LeetCode 3544 - Subtree Inversion Sum

This problem asks us to select a subset of nodes in a rooted tree such that each selected node triggers a subtree inversion, and inversions interact through ancestor-descendant relationships with a distance constraint.

LeetCode Problem 3544

Difficulty: 🔴 Hard
Topics: Array, Dynamic Programming, Tree, Depth-First Search

Solution

Problem Understanding

Each selected node inverts its subtree: every value in that subtree is multiplied by −1. Multiple inversions stack multiplicatively, so a node's final sign depends on how many selected nodes lie on its path from the root, the node itself included. An odd count negates the value, and an even count leaves it unchanged.

The input consists of an undirected tree with root fixed at node 0, an array nums representing node values, and an integer k which enforces spacing constraints between selected inversion nodes along any ancestor-descendant chain. Specifically, if two selected nodes lie on the same root-to-node path and one is an ancestor of the other, then their distance in edges must be at least k.

The output is the maximum possible sum of all node values after applying an optimal selection of inversion nodes.
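The constraint can be checked directly. Below is a minimal validity check, assuming the LeetCode input format (an edge list for a tree rooted at node 0). The helper name `is_valid_selection` is hypothetical. It walks up from each selected node and rejects the set if another selected node appears fewer than k edges above it.

```python
from typing import List, Set


def is_valid_selection(edges: List[List[int]], k: int, selected: Set[int]) -> bool:
    """Hypothetical helper: True if no two selected nodes on one ancestor chain
    are fewer than k edges apart (tree rooted at node 0)."""
    n = len(edges) + 1
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # Compute parents with an iterative DFS from the root.
    parent = [-1] * n
    seen = [False] * n
    seen[0] = True
    stack = [0]
    while stack:
        u = stack.pop()
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                parent[v] = u
                stack.append(v)

    # Only ancestors within k - 1 edges of a selected node can violate the rule.
    for s in selected:
        a, dist = parent[s], 1
        while a != -1 and dist < k:
            if a in selected:
                return False
            a, dist = parent[a], dist + 1
    return True
```

For instance, with k = 2 in Example 1 the set {0, 3, 4, 6} passes, because nodes 3, 4, and 6 are each two edges below node 0. The set {0, 1} fails.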

The key constraints are n ≤ 5 * 10^4 and k ≤ 50, which strongly suggests a tree dynamic programming solution with an additional bounded state dimension. The small value of k is crucial, since it allows us to track distance-related state efficiently.

Edge cases include trees that are skewed (degenerating into a chain), trees where all values are negative or zero, and cases where k is large relative to tree height, effectively preventing multiple selections on a single root-to-leaf path.

Approaches

The brute-force approach considers every subset of nodes as potential inversion sets, checks whether the distance constraint is satisfied, computes the resulting sign propagation for every node, and evaluates the total sum. This is exponential in nature because there are 2^n subsets, and for each subset we must recompute subtree effects, making it completely infeasible.
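Brute force is infeasible for large n, but it is still a handy reference for testing a faster solution on tiny trees. Below is a minimal sketch, assuming the LeetCode input format. The name `brute_force` is hypothetical. It checks every subset against the distance constraint and scores each valid one by propagating flip parity from the root.

```python
from typing import List


def brute_force(edges: List[List[int]], nums: List[int], k: int) -> int:
    """Exhaustive reference solution: O(2^n * n) time, usable only for tiny n."""
    n = len(nums)
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # Root at node 0; the preorder lists every parent before its children.
    parent = [-1] * n
    seen = [False] * n
    seen[0] = True
    order, stack = [], [0]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                parent[v] = u
                stack.append(v)

    best = sum(nums)  # mask 0: no inversions, always valid
    for mask in range(1, 1 << n):
        chosen = [(mask >> i) & 1 for i in range(n)]

        # Reject the subset if a chosen node has a chosen ancestor < k edges up.
        valid = True
        for s in range(n):
            if not chosen[s]:
                continue
            a, dist = parent[s], 1
            while a != -1 and dist < k:
                if chosen[a]:
                    valid = False
                    break
                a, dist = parent[a], dist + 1
            if not valid:
                break
        if not valid:
            continue

        # A node is negated iff an odd number of chosen nodes lie on its
        # root path (the node itself included).
        flips = [0] * n
        total = 0
        for u in order:
            above = flips[parent[u]] if parent[u] != -1 else 0
            flips[u] = above ^ chosen[u]
            total += -nums[u] if flips[u] else nums[u]
        best = max(best, total)
    return best


print(brute_force([[0, 1], [1, 2], [2, 3], [3, 4]], [-1, 3, -2, 4, -5], 2))  # 9
```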

The optimal solution relies on tree dynamic programming. The central insight is that inversion decisions propagate downward, and each node only needs to know two pieces of information from its ancestors: the parity of inversions applied so far (even or odd flips), and the distance since the last selected inversion node. Because k ≤ 50, we can explicitly track this distance as a bounded DP state.

We define a DP state per node that captures both ancestor parity and distance constraints, and we decide at each node whether to select it for inversion or not, propagating updated states to children.

| Approach | Time Complexity | Space Complexity | Notes |
|---|---|---|---|
| Brute Force | O(2^n · n) | O(n) | Enumerates all inversion subsets and recomputes subtree effects |
| Optimal Tree DP | O(n · k) | O(n · k) | DP on tree with parity and bounded distance state |

Algorithm Walkthrough

We solve the problem using a bottom-up DFS dynamic programming approach over the rooted tree.

Step 1: Build the rooted tree

We first convert the undirected tree into a rooted adjacency structure starting from node 0. This ensures a clear parent-child direction for DP propagation.
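Below is a minimal sketch of this step, assuming the edge-list input. The helper name `root_tree` is hypothetical. It uses an explicit stack rather than recursion, so chains of 5 * 10^4 nodes do not hit Python's recursion limit. It also returns a preorder; read backwards, that preorder gives the children-before-parent order needed later.

```python
from typing import List, Tuple


def root_tree(n: int, edges: List[List[int]]) -> Tuple[List[int], List[List[int]], List[int]]:
    """Hypothetical helper: return (parent, children, preorder) rooted at node 0."""
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    parent = [-1] * n
    children = [[] for _ in range(n)]
    visited = [False] * n
    visited[0] = True
    order, stack = [], [0]
    while stack:              # iterative DFS: no recursion-depth issues on long chains
        u = stack.pop()
        order.append(u)       # u is appended only after its parent
        for v in adj[u]:
            if not visited[v]:
                visited[v] = True
                parent[v] = u
                children[u].append(v)
                stack.append(v)
    return parent, children, order
```

Iterating over `reversed(order)` then visits every child before its parent, which is exactly what a bottom-up DP needs.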

Step 2: Define DP state

For each node v, we define a DP table:

DP[v][p][d]

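Here p ∈ {0,1} is the parity of inversions applied by ancestors (0 means an even number of flips, 1 means odd). d ∈ [1..k] is the distance in edges from v up to its nearest selected ancestor, capped at k. The value d = k means "safe": either there is no selected ancestor, or it is at least k edges away, so v may be selected.

DP[v][p][d] is the maximum sum of the values in v's subtree, given that state. This state fully captures what a node needs to decide whether it can be selected and how it contributes to its subtree. Only the nearest selected ancestor can violate the distance constraint, because every other selected ancestor is even farther away.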

Step 3: Base contribution at a node

At node v, its raw contribution depends on parity p. If p = 0, contribution is nums[v], otherwise it is -nums[v].

From here, we consider two choices:

We do not invert node v. Parity stays unchanged for the children, and the distance increases by 1 (capped at k).

We invert node v. Then v itself contributes the negated value (-nums[v] if p = 0, nums[v] if p = 1), parity flips for the children, and the distance resets to 1 for them. This choice is only allowed if d ≥ k.
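The two choices fit in a tiny transition function. This is an illustrative sketch: `choices` is a hypothetical name, and each move is returned as a tuple (sign applied to nums[v], child parity, child distance).

```python
from typing import List, Tuple


def choices(p: int, d: int, k: int) -> List[Tuple[int, int, int]]:
    """Hypothetical helper: legal moves at a node in state (p, d).

    Each move is (sign applied to nums[v], child parity, child distance)."""
    sign = 1 if p == 0 else -1
    moves = [(sign, p, min(d + 1, k))]      # do not invert v
    if d >= k:                              # far enough from the last selected ancestor
        moves.append((-sign, p ^ 1, 1))     # invert v: v flips, children are 1 edge below it
    return moves
```

For example, at distance d = 1 with k = 3, only the first move is legal: `choices(0, 1, 3)` returns `[(1, 0, 2)]`.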

Step 4: Combine children contributions

Each node aggregates contributions from all children independently. For a fixed DP state at node v, each child u is evaluated using the corresponding propagated state:

If v is not selected, child receives (p, min(d+1, k)).

If v is selected, child receives (p ^ 1, 1).

Because children are independent given the parent state, we sum their optimal contributions.
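For reference, Steps 3 and 4 combine into one recurrence. In this notation, introduced here, C(v) is the set of children of v, (-1)^p is the sign that parity p applies to nums[v], and ⊕ is XOR. Because d never exceeds k, the condition d ≥ k reduces to d = k.

$$
\begin{aligned}
\mathrm{keep}(v,p,d) &= (-1)^p\,\mathrm{nums}[v] + \sum_{u \in C(v)} \mathrm{DP}[u][p][\min(d+1,\,k)] \\
\mathrm{flip}(v,p) &= -(-1)^p\,\mathrm{nums}[v] + \sum_{u \in C(v)} \mathrm{DP}[u][p \oplus 1][1] \\
\mathrm{DP}[v][p][d] &= \begin{cases}
\max\bigl(\mathrm{keep}(v,p,d),\ \mathrm{flip}(v,p)\bigr) & \text{if } d = k \\
\mathrm{keep}(v,p,d) & \text{if } 1 \le d < k
\end{cases}
\end{aligned}
$$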

Step 5: Tree DP execution order

We run DFS postorder so that all children DP tables are computed before processing the parent node.

Step 6: Final answer

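Root node 0 has no ancestors, so it starts in state p = 0 (no flips yet) and d = k (no selected ancestor). The answer is therefore DP[0][0][k]. Taking the maximum over all (p, d) would be wrong, because the state p = 1 pretends the whole tree has already been flipped without spending an inversion.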
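Only the root state (p = 0, d = k) describes the real tree. Below is a minimal sketch of the DP under the transitions above; `subtree_inversion_dp` is a hypothetical name. It returns the whole table so that dp[0][0][k] can be compared with the maximum over every root state on a two-node tree.

```python
from typing import List


def subtree_inversion_dp(edges: List[List[int]], nums: List[int], k: int) -> List[List[List[int]]]:
    """Hypothetical helper: full table dp[v][p][d] (d in 1..k; index 0 unused)."""
    n = len(nums)
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    children = [[] for _ in range(n)]
    seen = [False] * n
    seen[0] = True
    order, stack = [], [0]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                children[u].append(v)
                stack.append(v)

    dp = [None] * n
    for v in reversed(order):                # children before parents
        table = [[0] * (k + 1) for _ in range(2)]
        for p in range(2):
            base = nums[v] if p == 0 else -nums[v]
            for d in range(1, k + 1):
                best = base + sum(dp[c][p][min(d + 1, k)] for c in children[v])
                if d >= k:                   # inverting v is allowed
                    best = max(best, -base + sum(dp[c][p ^ 1][1] for c in children[v]))
                table[p][d] = best
        dp[v] = table
    return dp


edges, nums, k = [[0, 1]], [-5, 10], 2
dp = subtree_inversion_dp(edges, nums, k)
print(dp[0][0][k])                                                  # 5: the true optimum
print(max(dp[0][p][d] for p in range(2) for d in range(1, k + 1)))  # 15: not achievable
```

The value 15 would require both nodes to end up with flipped signs relative to the root's real parity, which needs two inversions only one edge apart. That is forbidden when k = 2.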

Why it works

The DP works because the effect of an inversion node is fully captured by parity propagation and the distance constraint along ancestor chains. Once a node's state is fixed, the subtrees of its children are independent of each other. The bounded distance constraint lets us encode all valid configurations compactly, without enumerating subsets.

Python Solution

```python
from typing import List


class Solution:
    def subtreeInversionSum(self, edges: List[List[int]], nums: List[int], k: int) -> int:
        n = len(nums)
        g = [[] for _ in range(n)]
        for u, v in edges:
            g[u].append(v)
            g[v].append(u)

        # Root the tree at node 0. An explicit stack avoids Python's recursion
        # limit on chain-shaped trees with up to 5 * 10^4 nodes.
        parent = [-1] * n
        parent[0] = -2  # sentinel: the root is already visited
        children = [[] for _ in range(n)]
        order = []      # preorder: every node appears after its parent
        stack = [0]
        while stack:
            node = stack.pop()
            order.append(node)
            for nei in g[node]:
                if parent[nei] == -1:
                    parent[nei] = node
                    children[node].append(nei)
                    stack.append(nei)

        # dp[v][p][d]: best sum of v's subtree when ancestors flip it p times
        # (mod 2) and the nearest selected ancestor is d edges above v, with
        # d capped at k (d == k also covers "no selected ancestor").
        # Index d = 0 is unused.
        dp = [None] * n

        # Reverse preorder visits children before their parent (postorder).
        for v in reversed(order):
            table = [[0] * (k + 1) for _ in range(2)]
            for p in range(2):
                base = nums[v] if p == 0 else -nums[v]
                for d in range(1, k + 1):
                    # Case 1: do not invert v. Children keep parity p and are
                    # one edge farther from the last selected ancestor.
                    d2 = min(d + 1, k)
                    best = base
                    for c in children[v]:
                        best += dp[c][p][d2]

                    # Case 2: invert v, allowed only when d >= k. v's own value
                    # flips and its children are exactly 1 edge below it.
                    if d >= k:
                        total_yes = -base
                        for c in children[v]:
                            total_yes += dp[c][p ^ 1][1]
                        best = max(best, total_yes)

                    table[p][d] = best
            dp[v] = table

        # The root has no ancestors: parity 0 and no selected ancestor (d = k).
        # Maximizing over every (p, d) would wrongly allow a "free" flip (p = 1).
        return dp[0][0][k]
```

Code Explanation

The solution first builds a rooted tree from the undirected input. It then computes DP tables bottom-up, so that every child's table is complete before its parent's. Each node gets a table indexed by ancestor parity and by distance since the last inversion; counting the node index, that is three dimensions across the whole tree.

For each state, we evaluate two possibilities: not selecting the node, or selecting it as an inversion point when d ≥ k. We aggregate contributions from all children based on how the state changes when passing through the current node.

Finally, we read the answer from the root's starting state, dp[0][0][k]. The root has no ancestors, so its parity is 0 and no earlier inversion restricts it.

Go Solution

```go
package main

import "fmt"

// subtreeInversionSum returns the maximum total after subtree inversions, where
// any two selected nodes on one ancestor chain must be at least k edges apart.
func subtreeInversionSum(edges [][]int, nums []int, k int) int64 {
	n := len(nums)
	g := make([][]int, n)
	for _, e := range edges {
		u, v := e[0], e[1]
		g[u] = append(g[u], v)
		g[v] = append(g[v], u)
	}

	// Root the tree at node 0 with an explicit stack and record a preorder
	// (every node appears after its parent).
	parent := make([]int, n)
	for i := range parent {
		parent[i] = -1
	}
	parent[0] = -2 // sentinel: the root is already visited
	children := make([][]int, n)
	order := make([]int, 0, n)
	stack := []int{0}
	for len(stack) > 0 {
		v := stack[len(stack)-1]
		stack = stack[:len(stack)-1]
		order = append(order, v)
		for _, nei := range g[v] {
			if parent[nei] == -1 {
				parent[nei] = v
				children[v] = append(children[v], nei)
				stack = append(stack, nei)
			}
		}
	}

	// dp[v][p][d]: best sum of v's subtree for ancestor parity p and distance
	// d (1..k, capped at k) to the nearest selected ancestor. Index 0 is unused.
	dp := make([][2][]int64, n)

	// Reverse preorder processes every child before its parent.
	for i := n - 1; i >= 0; i-- {
		v := order[i]
		for p := 0; p < 2; p++ {
			dp[v][p] = make([]int64, k+1)
			base := int64(nums[v])
			if p == 1 {
				base = -base
			}
			for d := 1; d <= k; d++ {
				// Case 1: do not invert v; children see parity p, distance min(d+1, k).
				d2 := d + 1
				if d2 > k {
					d2 = k
				}
				best := base
				for _, c := range children[v] {
					best += dp[c][p][d2]
				}

				// Case 2: invert v, allowed only when d >= k.
				if d >= k {
					totalYes := -base
					for _, c := range children[v] {
						totalYes += dp[c][p^1][1]
					}
					if totalYes > best {
						best = totalYes
					}
				}
				dp[v][p][d] = best
			}
		}
	}

	// The root has no ancestors: parity 0 and no selected ancestor (d = k).
	return dp[0][0][k]
}

func main() {
	edges := [][]int{{0, 1}, {0, 2}, {1, 3}, {1, 4}, {2, 5}, {2, 6}}
	nums := []int{4, -8, -6, 3, 7, -2, 5}
	fmt.Println(subtreeInversionSum(edges, nums, 2)) // Example 1: 27
}
```

Go-specific notes

The Go implementation mirrors the Python DP. Each node's table is indexed by parity and by distance 1..k, and all sums are accumulated in int64 to match the return type. Nodes are processed so that every child's table is finished before its parent's is computed. The min(d + 1, k) cap is written as an explicit comparison.

Worked Examples

Example 1

Input: edges = [[0,1],[0,2],[1,3],[1,4],[2,5],[2,6]], nums = [4,-8,-6,3,7,-2,5], k = 2. We start at root 0. Initially, parity is 0 and distance is k.

At node 0, we evaluate both choices:

If we select node 0, parity flips and children receive distance 1.

If we do not select node 0, children receive distance k.

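Inverting node 0 turns the negative values at nodes 1, 2, and 5 (-8, -6, -2) positive, but it also negates the positive leaves 3, 4, and 6 (values 3, 7, 5). Those leaves are each 2 edges below node 0, so with k = 2 they can be inverted again to restore their values. The DP propagates these choices upward and selects nodes {0, 3, 4, 6} as the best configuration.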

The resulting transformed array becomes [-4, 8, 6, 3, 7, 2, 5], yielding 27.
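The final array can be checked by applying the inversions directly. Below is a minimal sketch using the example's edge list; `apply_inversions` is a hypothetical helper. Instead of negating each subtree separately, it pushes flip parity down from the root. Because negations commute, the result is the same.

```python
from typing import List, Set


def apply_inversions(edges: List[List[int]], nums: List[int], selected: Set[int]) -> List[int]:
    """Hypothetical helper: negate the subtree of every selected node (rooted at 0)."""
    n = len(nums)
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    result = list(nums)
    seen = [False] * n
    seen[0] = True
    # A node is negated iff an odd number of selected nodes lie on its root
    # path, itself included; carry that parity down with the DFS.
    stack = [(0, 1 if 0 in selected else 0)]
    while stack:
        u, parity = stack.pop()
        if parity:
            result[u] = -result[u]
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                stack.append((v, parity ^ (1 if v in selected else 0)))
    return result


edges = [[0, 1], [0, 2], [1, 3], [1, 4], [2, 5], [2, 6]]
nums = [4, -8, -6, 3, 7, -2, 5]
final = apply_inversions(edges, nums, {0, 3, 4, 6})
print(final, sum(final))  # [-4, 8, 6, 3, 7, 2, 5] 27
```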

Example 2

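Input: edges = [[0,1],[1,2],[2,3],[3,4]], nums = [-1,3,-2,4,-5], k = 2. The chain structure means decisions propagate linearly.

At node 4, selecting inversion flips -5 into 5, a gain of 10. Inverting any other node would also negate at least one of the positive values 3 and 4 in its subtree. The DP confirms that no valid combination beats selecting node 4 alone.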

DP propagates this upward, yielding final sum 9.

Example 3

Input: edges = [[0,1],[0,2]], nums = [0,-1,-2], k = 3. The root 0 has two children. Selecting both 1 and 2 is valid: neither is an ancestor of the other, so the distance constraint does not apply to them. Each inversion flips only its own subtree, turning -1 and -2 into 1 and 2, and the DP confirms that this reaches the maximum.

Final sum becomes 3.

Complexity Analysis

| Measure | Complexity | Explanation |
|---|---|---|
| Time | O(n · k) | Each node evaluates 2k states (two parities, k distances), and each state reads every child's table at most twice (once per option); summed over the tree there are n − 1 children |
| Space | O(n · k) | Each node stores a table of 2(k + 1) entries |

The complexity is linear in the number of nodes with an additional factor of k, which is acceptable given the constraint k ≤ 50.
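A rough operation count backs the time bound. It is an upper bound, since the flip option is only evaluated at d = k. Each of the 2k states at node v does constant work plus at most two reads per child. The children of all nodes together number n − 1, because every non-root node has exactly one parent. C(v) denotes the set of children of v.

$$
\text{work} \;\le\; \sum_{v=0}^{n-1} 2k\,\bigl(1 + 2\,|C(v)|\bigr) \;=\; 2kn + 4k\,(n-1) \;=\; O(n \cdot k)
$$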

Test Cases

```python
assert Solution().subtreeInversionSum([[0,1],[0,2]], [0,-1,-2], 3) == 3  # small branching tree (example 3)
assert Solution().subtreeInversionSum([[0,1],[1,2],[2,3],[3,4]], [-1,3,-2,4,-5], 2) == 9  # chain (example 2)
assert Solution().subtreeInversionSum([[0,1],[0,2],[1,3],[1,4],[2,5],[2,6]], [4,-8,-6,3,7,-2,5], 2) == 27  # full example 1
assert Solution().subtreeInversionSum([[0,1]], [5,-10], 1) == 15  # k = 1: invert only the leaf
assert Solution().subtreeInversionSum([[0,1],[1,2]], [-5,-5,-5], 2) == 15  # all-negative chain: invert the root
assert Solution().subtreeInversionSum([[0,1],[1,2],[2,3]], [1,1,1,1], 50) == 4  # all positive, no inversions needed
```

| Test | Why |
|---|---|
| small branching tree | verifies independent sibling subtrees |
| chain | stresses ancestor constraint propagation |
| full example 1 | validates multi-inversion optimal structure |
| two-node tree | minimal inversion case with k = 1 |
| all-negative chain | checks that a single root inversion beats scattered flips |
| all positive | ensures the algorithm avoids unnecessary inversions |

Edge Cases

One important edge case is a linear chain where k is large relative to the depth. On a chain, the constraint is equivalent to choosing indices spaced at least k apart. When k exceeds the chain length, at most one inversion can be selected, and the DP must correctly evaluate which single node yields the maximum gain. In a branching tree with k larger than the height, the same limit applies to each root-to-leaf path, but nodes in different branches (as in Example 3) can still both be selected. This tests whether the distance constraint is enforced correctly.

Another edge case occurs when all node values are negative or zero. Inverting only the root then makes every value non-negative and reaches the sum of |nums[i]|, which no configuration can exceed, so additional inversions cannot improve on it. For example, a three-node chain of -5 values is best served by a single root inversion, for a total of 15. This stresses whether parity propagation is handled correctly across multiple levels of the tree.

A final edge case is when k = 1, which effectively removes the spacing constraint entirely. In this case, adjacent nodes can both be selected, and the solution must still handle overlapping subtree inversions without double counting or incorrect parity propagation.
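Each of these edge cases comes with a property that can be checked mechanically. Below is a minimal sketch, assuming the DP transitions from the walkthrough; `dp_answer` is a hypothetical name. On random small trees it checks three things:

- With k = 1, the answer equals the sum of absolute values. Every subset is legal, so each node's sign can be set independently: invert a node exactly when its desired sign differs from its parent's.
- With all values non-positive, the answer is the same sum, reached by a root inversion.
- On a chain with k ≥ n, the DP equals the best single suffix negation.

```python
import random
from typing import List


def dp_answer(edges: List[List[int]], nums: List[int], k: int) -> int:
    """Hypothetical helper: parity/distance tree DP read at the root state (p=0, d=k)."""
    n = len(nums)
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    children, order, stack = [[] for _ in range(n)], [], [0]
    seen = [False] * n
    seen[0] = True
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                children[u].append(v)
                stack.append(v)
    dp = {}
    for v in reversed(order):  # children before parents
        table = [[0] * (k + 1) for _ in range(2)]
        for p in range(2):
            base = nums[v] if p == 0 else -nums[v]
            for d in range(1, k + 1):
                best = base + sum(dp[c][p][min(d + 1, k)] for c in children[v])
                if d >= k:
                    best = max(best, -base + sum(dp[c][p ^ 1][1] for c in children[v]))
                table[p][d] = best
        dp[v] = table
    return dp[0][0][k]


random.seed(1)
for _ in range(200):
    n = random.randint(2, 12)
    tree = [[random.randrange(i), i] for i in range(1, n)]  # random tree on 0..n-1
    chain = [[i - 1, i] for i in range(1, n)]
    vals = [random.randint(-9, 9) for _ in range(n)]
    magnitude = sum(abs(x) for x in vals)

    # k = 1: every subset is legal, so each node's sign can be chosen freely.
    assert dp_answer(tree, vals, 1) == magnitude

    # All values <= 0: inverting only the root already reaches sum(|nums|).
    neg = [-abs(x) for x in vals]
    assert dp_answer(tree, neg, random.randint(1, 5)) == magnitude

    # Chain with k >= n: at most one inversion, which negates a suffix.
    one_flip = max([sum(vals)] + [sum(vals[:i]) - sum(vals[i:]) for i in range(n)])
    assert dp_answer(chain, vals, n) == one_flip
print("edge-case properties hold")
```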
