LeetCode 366: Find Leaves of Binary Tree

A clear explanation of grouping binary tree nodes by the round in which they become leaves using postorder DFS.

Problem Restatement

We are given the root of a binary tree.

We need to collect the tree's nodes as if we repeatedly do this:

  1. Collect all current leaf nodes.
  2. Remove those leaf nodes.
  3. Repeat until the tree is empty.

Return a list of lists.

Each inner list contains the node values removed in the same round.

The order inside each round does not matter.

The official example is:

root = [1, 2, 3, 4, 5]

Output:

[[4, 5, 3], [2], [1]]

Other orders within the first group, such as [[3, 4, 5], [2], [1]], are also accepted. The constraints say the tree has between 1 and 100 nodes, and node values are between -100 and 100.
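LeetCode writes trees like [1, 2, 3, 4, 5] in level order. None marks a missing child, and the children of a missing node are not listed. To try the examples locally, you can use a small helper that turns such a list into TreeNode objects. The helper name `build_tree` is hypothetical and not part of the problem. This is a sketch that assumes the level-order convention just described.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a LeetCode-style level-order list (hypothetical helper)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    # Each dequeued node consumes the next two slots: left child, then right child.
    while queue and i < len(values):
        node = queue.popleft()
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


root = build_tree([1, 2, 3, 4, 5])
print(root.val, root.left.val, root.right.val, root.left.left.val, root.left.right.val)
# 1 2 3 4 5
```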

Input and Output

Item                     Meaning
Input                    Root of a binary tree
Output                   List of leaf-removal rounds
Leaf                     A node with no children
Order inside one round   Does not matter
Main goal                Group nodes by when they become leaves

Example function shape:

def findLeaves(root: Optional[TreeNode]) -> list[list[int]]:
    ...

Examples

Example 1:

      1
     / \
    2   3
   / \
  4   5

First round:

[4, 5, 3]

These are the original leaves.

After removing them, the tree becomes:

  1
 /
2

Second round:

[2]

After removing 2, the tree becomes:

1

Third round:

[1]

So the answer is:

[[4, 5, 3], [2], [1]]

Example 2:

root = [1]

The root is already a leaf.

Answer:

[[1]]

First Thought: Simulate Removal

A direct approach is to repeatedly scan the tree.

In each round:

  1. Find all current leaves.
  2. Add their values to the answer.
  3. Remove them from their parent.
  4. Repeat until the root is removed.

This matches the problem statement closely.
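Here is a minimal sketch of that simulation, assuming the standard TreeNode shape. The names `find_leaves_by_simulation` and `_strip_leaves` are hypothetical. Each pass detaches the current leaves, so the input tree is mutated. A node's leaf status is checked before its children are stripped, so a node that only becomes a leaf during a pass is kept until the next round.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def _strip_leaves(node: Optional[TreeNode], removed: List[int]) -> Optional[TreeNode]:
    """Remove the current leaves under node; return node, or None if node was a leaf."""
    if node is None:
        return None
    if node.left is None and node.right is None:
        removed.append(node.val)
        return None
    # Leaf status was checked above, before recursing, so this node
    # survives this pass even if both children are stripped now.
    node.left = _strip_leaves(node.left, removed)
    node.right = _strip_leaves(node.right, removed)
    return node


def find_leaves_by_simulation(root: Optional[TreeNode]) -> List[List[int]]:
    """Repeatedly strip leaves. Note: this mutates the input tree."""
    rounds = []
    while root is not None:
        removed: List[int] = []
        # Each pass walks every node that is still in the tree.
        root = _strip_leaves(root, removed)
        rounds.append(removed)
    return rounds


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(find_leaves_by_simulation(root))  # [[4, 5, 3], [2], [1]]
```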

But it is inefficient because we may scan the same nodes many times.

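For a skewed tree with n nodes, each round removes only one node, so there are n rounds. Each scan still visits every node left in the tree. That adds up to n + (n - 1) + ... + 1 = n(n + 1)/2 node visits:

O(n^2)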

We can do better with one DFS traversal.

Key Insight

The round in which a node is removed depends only on its height, measured from the bottom of its subtree.

Define height like this:

Node type                                Height
Missing child (None)                     -1
Leaf node                                0
Node whose tallest child is a leaf       1
Node whose tallest child has height 1    2

So:

height(node) = 1 + max(height(left), height(right))

Leaves have height 0 because both children have height -1.
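The recurrence can be written directly as a small recursive function. This is an illustration only, and the name `height` is an assumption. It recomputes subtrees on every call, so calling it separately on each node repeats work. The DFS in the Algorithm section computes all heights in a single pass instead. Running it on the example tree reproduces the heights listed in the table that follows.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def height(node: Optional[TreeNode]) -> int:
    """Height from the bottom: a missing child is -1, a leaf is 0."""
    if node is None:
        return -1
    return 1 + max(height(node.left), height(node.right))


# The example tree: 1 has children 2 and 3; 2 has children 4 and 5.
four, five, three = TreeNode(4), TreeNode(5), TreeNode(3)
two = TreeNode(2, four, five)
one = TreeNode(1, two, three)

for node in (four, five, three, two, one):
    print(node.val, height(node))
# 4 0
# 5 0
# 3 0
# 2 1
# 1 2
```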

Nodes with the same height are removed in the same round: height 0 in the first round, height 1 in the second, and so on.

For the tree:

      1
     / \
    2   3
   / \
  4   5

The heights are:

Node   Height   Removal round
4      0        first
5      0        first
3      0        first
2      1        second
1      2        third

So we only need to compute each node's height and append its value to answer[height].

Algorithm

Use postorder DFS.

For each node:

  1. Recursively compute the height of the left subtree.
  2. Recursively compute the height of the right subtree.
  3. Compute this node's height:
     height = 1 + max(left_height, right_height)
  4. Ensure answer has a list for this height.
  5. Append node.val to answer[height].
  6. Return height.

Postorder traversal is necessary because we need children's heights before computing the current node's height.

Correctness

A leaf has no children. The DFS returns -1 for missing children, so a leaf receives height:

1 + max(-1, -1) = 0

Therefore all original leaves are placed into answer[0], which is the first removal round.

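Number the rounds from 0, so round r corresponds to answer[r]. Now take a non-leaf node, and assume every node below it is placed correctly (the induction hypothesis). The node becomes a leaf only after all of its children have been removed. Each child c is removed in round height(c), so the last child to disappear is the tallest one. Its height is h = max(left_height, right_height), and it is removed in round h. The current node is therefore still not a leaf during round h, becomes a leaf right after it, and is removed in round h + 1.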

The DFS computes exactly:

1 + max(left_height, right_height)

which matches that removal round.

By induction from leaves upward, every node is placed into the list corresponding to the round in which it becomes a leaf.

Therefore, the algorithm returns exactly the required groups.
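You can also spot-check the induction argument empirically. The sketch below builds random trees, runs the height-based DFS, and then runs the leaf-stripping simulation on the same tree. It compares the groups while ignoring order inside each round. The random shape generator and all function names here are assumptions made for this check. A passing run is evidence, not a proof.

```python
import random
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_leaves_by_height(root: Optional[TreeNode]) -> List[List[int]]:
    answer: List[List[int]] = []

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return -1
        h = 1 + max(dfs(node.left), dfs(node.right))
        if h == len(answer):
            answer.append([])
        answer[h].append(node.val)
        return h

    dfs(root)
    return answer


def find_leaves_by_simulation(root: Optional[TreeNode]) -> List[List[int]]:
    # Mutates the tree: each pass detaches the current leaves.
    def strip(node, removed):
        if node is None:
            return None
        if node.left is None and node.right is None:
            removed.append(node.val)
            return None
        node.left = strip(node.left, removed)
        node.right = strip(node.right, removed)
        return node

    rounds = []
    while root is not None:
        removed: List[int] = []
        root = strip(root, removed)
        rounds.append(removed)
    return rounds


def random_tree(n: int, rng: random.Random) -> Optional[TreeNode]:
    """Random shape with n nodes: split the remaining n - 1 nodes between the subtrees."""
    if n == 0:
        return None
    left_size = rng.randint(0, n - 1)
    return TreeNode(
        rng.randint(-100, 100),
        random_tree(left_size, rng),
        random_tree(n - 1 - left_size, rng),
    )


def check(trials: int = 500, seed: int = 0) -> bool:
    rng = random.Random(seed)
    for _ in range(trials):
        root = random_tree(rng.randint(1, 100), rng)
        expected = find_leaves_by_height(root)  # run first: the simulation mutates root
        actual = find_leaves_by_simulation(root)
        if [sorted(g) for g in expected] != [sorted(g) for g in actual]:
            return False
    return True


print(check())  # True
```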

Complexity

Let n be the number of nodes.

Metric            Value   Why
Time              O(n)    Each node is visited once
Space (output)    O(n)    The output stores every node value once
Recursion stack   O(h)    h is the tree height

For a balanced tree, the recursion stack is O(log n).

For a skewed tree, it can be O(n).
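With at most 100 nodes, the recursion depth stays well under CPython's default recursion limit of 1000. On much deeper trees, however, the recursive DFS could exceed it. One option is an explicit-stack postorder traversal. It is sketched here as an alternative, not the reference solution. It stores each node's height in a dictionary keyed by `id(node)`, which is an assumption of this sketch. The name `find_leaves_iterative` is hypothetical.

```python
from typing import Dict, List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_leaves_iterative(root: Optional[TreeNode]) -> List[List[int]]:
    answer: List[List[int]] = []
    heights: Dict[int, int] = {}  # id(node) -> height from the bottom
    stack = [(root, False)] if root is not None else []
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Revisit this node after both children; left is pushed last
            # so it is processed first, matching the recursive postorder.
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
            continue
        left_h = heights[id(node.left)] if node.left is not None else -1
        right_h = heights[id(node.right)] if node.right is not None else -1
        h = 1 + max(left_h, right_h)
        heights[id(node)] = h
        if h == len(answer):
            answer.append([])
        answer[h].append(node.val)
    return answer


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(find_leaves_iterative(root))  # [[4, 5, 3], [2], [1]]
```

The extra dictionary costs O(n) memory. In exchange, a left-skewed chain of thousands of nodes no longer depends on Python's call stack.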

Implementation

from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def findLeaves(self, root: Optional[TreeNode]) -> list[list[int]]:
        answer = []

        def dfs(node: Optional[TreeNode]) -> int:
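            if node is None:
                return -1  # missing child, so a leaf gets height 0

            # Postorder: both child heights are computed before this node.
            left_height = dfs(node.left)
            right_height = dfs(node.right)

            height = 1 + max(left_height, right_height)

            # First node seen at this height: open a new group.
            if height == len(answer):
                answer.append([])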

            answer[height].append(node.val)

            return height

        dfs(root)
        return answer

Code Explanation

The answer list stores groups by height:

answer = []

The DFS returns the height from the bottom.

For a missing child, we return -1:

if node is None:
    return -1

This makes a leaf height 0.

We compute child heights first:

left_height = dfs(node.left)
right_height = dfs(node.right)

Then compute the current height:

height = 1 + max(left_height, right_height)

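If this is the first node at that height, create a new group. An equality check is enough here. Postorder visits a child of height height - 1 before its parent, so height can never exceed len(answer):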

if height == len(answer):
    answer.append([])

Then append the node value:

answer[height].append(node.val)

Finally, return the height to the parent:

return height
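You may prefer not to rely on traversal order when creating groups. In that case, a variant can collect them in a dictionary keyed by height and convert it to a list at the end. This is a sketch of an alternative using `collections.defaultdict`, not the original solution. The name `find_leaves_dict` is hypothetical.

```python
from collections import defaultdict
from typing import DefaultDict, List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_leaves_dict(root: Optional[TreeNode]) -> List[List[int]]:
    groups: DefaultDict[int, List[int]] = defaultdict(list)

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return -1
        h = 1 + max(dfs(node.left), dfs(node.right))
        groups[h].append(node.val)  # defaultdict creates the list on first use
        return h

    top = dfs(root)
    # Every height from 0 to the root's height appears on some root-to-leaf
    # path, so this range has no gaps. An empty tree gives top = -1 and [].
    return [groups[h] for h in range(top + 1)]


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(find_leaves_dict(root))  # [[4, 5, 3], [2], [1]]
```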

Testing

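# When running locally, place this class above Solution: the findLeaves
# signature refers to TreeNode when the Solution class is defined.
class TreeNode: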
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def run_tests():
    s = Solution()

    root = TreeNode(
        1,
        TreeNode(2, TreeNode(4), TreeNode(5)),
        TreeNode(3),
    )
    assert s.findLeaves(root) == [[4, 5, 3], [2], [1]]

    root = TreeNode(1)
    assert s.findLeaves(root) == [[1]]

    root = TreeNode(1, TreeNode(2, TreeNode(3)))
    assert s.findLeaves(root) == [[3], [2], [1]]

    root = TreeNode(
        1,
        TreeNode(2),
        TreeNode(3, None, TreeNode(4)),
    )
    assert s.findLeaves(root) == [[2, 4], [3], [1]]

    print("all tests passed")

run_tests()

Test meaning:

Test               Why
Standard tree      Checks multiple leaves in the first round
Single node        Root is already a leaf
Left-skewed tree   One node removed per round
Mixed shape        Checks subtrees of different heights
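The asserts above compare exact lists. That works because this DFS always visits nodes in the same order. The problem, however, accepts any order inside a round, so a more tolerant check sorts each group before comparing. `same_rounds` is a hypothetical helper name, not part of the original tests.

```python
from typing import List


def same_rounds(actual: List[List[int]], expected: List[List[int]]) -> bool:
    """Compare round by round, ignoring order within each round."""
    if len(actual) != len(expected):
        return False
    return all(sorted(a) == sorted(e) for a, e in zip(actual, expected))


print(same_rounds([[3, 4, 5], [2], [1]], [[4, 5, 3], [2], [1]]))  # True
print(same_rounds([[4, 5], [3, 2], [1]], [[4, 5, 3], [2], [1]]))  # False
```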