LeetCode 3585 - Find Weighted Median Node in Tree


LeetCode Problem 3585

Difficulty: 🔴 Hard
Topics: Array, Binary Search, Dynamic Programming, Bit Manipulation, Tree, Depth-First Search

Solution

Problem Understanding

We are given a weighted tree with n nodes. A tree is an undirected connected graph with exactly n - 1 edges, which guarantees that there is exactly one simple path between any two nodes.

Each edge has a positive weight. For every query [u, v], we look at the unique path from node u to node v. Let the total weight of that path be:

$$D = \text{dist}(u, v)$$

The problem asks us to find the first node encountered while moving from u toward v whose accumulated path weight from u is at least half of the total path weight.

More formally, we want the first node x on the path such that:

$$\text{dist}(u, x) \ge \frac{D}{2}$$

and every node before x on the path has distance strictly smaller than D / 2.

The output is an array where each query contributes exactly one weighted median node.

The constraints are large:

  • n <= 100000
  • queries.length <= 100000
  • Edge weights can be as large as 10^9

These constraints immediately rule out any approach that traverses entire paths for every query. A solution that takes O(n) per query would become O(nq), up to about 10^10 operations in the worst case, which is far too slow.

A few important observations:

  • Edge weights are positive, so cumulative path distance strictly increases as we move along a path.
  • The tree is static, meaning we can afford heavy preprocessing.
  • Every query asks about a path between two nodes, which strongly suggests using Lowest Common Ancestor (LCA) preprocessing.
  • Distances may reach 10^14, so 64-bit integers are required.
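A quick bound, assuming only the constraints listed above, shows why 64-bit integers are both necessary and sufficient. A simple path uses at most n - 1 edges, so:

$$\text{dist}(u, v) \le (n - 1) \cdot \max w \le (10^5 - 1) \cdot 10^9 < 10^{14}$$

This is far above the signed 32-bit maximum (2^31 - 1, roughly 2.1 * 10^9) and far below the signed 64-bit maximum (2^63 - 1, roughly 9.2 * 10^18).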

Important edge cases include:

  • A path consisting of a single edge.
  • The weighted median being exactly the LCA.
  • The weighted median lying on the upward segment from u to the LCA.
  • The weighted median lying on the downward segment from the LCA to v.
  • Odd total path weights, where the threshold becomes a non-integer value such as 5.5.
  • Very large edge weights that require 64-bit arithmetic.
  • Queries where u and v are the same node. The path has weight 0, so the answer is u itself.
  • Skewed or unbalanced trees, where a single path can contain almost all n nodes.

For reference, the input consists of:

  • An integer n, the number of nodes, numbered 0 to n - 1.
  • A list of edges [ui, vi, wi] defining an undirected tree with weighted edges.
  • A list of queries [uj, vj] asking for the weighted median along the path from uj to vj.

Approaches

Brute Force

A straightforward solution is to process every query independently.

For a query [u, v], we could first reconstruct the entire path between the two nodes. Then we compute the total path weight and walk along the path while accumulating edge weights until we reach at least half of the total.

This works because the weighted median is defined directly in terms of cumulative path weight.

Unfortunately, reconstructing a path can take O(n) time in the worst case. With up to 100000 queries, the total complexity becomes O(nq), which is much too large.
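For reference, here is a minimal sketch of that brute force, assuming the same input format as the main solution. The function name `brute_force_median` is hypothetical. It finds the path with a BFS from u and compares `2 * dist` against the total to avoid a fractional threshold:

```python
from collections import deque
from typing import List


def brute_force_median(n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
    # Adjacency list of (neighbor, weight) pairs.
    graph = [[] for _ in range(n)]
    for a, b, w in edges:
        graph[a].append((b, w))
        graph[b].append((a, w))

    answer = []
    for u, v in queries:
        # BFS from u records each node's predecessor and its distance from u.
        prev = [-1] * n
        dist_u = [0] * n
        seen = [False] * n
        seen[u] = True
        queue = deque([u])
        while queue:
            x = queue.popleft()
            for y, w in graph[x]:
                if not seen[y]:
                    seen[y] = True
                    prev[y] = x
                    dist_u[y] = dist_u[x] + w
                    queue.append(y)

        # Rebuild the path u -> v by following predecessors back from v.
        path = [v]
        while path[-1] != u:
            path.append(prev[path[-1]])
        path.reverse()

        total = dist_u[v]
        # First node x with dist(u, x) >= total / 2, written as 2 * dist >= total.
        for x in path:
            if 2 * dist_u[x] >= total:
                answer.append(x)
                break
    return answer
```

Because it is O(n) per query and easy to trust, a function like this is mainly useful as a reference when randomly testing the optimized solution.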

Key Insight

The tree never changes, so we should preprocess information that allows us to answer path queries efficiently.

The key observations are:

  1. Every path can be decomposed using the Lowest Common Ancestor.
  2. Distances between nodes can be computed from root distances.
  3. Binary lifting allows us to jump upward in powers of two.
  4. Since edge weights are positive, cumulative distance along a path is monotonic, which means we can use binary lifting to locate the first node that crosses the halfway threshold.

Instead of walking along a path node by node, we use binary lifting to skip large portions of the path while maintaining distance information.

This reduces each query to O(log n) time.

| Approach | Time Complexity | Space Complexity | Notes |
|---|---|---|---|
| Brute Force | O(nq) | O(n) | Reconstructs and scans paths directly |
| Optimal | O((n + q) log n) | O(n log n) | Binary lifting, LCA, and weighted ancestor search |

Algorithm Walkthrough

Preprocessing

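  1. Build an adjacency list for the tree.
  2. Root the tree at node 0.
  3. Run a DFS to compute:
     • depth[node], the number of edges from the root
     • parent[0][node], the direct parent (-1 for the root)
     • dist[node], the weighted distance from the root
  4. Build the binary lifting table, where parent[k][node] stores the 2^k-th ancestor of node (or -1 if it does not exist), using parent[k][node] = parent[k-1][parent[k-1][node]].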
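The preprocessing steps can be sketched as a standalone function. This is an illustration under the assumption that the tree is rooted at node 0; the helper name `preprocess` is hypothetical, and it uses an iterative DFS so that deep trees do not hit Python's recursion limit:

```python
from typing import List, Tuple


def preprocess(n: int, edges: List[List[int]]) -> Tuple[List[int], List[int], List[List[int]]]:
    # Hypothetical helper: returns (depth, dist, parent) for a tree rooted at node 0.
    graph = [[] for _ in range(n)]
    for a, b, w in edges:
        graph[a].append((b, w))
        graph[b].append((a, w))

    log = max(1, n.bit_length())  # 2^log > n - 1 covers every possible depth
    parent = [[-1] * n for _ in range(log)]
    depth = [0] * n
    dist = [0] * n

    # Iterative DFS from the root; each node is discovered exactly once.
    visited = [False] * n
    visited[0] = True
    stack = [0]
    while stack:
        x = stack.pop()
        for y, w in graph[x]:
            if not visited[y]:
                visited[y] = True
                parent[0][y] = x
                depth[y] = depth[x] + 1
                dist[y] = dist[x] + w
                stack.append(y)

    # The 2^k-th ancestor is the 2^(k-1)-th ancestor of the 2^(k-1)-th ancestor.
    for k in range(1, log):
        for x in range(n):
            mid = parent[k - 1][x]
            if mid != -1:
                parent[k][x] = parent[k - 1][mid]
    return depth, dist, parent
```

For the tree of Example 3 below (`edges = [[0, 1, 2], [0, 2, 5], [1, 3, 1], [2, 4, 3]]`), this gives `dist == [0, 2, 5, 3, 8]` and `parent[0] == [-1, 0, 0, 1, 2]`.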

LCA Query

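To find the LCA of two nodes:

  • Lift the deeper node until both nodes have the same depth.
  • If they are now the same node, that node is the LCA.
  • Otherwise, lift both nodes simultaneously from the highest power of two to the lowest, jumping only when their 2^k-th ancestors differ.
  • After this loop both nodes are children of the LCA, so the LCA is their common parent.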
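A minimal sketch of this LCA query, assuming the tree is given as a direct-parent array with -1 for the root. The helper names `build_table` and `lca` are hypothetical:

```python
from typing import List


def build_table(parent0: List[int]) -> List[List[int]]:
    # parent0[x] is the direct parent of x (-1 for the root).
    n = len(parent0)
    table = [parent0[:]]
    for _ in range(1, max(1, n.bit_length())):
        prev = table[-1]
        table.append([prev[prev[x]] if prev[x] != -1 else -1 for x in range(n)])
    return table


def lca(a: int, b: int, depth: List[int], table: List[List[int]]) -> int:
    if depth[a] < depth[b]:
        a, b = b, a
    # Lift the deeper node by the depth difference, one set bit at a time.
    diff = depth[a] - depth[b]
    k = 0
    while diff:
        if diff & 1:
            a = table[k][a]
        diff >>= 1
        k += 1
    if a == b:
        return a
    # Lift both while their 2^k-th ancestors differ, from large k to small.
    for k in range(len(table) - 1, -1, -1):
        if table[k][a] != table[k][b]:
            a, b = table[k][a], table[k][b]
    # a and b are now children of the LCA.
    return table[0][a]
```

With the Example 3 tree (`parent0 = [-1, 0, 0, 1, 2]`, `depth = [0, 1, 1, 2, 2]`), `lca(3, 4, ...)` walks both nodes up to 1 and 2 and returns their parent, node 0.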

Weighted Median Query

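For a query (u, v):

  1. Compute the LCA:

$$l = \text{LCA}(u,v)$$

  2. Compute the two segment lengths from the root distances:

$$du = \text{dist}(u,l) = \text{dist}[u] - \text{dist}[l]$$

$$dv = \text{dist}(v,l) = \text{dist}[v] - \text{dist}[l]$$

$$D = du + dv$$

  3. Compute the median threshold. Because all distances are integers, dist(u, x) is at least D / 2 exactly when it is at least the ceiling of D / 2, so:

$$need = \left\lceil \frac{D}{2} \right\rceil$$

Instead of using floating-point arithmetic, we compute it with integer division:

$$need = \left\lfloor \frac{D + 1}{2} \right\rfloor$$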
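The integer-threshold argument can be checked exhaustively for small totals. This is a small sanity sketch; the helper name `need_threshold` is hypothetical:

```python
def need_threshold(total: int) -> int:
    # Smallest integer >= total / 2, computed without floating point.
    return (total + 1) // 2


# For integer distances d, "d >= total / 2" (i.e. 2 * d >= total) holds exactly when d >= need.
for total in range(0, 50):
    need = need_threshold(total)
    for d in range(0, total + 1):
        assert (2 * d >= total) == (d >= need)
```

For example, a total of 11 gives a threshold of 6, matching the ceiling of 5.5.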

Case 1: Median Lies on the Upward Segment

If:

$$need \le du$$

then the median is somewhere on the path from u to l.

  1. Find the highest ancestor of u (the one closest to l) whose distance from u is still smaller than need.
  2. The parent of that node is the first node whose distance reaches or exceeds the threshold.
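The rule for this case can be written as a plain linear walk, which makes the invariant easy to see. This is a sketch assuming a direct-parent array `parent0`, root distances `dist`, and 1 <= need <= du. The name `upward_median_linear` is hypothetical. Binary lifting replaces the loop with power-of-two jumps that test the same condition:

```python
from typing import List


def upward_median_linear(u: int, need: int, parent0: List[int], dist: List[int]) -> int:
    # Assumes 1 <= need <= dist[u] - dist[lca], so the answer is on the u -> LCA segment.
    # Climb while the next ancestor is still strictly below the threshold.
    cur = u
    while dist[u] - dist[parent0[cur]] < need:
        cur = parent0[cur]
    # cur is the highest ancestor still below need; its parent is the first to reach it.
    return parent0[cur]
```

On a chain 0-1-2-3-4 with every edge weighing 2, a query from 4 with need = 4 stops at node 3 (distance 2) and returns node 2 (distance 4).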

Case 2: Median Lies on the Downward Segment

Otherwise, need > du and the median lies on the path from l to v. Let:

$$rem = need - du$$

This is how far we still need to travel after reaching the LCA.

A node x on the segment from l to v satisfies dist(u, x) = du + (dv - dist(v, x)), so dist(u, x) is at least need exactly when:

$$\text{dist}(v, x) \le dv - rem$$

From the perspective of v, the answer is therefore the highest ancestor of v whose distance from v is at most dv - rem. Binary lifting finds that ancestor directly, and we return it.
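A linear-walk sketch of this case, assuming a direct-parent array `parent0`, root distances `dist`, and need > du. The name `downward_median_linear` is hypothetical:

```python
from typing import List


def downward_median_linear(v: int, du: int, dv: int, need: int,
                           parent0: List[int], dist: List[int]) -> int:
    rem = need - du    # distance still needed past the LCA (assumed > 0)
    limit = dv - rem   # a node qualifies if its distance from v is at most this
    # Climb from v while the next ancestor is still within limit of v.
    cur = v
    while parent0[cur] != -1 and dist[v] - dist[parent0[cur]] <= limit:
        cur = parent0[cur]
    return cur
```

For Example 3 below (du = 3, dv = 8, need = 6), this computes rem = 3 and limit = 5, climbs from node 4 to node 2 (distance 3), and stops because node 0 is at distance 8.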

Why it works

The path distance increases monotonically because every edge weight is positive. Therefore, there is exactly one first node whose accumulated distance reaches at least half of the total path weight.

The LCA splits the path into two segments. Along each one, the distance from the query endpoint grows monotonically as we move toward the LCA, so binary lifting can jump through ancestors while preserving a distance invariant. In the upward case, every accepted jump keeps the distance from u strictly below need, so the search stops at the highest ancestor still below the threshold, and its parent is the first node that reaches it. In the downward case, every accepted jump keeps the distance from v at most dv - rem, so the search stops at the highest node on the v side that is at least rem away from the LCA, which is exactly the first node on the l-to-v segment that reaches the threshold.
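The invariant argument can also be sanity-checked empirically. The sketch below assumes a weighted chain rooted at node 0 (the helpers `build_up`, `greedy_lift`, and `linear_climb` are hypothetical) and compares the greedy power-of-two descent against a step-by-step climb on random inputs:

```python
import random
from typing import List


def build_up(parent0: List[int]) -> List[List[int]]:
    # up[k][x] is the 2^k-th ancestor of x, or -1 if it does not exist.
    n = len(parent0)
    up = [parent0[:]]
    for _ in range(1, max(1, n.bit_length())):
        prev = up[-1]
        up.append([prev[prev[x]] if prev[x] != -1 else -1 for x in range(n)])
    return up


def greedy_lift(start: int, limit: int, up: List[List[int]], dist: List[int]) -> int:
    # Try jumps from the largest power of two down; accept any that stays within limit.
    cur = start
    for k in range(len(up) - 1, -1, -1):
        nxt = up[k][cur]
        if nxt != -1 and dist[start] - dist[nxt] <= limit:
            cur = nxt
    return cur


def linear_climb(start: int, limit: int, parent0: List[int], dist: List[int]) -> int:
    # Reference: climb one edge at a time while still within limit.
    cur = start
    while parent0[cur] != -1 and dist[start] - dist[parent0[cur]] <= limit:
        cur = parent0[cur]
    return cur


random.seed(0)
for _ in range(200):
    n = random.randint(1, 40)
    parent0 = [-1] + list(range(n - 1))  # node i's parent is i - 1
    dist = [0] * n
    for i in range(1, n):
        dist[i] = dist[i - 1] + random.randint(1, 10)  # positive weights
    up = build_up(parent0)
    start = random.randrange(n)
    limit = random.randint(0, dist[start])
    assert greedy_lift(start, limit, up, dist) == linear_climb(start, limit, parent0, dist)
```

The check passes because distance from `start` is strictly increasing in the number of steps climbed, so the greedy binary decomposition recovers the largest valid step count.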


Python Solution

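```python
from typing import List


class Solution:
    def findMedian(self, n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
        # 2^LOG > n - 1, which covers every possible depth.
        LOG = n.bit_length()

        graph = [[] for _ in range(n)]
        for u, v, w in edges:
            graph[u].append((v, w))
            graph[v].append((u, w))

        # parent[k][node] is the 2^k-th ancestor of node, or -1 if it does not exist.
        parent = [[-1] * n for _ in range(LOG)]
        depth = [0] * n
        dist = [0] * n  # weighted distance from root 0

        # Iterative DFS from root 0. A recursive DFS can overflow the C stack on a
        # 10^5-node chain even with a raised recursion limit.
        visited = [False] * n
        visited[0] = True
        stack = [0]
        while stack:
            u = stack.pop()
            for v, w in graph[u]:
                if visited[v]:
                    continue
                visited[v] = True
                parent[0][v] = u
                depth[v] = depth[u] + 1
                dist[v] = dist[u] + w
                stack.append(v)

        # The 2^k-th ancestor is the 2^(k-1)-th ancestor of the 2^(k-1)-th ancestor.
        for k in range(1, LOG):
            prev_row, row = parent[k - 1], parent[k]
            for node in range(n):
                mid = prev_row[node]
                if mid != -1:
                    row[node] = prev_row[mid]

        def lca(a: int, b: int) -> int:
            if depth[a] < depth[b]:
                a, b = b, a

            # Lift the deeper node to the same depth.
            diff = depth[a] - depth[b]
            for k in range(LOG):
                if diff & (1 << k):
                    a = parent[k][a]

            if a == b:
                return a

            # Lift both while their ancestors differ; they end as children of the LCA.
            for k in range(LOG - 1, -1, -1):
                if parent[k][a] != parent[k][b]:
                    a = parent[k][a]
                    b = parent[k][b]

            return parent[0][a]

        def first_ancestor_reaching(start: int, need: int) -> int:
            # Climb while the distance from start stays strictly below need (need >= 1).
            cur = start
            for k in range(LOG - 1, -1, -1):
                nxt = parent[k][cur]
                if nxt != -1 and dist[start] - dist[nxt] < need:
                    cur = nxt
            # cur is the highest ancestor below the threshold; its parent reaches it.
            return parent[0][cur]

        def highest_ancestor_within(start: int, limit: int) -> int:
            # Climb while the distance from start stays at most limit.
            cur = start
            for k in range(LOG - 1, -1, -1):
                nxt = parent[k][cur]
                if nxt != -1 and dist[start] - dist[nxt] <= limit:
                    cur = nxt
            return cur

        answer = []

        for u, v in queries:
            if u == v:
                # Empty path: dist(u, u) = 0 already meets the threshold of 0.
                answer.append(u)
                continue

            p = lca(u, v)

            du = dist[u] - dist[p]
            dv = dist[v] - dist[p]

            total = du + dv
            need = (total + 1) // 2  # ceil(total / 2) in integer arithmetic

            if need <= du:
                # Median lies on the upward segment u -> p.
                answer.append(first_ancestor_reaching(u, need))
            else:
                # Median lies on the downward segment p -> v; search upward from v.
                rem = need - du
                limit = dv - rem
                answer.append(highest_ancestor_within(v, limit))

        return answer
```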

The preprocessing stage builds the adjacency list, computes depths and root distances, and fills the binary lifting table.

The lca function computes the Lowest Common Ancestor in O(log n) time.

The helper first_ancestor_reaching handles the case where the weighted median lies on the upward path from u toward the LCA. It finds the last ancestor whose accumulated distance is still below the threshold, then returns its parent.

The helper highest_ancestor_within handles the downward segment. Instead of traversing from the LCA toward v, it works from v upward and finds the highest ancestor that still satisfies the required distance constraint.

Each query performs a constant number of binary lifting operations, giving O(log n) query time.


## Go Solution

```go
package main

func findMedian(n int, edges [][]int, queries [][]int) []int {
	LOG := 0
	for (1 << LOG) <= n {
		LOG++
	}

	type Edge struct {
		to int
		w  int64
	}

	graph := make([][]Edge, n)

	for _, e := range edges {
		u, v, w := e[0], e[1], int64(e[2])

		graph[u] = append(graph[u], Edge{v, w})
		graph[v] = append(graph[v], Edge{u, w})
	}

	parent := make([][]int, LOG)
	for i := range parent {
		parent[i] = make([]int, n)
		// Mark every ancestor as missing (-1); the DFS and lifting loop fill in real ones.
		for j := range parent[i] {
			parent[i][j] = -1
		}
	}

	depth := make([]int, n)
	dist := make([]int64, n)

	var dfs func(int, int)

	dfs = func(u, p int) {
		parent[0][u] = p

		for _, e := range graph[u] {
			v := e.to

			if v == p {
				continue
			}

			depth[v] = depth[u] + 1
			dist[v] = dist[u] + e.w

			dfs(v, u)
		}
	}

	dfs(0, -1)

	for k := 1; k < LOG; k++ {
		for node := 0; node < n; node++ {
			mid := parent[k-1][node]
			if mid != -1 {
				parent[k][node] = parent[k-1][mid]
			}
		}
	}

	lca := func(a, b int) int {
		if depth[a] < depth[b] {
			a, b = b, a
		}

		diff := depth[a] - depth[b]

		for k := 0; k < LOG; k++ {
			if (diff & (1 << k)) != 0 {
				a = parent[k][a]
			}
		}

		if a == b {
			return a
		}

		for k := LOG - 1; k >= 0; k-- {
			if parent[k][a] != parent[k][b] {
				a = parent[k][a]
				b = parent[k][b]
			}
		}

		return parent[0][a]
	}

	firstAncestorReaching := func(start int, need int64) int {
		cur := start

		for k := LOG - 1; k >= 0; k-- {
			nxt := parent[k][cur]

			if nxt == -1 {
				continue
			}

			jumpDist := dist[start] - dist[nxt]

			if jumpDist < need {
				cur = nxt
			}
		}

		return parent[0][cur]
	}

	highestAncestorWithin := func(start int, limit int64) int {
		cur := start

		for k := LOG - 1; k >= 0; k-- {
			nxt := parent[k][cur]

			if nxt == -1 {
				continue
			}

			jumpDist := dist[start] - dist[nxt]

			if jumpDist <= limit {
				cur = nxt
			}
		}

		return cur
	}

	ans := make([]int, 0, len(queries))

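	for _, q := range queries {
		u, v := q[0], q[1]

		// An empty path (u == v) has weight 0, so u itself is the weighted median.
		if u == v {
			ans = append(ans, u)
			continue
		}

		p := lca(u, v)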

		du := dist[u] - dist[p]
		dv := dist[v] - dist[p]

		total := du + dv
		need := (total + 1) / 2

		if need <= du {
			ans = append(ans, firstAncestorReaching(u, need))
		} else {
			rem := need - du
			limit := dv - rem
			ans = append(ans, highestAncestorWithin(v, limit))
		}
	}

	return ans
}
```

The Go implementation mirrors the Python solution closely. The main difference is the explicit use of int64 for distances because path weights can reach approximately 10^14. All weighted distance calculations therefore use 64-bit integers to avoid overflow.

Worked Examples

Example 1

Input:

n = 2
edges = [[0,1,7]]
queries = [[1,0],[0,1]]

Tree:

0 --7-- 1

Query [1,0]

| Value | Result |
|---|---|
| LCA | 0 |
| du | 7 |
| dv | 0 |
| Total | 7 |
| need | 4 |

Since need <= du, search upward from node 1.

| Node | Distance from 1 |
|---|---|
| 1 | 0 |
| 0 | 7 |

The first node reaching at least 4 is 0.

Answer: 0

Query [0,1]

| Value | Result |
|---|---|
| LCA | 0 |
| du | 0 |
| dv | 7 |
| Total | 7 |
| need | 4 |

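Since need > du, the median lies on the downward segment:

rem = 4 - 0 = 4
limit = 7 - 4 = 3

Node 1 is at distance 0 from itself, while its parent 0 is at distance 7 > 3, so the highest ancestor of node 1 within the limit is node 1 itself.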

Answer: 1

Final output:

[0, 1]

Example 2

Input:

edges = [[0,1,2],[2,0,4]]
query = [1,2]

Path:

1 --2-- 0 --4-- 2

| Value | Result |
|---|---|
| LCA | 0 |
| du | 2 |
| dv | 4 |
| Total | 6 |
| need | 3 |

Since need > du, the median is on the LCA-to-v segment.

rem = 3 - 2 = 1
limit = 4 - 1 = 3

From node 2:

| Ancestor | Distance from 2 |
|---|---|
| 2 | 0 |
| 0 | 4 |

The highest ancestor within distance 3 is node 2.

Answer:

2

Example 3

Input:

edges =
[[0,1,2],
 [0,2,5],
 [1,3,1],
 [2,4,3]]

query = [3,4]

Path:

3 --1-- 1 --2-- 0 --5-- 2 --3-- 4

| Value | Result |
|---|---|
| LCA | 0 |
| du | 3 |
| dv | 8 |
| Total | 11 |
| need | 6 |

Since need > du:

rem = 6 - 3 = 3
limit = 8 - 3 = 5

From node 4:

| Ancestor | Distance from 4 |
|---|---|
| 4 | 0 |
| 2 | 3 |
| 0 | 8 |

The highest ancestor within distance 5 is node 2.

Answer:

2

Complexity Analysis

| Measure | Complexity | Explanation |
|---|---|---|
| Time | O((n + q) log n) | Preprocessing and each query use binary lifting |
| Space | O(n log n) | The ancestor sparse table dominates memory |

The DFS preprocessing takes O(n). Building the binary lifting table takes O(n log n). Every query performs an LCA computation and one additional binary lifting search, both of which require O(log n) time. Therefore the total complexity is:

$$O(n \log n + q \log n)$$

which is efficient for 100000 nodes and queries.

Test Cases

```python
sol = Solution()

# Example 1
assert sol.findMedian(
    2,
    [[0, 1, 7]],
    [[1, 0], [0, 1]]
) == [0, 1]  # single edge

# Example 2
assert sol.findMedian(
    3,
    [[0, 1, 2], [2, 0, 4]],
    [[0, 1], [2, 0], [1, 2]]
) == [1, 0, 2]  # median on different segments

# Example 3
assert sol.findMedian(
    5,
    [[0, 1, 2], [0, 2, 5], [1, 3, 1], [2, 4, 3]],
    [[3, 4], [1, 2]]
) == [2, 2]  # crosses the LCA

# Chain tree
assert sol.findMedian(
    4,
    [[0, 1, 1], [1, 2, 1], [2, 3, 1]],
    [[0, 3]]
) == [2]  # odd total 3, threshold 1.5 falls inside edge 1-2

# Heavy edge dominates
assert sol.findMedian(
    3,
    [[0, 1, 1], [1, 2, 100]],
    [[0, 2]]
) == [2]  # crossing occurs on large edge

# Median equals LCA
assert sol.findMedian(
    3,
    [[0, 1, 5], [0, 2, 5]],
    [[1, 2]]
) == [0]  # exact half reached at LCA

# Root involved
assert sol.findMedian(
    3,
    [[0, 1, 4], [1, 2, 6]],
    [[0, 2]]
) == [2]  # downward segment only

# Multiple queries
assert sol.findMedian(
    5,
    [[0, 1, 2], [1, 2, 2], [2, 3, 2], [3, 4, 2]],
    [[0, 4], [1, 4], [2, 4]]
) == [2, 3, 3]  # repeated path checks
```

| Test | Why |
|---|---|
| Single edge tree | Smallest valid tree |
| Example 2 | Median may lie on either side of the LCA |
| Example 3 | Long weighted path crossing the LCA |
| Uniform chain | Odd total, so the threshold (1.5) falls strictly inside an edge |
| Heavy edge | Large weight causes immediate crossing |
| Median equals LCA | Important equality case |
| Root involved | Path starts at the root |
| Multiple queries | Reuses preprocessing heavily |

Edge Cases

One important edge case occurs when the weighted median is exactly the LCA. This happens when the cumulative distance reaches half of the path weight precisely at the ancestor where the two path segments meet. Implementations that mix up strict and non-strict inequalities can easily return a child of the LCA instead. Here the conditions prevent that: the upward search jumps only while the distance from u stays strictly below need, so when the LCA sits at exactly need it is returned as the parent of the last accepted node, and the downward case is entered only when need > du, so the LCA is never a candidate there.

Another important case is when a single edge weight is much larger than all others. In such situations, the weighted median may jump directly across that edge. A naive approach that reasons in terms of node counts rather than weighted distance would fail. Because the algorithm uses actual path weights, it correctly identifies the crossing point regardless of edge distribution.

A third edge case involves odd total path weights. For example, if the total distance is 11, the threshold is 5.5. Using floating-point arithmetic introduces unnecessary precision concerns. The implementation avoids this entirely by using:

$$\left\lceil \frac{D}{2} \right\rceil$$

computed as:

(D + 1) // 2

This preserves correctness while keeping all calculations in integer arithmetic.

The solution relies on binary lifting, LCA preprocessing, and weighted distance calculations to answer every query in logarithmic time, making it suitable for the full constraint range.