LeetCode 669: Trim a Binary Search Tree

A clear explanation of trimming a BST so that all remaining node values lie inside a given inclusive range.

Problem Restatement

We are given the root of a binary search tree and two boundaries:

Variable  Meaning
low       Lowest allowed node value
high      Highest allowed node value

We need to trim the tree so that every remaining node value lies inside:

[low, high]

The relative structure of the remaining nodes must stay the same. If one remaining node was a descendant of another remaining node before trimming, it should still be a descendant after trimming.

Return the root of the trimmed tree. The root may change if the original root is outside the allowed range. The official problem states that the input is a valid BST, node values are unique, and the answer is unique.

Input and Output

Item            Meaning
Input           Root of a BST, integer low, integer high
Output          Root of the trimmed BST
Keep rule       Keep nodes with low <= node.val <= high
Remove rule     Remove nodes outside the range
Structure rule  Preserve relative structure among remaining nodes

Example function shape:

def trimBST(root: Optional[TreeNode], low: int, high: int) -> Optional[TreeNode]:
    ...

Examples

Consider:

root = [1, 0, 2]
low = 1
high = 2

The tree is:

  1
 / \
0   2

The value 0 is smaller than low, so it must be removed.

The trimmed tree is:

1
 \
  2

So the output is:

[1, None, 2]

Another example:

root = [3, 0, 4, None, 2, None, None, 1]
low = 1
high = 3

The tree is:

      3
     / \
    0   4
     \
      2
     /
    1

Values outside [1, 3] are removed.

The trimmed tree is:

    3
   /
  2
 /
1

So the output is:

[3, 2, None, 1]
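The bracketed lists in these examples use LeetCode's level-order notation. Nodes are listed level by level, left to right, and None marks a missing child. Children of a missing node are not listed, and trailing Nones are dropped. The helpers below, build_tree and to_level_order, are hypothetical and not part of the problem's API. They are a minimal sketch for converting between that notation and TreeNode objects so the two examples can be checked directly. The block also redefines TreeNode and includes a standalone copy of the trimming logic as trim_bst, so that it runs on its own.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def to_level_order(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Serialize a tree to level order, dropping trailing None entries."""
    result: List[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            result.append(None)  # missing child; its own children are not listed
            continue
        result.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while result and result[-1] is None:
        result.pop()
    return result


def trim_bst(root: Optional[TreeNode], low: int, high: int) -> Optional[TreeNode]:
    # Standalone copy of the trimming logic described in this article.
    if root is None:
        return None
    if root.val < low:
        return trim_bst(root.right, low, high)
    if root.val > high:
        return trim_bst(root.left, low, high)
    root.left = trim_bst(root.left, low, high)
    root.right = trim_bst(root.right, low, high)
    return root


print(to_level_order(trim_bst(build_tree([1, 0, 2]), 1, 2)))  # [1, None, 2]
print(to_level_order(trim_bst(build_tree([3, 0, 4, None, 2, None, None, 1]), 1, 3)))  # [3, 2, None, 1]
```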

First Thought: Traverse Every Node

A direct solution is to visit every node and decide whether it should stay.

If a node value is inside the range, keep it.

If a node value is outside the range, remove it and reconnect any valid descendants.

This works, but if we ignore the BST property, it is not obvious where the surviving descendants of a removed node should be reattached.

The BST property gives a cleaner rule.

Key Insight

In a binary search tree:

Relation       Meaning
Left subtree   All values are smaller than the current node
Right subtree  All values are larger than the current node

So if:

root.val < low

then the root is too small.

Also, every node in its left subtree is even smaller, so the entire left subtree can be discarded.

Only the right subtree might contain valid nodes.

Similarly, if:

root.val > high

then the root is too large.

Also, every node in its right subtree is even larger, so the entire right subtree can be discarded.

Only the left subtree might contain valid nodes.

This pruning rule is the core of the solution.
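One consequence of this rule is that a discarded subtree is never even visited. The sketch below illustrates this with a hypothetical instrumented variant, trim_recording_visits. It applies the same pruning rule but records the value of every node it touches. In the sample tree, node 2 is below low = 3, so its left child 1 is dropped without being reached.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trim_recording_visits(
    root: Optional[TreeNode], low: int, high: int, visited: List[int]
) -> Optional[TreeNode]:
    """Same pruning rule as the solution, but records each node value it touches."""
    if root is None:
        return None
    visited.append(root.val)
    if root.val < low:
        # root.left is never passed to a recursive call, so none of it is visited.
        return trim_recording_visits(root.right, low, high, visited)
    if root.val > high:
        # Symmetric case: root.right is skipped entirely.
        return trim_recording_visits(root.left, low, high, visited)
    root.left = trim_recording_visits(root.left, low, high, visited)
    root.right = trim_recording_visits(root.right, low, high, visited)
    return root


# Tree:
#       5
#      / \
#     2   8
#    / \
#   1   3
root = TreeNode(5, TreeNode(2, TreeNode(1), TreeNode(3)), TreeNode(8))
visited: List[int] = []
trimmed = trim_recording_visits(root, 3, 8, visited)
print(visited)  # [5, 2, 3, 8]; node 1 is never reached
```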

Algorithm

Use recursion.

For a node root:

  1. If root is None, return None.
  2. If root.val < low, return the trimmed version of root.right.
  3. If root.val > high, return the trimmed version of root.left.
  4. Otherwise, the current node is valid:
    • Trim its left subtree.
    • Trim its right subtree.
    • Return the current node.
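To see the four steps in action, the hypothetical helper trim_with_log below follows the same recursion. It also appends a (value, decision) pair for every node it reaches. This is an illustrative trace, not part of the solution. On the second example, it keeps 3, 2, and 1. It drops 0 and continues into 0's right subtree. It drops 4 and continues into 4's (empty) left subtree.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trim_with_log(
    root: Optional[TreeNode], low: int, high: int, log: List[Tuple[int, str]]
) -> Optional[TreeNode]:
    if root is None:  # step 1
        return None
    if root.val < low:  # step 2
        log.append((root.val, "too small: trim right subtree"))
        return trim_with_log(root.right, low, high, log)
    if root.val > high:  # step 3
        log.append((root.val, "too large: trim left subtree"))
        return trim_with_log(root.left, low, high, log)
    log.append((root.val, "keep"))  # step 4
    root.left = trim_with_log(root.left, low, high, log)
    root.right = trim_with_log(root.right, low, high, log)
    return root


# Second example: [3, 0, 4, None, 2, None, None, 1], range [1, 3].
root = TreeNode(3, TreeNode(0, None, TreeNode(2, TreeNode(1))), TreeNode(4))
log: List[Tuple[int, str]] = []
trim_with_log(root, 1, 3, log)
# Decisions in order: 3 keep, 0 too small, 2 keep, 1 keep, 4 too large.
for entry in log:
    print(entry)
```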

Correctness

The algorithm considers each node according to its value and the BST property.

If root.val < low, then the current node cannot remain. Since all nodes in its left subtree are less than root.val, they are also less than low. Therefore, none of them can remain. Any valid remaining node must come from the right subtree, so returning the trimmed right subtree is correct.

If root.val > high, then the current node cannot remain. Since all nodes in its right subtree are greater than root.val, they are also greater than high. Therefore, none of them can remain. Any valid remaining node must come from the left subtree, so returning the trimmed left subtree is correct.

If low <= root.val <= high, the current node must remain. Its left and right children may still contain invalid nodes, so the algorithm recursively trims both subtrees and attaches the results back to the current node.

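The recursion applies the same logic to every reachable subtree. A subtree is discarded only when all of its values are outside the range, so every node inside [low, high] survives and every node outside it is removed. The relative structure is preserved because a removed node is always replaced by the trimmed version of one of its own subtrees. Each surviving node therefore keeps every surviving ancestor it had, and it never gains a new one.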
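The argument can also be spot-checked empirically. The sketch below is a randomized property test, not a proof. It builds random BSTs with a hypothetical insert helper, trims them with a standalone copy of the solution (trim_bst), and checks two things. First, the in-order traversal of the result equals the in-range values of the original. Second, each survivor's ancestor set equals its original ancestors that also survived. The helper names and the value range 0 to 49 are arbitrary choices for illustration.

```python
import random
from typing import Dict, List, Optional, Set


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trim_bst(root: Optional[TreeNode], low: int, high: int) -> Optional[TreeNode]:
    # Standalone copy of Solution.trimBST.
    if root is None:
        return None
    if root.val < low:
        return trim_bst(root.right, low, high)
    if root.val > high:
        return trim_bst(root.left, low, high)
    root.left = trim_bst(root.left, low, high)
    root.right = trim_bst(root.right, low, high)
    return root


def insert(root: Optional[TreeNode], val: int) -> TreeNode:
    """Plain BST insertion, used only to build random test trees."""
    if root is None:
        return TreeNode(val)
    if val < root.val:
        root.left = insert(root.left, val)
    else:
        root.right = insert(root.right, val)
    return root


def inorder(root: Optional[TreeNode]) -> List[int]:
    return inorder(root.left) + [root.val] + inorder(root.right) if root else []


def ancestors(root: Optional[TreeNode]) -> Dict[int, Set[int]]:
    """Map each node value to the set of its ancestors' values."""
    result: Dict[int, Set[int]] = {}

    def walk(node: Optional[TreeNode], path: List[int]) -> None:
        if node is None:
            return
        result[node.val] = set(path)
        walk(node.left, path + [node.val])
        walk(node.right, path + [node.val])

    walk(root, [])
    return result


def check_once(rng: random.Random) -> None:
    root = None
    for v in rng.sample(range(50), rng.randint(0, 15)):
        root = insert(root, v)
    low = rng.randint(0, 49)
    high = rng.randint(low, 49)
    # Record the original shape first, because trim_bst rewires nodes in place.
    expected = [v for v in inorder(root) if low <= v <= high]
    before = ancestors(root)
    trimmed = trim_bst(root, low, high)
    # Exactly the in-range values survive, still in sorted (BST) order.
    assert inorder(trimmed) == expected
    # Each survivor's ancestors are its original ancestors that also survived.
    for v, anc in ancestors(trimmed).items():
        assert anc == {a for a in before[v] if low <= a <= high}


rng = random.Random(669)
for _ in range(1000):
    check_once(rng)
print("property checks passed")
```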

Complexity

Metric  Value  Why
Time    O(n)   Each node is visited at most once
Space   O(h)   Recursion stack depends on tree height

Here, n is the number of nodes and h is the height of the tree.

In the worst case, a skewed tree has h = n.
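On a skewed tree, the O(h) stack can reach n frames. In Python, that can hit the interpreter's recursion limit (1000 by default) on very deep trees. One alternative, which is not the article's solution, is an iterative variant with O(1) extra space. It first walks down to an in-range root. Then it walks the left spine fixing only the lower bound, and the right spine fixing only the upper bound. This works because every value left of an in-range root is already below high, and every value to its right is already above low. The function name trim_bst_iterative is a hypothetical label for this sketch.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trim_bst_iterative(root: Optional[TreeNode], low: int, high: int) -> Optional[TreeNode]:
    # Step 1: move down until the root itself is inside [low, high].
    while root is not None and not (low <= root.val <= high):
        root = root.right if root.val < low else root.left
    if root is None:
        return None

    # Step 2: values under root.left are < root.val <= high,
    # so on the left side only the lower bound can be violated.
    node = root
    while node is not None:
        while node.left is not None and node.left.val < low:
            # node.left and its left subtree are too small; splice in its right subtree.
            node.left = node.left.right
        node = node.left

    # Step 3: symmetrically, only the upper bound matters on the right side.
    node = root
    while node is not None:
        while node.right is not None and node.right.val > high:
            node.right = node.right.left
        node = node.right

    return root


# Second example: [3, 0, 4, None, 2, None, None, 1], range [1, 3].
tree = TreeNode(3, TreeNode(0, None, TreeNode(2, TreeNode(1))), TreeNode(4))
result = trim_bst_iterative(tree, 1, 3)
print(result.val, result.left.val, result.left.left.val)  # 3 2 1
```

The right subtrees of kept left-spine nodes never need a separate pass. Their values lie between two in-range values, so they are already valid.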

Implementation

from typing import Optional

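# Definition for a binary tree node. LeetCode supplies this class; it is
# defined here so the solution and the tests below run as a standalone script.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right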

class Solution:
    def trimBST(
        self,
        root: Optional[TreeNode],
        low: int,
        high: int,
    ) -> Optional[TreeNode]:
        if root is None:
            return None

        if root.val < low:
            return self.trimBST(root.right, low, high)

        if root.val > high:
            return self.trimBST(root.left, low, high)

        root.left = self.trimBST(root.left, low, high)
        root.right = self.trimBST(root.right, low, high)

        return root

Code Explanation

The base case handles an empty subtree:

if root is None:
    return None

If the current value is too small:

if root.val < low:
    return self.trimBST(root.right, low, high)

we discard the current node and its left subtree. Only the right subtree can contain values inside the range.

If the current value is too large:

if root.val > high:
    return self.trimBST(root.left, low, high)

we discard the current node and its right subtree. Only the left subtree can contain values inside the range.

If the current node is valid, we keep it and trim both children:

root.left = self.trimBST(root.left, low, high)
root.right = self.trimBST(root.right, low, high)

Finally, we return the current node:

return root

This returned node becomes the root of the trimmed subtree.
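Note that the solution rewires the original nodes in place: it overwrites root.left and root.right on every kept node. If a caller still needs the untrimmed tree, one option is a copying variant that builds fresh nodes for the survivors. The sketch below assumes that requirement. The name trim_bst_copy is hypothetical.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trim_bst_copy(root: Optional[TreeNode], low: int, high: int) -> Optional[TreeNode]:
    """Same pruning rule, but returns new nodes and leaves the input unchanged."""
    if root is None:
        return None
    if root.val < low:
        return trim_bst_copy(root.right, low, high)
    if root.val > high:
        return trim_bst_copy(root.left, low, high)
    # Build a fresh node instead of overwriting root.left / root.right.
    return TreeNode(
        root.val,
        trim_bst_copy(root.left, low, high),
        trim_bst_copy(root.right, low, high),
    )


original = TreeNode(1, TreeNode(0), TreeNode(2))
trimmed = trim_bst_copy(original, 1, 2)
print(original.left.val)  # 0: the input still has its left child
print(trimmed.left)       # None
```

The trade-off is one new node per surviving value, whereas the in-place version creates no new nodes.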

Testing

def serialize(root):
    if root is None:
        return None

    return [
        root.val,
        serialize(root.left),
        serialize(root.right),
    ]

def run_tests():
    s = Solution()

    # Tree:
    #   1
    #  / \
    # 0   2
    root = TreeNode(1, TreeNode(0), TreeNode(2))

    trimmed = s.trimBST(root, 1, 2)

    assert serialize(trimmed) == [
        1,
        None,
        [2, None, None],
    ]

    # Tree:
    #       3
    #      / \
    #     0   4
    #      \
    #       2
    #      /
    #     1
    root = TreeNode(3)
    root.left = TreeNode(0, None, TreeNode(2, TreeNode(1), None))
    root.right = TreeNode(4)

    trimmed = s.trimBST(root, 1, 3)

    assert serialize(trimmed) == [
        3,
        [
            2,
            [1, None, None],
            None,
        ],
        None,
    ]

    # Root is too small, so root changes.
    root = TreeNode(1, None, TreeNode(2))

    trimmed = s.trimBST(root, 2, 2)

    assert serialize(trimmed) == [
        2,
        None,
        None,
    ]

    # Root is too large, so root changes.
    root = TreeNode(3, TreeNode(2), None)

    trimmed = s.trimBST(root, 2, 2)

    assert serialize(trimmed) == [
        2,
        None,
        None,
    ]

    # Everything remains.
    root = TreeNode(2, TreeNode(1), TreeNode(3))

    trimmed = s.trimBST(root, 1, 3)

    assert serialize(trimmed) == [
        2,
        [1, None, None],
        [3, None, None],
    ]

    print("all tests passed")

run_tests()

Test meaning:

Test                  Why
[1,0,2], range [1,2]  Removes a too-small left child
Larger sample         Checks reconnecting valid descendants
Root too small        Confirms returned root can change
Root too large        Confirms trimming can return the left subtree
Full range            Confirms valid nodes are preserved
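None of the tests above covers an empty input or a range that removes every node. A few extra edge-case checks might look like the sketch below. It repeats TreeNode, Solution, and serialize from this article so the block runs on its own. The function name run_edge_tests is a hypothetical addition.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def trimBST(self, root: Optional[TreeNode], low: int, high: int) -> Optional[TreeNode]:
        if root is None:
            return None
        if root.val < low:
            return self.trimBST(root.right, low, high)
        if root.val > high:
            return self.trimBST(root.left, low, high)
        root.left = self.trimBST(root.left, low, high)
        root.right = self.trimBST(root.right, low, high)
        return root


def serialize(root):
    if root is None:
        return None
    return [root.val, serialize(root.left), serialize(root.right)]


def run_edge_tests():
    s = Solution()

    # Empty input tree.
    assert s.trimBST(None, 1, 2) is None

    # Range lies entirely above every value: all nodes removed.
    assert s.trimBST(TreeNode(2, TreeNode(1), TreeNode(3)), 10, 20) is None

    # Range lies entirely below every value: all nodes removed.
    assert s.trimBST(TreeNode(2, TreeNode(1), TreeNode(3)), -5, 0) is None

    # The only surviving node sits deep in the tree and becomes the root.
    #     4
    #    /
    #   1
    #    \
    #     3
    #    /
    #   2
    root = TreeNode(4, TreeNode(1, None, TreeNode(3, TreeNode(2))))
    assert serialize(s.trimBST(root, 2, 2)) == [2, None, None]

    print("edge tests passed")


run_edge_tests()
```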