LeetCode 1740 - Find Distance in a Binary Tree

LeetCode Problem 1740

Difficulty: Medium
Topics: Hash Table, Tree, Depth-First Search, Breadth-First Search, Binary Tree

Solution

Problem Understanding

This problem asks us to determine the distance between two nodes in a binary tree, given their values p and q. The distance is defined as the number of edges in the shortest path connecting the two nodes. The input is the root of a binary tree and two integers representing values of nodes in that tree. The output is a single integer representing the distance between these two nodes.

The constraints provide important information. The tree contains between 1 and 10,000 nodes, each node has a unique value between 0 and 10^9, and the values p and q are guaranteed to exist in the tree. These constraints suggest that we can rely on the uniqueness of values to identify nodes and do not need to handle duplicates. Because the number of nodes can be up to 10,000, an O(n²) solution may be inefficient, and we should aim for O(n) time complexity.

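Important edge cases include p and q being the same node (distance 0) and one node being an ancestor of the other. In that case the ancestor is itself the LCA, and a parent and child are at distance 1. Very deep trees also matter. A skewed tree with 10,000 nodes has a depth close to 10,000, which exceeds CPython's default recursion limit of 1,000, so a recursive Python solution may need a higher limit (via sys.setrecursionlimit) or an iterative traversal. Because all values are unique, we do not have to consider multiple matches for a value.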

Approaches

The brute-force approach first finds the path from the root to p and the path from the root to q, stores both, and then compares them to find the last common node, the lowest common ancestor (LCA). The distance is the number of steps from each node up to the LCA, summed. The approach is correct and still O(n) time, but it stores two root-to-node paths (up to O(n) space in a skewed tree) and traverses the tree twice.
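A minimal sketch of the path-comparison approach, assuming the standard LeetCode TreeNode shape (redefined here so the block runs on its own). The helper names path_to and find_distance_brute are hypothetical. Because node values are unique, each path can be stored as a list of values; the length of the shared prefix locates the LCA.

```python
from typing import List, Optional


class TreeNode:
    # Minimal node class matching the LeetCode definition.
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def path_to(root: Optional[TreeNode], target: int, path: List[int]) -> bool:
    """Append the values on the root-to-target path to `path`; return True if found."""
    if root is None:
        return False
    path.append(root.val)
    if root.val == target:
        return True
    if path_to(root.left, target, path) or path_to(root.right, target, path):
        return True
    path.pop()  # target is not in this subtree: backtrack
    return False


def find_distance_brute(root: Optional[TreeNode], p: int, q: int) -> int:
    path_p: List[int] = []
    path_q: List[int] = []
    path_to(root, p, path_p)
    path_to(root, q, path_q)
    # Length of the shared prefix; its last element is the LCA.
    common = 0
    while (common < len(path_p) and common < len(path_q)
           and path_p[common] == path_q[common]):
        common += 1
    # Edges from p up to the LCA plus edges from the LCA down to q.
    return (len(path_p) - common) + (len(path_q) - common)
```

For Example 1 below, the paths are [3, 5] and [3, 1, 0]. They share one element, so the distance is (2 - 1) + (3 - 1) = 3.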

The optimal approach uses the Lowest Common Ancestor (LCA). The key insight is that the shortest path between p and q runs through their LCA, so the distance is the sum of the distances from the LCA to p and from the LCA to q. The implementation below runs one DFS to find the LCA and then two downward searches from the LCA to measure those distances. No root-to-node paths are stored, the total work stays linear, and the only extra space is the O(h) recursion stack.
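The LCA and the distances can also be computed in a genuinely single DFS. The following is a sketch of a hypothetical variant (find_distance_one_pass is an illustrative name), not the two-phase solution given later. Each call returns how far below the current node a target was found, or -1 if none was. The distance is recorded at the first node that sees both targets: either the LCA, with one target on each side, or a target with the other target beneath it.

```python
from typing import Optional


class TreeNode:
    # Minimal node class matching the LeetCode definition.
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def find_distance_one_pass(root: Optional[TreeNode], p: int, q: int) -> int:
    if p == q:
        return 0
    answer = 0

    def dfs(node: Optional[TreeNode]) -> int:
        """Return the depth below `node` (in edges) of the target found in this
        subtree, or -1 if there is none. Records `answer` once both are seen."""
        nonlocal answer
        if node is None:
            return -1
        left = dfs(node.left)
        right = dfs(node.right)
        if node.val == p or node.val == q:
            below = max(left, right)  # the other target, if it lies under this node
            if below != -1:
                answer = below + 1
            return 0
        if left != -1 and right != -1:
            # One target on each side: this node is the LCA.
            answer = left + right + 2
            return -1
        if left != -1:
            return left + 1
        if right != -1:
            return right + 1
        return -1

    dfs(root)
    return answer
```

The trade-off is readability. The two-phase version separates "find the LCA" from "measure depths", while this one folds both into a single post-order return value.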

| Approach | Time Complexity | Space Complexity | Notes |
|---|---|---|---|
| Brute Force | O(n) | O(n) | Find the paths to p and q separately, compare them to find the LCA, sum the distances |
| Optimal | O(n) | O(h) | DFS to find the LCA, then two searches from the LCA for the distances; h = height of tree |

Algorithm Walkthrough

  1. Define a helper function lca(node, p, q) that recursively finds the Lowest Common Ancestor of nodes with values p and q.
  2. Base case: if the current node is null, or its value equals p or q, return the node itself.
  3. Recursively traverse the left and right subtrees.
  4. If both left and right recursive calls return non-null nodes, the current node is the LCA.
  5. Otherwise, propagate whichever child is non-null.
  6. Once the LCA is found, compute the distance from the LCA to p and from the LCA to q separately using a helper distance_from(node, target).
  7. Return the sum of the two distances as the final result.

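Why it works: In a tree there is exactly one simple path between any two nodes. That path climbs from p up to their LCA and then descends to q. The LCA is the deepest node that lies on both the root-to-p and the root-to-q paths, so the two segments share no edges. Summing the distances from the LCA to each node therefore counts every edge of the p-to-q path exactly once.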
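The same argument gives the identity dist(p, q) = depth(p) + depth(q) - 2 * depth(LCA). The sketch below uses that identity without recursion, drawing on the Hash Table and Breadth-First Search topics. A BFS records each node's parent and depth in dictionaries keyed by value, and then the two nodes are lifted toward the root until they meet. The function name is hypothetical. Because nothing recurses, a skewed 10,000-node tree does not hit Python's recursion limit, at the cost of O(n) extra space for the maps.

```python
from collections import deque
from typing import Dict, Optional


class TreeNode:
    # Minimal node class matching the LeetCode definition.
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def find_distance_parent_map(root: TreeNode, p: int, q: int) -> int:
    # BFS from the root records parent and depth, keyed by value
    # (values are unique, so a value identifies its node).
    parent: Dict[int, int] = {}
    depth: Dict[int, int] = {root.val: 0}
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for child in (node.left, node.right):
            if child is not None:
                parent[child.val] = node.val
                depth[child.val] = depth[node.val] + 1
                queue.append(child)

    # Lift the deeper node until both are at the same depth,
    # then lift both together until they meet at the LCA.
    a, b = p, q
    while depth[a] > depth[b]:
        a = parent[a]
    while depth[b] > depth[a]:
        b = parent[b]
    while a != b:
        a, b = parent[a], parent[b]
    lca = a
    return depth[p] + depth[q] - 2 * depth[lca]
```

In Example 1, depth(5) = 1, depth(0) = 2 and the LCA is the root (depth 0), so the distance is 1 + 2 - 0 = 3.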

Python Solution

from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def findDistance(self, root: Optional[TreeNode], p: int, q: int) -> int:
        def lca(node: Optional[TreeNode], p: int, q: int) -> Optional[TreeNode]:
            # Base case: empty subtree, or this node is one of the targets.
            if not node or node.val == p or node.val == q:
                return node
            left = lca(node.left, p, q)
            right = lca(node.right, p, q)
            # One target on each side: this node is the LCA.
            if left and right:
                return node
            # Otherwise pass up whichever side found a target (or None).
            return left or right

        def distance_from(node: TreeNode, target: int) -> int:
            # Edges from node down to target, or -1 if target is not below node.
            if node.val == target:
                return 0
            if node.left:
                left_distance = distance_from(node.left, target)
                if left_distance != -1:
                    return left_distance + 1
            if node.right:
                right_distance = distance_from(node.right, target)
                if right_distance != -1:
                    return right_distance + 1
            return -1

        # p and q are guaranteed to exist, so ancestor is never None.
        ancestor = lca(root, p, q)
        return distance_from(ancestor, p) + distance_from(ancestor, q)

The Python implementation defines two helper functions. lca finds the lowest common ancestor of p and q using the standard recursive approach. distance_from counts the edges from a given node down to the target node and returns -1 when the target is not in that subtree. The sum of the distances from the LCA to each node gives the total distance.

Go Solution

/**
 * Definition for a binary tree node.
 * type TreeNode struct {
 *     Val int
 *     Left *TreeNode
 *     Right *TreeNode
 * }
 */
func findDistance(root *TreeNode, p int, q int) int {
    var lca func(node *TreeNode, p, q int) *TreeNode
    lca = func(node *TreeNode, p, q int) *TreeNode {
        if node == nil || node.Val == p || node.Val == q {
            return node
        }
        left := lca(node.Left, p, q)
        right := lca(node.Right, p, q)
        if left != nil && right != nil {
            return node
        }
        if left != nil {
            return left
        }
        return right
    }

    var distanceFrom func(node *TreeNode, target int) int
    distanceFrom = func(node *TreeNode, target int) int {
        if node.Val == target {
            return 0
        }
        if node.Left != nil {
            leftDist := distanceFrom(node.Left, target)
            if leftDist != -1 {
                return leftDist + 1
            }
        }
        if node.Right != nil {
            rightDist := distanceFrom(node.Right, target)
            if rightDist != -1 {
                return rightDist + 1
            }
        }
        return -1
    }

    ancestor := lca(root, p, q)
    return distanceFrom(ancestor, p) + distanceFrom(ancestor, q)
}

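The Go implementation mirrors the Python logic. nil checks replace Python's None. Each closure variable (var lca func(...) and var distanceFrom func(...)) is declared before the function literal is assigned to it, which lets the closure call itself recursively. As in Python, distanceFrom returns -1 to signal that the target is not in the subtree.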

Worked Examples

Example 1: root = [3,5,1,6,2,0,8,null,null,7,4], p = 5, q = 0

  1. Find LCA of 5 and 0 → node 3.
  2. Distance from 3 to 5 → 3 → 5 = 1 edge.
  3. Distance from 3 to 0 → 3 → 1 → 0 = 2 edges.
  4. Total distance = 1 + 2 = 3.

Example 2: root = [3,5,1,6,2,0,8,null,null,7,4], p = 5, q = 7

  1. LCA of 5 and 7 → node 5.
  2. Distance from 5 to 5 → 0 edges.
  3. Distance from 5 to 7 → 5 → 2 → 7 = 2 edges.
  4. Total distance = 0 + 2 = 2.

Example 3: root = [3,5,1,6,2,0,8,null,null,7,4], p = 5, q = 5

  1. LCA of 5 and 5 → node 5.
  2. Distance from 5 to 5 → 0 edges.
  3. Distance from 5 to 5 → 0 edges.
  4. Total distance = 0 + 0 = 0.
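To run these examples locally, the level-order list has to be turned into TreeNode objects. Below is a minimal sketch of such a builder. It assumes LeetCode's serialization, where children are filled left to right and None marks a missing child. build_tree is a hypothetical helper, not part of the problem's API.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    # Minimal node class matching the LeetCode definition.
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a LeetCode-style level-order list (None = no child)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value, if present, is this node's left child.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # The value after that is its right child.
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root
```

For example, Solution().findDistance(build_tree([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4]), 5, 0) should return 3, matching Example 1.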

Complexity Analysis

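| Measure | Complexity | Explanation |
|---|---|---|
| Time | O(n) | lca visits each node at most once, and each of the two distance_from calls visits each node in the LCA's subtree at most once, so every node is visited at most three times. |
| Space | O(h) | Recursion stack depth is proportional to the tree height h. In the worst case (a skewed tree), h = n. |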

The algorithm is linear in the number of nodes because each node is visited a constant number of times: at most once by lca and at most once by each of the two distance_from calls.
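The visit bound can be checked concretely with an instrumented copy of the two helpers that counts how many times any node is visited. find_distance_counted is a hypothetical name, and the counter exists purely for illustration. The traversal logic is the same as in the Python solution above, so on any tree of n nodes the count stays within 3n.

```python
from typing import Optional, Tuple


class TreeNode:
    # Minimal node class matching the LeetCode definition.
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def find_distance_counted(root: TreeNode, p: int, q: int) -> Tuple[int, int]:
    """Return (distance, number of node visits across both phases)."""
    visits = 0

    def lca(node: Optional[TreeNode]) -> Optional[TreeNode]:
        nonlocal visits
        if node is None:
            return None
        visits += 1  # count each real node lca touches
        if node.val == p or node.val == q:
            return node
        left = lca(node.left)
        right = lca(node.right)
        if left and right:
            return node
        return left or right

    def distance_from(node: TreeNode, target: int) -> int:
        nonlocal visits
        visits += 1  # count each node a distance search touches
        if node.val == target:
            return 0
        for child in (node.left, node.right):
            if child is not None:
                d = distance_from(child, target)
                if d != -1:
                    return d + 1
        return -1

    ancestor = lca(root)
    dist = distance_from(ancestor, p) + distance_from(ancestor, q)
    return dist, visits
```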

Test Cases

# Provided examples
assert Solution().findDistance(TreeNode(3, TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))), TreeNode(1, TreeNode(0), TreeNode(8))), 5, 0) == 3  # 5-3-1-0
assert Solution().findDistance(TreeNode(3, TreeNode(5