LeetCode 437: Path Sum III
Count downward paths in a binary tree whose values sum to targetSum using DFS and prefix sums.
Problem Restatement
We are given the root of a binary tree and an integer targetSum.
We need to count how many paths have node values that add up to targetSum.
A valid path:
| Rule | Meaning |
|---|---|
| Can start anywhere | It does not have to start at the root |
| Can end anywhere | It does not have to end at a leaf |
| Must go downward | It can only move from parent to child |
So this path is valid:
parent -> child -> grandchild
But this path is not valid:
left child -> parent -> right child
because it moves upward.
The official problem asks for the number of downward paths whose sum equals targetSum. The tree can be empty, and node values may be large or negative.
Input and Output
| Item | Meaning |
|---|---|
| Input | root of a binary tree and integer targetSum |
| Output | Number of valid downward paths |
| Path start | Any node |
| Path end | The start node itself or any of its descendants |
| Empty tree | Return 0 |
The node class is usually:
```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
```
Examples
Example 1:
root = [10, 5, -3, 3, 2, None, 11, 3, -2, None, 1]
targetSum = 8
The answer is:
3
The three paths are:
5 -> 3
5 -> 2 -> 1
-3 -> 11
Example 2:
root = [5, 4, 8, 11, None, 13, 4, 7, 2, None, None, 5, 1]
targetSum = 22
The answer is:
3
The three paths are:
5 -> 4 -> 11 -> 2
5 -> 8 -> 4 -> 5
4 -> 11 -> 7
Example 3:
root = None
targetSum = 0
The answer is:
0
There are no nodes, so there are no paths.
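The examples write trees in LeetCode's level-order list notation, where `None` marks a missing child. A small helper can turn such a list into `TreeNode` objects for local experiments. This is a minimal sketch; `build_tree` is a hypothetical helper name, not part of the problem's API.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list in which None means 'no child'."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two list entries are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


root = build_tree([10, 5, -3, 3, 2, None, 11, 3, -2, None, 1])
print(root.left.right.right.val)  # 1, the last node of 5 -> 2 -> 1
```

Missing children consume a slot in the list but are never enqueued, which is why `None` entries do not get children of their own.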
First Thought: Start DFS From Every Node
A direct approach is:
- Pick every node as a possible path start.
- From that node, try every downward path.
- Count paths whose sum equals targetSum.
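This works, but every node is summed again once for each of its ancestors, so the running time is O(n * h), where h is the height of the tree. For a balanced tree that is O(n log n); for a skewed tree it becomes:
O(n^2)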
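The brute-force approach can be sketched as two nested recursions: an outer one that picks every node as a start, and an inner one that walks every downward path from that start. The function names `count_from` and `path_sum_brute_force` are hypothetical, chosen for this illustration.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_from(node: Optional[TreeNode], remaining: int) -> int:
    """Count downward paths that start exactly at `node` and sum to `remaining`."""
    if node is None:
        return 0
    remaining -= node.val
    # A path may end here; keep going even after a hit, because negative
    # values further down can bring the sum back to the target.
    found = 1 if remaining == 0 else 0
    return found + count_from(node.left, remaining) + count_from(node.right, remaining)


def path_sum_brute_force(root: Optional[TreeNode], target: int) -> int:
    """Treat every node as a possible start: O(n * h) time overall."""
    if root is None:
        return 0
    return (
        count_from(root, target)
        + path_sum_brute_force(root.left, target)
        + path_sum_brute_force(root.right, target)
    )


# Example 1 from above.
root = TreeNode(
    10,
    TreeNode(5, TreeNode(3, TreeNode(3), TreeNode(-2)), TreeNode(2, None, TreeNode(1))),
    TreeNode(-3, None, TreeNode(11)),
)
print(path_sum_brute_force(root, 8))  # 3
```

The inner walk cannot stop early when `remaining` reaches 0 or goes negative, since node values may be negative.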
We need to count all possible downward paths while visiting each node only once.
Key Insight
This problem is the tree version of counting subarrays with a given sum using prefix sums (LeetCode 560, Subarray Sum Equals K).
Suppose we are walking from the root to the current node and the running sum is:
current_sum
We want to know whether there is an earlier point on the same root-to-current path where the prefix sum was:
current_sum - targetSum
Why?
If:
current_sum - old_prefix_sum = targetSum
then the part of the root-to-current path that comes after the point where old_prefix_sum was recorded, ending at the current node, has sum targetSum.
Rearrange:
old_prefix_sum = current_sum - targetSum
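A worked instance from Example 1: along the root-to-node path 10 -> 5 -> 3 with targetSum = 8, the running sums are 0 (the empty prefix), 10, 15, and 18. At the node with value 3, current_sum - targetSum = 18 - 8 = 10, which is the prefix recorded at the root, so the part of the path after the root, 5 -> 3, sums to 8. The lines below only recompute that arithmetic; the list-based bookkeeping is for illustration.

```python
# Running sums along the root-to-node path 10 -> 5 -> 3 from Example 1.
path = [10, 5, 3]
target = 8

prefix_sums = [0]  # the empty prefix before the root
for value in path:
    prefix_sums.append(prefix_sums[-1] + value)
print(prefix_sums)  # [0, 10, 15, 18]

current_sum = prefix_sums[-1]              # 18, at the node with value 3
old_prefix_sum = current_sum - target      # 10, the prefix ending at the root
start = prefix_sums.index(old_prefix_sum)  # prefix_sums[i] == sum(path[:i])
print(path[start:], sum(path[start:]))     # [5, 3] 8
```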
So during DFS, we keep a hash map:
prefix_count[prefix_sum] = frequency
This map stores prefix sums only along the current root-to-node path.
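For comparison, the one-dimensional version of the same trick on a plain list, in the style of LeetCode 560, looks like this. It is a minimal sketch, and `count_subarrays_with_sum` is a hypothetical name. On a tree, the list is replaced by the current root-to-node path, which is why the tree version must also undo its map updates when it leaves a node.

```python
from collections import defaultdict
from typing import List


def count_subarrays_with_sum(nums: List[int], target: int) -> int:
    """Count contiguous subarrays of nums that sum to target in one pass."""
    prefix_count = defaultdict(int)
    prefix_count[0] = 1  # the empty prefix, so subarrays starting at index 0 count
    current_sum = 0
    total = 0
    for value in nums:
        current_sum += value
        # Every earlier prefix equal to current_sum - target starts a valid subarray.
        total += prefix_count[current_sum - target]
        prefix_count[current_sum] += 1
    return total


print(count_subarrays_with_sum([10, 5, 3, 3], 8))   # 1, the subarray [5, 3]
print(count_subarrays_with_sum([1, -1, 1, -1], 0))  # 4
```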
Algorithm
Initialize:
prefix_count = {0: 1}
The prefix sum 0 handles paths that start at the root.
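To see why the seed matters, the sketch below runs the same DFS with and without the 0 entry on a two-node tree whose only valid path starts at the root. `count_paths_seed_demo` and its `seed_empty_prefix` flag are hypothetical names used only for this illustration.

```python
from collections import defaultdict
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_paths_seed_demo(root: Optional[TreeNode], target: int, seed_empty_prefix: bool) -> int:
    prefix_count = defaultdict(int)
    if seed_empty_prefix:
        prefix_count[0] = 1  # the empty prefix before the root

    def dfs(node, current_sum):
        if node is None:
            return 0
        current_sum += node.val
        total = prefix_count[current_sum - target]
        prefix_count[current_sum] += 1
        total += dfs(node.left, current_sum) + dfs(node.right, current_sum)
        prefix_count[current_sum] -= 1
        return total

    return dfs(root, 0)


root = TreeNode(5, TreeNode(3))  # the only path summing to 8 is 5 -> 3, starting at the root
print(count_paths_seed_demo(root, 8, seed_empty_prefix=True))   # 1
print(count_paths_seed_demo(root, 8, seed_empty_prefix=False))  # 0
```

Without the seed, a path that starts at the root would need an "earlier" prefix of 0 that no node ever records.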
Run DFS with:
dfs(node, current_sum)
For each node:
- Add the node value to current_sum.
- Count how many earlier prefixes equal current_sum - targetSum.
- Add the current prefix sum to the hash map.
- Recurse into left and right children.
- Remove the current prefix sum from the hash map before returning.
The last step is important. Prefix sums from one branch must not affect another branch.
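The effect of skipping the last step can be shown on a three-node tree. In this sketch, `count_paths_backtrack_demo` and its `backtrack` flag are hypothetical names. No real downward path in the tree sums to -3, but without backtracking the left child's prefix 3 is still in the map when the right child is visited, and it gets paired with a node that is not on the same root-to-node path.

```python
from collections import defaultdict
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_paths_backtrack_demo(root: Optional[TreeNode], target: int, backtrack: bool) -> int:
    prefix_count = defaultdict(int)
    prefix_count[0] = 1

    def dfs(node, current_sum):
        if node is None:
            return 0
        current_sum += node.val
        total = prefix_count[current_sum - target]
        prefix_count[current_sum] += 1
        total += dfs(node.left, current_sum) + dfs(node.right, current_sum)
        if backtrack:
            prefix_count[current_sum] -= 1  # undo before returning to the parent
        return total

    return dfs(root, 0)


# Real downward sums: 1, 2, -1, 1 -> 2 = 3, 1 -> -1 = 0. None equals -3.
root = TreeNode(1, TreeNode(2), TreeNode(-1))
print(count_paths_backtrack_demo(root, -3, backtrack=True))   # 0
print(count_paths_backtrack_demo(root, -3, backtrack=False))  # 1, a false match
```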
Correctness
When the DFS enters a node, prefix_count contains exactly the empty prefix 0 plus the prefix sums of every node on the path from the root to that node's parent.
When we visit a node, we update current_sum to include that node.
A downward path ending at the current node has sum targetSum exactly when there exists an earlier prefix sum p such that:
current_sum - p == targetSum
This is equivalent to:
p == current_sum - targetSum
So prefix_count[current_sum - targetSum] gives exactly the number of valid paths ending at the current node.
After counting paths ending at the current node, we add current_sum to prefix_count before processing children. This lets paths in the subtree start at a child of the current node; paths that start at the current node or higher are already covered by the prefixes recorded for its ancestors, including the initial 0.
After processing both children, we decrement the count for current_sum. This restores the hash map to its previous state before returning to the parent, so a sibling subtree never sees prefixes that are not on its own root-to-node path.
Because the DFS visits every node and counts exactly the valid paths ending at that node, the final total is the number of all valid downward paths.
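The argument can also be spot-checked empirically. The sketch below compares the prefix-sum count against an explicit enumeration of every downward path on random small trees that include negative values. The helper names (`prefix_sum_count`, `brute_force_count`, `random_tree`) and the value ranges are arbitrary choices for this check, and passing it is evidence rather than a proof.

```python
import random
from collections import defaultdict


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def prefix_sum_count(root, target):
    prefix_count = defaultdict(int)
    prefix_count[0] = 1

    def dfs(node, current_sum):
        if node is None:
            return 0
        current_sum += node.val
        total = prefix_count[current_sum - target]
        prefix_count[current_sum] += 1
        total += dfs(node.left, current_sum) + dfs(node.right, current_sum)
        prefix_count[current_sum] -= 1
        return total

    return dfs(root, 0)


def brute_force_count(root, target):
    # For every start node, walk all of its descendants and test each running sum.
    def from_start(node, running):
        if node is None:
            return 0
        running += node.val
        return (running == target) + from_start(node.left, running) + from_start(node.right, running)

    def every_start(node):
        if node is None:
            return 0
        return from_start(node, 0) + every_start(node.left) + every_start(node.right)

    return every_start(root)


def random_tree(rng, size):
    # Random shape with `size` nodes and small values, negatives included.
    if size == 0:
        return None
    left_size = rng.randint(0, size - 1)
    return TreeNode(
        rng.randint(-3, 3),
        random_tree(rng, left_size),
        random_tree(rng, size - 1 - left_size),
    )


rng = random.Random(0)
for _ in range(500):
    tree = random_tree(rng, rng.randint(0, 12))
    target = rng.randint(-4, 4)
    assert prefix_sum_count(tree, target) == brute_force_count(tree, target)
print("prefix-sum count matches brute force on 500 random trees")
```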
Complexity
| Metric | Value | Why |
|---|---|---|
| Time | O(n) | Each node is visited once |
| Space | O(h) | The recursion stack and the nonzero hash map entries only cover the current root-to-node path |
Here, n is the number of nodes and h is the height of the tree. In the worst case (a skewed tree), h = n; for a balanced tree, h = O(log n).
Implementation
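```python
from collections import defaultdict
from typing import Optional


class Solution:
    def pathSum(self, root: 'Optional[TreeNode]', targetSum: int) -> int:
        # Frequency of each prefix sum on the current root-to-node path.
        # The 0 entry is the empty prefix, so paths starting at the root count.
        prefix_count = defaultdict(int)
        prefix_count[0] = 1

        def dfs(node: 'Optional[TreeNode]', current_sum: int) -> int:
            if not node:
                return 0
            current_sum += node.val
            # Paths ending here = earlier prefixes equal to current_sum - targetSum.
            total = prefix_count[current_sum - targetSum]
            # Make this prefix visible to the node's descendants.
            prefix_count[current_sum] += 1
            total += dfs(node.left, current_sum)
            total += dfs(node.right, current_sum)
            # Backtrack so sibling subtrees do not see this prefix.
            prefix_count[current_sum] -= 1
            return total

        return dfs(root, 0)
```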
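One detail of the implementation above: reading `prefix_count[current_sum - targetSum]` on a `defaultdict` inserts a zero-count key on every miss, and the decrement leaves zero counts behind. Answers are unaffected, but the map itself can accumulate O(n) keys rather than the O(h) live entries assumed in the complexity analysis. The sketch below keeps the map tight by reading with `.get()` and deleting keys that drop to zero. `SolutionCompactMap` is a hypothetical name chosen to avoid clashing with `Solution`.

```python
from collections import defaultdict
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class SolutionCompactMap:
    def pathSum(self, root: Optional[TreeNode], targetSum: int) -> int:
        prefix_count = defaultdict(int)
        prefix_count[0] = 1

        def dfs(node: Optional[TreeNode], current_sum: int) -> int:
            if node is None:
                return 0
            current_sum += node.val
            # .get() reads without inserting a key, unlike prefix_count[key].
            total = prefix_count.get(current_sum - targetSum, 0)
            prefix_count[current_sum] += 1
            total += dfs(node.left, current_sum)
            total += dfs(node.right, current_sum)
            prefix_count[current_sum] -= 1
            if prefix_count[current_sum] == 0:
                # Drop keys that are no longer on the active path.
                del prefix_count[current_sum]
            return total

        return dfs(root, 0)


chain = TreeNode(1, TreeNode(2, TreeNode(3)))  # paths summing to 3: 1 -> 2 and 3
print(SolutionCompactMap().pathSum(chain, 3))  # 2
```

The seed entry for 0 is never deleted, because its count never falls below 1.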
Code Explanation
We use a hash map to count prefix sums:
prefix_count = defaultdict(int)
prefix_count[0] = 1
The initial 0 means there is one empty prefix before the root.
The DFS receives the current running sum:
def dfs(node, current_sum):
If the node is missing, it contributes no paths:
if not node:
    return 0
We include the current node:
current_sum += node.val
Now we count paths ending exactly at this node:
total = prefix_count[current_sum - targetSum]
Then we add the current prefix before going downward:
prefix_count[current_sum] += 1
Now the left and right children can use this prefix:
total += dfs(node.left, current_sum)
total += dfs(node.right, current_sum)
After both children are processed, we backtrack:
prefix_count[current_sum] -= 1
This removes the current node’s prefix from the active path. The key itself stays in the defaultdict with a count of 0, which does not affect the answer.
Finally, the DFS returns how many valid paths were found in this subtree.
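On a deeply skewed tree, the recursive DFS nests one Python call per level, and CPython's default recursion limit is 1000 (see `sys.getrecursionlimit()`). A sketch of the same algorithm with an explicit stack avoids that limit. `path_sum_iterative` is a hypothetical name; each node is pushed twice, once to enter it and once as an exit marker that performs the backtracking step.

```python
from collections import defaultdict
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def path_sum_iterative(root: Optional[TreeNode], target: int) -> int:
    prefix_count = defaultdict(int)
    prefix_count[0] = 1
    total = 0
    # Entries are (node, running sum, leaving). On entry the sum excludes
    # node.val; on the exit marker it already includes it.
    stack = [(root, 0, False)] if root else []
    while stack:
        node, current_sum, leaving = stack.pop()
        if leaving:
            # Both subtrees are finished: backtrack this node's prefix.
            prefix_count[current_sum] -= 1
            continue
        current_sum += node.val
        total += prefix_count[current_sum - target]
        prefix_count[current_sum] += 1
        # Push the exit marker first so it is popped after both subtrees.
        stack.append((node, current_sum, True))
        if node.right:
            stack.append((node.right, current_sum, False))
        if node.left:
            stack.append((node.left, current_sum, False))
    return total


# A 5000-node chain of 1s; every single node is a path summing to 1.
node = None
for _ in range(5000):
    node = TreeNode(1, node)
print(path_sum_iterative(node, 1))  # 5000
```

Because the left child is pushed last, it is popped first, and its whole subtree (including its exit markers) finishes before the right child is entered, matching the recursive visiting order.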
Testing
```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def run_tests():
    s = Solution()

    root = TreeNode(
        10,
        TreeNode(
            5,
            TreeNode(
                3,
                TreeNode(3),
                TreeNode(-2),
            ),
            TreeNode(
                2,
                None,
                TreeNode(1),
            ),
        ),
        TreeNode(
            -3,
            None,
            TreeNode(11),
        ),
    )
    assert s.pathSum(root, 8) == 3

    root = TreeNode(
        5,
        TreeNode(
            4,
            TreeNode(
                11,
                TreeNode(7),
                TreeNode(2),
            ),
        ),
        TreeNode(
            8,
            TreeNode(13),
            TreeNode(
                4,
                TreeNode(5),
                TreeNode(1),
            ),
        ),
    )
    assert s.pathSum(root, 22) == 3

    assert s.pathSum(None, 0) == 0

    root = TreeNode(1, TreeNode(-1), TreeNode(1))
    assert s.pathSum(root, 0) == 1

    root = TreeNode(0, TreeNode(0), TreeNode(0))
    assert s.pathSum(root, 0) == 5

    print("all tests passed")


run_tests()
```
Test meaning:
| Test | Why |
|---|---|
| Standard example | Checks paths starting below root |
| Target 22 example | Checks longer downward paths |
| Empty tree | Checks no-node case |
| Negative value | Checks prefix sums with subtraction |
| Zero values | Checks multiple overlapping valid paths |