LeetCode 304: Range Sum Query 2D - Immutable
A clear explanation of Range Sum Query 2D - Immutable using a 2D prefix sum matrix for constant-time rectangle queries.
Problem Restatement
We are given a 2D integer matrix.
We need to implement a class called NumMatrix that supports rectangle sum queries.
The class supports two operations:
| Operation | Meaning |
|---|---|
| NumMatrix(matrix) | Initialize the object |
| sumRegion(row1, col1, row2, col2) | Return the sum of all elements inside the rectangle |
The rectangle is inclusive, so all cells from (row1, col1) through (row2, col2) must be included.
The matrix never changes after construction. The official problem expects many rectangle queries on the same matrix.
Input and Output
| Item | Meaning |
|---|---|
| Input | A 2D integer matrix |
| Query input | row1, col1, row2, col2 |
| Output | Sum of all values inside the rectangle |
| Property | Matrix is immutable |
| Goal | Efficient repeated queries |
Class shape:
class NumMatrix:
    def __init__(self, matrix: list[list[int]]):
        ...
    def sumRegion(
        self,
        row1: int,
        col1: int,
        row2: int,
        col2: int,
    ) -> int:
        ...
Examples
Example matrix:
matrix = [
[3, 0, 1, 4, 2],
[5, 6, 3, 2, 1],
[1, 2, 0, 1, 5],
[4, 1, 0, 1, 7],
[1, 0, 3, 0, 5],
]
Create the object:
obj = NumMatrix(matrix)
Query:
obj.sumRegion(2, 1, 4, 3)
This rectangle contains:
2 0 1
1 0 1
0 3 0
Sum:
2 + 0 + 1 + 1 + 0 + 1 + 0 + 3 + 0 = 8
Output:
8
Another query:
obj.sumRegion(1, 1, 2, 2)
Rectangle:
6 3
2 0
Sum:
11
Another query:
obj.sumRegion(1, 2, 2, 4)
Rectangle:
3 2 1
0 1 5
Sum:
12
First Thought: Sum Every Rectangle Directly
The simplest solution loops through every cell inside the requested rectangle.
total = 0
for r in range(row1, row2 + 1):
    for c in range(col1, col2 + 1):
        total += matrix[r][c]
This works.
But one query may scan a large part of the matrix.
If the matrix has m rows and n columns, one query can cost:
O(mn)
The problem expects many queries, so repeated scanning becomes expensive.
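For reference, the direct approach can be wrapped in the same class shape. This is a minimal sketch; the class name NaiveNumMatrix is hypothetical and only serves as a slow baseline that the prefix-sum version can be checked against.

```python
class NaiveNumMatrix:
    """Brute-force baseline (hypothetical name): every query rescans the rectangle."""

    def __init__(self, matrix: list[list[int]]):
        # Keep a reference to the matrix; no preprocessing is done.
        self.matrix = matrix

    def sumRegion(self, row1: int, col1: int, row2: int, col2: int) -> int:
        total = 0
        # Both bounds are inclusive, hence the + 1 on each range end.
        for r in range(row1, row2 + 1):
            for c in range(col1, col2 + 1):
                total += self.matrix[r][c]
        return total


matrix = [
    [3, 0, 1, 4, 2],
    [5, 6, 3, 2, 1],
    [1, 2, 0, 1, 5],
    [4, 1, 0, 1, 7],
    [1, 0, 3, 0, 5],
]
naive = NaiveNumMatrix(matrix)
print(naive.sumRegion(2, 1, 4, 3))  # 8
```

Each call costs time proportional to the rectangle's area, which is the repeated work the prefix matrix is meant to avoid.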
Key Insight
This problem is the 2D version of prefix sums.
Instead of storing prefix sums for a 1D array, we store prefix sums for rectangles.
Define prefix[r][c] to mean the sum of all cells inside the rectangle from (0, 0) to (r - 1, c - 1).
So prefix has one extra row and one extra column filled with zeros.
This extra border removes edge-case handling.
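For comparison, here is the 1D version of the same idea, as a short sketch with hypothetical helper names build_prefix_1d and range_sum_1d. The single leading zero plays the same role as the extra row and column in 2D: a range starting at index 0 needs no special case.

```python
def build_prefix_1d(nums: list[int]) -> list[int]:
    # prefix[i] holds the sum of nums[0..i-1]; prefix[0] is the zero border.
    prefix = [0] * (len(nums) + 1)
    for i, x in enumerate(nums):
        prefix[i + 1] = prefix[i] + x
    return prefix


def range_sum_1d(prefix: list[int], left: int, right: int) -> int:
    # Inclusive range [left, right]: one subtraction, no edge case at left == 0.
    return prefix[right + 1] - prefix[left]


nums = [3, 0, 1, 4, 2]
p = build_prefix_1d(nums)
print(p)                      # [0, 3, 3, 4, 8, 10]
print(range_sum_1d(p, 1, 3))  # 5
```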
Building the 2D Prefix Sum
For every cell matrix[r][c], we compute prefix[r + 1][c + 1]. Writing P(r, c) for prefix[r][c] and A(r, c) for matrix[r][c], the recurrence is:
$$ P(r+1,c+1)=P(r,c+1)+P(r+1,c)-P(r,c)+A(r,c) $$
In code form:
prefix[r + 1][c + 1] = (
    prefix[r][c + 1]
    + prefix[r + 1][c]
    - prefix[r][c]
    + matrix[r][c]
)
Why subtraction?
Both prefix[r][c + 1] and prefix[r + 1][c] contain the overlapping top-left rectangle prefix[r][c], so that overlap gets counted twice.
We subtract it once to correct the total.
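As a worked check, write P(r, c) for prefix[r][c] and A(r, c) for matrix[r][c], and take the top-left 2 x 2 corner of the example matrix, where A(0,0) = 3, A(0,1) = 0, A(1,0) = 5, and A(1,1) = 6. Then P(1,1) = 3, P(1,2) = 3 + 0 = 3, and P(2,1) = 3 + 5 = 8. Both P(1,2) and P(2,1) include A(0,0) = 3, which is exactly P(1,1), so it is subtracted once:
$$ P(2,2)=P(1,2)+P(2,1)-P(1,1)+A(1,1)=3+8-3+6=14 $$
This matches the direct sum 3 + 0 + 5 + 6 = 14.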
Rectangle Query Formula
Suppose we want the rectangle from (row1, col1) through (row2, col2).
Start with the large rectangle from (0, 0) to (row2, col2).
Then remove:
- The rows above row1
- The columns left of col1
The top-left block (rows above row1 and columns left of col1) lies in both removed regions, so it gets removed twice. We add it back once.
The formula becomes:
$$ S=P(r_2+1,c_2+1)-P(r_1,c_2+1)-P(r_2+1,c_1)+P(r_1,c_1) $$
Code:
return (
    prefix[row2 + 1][col2 + 1]
    - prefix[row1][col2 + 1]
    - prefix[row2 + 1][col1]
    + prefix[row1][col1]
)
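As a worked example, take obj.sumRegion(2, 1, 4, 3) on the example matrix, again writing P(r, c) for prefix[r][c]. Computed by hand, P(5, 4) covers rows 0 to 4 and columns 0 to 3 and sums to 38; P(2, 4) covers rows 0 to 1 and columns 0 to 3 and sums to 24; P(5, 1) covers column 0 for rows 0 to 4 and sums to 14; and P(2, 1) covers column 0 for rows 0 to 1 and sums to 8. Then:
$$ S=P(5,4)-P(2,4)-P(5,1)+P(2,1)=38-24-14+8=8 $$
This agrees with the direct sum from the first example.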
Algorithm
During initialization:
- Create a prefix matrix of size (m + 1) x (n + 1).
- Fill it using the 2D prefix formula.
During sumRegion:
- Read four prefix values.
- Use inclusion-exclusion to compute the rectangle sum.
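The two phases above can also be sketched as standalone functions. This sketch swaps in a common equivalent fill rather than the four-term recurrence: it keeps a running sum of the current row, so each entry is the prefix value directly above plus the row total so far. The helper names build_prefix_rowwise and rect_sum are hypothetical; the resulting prefix matrix is the same.

```python
def build_prefix_rowwise(matrix: list[list[int]]) -> list[list[int]]:
    """Build the (m + 1) x (n + 1) prefix matrix with a running sum per row.

    prefix[r + 1][c + 1] = prefix[r][c + 1] + (matrix[r][0] + ... + matrix[r][c]),
    i.e. "everything above" plus "this row so far", so no subtraction is needed.
    """
    m = len(matrix)
    n = len(matrix[0]) if m else 0
    prefix = [[0] * (n + 1) for _ in range(m + 1)]
    for r in range(m):
        row_total = 0  # sum of matrix[r][0..c]
        for c in range(n):
            row_total += matrix[r][c]
            prefix[r + 1][c + 1] = prefix[r][c + 1] + row_total
    return prefix


def rect_sum(prefix: list[list[int]], row1: int, col1: int, row2: int, col2: int) -> int:
    # Same four-term inclusion-exclusion as sumRegion.
    return (
        prefix[row2 + 1][col2 + 1]
        - prefix[row1][col2 + 1]
        - prefix[row2 + 1][col1]
        + prefix[row1][col1]
    )


p = build_prefix_rowwise([[1, 2], [3, 4]])
print(p)                        # [[0, 0, 0], [0, 1, 3], [0, 4, 10]]
print(rect_sum(p, 1, 0, 1, 1))  # 7
```

The query side is unchanged; only the build loop differs, and both variants run in O(mn).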
Correctness
Let prefix[r][c] represent the sum of all matrix cells in the rectangle (0, 0) through (r - 1, c - 1).
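When building the prefix matrix, we compute each value by:
- Adding the rectangle above
- Adding the rectangle to the left
- Subtracting the overlap that was counted twice
- Adding the current matrix cell
The border entries prefix[0][c] and prefix[r][0] are 0, which is the correct sum of an empty rectangle. Because the loops visit rows top to bottom and columns left to right, the three prefix values each entry reads are already correct when that entry is computed. By induction, every prefix[r][c] stores the correct rectangle sum.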
Now consider a query rectangle from (row1, col1) through (row2, col2).
prefix[row2 + 1][col2 + 1] contains the entire area from (0, 0) to (row2, col2).
This includes extra cells:
- Rows above row1
- Columns left of col1
Subtracting prefix[row1][col2 + 1] removes the rows above row1 (columns 0 through col2).
Subtracting prefix[row2 + 1][col1] removes the columns left of col1 (rows 0 through row2).
The block of rows above row1 and columns left of col1 lies in both subtracted regions, so it was removed twice. We add it back once with prefix[row1][col1].
The remaining value is exactly the requested rectangle sum.
Therefore, the algorithm always returns the correct answer.
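The argument can also be spot-checked empirically. This sketch is a supplementary test, not part of the proof: it compares the four-term query against a brute-force sum for every sub-rectangle of small random matrices. The helper names (build_prefix, query, brute, cross_check), the size limits, and the value range are arbitrary choices for illustration.

```python
import random


def build_prefix(matrix: list[list[int]]) -> list[list[int]]:
    # Four-term recurrence with a zero border, as in the article.
    m, n = len(matrix), len(matrix[0])
    prefix = [[0] * (n + 1) for _ in range(m + 1)]
    for r in range(m):
        for c in range(n):
            prefix[r + 1][c + 1] = (
                prefix[r][c + 1] + prefix[r + 1][c] - prefix[r][c] + matrix[r][c]
            )
    return prefix


def query(prefix, row1, col1, row2, col2):
    return (
        prefix[row2 + 1][col2 + 1]
        - prefix[row1][col2 + 1]
        - prefix[row2 + 1][col1]
        + prefix[row1][col1]
    )


def brute(matrix, row1, col1, row2, col2):
    return sum(
        matrix[r][c] for r in range(row1, row2 + 1) for c in range(col1, col2 + 1)
    )


def cross_check(trials: int = 200, seed: int = 0) -> int:
    """Check every sub-rectangle of random matrices; return how many were checked."""
    rng = random.Random(seed)
    checked = 0
    for _ in range(trials):
        m, n = rng.randint(1, 5), rng.randint(1, 5)
        matrix = [[rng.randint(-10, 10) for _ in range(n)] for _ in range(m)]
        prefix = build_prefix(matrix)
        for row1 in range(m):
            for row2 in range(row1, m):
                for col1 in range(n):
                    for col2 in range(col1, n):
                        assert query(prefix, row1, col1, row2, col2) == brute(
                            matrix, row1, col1, row2, col2
                        )
                        checked += 1
    return checked


print(cross_check() > 0)  # True
```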
Complexity
| Operation | Time | Space | Why |
|---|---|---|---|
| Constructor | O(mn) | O(mn) | Build the full prefix matrix |
| sumRegion | O(1) | O(1) | Four reads and arithmetic operations |
The preprocessing cost happens only once.
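To make the trade-off concrete, suppose q queries are issued against one m x n matrix. Summing each rectangle directly costs up to O(mn) per query, while the prefix approach pays the build cost once:
$$ T_{\text{direct}}=O(q\cdot mn), \qquad T_{\text{prefix}}=O(mn)+q\cdot O(1)=O(mn+q) $$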
Implementation
class NumMatrix:
    def __init__(self, matrix: list[list[int]]):
        if not matrix or not matrix[0]:
            self.prefix = [[0]]
            return
        m = len(matrix)
        n = len(matrix[0])
        self.prefix = [[0] * (n + 1) for _ in range(m + 1)]
        for r in range(m):
            for c in range(n):
                self.prefix[r + 1][c + 1] = (
                    self.prefix[r][c + 1]
                    + self.prefix[r + 1][c]
                    - self.prefix[r][c]
                    + matrix[r][c]
                )
    def sumRegion(
        self,
        row1: int,
        col1: int,
        row2: int,
        col2: int,
    ) -> int:
        return (
            self.prefix[row2 + 1][col2 + 1]
            - self.prefix[row1][col2 + 1]
            - self.prefix[row2 + 1][col1]
            + self.prefix[row1][col1]
        )
Code Explanation
We first handle the empty matrix case.
if not matrix or not matrix[0]:
    self.prefix = [[0]]
    return
Then create the prefix matrix:
self.prefix = [[0] * (n + 1) for _ in range(m + 1)]
The extra row and column simplify boundaries.
Now fill the prefix matrix.
for r in range(m):
    for c in range(n):
Each prefix value combines:
- Add the top rectangle
- Add the left rectangle
- Subtract the overlap
- Add the current value
self.prefix[r + 1][c + 1] = (
    self.prefix[r][c + 1]
    + self.prefix[r + 1][c]
    - self.prefix[r][c]
    + matrix[r][c]
)
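To see the fill order, this small tracing sketch (the function name trace_prefix is hypothetical) prints each term of the recurrence for a 2 x 3 matrix. Because the loops go row by row, left to right, every cell's top, left, and overlap entries are already filled when it is reached.

```python
def trace_prefix(matrix: list[list[int]]) -> list[list[int]]:
    m, n = len(matrix), len(matrix[0])
    prefix = [[0] * (n + 1) for _ in range(m + 1)]
    for r in range(m):
        for c in range(n):
            top = prefix[r][c + 1]
            left = prefix[r + 1][c]
            overlap = prefix[r][c]
            prefix[r + 1][c + 1] = top + left - overlap + matrix[r][c]
            # Show each term of the recurrence for this cell.
            print(
                f"prefix[{r + 1}][{c + 1}] = {top} + {left} - {overlap}"
                f" + {matrix[r][c]} = {prefix[r + 1][c + 1]}"
            )
    return prefix


result = trace_prefix([[1, 2, 3], [4, 5, 6]])
print(result)  # [[0, 0, 0, 0], [0, 1, 3, 6], [0, 5, 12, 21]]
```

For example, the line for prefix[2][3] reads 6 + 12 - 3 + 6 = 21, the sum of all six cells.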
For queries:
return (
    self.prefix[row2 + 1][col2 + 1]
    - self.prefix[row1][col2 + 1]
    - self.prefix[row2 + 1][col1]
    + self.prefix[row1][col1]
)
This uses inclusion-exclusion to isolate exactly the requested rectangle.
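The zero border is what lets this formula work unchanged when row1 = 0 or col1 = 0. The sketch below (hypothetical helper name query_terms) returns the four terms separately, using the prefix matrix of [[1, 2, 3], [4, 5, 6]] written out by hand. The first query touches the first row and first column, so three of its terms read from the zero border.

```python
def query_terms(prefix, row1, col1, row2, col2):
    # The four inclusion-exclusion terms, in formula order.
    return (
        prefix[row2 + 1][col2 + 1],
        prefix[row1][col2 + 1],
        prefix[row2 + 1][col1],
        prefix[row1][col1],
    )


# Prefix matrix of [[1, 2, 3], [4, 5, 6]] with its zero border.
prefix = [
    [0, 0, 0, 0],
    [0, 1, 3, 6],
    [0, 5, 12, 21],
]

whole, above, left, overlap = query_terms(prefix, 0, 0, 1, 1)
print(whole, above, left, overlap)     # 12 0 0 0
print(whole - above - left + overlap)  # 12  (1 + 2 + 4 + 5)

whole, above, left, overlap = query_terms(prefix, 1, 1, 1, 2)
print(whole - above - left + overlap)  # 11  (5 + 6)
```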
Testing
def run_tests():
    matrix = [
        [3, 0, 1, 4, 2],
        [5, 6, 3, 2, 1],
        [1, 2, 0, 1, 5],
        [4, 1, 0, 1, 7],
        [1, 0, 3, 0, 5],
    ]
    obj = NumMatrix(matrix)
    assert obj.sumRegion(2, 1, 4, 3) == 8
    assert obj.sumRegion(1, 1, 2, 2) == 11
    assert obj.sumRegion(1, 2, 2, 4) == 12
    # Full matrix query: row sums 10 + 17 + 9 + 13 + 9.
    assert obj.sumRegion(0, 0, 4, 4) == 58
    single = NumMatrix([[5]])
    assert single.sumRegion(0, 0, 0, 0) == 5
    row_matrix = NumMatrix([[1, 2, 3]])
    assert row_matrix.sumRegion(0, 0, 0, 2) == 6
    col_matrix = NumMatrix([
        [1],
        [2],
        [3],
    ])
    assert col_matrix.sumRegion(0, 0, 2, 0) == 6
    negatives = NumMatrix([
        [-1, -2],
        [-3, -4],
    ])
    assert negatives.sumRegion(0, 0, 1, 1) == -10
    print("all tests passed")
run_tests()
Test meaning:
| Test | Why |
|---|---|
| Official example | Standard rectangle queries |
| Single cell | Smallest matrix |
| One-row matrix | Horizontal range handling |
| One-column matrix | Vertical range handling |
| Negative numbers | Confirms arithmetic correctness |
| Full matrix query | Checks large rectangle sums |