LeetCode 3553 - Minimum Weighted Subgraph With the Required Paths II

Difficulty: 🔴 Hard
Topics: Array, Dynamic Programming, Bit Manipulation, Tree, Depth-First Search

Solution

Problem Understanding

We are given a weighted, undirected tree with n nodes. Because the graph is a tree, there is exactly one simple path between any two nodes.

For each query [src1, src2, dest], we must find the minimum total weight of a connected subtree such that both src1 and src2 can reach dest using only edges contained in that subtree.

Since the original graph is already a tree, any connected subtree is simply a subset of edges that remains connected and acyclic.

The key observation is that the required subtree must contain:

  • The unique path from src1 to dest
  • The unique path from src2 to dest

Any valid subtree must include all edges from both of those paths. Conversely, the union of those two paths is itself a connected subtree, so it is the smallest possible valid subtree.

Therefore, every query asks for the total weight of:

Path(src1, dest) ∪ Path(src2, dest)

The constraints are very large:

  • n ≤ 100,000
  • queries.length ≤ 100,000

A per-query traversal of the tree would be far too slow. We need a solution that preprocesses the tree once and answers each query in logarithmic time.

Important edge cases include situations where the two paths overlap heavily, where one node lies on the path between the other two, and where the tree degenerates into a long chain. The problem guarantees that the input graph is a valid tree and that the three query nodes are distinct.

To restate the input format: the tree has n nodes labeled from 0 to n-1 and, being connected and acyclic, exactly n-1 edges. It is given as a list of edges [u, v, w], each connecting nodes u and v with a positive weight w.

We are also given queries, each of the form [src1, src2, dest]. For each query, the goal is to find a subtree (any connected subset of nodes and edges of the original tree) such that there exist paths from src1 to dest and from src2 to dest using edges inside the subtree. Among all such subtrees, we must return the minimum total weight (sum of edge weights).

Because there is exactly one simple path between any two nodes, the subtree needed for a query is the union of the paths src1 → dest and src2 → dest. Once the two paths meet they stay together all the way to dest, so they share a single contiguous segment ending at dest. The minimal subtree consists of all edges on either path, with each shared edge counted once.

The constraints set the scale: n and the number of queries can each reach 10^5, so any solution that spends O(n) per query walking paths is too slow. Edge weights are moderate (1 <= w <= 10^4). The input guarantees a valid tree and that all query nodes are distinct, so disconnected graphs and self-loops are not a concern.

Cases worth testing include queries where one of the three nodes lies on the path between the other two, queries where src1 and src2 are neighbors, and queries along long chains (to test cumulative weight correctness). Queries where src1 or src2 equals dest cannot occur, because the three nodes are distinct.

Approaches

Brute Force

A direct approach is to process each query independently.

For a query (src1, src2, dest), we could explicitly find:

  • The path from src1 to dest
  • The path from src2 to dest

Then compute the union of their edges and sum the weights.

Because the graph is a tree, each path can be recovered using DFS or parent reconstruction. However, a path may contain O(n) edges. With up to 100,000 queries, the worst case becomes O(nq).

This is far too slow.
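For reference, the brute-force idea can be sketched as follows. This is a minimal illustration, not part of the final solution; the function name `brute_force_min_weight` and its single-query signature are assumptions made for this example. It runs a BFS from dest to learn each node's next hop toward dest, then walks both sources toward dest while adding every edge once.

```python
from collections import deque
from typing import List


def brute_force_min_weight(edges: List[List[int]], query: List[int]) -> int:
    """Sum the weights of Path(src1, dest) ∪ Path(src2, dest) by explicit traversal."""
    n = len(edges) + 1
    graph = [[] for _ in range(n)]
    for u, v, w in edges:
        graph[u].append((v, w))
        graph[v].append((u, w))

    src1, src2, dest = query

    # BFS from dest records, for every node, the next hop toward dest
    # and the weight of that hop.
    toward = [(-1, 0)] * n
    seen = [False] * n
    seen[dest] = True
    queue = deque([dest])
    while queue:
        u = queue.popleft()
        for v, w in graph[u]:
            if not seen[v]:
                seen[v] = True
                toward[v] = (u, w)
                queue.append(v)

    # Walk from each source toward dest. An edge is identified by its
    # endpoint farther from dest, so stopping at an already-used node
    # skips the segment shared by both paths.
    used = set()
    total = 0
    for start in (src1, src2):
        node = start
        while node != dest and node not in used:
            used.add(node)
            nxt, w = toward[node]
            total += w
            node = nxt
    return total
```

Each call costs O(n) because of the BFS, which is exactly the per-query cost that makes this approach too slow at full scale. It remains useful as a reference implementation for checking a faster solution on small inputs.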

Key Insight

The union of the paths

Path(src1, dest) ∪ Path(src2, dest)

is exactly the minimum Steiner tree connecting the three terminals:

src1, src2, dest

For any three nodes a, b, and c in a tree, the total weight of their Steiner tree is:

$\frac{dist(a,b)+dist(b,c)+dist(c,a)}{2}$

This is a classical tree property.

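Every edge of the Steiner tree appears in exactly two of the three pairwise paths: removing such an edge splits the tree into two parts with one terminal on one side and two on the other, so exactly the two pairwise paths that cross the split use the edge. Therefore, when we sum the three pairwise distances, every Steiner-tree edge is counted twice. Dividing by two gives the desired weight.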
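The double counting can also be written out directly. In the short derivation below, the symbols m, x, y, and z are introduced only for this illustration: m is the junction node where the branches toward a, b, and c meet (it may coincide with one of the terminals), and x, y, z are the weighted lengths of the branches from m to a, b, and c.

$$
\begin{aligned}
dist(a,b) &= x + y,\\
dist(b,c) &= y + z,\\
dist(c,a) &= z + x,\\
\frac{dist(a,b)+dist(b,c)+dist(c,a)}{2} &= \frac{2(x+y+z)}{2} = x + y + z.
\end{aligned}
$$

Here x + y + z is the total weight of the three branches, which is the weight of the Steiner tree. If one terminal lies on the path between the other two, its branch length is 0 and the identity still holds.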

So each query reduces to computing three pairwise distances:

  • dist(src1, src2)
  • dist(src1, dest)
  • dist(src2, dest)

If we can answer distance queries quickly using LCA preprocessing, each query becomes very efficient.

Approach Comparison

| Approach | Time Complexity | Space Complexity | Notes |
| --- | --- | --- | --- |
| Brute Force | O(nq) | O(n) | Reconstruct paths for every query |
| Optimal | O(n log n + q log n) | O(n log n) | LCA preprocessing, O(log n) per query |

Algorithm Walkthrough

1. Build the tree

Create an adjacency list containing:

  • Neighbor node
  • Edge weight

This allows efficient DFS traversal.

2. Root the tree

Choose node 0 as the root.

During DFS, compute:

  • depth[u]
  • distRoot[u], distance from root to u
  • parent[0][u], immediate parent

These values are needed for binary lifting and distance calculations.
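Steps 1 and 2 can be sketched with an explicit stack instead of recursion. This is one possible variant, shown under the assumption that avoiding deep recursion is desirable; the helper name `root_tree` and its return shape are illustrative choices, not something the problem prescribes.

```python
from typing import List, Tuple


def root_tree(edges: List[List[int]], root: int = 0) -> Tuple[List[int], List[int], List[int]]:
    """Return (parent0, depth, dist_root) for the tree rooted at `root`."""
    n = len(edges) + 1
    graph = [[] for _ in range(n)]
    for u, v, w in edges:
        graph[u].append((v, w))
        graph[v].append((u, w))

    parent0 = [-1] * n      # immediate parent; -1 marks the root
    depth = [0] * n         # number of edges from the root
    dist_root = [0] * n     # weighted distance from the root

    # Explicit stack instead of recursion, so a 100,000-node chain
    # does not depend on the interpreter's recursion limit.
    stack = [root]
    while stack:
        u = stack.pop()
        for v, w in graph[u]:
            if v == parent0[u]:
                continue  # do not walk back up the edge we came from
            parent0[v] = u
            depth[v] = depth[u] + 1
            dist_root[v] = dist_root[u] + w
            stack.append(v)
    return parent0, depth, dist_root
```

The traversal order differs from a recursive DFS, but the three arrays depend only on the tree and the chosen root, so the result is the same.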

3. Build the binary lifting table

Let:

parent[k][u]

represent the 2^k-th ancestor of node u.

Using the first parent layer from DFS:

parent[k][u] = parent[k-1][ parent[k-1][u] ]

This preprocessing takes O(n log n).
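The recurrence can be sketched in isolation as follows. The function name `build_ancestor_table` and the use of -1 for a missing ancestor are assumptions for this example; the input is the list of immediate parents with -1 at the root.

```python
from typing import List


def build_ancestor_table(parent0: List[int]) -> List[List[int]]:
    """parent[k][u] is the 2^k-th ancestor of u, or -1 if it does not exist."""
    n = len(parent0)
    log = max(1, n.bit_length())  # enough levels to cover any depth below n
    parent = [parent0[:]] + [[-1] * n for _ in range(log - 1)]
    for k in range(1, log):
        prev, curr = parent[k - 1], parent[k]
        for u in range(n):
            mid = prev[u]            # 2^(k-1)-th ancestor of u
            if mid != -1:
                curr[u] = prev[mid]  # jump another 2^(k-1) steps from mid
    return parent
```

For a five-node chain with `parent0 = [-1, 0, 1, 2, 3]`, level 1 becomes `[-1, -1, 0, 1, 2]` and level 2 becomes `[-1, -1, -1, -1, 0]`.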

4. Implement LCA queries

To find the Lowest Common Ancestor of two nodes:

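  1. Lift the deeper node until both depths match.
  2. If the two nodes are now the same, that node is the LCA.
  3. Otherwise, for k from the largest power of two downward, lift both nodes by 2^k whenever their 2^k-th ancestors differ.
  4. After the loop the two nodes are distinct children of the LCA, so the parent of either one is the answer.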

Each LCA query takes O(log n).
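The lifting procedure can be sketched as a standalone function. The signature `lca(parent, depth, u, v)` is an assumption for this example, and the small table below is written out by hand for the Example 1 tree rooted at node 0 rather than computed.

```python
from typing import List


def lca(parent: List[List[int]], depth: List[int], u: int, v: int) -> int:
    """Lowest common ancestor; parent[k][x] is the 2^k-th ancestor of x (-1 if none)."""
    if depth[u] < depth[v]:
        u, v = v, u

    # Lift the deeper node u by the depth difference, one binary digit at a time.
    diff = depth[u] - depth[v]
    k = 0
    while diff:
        if diff & 1:
            u = parent[k][u]
        diff >>= 1
        k += 1

    # If the nodes coincide, v was an ancestor of u.
    if u == v:
        return u

    # Take the largest jumps that keep the two nodes distinct.
    for k in range(len(parent) - 1, -1, -1):
        if parent[k][u] != parent[k][v]:
            u = parent[k][u]
            v = parent[k][v]

    # Both nodes are now children of the LCA.
    return parent[0][u]


# Hand-written table for edges 0-1, 1-2, 1-3, 1-4, 2-5 rooted at node 0.
table = [
    [-1, 0, 1, 1, 1, 2],     # 1st ancestors
    [-1, -1, 0, 0, 0, 1],    # 2nd ancestors
    [-1, -1, -1, -1, -1, -1],  # 4th ancestors (none exist at this depth)
]
depth = [0, 1, 2, 2, 2, 3]
print(lca(table, depth, 5, 3))  # 1
print(lca(table, depth, 5, 0))  # 0
```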

5. Compute distances

For any nodes u and v:

dist(u, v) = distRoot[u] + distRoot[v] - 2 * distRoot[lca(u, v)]

This follows directly from tree path properties.
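A small numeric check of the formula, using the Example 1 tree rooted at node 0. To keep the sketch short it uses a naive parent-climbing LCA rather than binary lifting; the helper names and the hand-written arrays are assumptions for illustration only.

```python
from typing import List


def naive_lca(parent0: List[int], depth: List[int], u: int, v: int) -> int:
    """Climb one parent at a time; O(depth), for illustration only."""
    while depth[u] > depth[v]:
        u = parent0[u]
    while depth[v] > depth[u]:
        v = parent0[v]
    while u != v:
        u, v = parent0[u], parent0[v]
    return u


def tree_distance(parent0: List[int], depth: List[int], dist_root: List[int], u: int, v: int) -> int:
    a = naive_lca(parent0, depth, u, v)
    # The root-to-a prefix lies on both root paths, so it is removed twice.
    return dist_root[u] + dist_root[v] - 2 * dist_root[a]


# Example 1 tree (edges 0-1:2, 1-2:3, 1-3:5, 1-4:4, 2-5:6) rooted at node 0.
parent0 = [-1, 0, 1, 1, 1, 2]
depth = [0, 1, 2, 2, 2, 3]
dist_root = [0, 2, 5, 7, 6, 11]
print(tree_distance(parent0, depth, dist_root, 3, 5))  # 7 + 11 - 2*2 = 14
```

The direct path 3 → 1 → 2 → 5 has weight 5 + 3 + 6 = 14, matching the formula.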

6. Answer each query

For:

a = src1
b = src2
c = dest

Compute:

d1 = dist(a,b)
d2 = dist(a,c)
d3 = dist(b,c)

The answer is:

(d1 + d2 + d3) // 2

Why it works

The union of the two required paths is the unique minimum connected subtree containing src1, src2, and dest. In a tree, every edge of this subtree belongs to exactly two of the three pairwise terminal paths. Therefore the sum of the three pairwise distances counts every subtree edge exactly twice. Dividing by two yields precisely the total weight of the required subtree.

For comparison, a brute-force method would compute the paths from src1 to dest and from src2 to dest for each query using DFS or BFS, collect all edges in the union of both paths, and sum their weights. While this is correct, it requires an O(n) traversal per query, yielding O(n × q) overall time. With n and q up to 10^5, this reaches 10^10 operations, which is infeasible.

Optimal Approach

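Key observation: in a tree, the union of the two paths src1 → dest and src2 → dest is Y-shaped. The paths leave src1 and src2 separately, meet at a junction node, and then share a common segment from that junction to dest. (If one node lies on the path between the other two, the Y degenerates into a single path.)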

Thus, if we precompute:

  1. Distance from every node to every other node - impractical.
  2. Distance from a fixed root to all nodes, and ability to compute LCA - feasible.

We can root the tree arbitrarily (e.g., node 0) and precompute prefix distances along the tree to allow O(log n) queries for distances between any two nodes. Using binary lifting for LCA allows O(log n) per query. Then, for a query [src1, src2, dest], the minimal subtree weight is:

weight(src1 → dest) + weight(src2 → dest) - weight(m → dest)

Here m is the junction node where the two paths first meet; it depends on all three nodes and is in general not the root-relative LCA of src1 and src2. The shared segment m → dest is counted twice in the sum of the two path weights, so it is subtracted once. Its length follows from pairwise distances alone: weight(m → dest) = (dist(src1, dest) + dist(src2, dest) - dist(src1, src2)) / 2.
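The length of the shared segment can be recovered from pairwise distances without locating the junction node. A minimal sketch, assuming a `dist(u, v)` callable is available (hypothetical here; any O(log n) tree-distance helper would serve), with the function name chosen for this example:

```python
from typing import Callable


def min_weight_via_overlap(dist: Callable[[int, int], int], src1: int, src2: int, dest: int) -> int:
    d1 = dist(src1, dest)
    d2 = dist(src2, dest)
    d12 = dist(src1, src2)
    # d1 + d2 counts the shared segment twice and the two private branches once;
    # d12 counts only the private branches. The difference is twice the overlap.
    overlap = (d1 + d2 - d12) // 2
    return d1 + d2 - overlap
```

Algebraically the result equals (d1 + d2 + d12) / 2. Using the root-relative LCA of src1 and src2 as the junction is not safe in general: in Example 1 rooted at node 0, the query [0, 2, 5] has LCA(0, 2) = 0, and d1 + d2 - dist(0, 5) = 11 + 6 - 11 = 6, whereas the correct answer is 11.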

We can achieve O(n log n + q log n) overall time with O(n log n) space for binary lifting tables.

| Approach | Time Complexity | Space Complexity | Notes |
| --- | --- | --- | --- |
| Brute Force | O(n × q) | O(n) | DFS for each query to collect union of paths |
| Optimal | O((n + q) log n) | O(n log n) | Precompute LCA with binary lifting, compute distances in O(log n) |

Algorithm Walkthrough

  1. Build adjacency list: Convert edges into a standard tree adjacency list storing (neighbor, weight) pairs.
  2. Root the tree: Arbitrarily select node 0 as the root.
  3. Precompute depths and binary lifting table: For each node u, store up[u][i] as the 2^i-th ancestor of u. Also store dist[u], the total distance from root to u.
  4. Implement LCA query: Using binary lifting, for any nodes u and v, find their lowest common ancestor in O(log n) time.
  5. Compute distance between any two nodes: For nodes u and v, distance is dist[u] + dist[v] - 2 * dist[lca(u, v)].
  6. Process queries: For each query [src1, src2, dest], compute:
     d1 = distance(src1, dest)
     d2 = distance(src2, dest)
     d12 = distance(src1, src2)
     d_overlap = (d1 + d2 - d12) / 2
     answer = d1 + d2 - d_overlap
     This removes the shared segment that d1 + d2 counts twice. LCA(src1, src2) relative to the root is not, in general, the node where the two paths meet, so the overlap is derived from the three pairwise distances instead.
  7. Return results: Collect answers for all queries in an array.

Why it works: In a tree, paths are unique, and the two paths into dest merge at a junction node and coincide from there to dest. Subtracting the shared segment once ensures every edge is counted exactly once, yielding the minimal total weight.

Python Solution

from typing import List
import sys

class Solution:
    def minimumWeight(self, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
        n = len(edges) + 1

        graph = [[] for _ in range(n)]
        for u, v, w in edges:
            graph[u].append((v, w))
            graph[v].append((u, w))

        LOG = (n).bit_length()

        parent = [[-1] * n for _ in range(LOG)]
        depth = [0] * n
        dist_root = [0] * n

        sys.setrecursionlimit(300000)

        def dfs(u: int, p: int) -> None:
            parent[0][u] = p
            for v, w in graph[u]:
                if v == p:
                    continue
                depth[v] = depth[u] + 1
                dist_root[v] = dist_root[u] + w
                dfs(v, u)

        dfs(0, -1)

        for k in range(1, LOG):
            prev = parent[k - 1]
            curr = parent[k]
            for node in range(n):
                mid = prev[node]
                if mid != -1:
                    curr[node] = prev[mid]

        def lca(u: int, v: int) -> int:
            if depth[u] < depth[v]:
                u, v = v, u

            diff = depth[u] - depth[v]

            bit = 0
            while diff:
                if diff & 1:
                    u = parent[bit][u]
                diff >>= 1
                bit += 1

            if u == v:
                return u

            for k in range(LOG - 1, -1, -1):
                if parent[k][u] != parent[k][v]:
                    u = parent[k][u]
                    v = parent[k][v]

            return parent[0][u]

        def distance(u: int, v: int) -> int:
            ancestor = lca(u, v)
            return dist_root[u] + dist_root[v] - 2 * dist_root[ancestor]

        answer = []

        for src1, src2, dest in queries:
            d1 = distance(src1, src2)
            d2 = distance(src1, dest)
            d3 = distance(src2, dest)
            answer.append((d1 + d2 + d3) // 2)

        return answer

The implementation begins by constructing the adjacency list representation of the tree. A DFS rooted at node 0 computes depths, distances from the root, and immediate parents.

Next, binary lifting preprocessing builds ancestor tables for powers of two. These tables allow efficient LCA queries.

The lca() function first equalizes depths and then lifts both nodes upward until their lowest common ancestor is found.

The distance() helper uses the standard tree distance formula based on root distances and LCA.

For each query, the algorithm computes the three pairwise distances among the terminal nodes and applies the Steiner-tree formula. Since the sum counts every edge of the subtree exactly twice, it is always even, so the integer division is exact.

An alternative Python implementation, which stores the ancestor table per node and subtracts the shared segment explicitly:

from typing import List

class Solution:
    def minimumWeight(self, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
        import sys
        sys.setrecursionlimit(1 << 20)

        n = len(edges) + 1
        adj = [[] for _ in range(n)]
        for u, v, w in edges:
            adj[u].append((v, w))
            adj[v].append((u, w))

        LOG = 17  # 2^17 = 131072 > 10^5, enough for any depth
        up = [[-1] * LOG for _ in range(n)]  # up[u][i]: 2^i-th ancestor of u
        depth = [0] * n
        dist = [0] * n  # weighted distance from the root

        def dfs(u, p):
            up[u][0] = p
            # Ancestors of u are already filled in, so its row can be built here.
            for i in range(1, LOG):
                if up[u][i-1] != -1:
                    up[u][i] = up[up[u][i-1]][i-1]
            for v, w in adj[u]:
                if v != p:
                    depth[v] = depth[u] + 1
                    dist[v] = dist[u] + w
                    dfs(v, u)

        dfs(0, -1)

        def lca(u, v):
            if depth[u] < depth[v]:
                u, v = v, u
            # Lift u to the depth of v.
            for i in reversed(range(LOG)):
                if up[u][i] != -1 and depth[up[u][i]] >= depth[v]:
                    u = up[u][i]
            if u == v:
                return u
            # Lift both nodes while their ancestors differ.
            for i in reversed(range(LOG)):
                if up[u][i] != -1 and up[u][i] != up[v][i]:
                    u = up[u][i]
                    v = up[v][i]
            return up[u][0]

        def distance(u, v):
            return dist[u] + dist[v] - 2 * dist[lca(u, v)]

        answer = []
        for src1, src2, dest in queries:
            d1 = distance(src1, dest)
            d2 = distance(src2, dest)
            # The segment shared by both paths is counted twice in d1 + d2.
            overlap = (d1 + d2 - distance(src1, src2)) // 2
            answer.append(d1 + d2 - overlap)
        return answer

The implementation first constructs the tree, then performs DFS to populate depths, distances, and binary lifting tables for LCA queries. Distances between any pair of nodes are computed using the standard formula `dist[u] + dist[v] - 2 * dist[lca(u,v)]`. For each query, the minimal subtree weight is computed as described.

## Go Solution

```go
func minimumWeight(edges [][]int, queries [][]int) []int {
	n := len(edges) + 1

	type Edge struct {
		to int
		w  int
	}

	graph := make([][]Edge, n)

	for _, e := range edges {
		u, v, w := e[0], e[1], e[2]
		graph[u] = append(graph[u], Edge{v, w})
		graph[v] = append(graph[v], Edge{u, w})
	}

	LOG := 0
	for (1 << LOG) <= n {
		LOG++
	}

	parent := make([][]int, LOG)
	for i := range parent {
		parent[i] = make([]int, n)
		for j := range parent[i] {
			parent[i][j] = -1
		}
	}

	depth := make([]int, n)
	distRoot := make([]int64, n)

	var dfs func(int, int)
	dfs = func(u, p int) {
		parent[0][u] = p

		for _, e := range graph[u] {
			v := e.to
			if v == p {
				continue
			}

			depth[v] = depth[u] + 1
			distRoot[v] = distRoot[u] + int64(e.w)
			dfs(v, u)
		}
	}

	dfs(0, -1)

	for k := 1; k < LOG; k++ {
		for node := 0; node < n; node++ {
			mid := parent[k-1][node]
			if mid != -1 {
				parent[k][node] = parent[k-1][mid]
			}
		}
	}

	lca := func(u, v int) int {
		if depth[u] < depth[v] {
			u, v = v, u
		}

		diff := depth[u] - depth[v]

		for k := 0; diff > 0; k++ {
			if diff&1 == 1 {
				u = parent[k][u]
			}
			diff >>= 1
		}

		if u == v {
			return u
		}

		for k := LOG - 1; k >= 0; k-- {
			if parent[k][u] != parent[k][v] {
				u = parent[k][u]
				v = parent[k][v]
			}
		}

		return parent[0][u]
	}

	distance := func(u, v int) int64 {
		a := lca(u, v)
		return distRoot[u] + distRoot[v] - 2*distRoot[a]
	}

	ans := make([]int, len(queries))

	for i, q := range queries {
		s1, s2, d := q[0], q[1], q[2]

		d1 := distance(s1, s2)
		d2 := distance(s1, d)
		d3 := distance(s2, d)

		ans[i] = int((d1 + d2 + d3) / 2)
	}

	return ans
}
```

The Go implementation mirrors the Python solution. The main difference is that distances are stored as int64: a single path weighs at most about 10^9, and the sum of three pairwise distances can approach 2 × 10^9, which sits close to the signed 32-bit limit, so 64-bit arithmetic leaves a comfortable margin. Binary lifting tables are represented as slices of slices, and recursion is used for DFS exactly as in the Python version.

Worked Examples

Example 1

Input:

edges =
[[0,1,2],
 [1,2,3],
 [1,3,5],
 [1,4,4],
 [2,5,6]]

query = [2,3,4]

First compute pairwise distances.

| Pair | Distance |
| --- | --- |
| 2 ↔ 3 | 3 + 5 = 8 |
| 2 ↔ 4 | 3 + 4 = 7 |
| 3 ↔ 4 | 5 + 4 = 9 |

Apply the formula:

| Value | Result |
| --- | --- |
| d(2,3) | 8 |
| d(2,4) | 7 |
| d(3,4) | 9 |
| Sum | 24 |
| Answer | 24 / 2 = 12 |

Output:

12

Example 1, Second Query

Query:

[0,2,5]

Distances:

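| Pair | Distance |
| --- | --- |
| 0 ↔ 2 | 2 + 3 = 5 |
| 0 ↔ 5 | 2 + 3 + 6 = 11 |
| 2 ↔ 5 | 6 |

Formula:

| Value | Result |
| --- | --- |
| d(0,2) | 5 |
| d(0,5) | 11 |
| d(2,5) | 6 |
| Sum | 22 |
| Answer | 22 / 2 = 11 |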

Output:

11

Example 2

Input:

edges = [[1,0,8],[0,2,7]]
query = [0,1,2]

Distances:

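| Pair | Distance |
| --- | --- |
| 0 ↔ 1 | 8 |
| 0 ↔ 2 | 7 |
| 1 ↔ 2 | 8 + 7 = 15 |

Formula:

| Value | Result |
| --- | --- |
| Sum | 8 + 7 + 15 = 30 |
| Answer | 30 / 2 = 15 |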

Output:

15

Complexity Analysis

| Measure | Complexity | Explanation |
| --- | --- | --- |
| Time | O(n log n + q log n) | LCA preprocessing plus O(log n) per query |
| Space | O(n log n) | Binary lifting ancestor table |

The DFS traversal is linear in the number of nodes. Building the ancestor table requires O(n log n) time and memory. Each query performs three distance computations, and each distance computation requires one LCA query, resulting in O(log n) time per query.

Test Cases

sol = Solution()

# Example 1
assert sol.minimumWeight(
    [[0,1,2],[1,2,3],[1,3,5],[1,4,4],[2,5,6]],
    [[2,3,4],[0,2,5]]
) == [12, 11]

# Example 2
assert sol.minimumWeight(
    [[1,0,8],[0,2,7]],
    [[0,1,2]]
) == [15]

# Small chain
assert sol.minimumWeight(
    [[0,1,1],[1,2,2]],
    [[0,1,2]]
) == [3]  # entire chain

# Star tree
assert sol.minimumWeight(
    [[0,1,1],[0,2,2],[0,3,3],[0,4,4]],
    [[1,2,3]]
) == [6]  # spokes 0-1, 0-2 and 0-3: 1 + 2 + 3

# One node lies on path between others
assert sol.minimumWeight(
    [[0,1,5],[1,2,7],[2,3,9]],
    [[0,2,3]]
) == [21]

# Heavy overlap
assert sol.minimumWeight(
    [[0,1,1],[1,2,1],[2,3,1],[3,4,1]],
    [[0,1,4]]
) == [4]

# Multiple queries
assert sol.minimumWeight(
    [[0,1,2],[1,2,3],[2,3,4],[3,4,5]],
    [[0,2,4],[1,3,4],[0,1,2]]
) == [14, 12, 5]

# Symmetry check
assert sol.minimumWeight(
    [[0,1,10],[1,2,20],[1,3,30]],
    [[2,3,0]]
) == [60]

| Test | Why |
| --- | --- |
| Example 1 | Validates official sample |
| Example 2 | Validates smallest nontrivial tree |
| Small chain | Tests linear topology |
| Star tree | Tests branching through common center |
| Node on path | Tests maximal path overlap |
| Heavy overlap | Ensures duplicate edges are not double counted |
| Multiple queries | Verifies reuse of preprocessing |
| Symmetry check | Confirms Steiner-tree formula correctness |

Edge Cases

One terminal lies on the path between the other two

Suppose the query nodes are arranged along a chain such that one node already lies on the path connecting the other two. In this case the required subtree is simply a single path. A naive union computation can accidentally double count overlapping segments. The Steiner-tree formula handles this automatically, because every edge of that path still lies on exactly two of the three pairwise paths.

Highly unbalanced trees

The tree may degenerate into a chain of length 100,000. Per-query path reconstruction then costs O(n) per query, and a recursive DFS descends 100,000 levels, which is why the Python solution raises the recursion limit. The binary lifting structure still supports O(log n) LCA queries regardless of tree shape, so performance remains acceptable.

Large edge weights

Each edge weight can be as large as 10,000, and a path may contain nearly 100,000 edges, so a single distance can approach 10^9 and the sum of three pairwise distances can approach 2 × 10^9, close to the signed 32-bit limit of 2,147,483,647. The Go solution uses int64 for distance accumulation, which leaves ample headroom. The Python implementation naturally supports arbitrary-size integers.

package main

func minimumWeight(edges [][]int, queries [][]int) []int {
n := len(edges) + 1
adj := make([][][2]int, n)
for _, e := range edges {
    u, v, w := e[0], e[1], e[2]
    adj[u] = append(adj[u], [2]int{v, w})
    adj[v] = append(adj[v], [2]int{u, w})
}

const LOG = 17
up := make([][]int, n)
for i := range up {
    up[i] = make([]int, LOG)
    for j := 0; j < LOG; j++ {
        up[i][j] = -1
    }
}
depth := make([]int, n)
dist := make([]int, n)

var dfs func(u, p int)
dfs = func(u, p int) {
    up[u][0] = p
    for i := 1; i < LOG; i++ {
        if up[u][i