LeetCode 124 - Binary Tree Maximum Path Sum

LeetCode Problem 124

Difficulty: 🔴 Hard
Topics: Dynamic Programming, Tree, Depth-First Search, Binary Tree

Solution

Problem Understanding

The problem is asking us to find the maximum path sum in a binary tree. A path is defined as any sequence of nodes connected by edges, where each node is included at most once. The path does not need to start at the root or end at a leaf. The sum of a path is simply the sum of the node values along that path.

The input is the root of a non-empty binary tree, and the output is a single integer: the largest possible path sum. The tree can contain up to 30,000 nodes, and node values range from -1000 to 1000, so they may be negative. At this size, an approach that enumerates every path explicitly is too slow. Edge cases include trees whose values are all negative, single-node trees, and skewed trees (linked-list-like structures).

Approaches

The brute-force approach enumerates every possible path and computes its sum. Because a tree has exactly one simple path between any two nodes, there are n(n+1)/2 paths when single-node paths are counted. Running a DFS from every node and accumulating the sum along the way reaches all of them in O(n) per starting node, or O(n^2) overall. This is correct, but for n = 30,000 it means roughly 9 × 10^8 steps, which is far too slow.
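To make the quadratic bound concrete, here is a minimal brute-force sketch, mainly useful as a reference for checking the optimal solution on small trees. The function name `brute_force_max_path_sum` is hypothetical, and `TreeNode` mirrors LeetCode's definition. It treats the tree as an undirected graph and starts a DFS from every node, so each simple path is summed incrementally.

```python
from collections import defaultdict
from typing import Dict, List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def brute_force_max_path_sum(root: Optional[TreeNode]) -> int:
    """Hypothetical O(n^2) reference: walk every simple path from every start node."""
    # Build an undirected adjacency list keyed by node identity.
    adj: Dict[TreeNode, List[TreeNode]] = defaultdict(list)
    nodes: List[TreeNode] = []
    stack = [root] if root else []
    while stack:
        node = stack.pop()
        nodes.append(node)
        for child in (node.left, node.right):
            if child:
                adj[node].append(child)
                adj[child].append(node)
                stack.append(child)

    best = float('-inf')
    for start in nodes:
        # Each (node, parent, running_sum) entry extends one simple path from start.
        # Skipping the parent is enough to stay simple, because a tree has no cycles.
        frontier = [(start, None, start.val)]
        while frontier:
            node, parent, total = frontier.pop()
            best = max(best, total)
            for nxt in adj[node]:
                if nxt is not parent:
                    frontier.append((nxt, node, total + nxt.val))
    return best
```

Each start node costs O(n), giving O(n^2) time and O(n) extra space for the adjacency list.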

The optimal approach uses a post-order depth-first search (DFS) built on one observation. Define the gain of a node as the largest sum of a downward path that starts at that node. For any node, the best path whose highest point is that node has one of four shapes: the node alone, the node plus its left child's gain, the node plus its right child's gain, or the node plus both gains. When returning a value to the parent, however, a node may extend only one branch (left or right); otherwise the path would fork and could not continue upward. Because children are processed before their parent, every node is handled in constant time.
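The claim that clamping negative gains to zero covers all four shapes can be checked exhaustively. In this sketch the three helper names are hypothetical, and `left_gain`/`right_gain` stand for the values a child would return.

```python
from itertools import product


def best_through_four_cases(val: int, left_gain: int, right_gain: int) -> int:
    """Max over the four shapes: node alone, +left, +right, +both (hypothetical helper)."""
    return max(val, val + left_gain, val + right_gain, val + left_gain + right_gain)


def best_through_clamped(val: int, left_gain: int, right_gain: int) -> int:
    """Same value, computed by dropping any negative child gain (hypothetical helper)."""
    return val + max(left_gain, 0) + max(right_gain, 0)


def gain_to_parent(val: int, left_gain: int, right_gain: int) -> int:
    """What a node may hand upward: itself plus at most one branch (hypothetical helper)."""
    return val + max(left_gain, right_gain, 0)


# Compare both formulations on every combination of small values.
values = range(-3, 4)
assert all(
    best_through_four_cases(v, l, r) == best_through_clamped(v, l, r)
    for v, l, r in product(values, repeat=3)
)
```

The equivalence holds because the left and right choices are independent, so the maximum of their sum is the sum of their individual maxima, each of which is either 0 (skip) or the gain.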

| Approach | Time Complexity | Space Complexity | Notes |
| --- | --- | --- | --- |
| Brute Force | O(n^2) | O(n) | DFS from every node over all simple paths; too slow for large trees |
| Optimal | O(n) | O(h) | Post-order DFS; each node passes its gain up through at most one child; h is the tree height |

Algorithm Walkthrough

  1. Initialize a variable max_sum, shared by all recursive calls, to negative infinity; it tracks the best path sum found anywhere in the tree.
  2. Define a recursive function dfs(node) that returns the gain of node: the maximum sum of a path that starts at node and extends downward through at most one child (an empty subtree has gain 0).
  3. For each node, recursively compute the maximum path sum of the left child and right child using dfs(node.left) and dfs(node.right).
  4. If a child contributes a negative sum, discard it by taking max(left_sum, 0) and max(right_sum, 0).
  5. Compute the maximum path sum through the current node as node.val + left_sum + right_sum.
  6. Update the global max_sum if this path sum is larger than the current max_sum.
  7. Return to the parent the maximum sum including only one branch: node.val + max(left_sum, right_sum).
  8. Call dfs(root) to traverse the entire tree and finally return max_sum.
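The steps above keep max_sum as shared state. As an alternative sketch, not the solution given below, each call can instead return a pair (gain, best) so the running maximum travels upward through return values. The name `max_path_sum_pure` is hypothetical, and it assumes a non-empty tree, as the constraints guarantee.

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_path_sum_pure(root: TreeNode) -> int:
    """Hypothetical variant: same recurrence, no shared mutable max_sum."""

    def dfs(node: Optional[TreeNode]) -> Tuple[int, float]:
        if node is None:
            # Empty subtree: contributes no gain and contains no path.
            return 0, float('-inf')
        left_gain, left_best = dfs(node.left)
        right_gain, right_best = dfs(node.right)
        left_gain = max(left_gain, 0)    # step 4: drop negative branches
        right_gain = max(right_gain, 0)
        through = node.val + left_gain + right_gain   # step 5
        best = max(left_best, right_best, through)    # step 6, kept local
        return node.val + max(left_gain, right_gain), best  # step 7

    return dfs(root)[1]
```

Returning the best value instead of mutating it makes each call a pure function of its subtree, which some readers find easier to test; the cost is a tuple per call.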

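Why it works: Every path has exactly one highest node, the one closest to the root. Seen from that node, the path is the node itself plus at most one downward chain into each subtree, so the best path whose highest node is v has sum v.val + max(left_gain, 0) + max(right_gain, 0), which is exactly the value computed in step 5. Because step 6 evaluates this at every node, the overall best path is recorded when its highest node is processed. The value returned in step 7 includes only one branch, because a path that continues upward to the parent cannot also fork into both children. Each node does constant work, so the total time is linear.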
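The highest-node argument can be made visible by recording which chains form the winning path. This sketch, `best_path_values` (a hypothetical name), returns the node values of one maximum-sum path: the reversed left chain, then the highest node, then the right chain. It copies lists at every node, so it is quadratic in the worst case and meant only for illustration.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def best_path_values(root: TreeNode) -> List[int]:
    """Hypothetical helper: values of one maximum-sum path, from one end to the other."""
    best_sum = float('-inf')
    best_path: List[int] = []

    def dfs(node: Optional[TreeNode]) -> Tuple[int, List[int]]:
        # Returns (gain, chain): the best downward path starting at node, listed top-down.
        nonlocal best_sum, best_path
        if node is None:
            return 0, []
        left_gain, left_chain = dfs(node.left)
        right_gain, right_chain = dfs(node.right)
        if left_gain < 0:
            left_gain, left_chain = 0, []
        if right_gain < 0:
            right_gain, right_chain = 0, []
        through = node.val + left_gain + right_gain
        if through > best_sum:
            # node is the highest point of this candidate path.
            best_sum = through
            best_path = left_chain[::-1] + [node.val] + right_chain
        # Only one chain may be extended upward to the parent.
        if left_gain >= right_gain:
            return node.val + left_gain, [node.val] + left_chain
        return node.val + right_gain, [node.val] + right_chain

    dfs(root)
    return best_path
```

On the tree [-10,9,20,null,null,15,7] this yields [15, 20, 7]: node 20 is the highest node, with one chain on each side.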

Python Solution

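```python
from typing import Optional


# Definition for a binary tree node.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def maxPathSum(self, root: Optional[TreeNode]) -> int:
        # Best path sum seen anywhere in the tree so far.
        self.max_sum = float('-inf')

        def dfs(node: Optional[TreeNode]) -> int:
            # Returns the gain: best downward path sum starting at node.
            if not node:
                return 0
            left = max(dfs(node.left), 0)    # drop a negative left branch
            right = max(dfs(node.right), 0)  # drop a negative right branch
            current_sum = node.val + left + right  # best path with node as its top
            self.max_sum = max(self.max_sum, current_sum)
            return node.val + max(left, right)  # parent may extend only one branch

        dfs(root)
        return self.max_sum
```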

This Python solution stores the running maximum in the instance attribute self.max_sum and uses a recursive DFS to compute each node's gain. At each node, it clamps negative child gains to zero and updates self.max_sum with the best path whose highest node is the current node. When returning, it propagates only the larger child gain to the parent, so no node appears twice in a path.

Go Solution

```go
/**
 * Definition for a binary tree node.
 * type TreeNode struct {
 *     Val int
 *     Left *TreeNode
 *     Right *TreeNode
 * }
 */
func maxPathSum(root *TreeNode) int {
    maxSum := -1 << 31 // sentinel below any reachable path sum (node values are >= -1000)

    var dfs func(node *TreeNode) int
    dfs = func(node *TreeNode) int {
        if node == nil {
            return 0
        }
        left := max(dfs(node.Left), 0)
        right := max(dfs(node.Right), 0)
        currentSum := node.Val + left + right
        if currentSum > maxSum {
            maxSum = currentSum
        }
        return node.Val + max(left, right)
    }

    dfs(root)
    return maxSum
}

func max(a, b int) int {
    if a > b {
        return a
    }
    return b
}
```

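The Go solution mirrors the Python one. The closure dfs is declared with var before it is assigned so that it can call itself recursively, and nil pointers are checked explicitly. Go versions before 1.21 have no built-in max for integers, so a helper max is defined; on Go 1.21 and later, this package-level function simply shadows the built-in.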

Worked Examples

Example 1: [1,2,3]

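| Node | Left Gain | Right Gain | Path Through Node | Returned Gain | max_sum |
| --- | --- | --- | --- | --- | --- |
| 2 | 0 | 0 | 2 | 2 | 2 |
| 3 | 0 | 0 | 3 | 3 | 3 |
| 1 | 2 | 3 | 6 | 4 | 6 |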

Output: 6

Example 2: [-10,9,20,null,null,15,7]

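| Node | Left Gain | Right Gain | Path Through Node | Returned Gain | max_sum |
| --- | --- | --- | --- | --- | --- |
| 9 | 0 | 0 | 9 | 9 | 9 |
| 15 | 0 | 0 | 15 | 15 | 15 |
| 7 | 0 | 0 | 7 | 7 | 15 |
| 20 | 15 | 7 | 42 | 35 | 42 |
| -10 | 9 | 35 | 34 | 25 | 42 |

Rows follow post-order, the order in which dfs finishes each node. At -10, the right gain is 35 (20 + 15) rather than 20, because node 20 passes up its best single branch; the path through -10 is therefore -10 + 9 + 35 = 34, which does not beat 42.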

Output: 42
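The per-node values in the worked examples can be checked by instrumenting the recursion. In this sketch, `trace_max_path_sum` (a hypothetical name) appends one row per node in the order dfs finishes it: node value, left gain, right gain, path through node, returned gain, and max_sum after the update.

```python
from typing import List, Optional, Tuple

# (node, left gain, right gain, path through node, returned gain, max_sum)
Row = Tuple[int, int, int, int, int, int]


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trace_max_path_sum(root: TreeNode) -> List[Row]:
    """Hypothetical tracing helper: same recursion as the solution, plus a row log."""
    rows: List[Row] = []
    max_sum = float('-inf')

    def dfs(node: Optional[TreeNode]) -> int:
        nonlocal max_sum
        if node is None:
            return 0
        left = max(dfs(node.left), 0)
        right = max(dfs(node.right), 0)
        through = node.val + left + right
        max_sum = max(max_sum, through)
        returned = node.val + max(left, right)
        rows.append((node.val, left, right, through, returned, max_sum))
        return returned

    dfs(root)
    return rows


example2 = TreeNode(-10, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))
for row in trace_max_path_sum(example2):
    print(row)
```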

Complexity Analysis

| Measure | Complexity | Explanation |
| --- | --- | --- |
| Time | O(n) | Each node is visited once in the DFS |
| Space | O(h) | Recursion stack depth equals the tree height |

The algorithm is linear in time because every node is processed exactly once. Space is proportional to the height of the tree due to recursion, which can be O(log n) for balanced trees or O(n) for skewed trees.
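In CPython the default recursion limit is 1000 (as reported by `sys.getrecursionlimit()`), so running the recursive solution locally on a skewed tree with thousands of nodes raises `RecursionError` unless the limit is raised. One hedged alternative is a post-order traversal with an explicit stack. `max_path_sum_iterative` is a hypothetical name; it computes the same gains in the same order, and each child's gain is removed from the dictionary as soon as its parent consumes it.

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_path_sum_iterative(root: TreeNode) -> int:
    """Hypothetical iterative variant: post-order DFS without Python recursion."""
    gain: Dict[TreeNode, int] = {}  # gains of finished nodes awaiting their parent
    best = float('-inf')
    stack: List[Tuple[Optional[TreeNode], bool]] = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if node is None:
            continue
        if not children_done:
            # Revisit node after both children are finished; left is popped first.
            stack.append((node, True))
            stack.append((node.right, False))
            stack.append((node.left, False))
            continue
        # A missing child is never stored, so pop falls back to a gain of 0.
        left = max(gain.pop(node.left, 0), 0)
        right = max(gain.pop(node.right, 0), 0)
        best = max(best, node.val + left + right)
        gain[node] = node.val + max(left, right)
    return best


# A left-skewed chain of 30,000 nodes, far deeper than the default recursion limit.
chain = TreeNode(1)
tail = chain
for _ in range(29_999):
    tail.left = TreeNode(1)
    tail = tail.left
print(max_path_sum_iterative(chain))  # 30000: the whole chain is the best path
```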

Test Cases

```python
# test cases
assert Solution().maxPathSum(TreeNode(1, TreeNode(2), TreeNode(3))) == 6  # Example 1
assert Solution().maxPathSum(TreeNode(-10, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))) == 42  # Example 2
assert Solution().maxPathSum(TreeNode(-3)) == -3  # Single negative node
assert Solution().maxPathSum(TreeNode(2, TreeNode(-1))) == 2  # Path ignores negative child
assert Solution().maxPathSum(TreeNode(-2, TreeNode(-1))) == -1  # Path chooses less negative node
assert Solution().maxPathSum(TreeNode(1)) == 1  # Single node tree
```
| Test | Why |
| --- | --- |
| [1,2,3] | Simple balanced tree |
| [-10,9,20,null,null,15,7] | Negative root; the best path (15, 20, 7) lies entirely in the right subtree |
| [-3] | Single node with a negative value |
| [2,-1] | Negative child is dropped from the sum |
| [-2,-1] | All values negative; the least negative node wins |
| [1] | Single node with a positive value |
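The Test column uses LeetCode's level-order notation, where null marks a missing child. A small helper can turn such a list into TreeNode objects for testing. `build_tree` is a hypothetical name, and Python's None stands in for null.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from level-order values (None = missing child)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root
```

For example, `build_tree([-10, 9, 20, None, None, 15, 7])` produces the tree from Example 2, with 9 as a leaf and 15 and 7 as the children of 20.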

Edge Cases

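All negative nodes: When every value is negative, the answer is the least negative single node. Each gain is then negative, so max(gain, 0) clamps it to 0 and the path through each node is just node.val; max_sum therefore ends at the largest node value. This relies on initializing max_sum to negative infinity rather than 0, since no path in such a tree has sum 0.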
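The dependence on the initial value is easy to demonstrate. This sketch reuses the recursion with max_sum's starting value exposed as a hypothetical parameter `initial`; starting from 0 instead of negative infinity returns 0 for an all-negative tree, which is not the sum of any path.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_path_sum(root: TreeNode, initial: float = float('-inf')) -> float:
    """The recursive solution, with max_sum's starting value as a hypothetical parameter."""
    best = initial

    def dfs(node: Optional[TreeNode]) -> int:
        nonlocal best
        if node is None:
            return 0
        left = max(dfs(node.left), 0)
        right = max(dfs(node.right), 0)
        best = max(best, node.val + left + right)
        return node.val + max(left, right)

    dfs(root)
    return best


all_negative = TreeNode(-3, TreeNode(-1), TreeNode(-2))
print(max_path_sum(all_negative))     # -1: the least negative single node
print(max_path_sum(all_negative, 0))  # 0: wrong, no path has sum 0
```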

Single node tree: The tree contains only one node. The algorithm correctly handles this by initializing max_sum to negative infinity and updating it with the node value.

Skewed tree (linked list): If the tree is heavily skewed to the left or right, recursion depth could reach O(n). The algorithm still works because the DFS correctly propagates the maximum sum down the single path and