LeetCode 834 - Sum of Distances in Tree

Difficulty: 🔴 Hard
Topics: Dynamic Programming, Tree, Depth-First Search, Graph Theory

Solution

Problem Understanding

The problem gives us an undirected tree with n nodes labeled from 0 to n - 1. A tree is a connected graph with exactly n - 1 edges and no cycles. Each edge connects two nodes bidirectionally.

For every node i, we must compute the total distance from i to every other node in the tree. The result should be returned as an array where:

  • answer[i] equals the sum of shortest path distances from node i to all other nodes.

The distance between two nodes in a tree is simply the number of edges along the unique path connecting them.

For example, consider this tree:

    0
   / \
  1   2
     /|\
    3 4 5

The distances from node 0 are:

  • to 1 = 1
  • to 2 = 1
  • to 3 = 2
  • to 4 = 2
  • to 5 = 2

Their sum is 8, so answer[0] = 8.

The constraints are important:

  • n can be as large as 30,000
  • The graph is guaranteed to be a valid tree
  • There are exactly n - 1 edges

These constraints rule out all-pairs shortest path algorithms. An O(n^2) solution would need on the order of

30,000^2 = 900,000,000

basic steps in the worst case, which is likely to exceed typical runtime limits.

The important edge cases include:

  • A single-node tree, where the answer is [0]
  • A two-node tree, where both answers are 1
  • Highly unbalanced trees, such as a straight chain
  • Star-shaped trees where one node connects to every other node

The tree guarantee is extremely useful because:

  • There is exactly one path between any two nodes
  • DFS traversals can compute subtree information efficiently
  • Dynamic programming on trees becomes possible

Approaches

Brute Force Approach

The most direct solution is to compute the distance from every node to every other node independently.

For each node:

  1. Run BFS or DFS
  2. Compute distances to all other nodes
  3. Sum the distances
  4. Store the result

Since a BFS or DFS traversal takes O(n) time in a tree, and we must do this for every node, the total complexity becomes O(n^2).

This approach is correct because BFS computes shortest path distances in an unweighted graph. However, it is far too slow for n = 30,000.
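
The brute-force idea can be sketched as follows. This is a minimal reference implementation, not the intended solution; the function name brute_force_sums is hypothetical, and the sketch assumes the input is a valid tree as the problem guarantees. It is mainly useful for cross-checking a faster solution on small inputs.

```python
from collections import deque
from typing import List


def brute_force_sums(n: int, edges: List[List[int]]) -> List[int]:
    """Reference O(n^2) approach: run one BFS per start node."""
    graph: List[List[int]] = [[] for _ in range(n)]
    for u, v in edges:
        graph[u].append(v)
        graph[v].append(u)

    result = []
    for start in range(n):
        dist = [-1] * n  # -1 marks "not visited yet"
        dist[start] = 0
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for neighbor in graph[node]:
                if dist[neighbor] == -1:
                    dist[neighbor] = dist[node] + 1
                    queue.append(neighbor)
        # The tree is connected, so every entry of dist is now a real distance.
        result.append(sum(dist))
    return result


print(brute_force_sums(6, [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]]))
# [8, 12, 6, 10, 10, 10]
```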

Optimal Approach

The key insight is that the answers for neighboring nodes are closely related.

Suppose we already know the total distance sum for some node u. If we move the root from u to one of its children v, distances change in a predictable way:

  • Nodes inside v's subtree become 1 step closer
  • Nodes outside v's subtree become 1 step farther

If:

  • subtree_size[v] is the number of nodes in v's subtree
  • n is the total number of nodes

then:

answer[v] = answer[u] - subtree_size[v] + (n - subtree_size[v])
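
To make the transition concrete, here is a small sketch that applies it to the sample tree from the problem description. The helper name reroot is hypothetical and exists only for illustration.

```python
def reroot(parent_answer: int, child_size: int, n: int) -> int:
    """Distance sum at a child, given the distance sum at its parent."""
    closer = child_size       # nodes in the child's subtree each lose 1
    farther = n - child_size  # all remaining nodes each gain 1
    return parent_answer - closer + farther


# Sample tree: n = 6, answer[0] = 8, and node 2's subtree is {2, 3, 4, 5}.
print(reroot(8, 4, 6))  # 6
# Node 1 is a leaf, so its subtree contains only itself.
print(reroot(8, 1, 6))  # 12
```

The same expression can be written as answer[u] + n - 2 * subtree_size[v], which is sometimes easier to reason about.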

This observation allows us to compute all answers in linear time using two DFS traversals:

  1. First DFS:
      • Compute subtree sizes
      • Compute the answer for an arbitrary root, usually node 0
  2. Second DFS:
      • Reuse the root transition formula
      • Compute answers for all nodes

This transforms the problem from repeated graph traversals into dynamic programming on trees.

Approach Comparison

| Approach | Time Complexity | Space Complexity | Notes |
|---|---|---|---|
| Brute Force | O(n²) | O(n) | Run BFS/DFS from every node independently |
| Optimal | O(n) | O(n) | Two DFS traversals with tree DP |

Algorithm Walkthrough

Step 1: Build the Adjacency List

We represent the tree using an adjacency list.

Since the graph is undirected:

  • if there is an edge [u, v]
  • then u is added to v's neighbor list
  • and v is added to u's neighbor list

This structure allows efficient traversal of the tree.
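
As a minimal sketch of this step, the helper below builds the adjacency list as a list of lists indexed by node label. The name build_adjacency is hypothetical; a defaultdict(list) would work equally well.

```python
from typing import List


def build_adjacency(n: int, edges: List[List[int]]) -> List[List[int]]:
    """Return graph[node] = list of neighbors, for nodes 0 .. n - 1."""
    graph: List[List[int]] = [[] for _ in range(n)]
    for u, v in edges:
        graph[u].append(v)  # v is a neighbor of u
        graph[v].append(u)  # and u is a neighbor of v
    return graph


print(build_adjacency(6, [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]]))
# [[1, 2], [0], [0, 3, 4, 5], [2], [2], [2]]
```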

Step 2: Initialize Arrays

We maintain two important arrays:

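  • count[node]
      • stores how many nodes exist in the subtree rooted at node
      • initially every node counts itself, so initialize to 1
  • answer[node]
      • after the first DFS, stores the sum of distances from node to the nodes in its own subtree
      • after the second DFS, stores the sum of distances from node to all nodes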

Step 3: First DFS, Compute Subtree Sizes and Root Answer

We root the tree at node 0.

During DFS:

  • recursively process all children
  • accumulate subtree sizes
  • accumulate distance sums

For each child of node (written child below):

count[node] += count[child]
answer[node] += answer[child] + count[child]

Why does this work?

  • answer[child] already contains the distances from child to every node in its subtree
  • Every node in that subtree is 1 edge farther from node than it is from child
  • There are count[child] such nodes, so the extra cost is count[child]

After this DFS:

  • count[node] is known for every node
  • answer[0] becomes the total distance sum for root 0
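
The state after this step can be inspected with a short sketch that runs only the first pass. The function name first_pass is hypothetical. Note that only the root's entry is a full distance sum at this point; every other entry covers just that node's own subtree.

```python
from typing import List, Tuple


def first_pass(n: int, edges: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Run only the first DFS from root 0 and return (count, answer)."""
    graph: List[List[int]] = [[] for _ in range(n)]
    for u, v in edges:
        graph[u].append(v)
        graph[v].append(u)

    count = [1] * n
    answer = [0] * n

    def dfs(node: int, parent: int) -> None:
        for child in graph[node]:
            if child == parent:
                continue
            dfs(child, node)
            count[node] += count[child]
            answer[node] += answer[child] + count[child]

    dfs(0, -1)
    return count, answer


count, partial = first_pass(6, [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]])
print(count)    # [6, 1, 4, 1, 1, 1]
print(partial)  # [8, 0, 3, 0, 0, 0]
```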

Step 4: Second DFS, Reroot the Tree

Now we propagate answers from parent to child.

Suppose we move root from node to child.

Two groups of nodes change:

  1. Nodes inside child's subtree
      • become 1 step closer
      • decrease total by count[child]
  2. Nodes outside child's subtree
      • become 1 step farther
      • increase total by n - count[child]

So:

answer[child] = answer[node] - count[child] + (n - count[child])

We recursively apply this formula throughout the tree.

Step 5: Return the Final Answer

After the second DFS, every node has its correct distance sum.

Return the answer array.

Why it works

The algorithm relies on a rerooting dynamic programming invariant.

The first DFS computes complete subtree information assuming node 0 is the root. The second DFS exploits the fact that moving the root across one edge changes distances uniformly:

  • one subtree gets closer
  • the rest of the tree gets farther

Because every edge is processed only a constant number of times, the entire computation remains linear.

Python Solution

from typing import List
from collections import defaultdict

class Solution:
    def sumOfDistancesInTree(self, n: int, edges: List[List[int]]) -> List[int]:
        graph = defaultdict(list)

        for u, v in edges:
            graph[u].append(v)
            graph[v].append(u)

        count = [1] * n
        answer = [0] * n

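        def dfs1(node: int, parent: int) -> None:
            # Post-order pass: fill count[] and the subtree-only distance sums.
            for neighbor in graph[node]:
                if neighbor == parent:
                    continue  # do not walk back up the tree

                dfs1(neighbor, node)

                # Fold the finished child subtree into its parent.
                count[node] += count[neighbor]
                answer[node] += answer[neighbor] + count[neighbor]

        def dfs2(node: int, parent: int) -> None:
            # Pre-order pass: derive each child's full answer from its parent's.
            for neighbor in graph[node]:
                if neighbor == parent:
                    continue

                # count[neighbor] nodes get closer, n - count[neighbor] get farther.
                answer[neighbor] = (
                    answer[node]
                    - count[neighbor]
                    + (n - count[neighbor])
                )

                dfs2(neighbor, node)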

        dfs1(0, -1)
        dfs2(0, -1)

        return answer

The implementation starts by constructing an adjacency list using defaultdict(list). Since the tree is undirected, every edge is inserted in both directions.

The count array is initialized with 1 because every node is part of its own subtree. The answer array begins with zeros and is gradually filled during DFS.

The first DFS computes two things simultaneously:

  • subtree sizes
  • the total distance sum for root 0

Each child contributes both its subtree size and its already-computed internal distances.

The second DFS performs rerooting dynamic programming. Instead of recomputing distances from scratch, it transforms the parent answer into the child answer using the derived formula.

Because each DFS visits every node once and looks at every edge twice (once from each endpoint), the solution remains linear.

Go Solution

package main

func sumOfDistancesInTree(n int, edges [][]int) []int {
	graph := make([][]int, n)

	for _, edge := range edges {
		u := edge[0]
		v := edge[1]

		graph[u] = append(graph[u], v)
		graph[v] = append(graph[v], u)
	}

	count := make([]int, n)
	answer := make([]int, n)

	for i := 0; i < n; i++ {
		count[i] = 1
	}

	var dfs1 func(int, int)
	dfs1 = func(node int, parent int) {
		for _, neighbor := range graph[node] {
			if neighbor == parent {
				continue
			}

			dfs1(neighbor, node)

			count[node] += count[neighbor]
			answer[node] += answer[neighbor] + count[neighbor]
		}
	}

	var dfs2 func(int, int)
	dfs2 = func(node int, parent int) {
		for _, neighbor := range graph[node] {
			if neighbor == parent {
				continue
			}

			answer[neighbor] =
				answer[node] -
					count[neighbor] +
					(n - count[neighbor])

			dfs2(neighbor, node)
		}
	}

	dfs1(0, -1)
	dfs2(0, -1)

	return answer
}

The Go implementation follows the same logic as the Python version, but uses slices instead of Python lists and dictionaries.

The adjacency list is represented as [][]int, where each index stores neighboring nodes.

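The recursive DFS functions are declared with var before they are assigned. A function literal introduced with := cannot call itself, because the variable is not yet in scope inside its own initializer; declaring the variable first makes the name visible inside the closure body.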

Integer overflow is not a concern. The largest possible distance sum occurs at the end of a 30,000-node chain and equals 30,000 × 29,999 / 2 = 449,985,000, which fits even in a 32-bit signed integer, and Go's int is at least 32 bits wide.

Worked Examples

Example 1

n = 6
edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]

Tree:

    0
   / \
  1   2
     /|\
    3 4 5

First DFS

Initial state:

Node count answer
0 1 0
1 1 0
2 1 0
3 1 0
4 1 0
5 1 0

Process leaves first.

Node 3

No children.

Node count answer
3 1 0

Node 4

Node count answer
4 1 0

Node 5

Node count answer
5 1 0

Process node 2

count[2] = 1 + 1 + 1 + 1 = 4
answer[2] = (0 + 1) + (0 + 1) + (0 + 1) = 3

Node count answer
2 4 3

Process node 1

Node count answer
1 1 0

Process node 0

count[0] = 1 + 1 + 4 = 6
answer[0] = (0 + 1) + (3 + 4) = 8

Node count answer
0 6 8

Second DFS

Move root from 0 to 1

answer[1] = 8 - 1 + (6 - 1)
          = 12

Move root from 0 to 2

answer[2] = 8 - 4 + (6 - 4)
          = 6

Move root from 2 to 3

answer[3] = 6 - 1 + 5
          = 10

Move root from 2 to 4

answer[4] = 6 - 1 + 5
          = 10

Move root from 2 to 5

answer[5] = 6 - 1 + 5
          = 10

Final result:

[8,12,6,10,10,10]

Example 2

n = 1
edges = []

There is only one node, so it has no neighbors.

Both DFS calls return immediately and answer[0] stays 0.

Result:

[0]

Example 3

n = 2
edges = [[1,0]]

Tree:

0 -- 1

First DFS:

Node count answer
1 1 0
0 2 1

Second DFS:

answer[1] = 1 - 1 + 1 = 1

Final result:

[1,1]

Complexity Analysis

| Measure | Complexity | Explanation |
|---|---|---|
| Time | O(n) | Each DFS visits every node once and every edge a constant number of times |
| Space | O(n) | Adjacency list, recursion stack, and arrays |

The adjacency list stores 2 * (n - 1) edge references because the graph is undirected. Both DFS traversals process each edge a constant number of times, giving linear runtime overall.

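The recursion depth can reach O(n) in a skewed tree, which contributes to the space complexity alongside the graph and auxiliary arrays. In CPython the default recursion limit is typically 1000, so on a chain of 30,000 nodes the recursive Python version needs either a raised limit (sys.setrecursionlimit) or an iterative traversal.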
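
One way to avoid deep recursion in Python is to replace both traversals with loops over an explicit visit order. The following is a minimal sketch under that assumption, not the implementation shown earlier; the function name sum_of_distances_iterative and the parent and order arrays are introduced here for illustration only.

```python
from typing import List


def sum_of_distances_iterative(n: int, edges: List[List[int]]) -> List[int]:
    """Same two-pass rerooting idea, without recursive calls."""
    graph: List[List[int]] = [[] for _ in range(n)]
    for u, v in edges:
        graph[u].append(v)
        graph[v].append(u)

    # Build a visit order with an explicit stack, starting at root 0.
    # Every node is appended before any of its children.
    parent = [-1] * n
    order: List[int] = []
    stack = [0]
    while stack:
        node = stack.pop()
        order.append(node)
        for neighbor in graph[node]:
            if neighbor != parent[node]:
                parent[neighbor] = node
                stack.append(neighbor)

    count = [1] * n
    answer = [0] * n

    # Pass 1 (children before parents): subtree sizes and subtree sums.
    for node in reversed(order):
        p = parent[node]
        if p != -1:
            count[p] += count[node]
            answer[p] += answer[node] + count[node]

    # Pass 2 (parents before children): apply the rerooting formula.
    for node in order:
        p = parent[node]
        if p != -1:
            answer[node] = answer[p] - count[node] + (n - count[node])

    return answer


print(sum_of_distances_iterative(6, [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]]))
# [8, 12, 6, 10, 10, 10]
```

The reversed visit order plays the role of the post-order traversal, and the forward order plays the role of the pre-order traversal, so the two loops mirror the two DFS passes.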

Test Cases

from typing import List

class Solution:
    def sumOfDistancesInTree(self, n: int, edges: List[List[int]]) -> List[int]:
        from collections import defaultdict

        graph = defaultdict(list)

        for u, v in edges:
            graph[u].append(v)
            graph[v].append(u)

        count = [1] * n
        answer = [0] * n

        def dfs1(node: int, parent: int):
            for neighbor in graph[node]:
                if neighbor == parent:
                    continue

                dfs1(neighbor, node)

                count[node] += count[neighbor]
                answer[node] += answer[neighbor] + count[neighbor]

        def dfs2(node: int, parent: int):
            for neighbor in graph[node]:
                if neighbor == parent:
                    continue

                answer[neighbor] = (
                    answer[node]
                    - count[neighbor]
                    + (n - count[neighbor])
                )

                dfs2(neighbor, node)

        dfs1(0, -1)
        dfs2(0, -1)

        return answer

solution = Solution()

assert solution.sumOfDistancesInTree(
    6,
    [[0,1],[0,2],[2,3],[2,4],[2,5]]
) == [8,12,6,10,10,10]  # example 1

assert solution.sumOfDistancesInTree(
    1,
    []
) == [0]  # single node tree

assert solution.sumOfDistancesInTree(
    2,
    [[1,0]]
) == [1,1]  # two node tree

assert solution.sumOfDistancesInTree(
    4,
    [[0,1],[1,2],[2,3]]
) == [6,4,4,6]  # linear chain

assert solution.sumOfDistancesInTree(
    5,
    [[0,1],[0,2],[0,3],[0,4]]
) == [4,7,7,7,7]  # star topology

assert solution.sumOfDistancesInTree(
    3,
    [[0,1],[1,2]]
) == [3,2,3]  # small chain

assert solution.sumOfDistancesInTree(
    7,
    [[0,1],[0,2],[1,3],[1,4],[2,5],[2,6]]
) == [10,11,11,16,16,16,16]  # balanced binary tree

| Test | Why |
|---|---|
| n = 6 example tree | Validates the main rerooting logic |
| Single node | Verifies minimum boundary case |
| Two nodes | Verifies simplest non-trivial tree |
| Linear chain | Tests maximum depth behavior |
| Star topology | Tests highly centralized structure |
| Small chain of 3 nodes | Verifies symmetry |
| Balanced binary tree | Tests multiple subtree interactions |

Edge Cases

Single Node Tree

A tree with only one node has no edges and no neighbors. This case is easy to mishandle because DFS logic often assumes children exist.

The implementation handles this naturally because:

  • the adjacency list for node 0 is empty
  • both DFS traversals terminate immediately
  • answer[0] remains 0

No special-case code is required.

Highly Skewed Tree

A chain-like tree such as:

0 - 1 - 2 - 3 - 4

creates maximum recursion depth and asymmetric subtree sizes.

Naive implementations may repeatedly recompute distances and degrade to quadratic time.

This solution still runs in linear time because:

  • each edge is processed only twice
  • rerooting updates answers in constant time per edge
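
For a chain, the expected answers have a simple closed form, which gives a cheap way to check any implementation on large skewed inputs. The sketch below assumes the chain is labeled 0 - 1 - ... - (n - 1) in order; the function name chain_sums is hypothetical.

```python
from typing import List


def chain_sums(n: int) -> List[int]:
    """Distance sums for the path 0 - 1 - ... - (n - 1)."""
    result = []
    for i in range(n):
        left = i * (i + 1) // 2             # 1 + 2 + ... + i
        right = (n - 1 - i) * (n - i) // 2  # 1 + 2 + ... + (n - 1 - i)
        result.append(left + right)
    return result


print(chain_sums(5))  # [10, 7, 6, 7, 10]
```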

Star-Shaped Tree

A tree where one center node connects to every other node creates a large imbalance in subtree sizes.

For example:

    1
    |
2 - 0 - 3
    |
    4

The center node has very small distances, while leaf nodes have much larger sums.

The rerooting formula correctly handles this because subtree size directly determines how distances change when moving the root.
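
The star case also has a closed form that follows directly from the rerooting formula: the center's sum is n - 1, and moving the root to a leaf (subtree size 1) gives (n - 1) - 1 + (n - 1) = 2n - 3. A small sketch, assuming node 0 is the center and using the hypothetical name star_sums:

```python
from typing import List


def star_sums(n: int) -> List[int]:
    """Distance sums for a star whose center is node 0."""
    center = n - 1    # every leaf is 1 edge away from the center
    leaf = 2 * n - 3  # 1 edge to the center, 2 edges to each other leaf
    return [center] + [leaf] * (n - 1)


print(star_sums(5))  # [4, 7, 7, 7, 7]
```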