LeetCode 3359 - Find Sorted Submatrices With Maximum Element at Most K


LeetCode Problem 3359

Difficulty: 🔴 Hard
Topics: Array, Stack, Matrix, Monotonic Stack

Solution

Problem Understanding

We are given a matrix grid with m rows and n columns, along with an integer k. The task is to count how many rectangular submatrices satisfy two conditions simultaneously:

  1. Every value inside the submatrix must be at most k.
  2. Every row inside the submatrix must be sorted in non-increasing order from left to right.

A row is non-increasing if every element is greater than or equal to the next element. For example:

  • [5, 4, 4, 1] is valid
  • [5, 3, 6] is not valid because 3 < 6
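The row condition is easy to state in code. Below is a minimal sketch using a hypothetical helper name, `is_non_increasing`, which is not part of the LeetCode interface:

```python
from typing import List


def is_non_increasing(row: List[int]) -> bool:
    """Hypothetical helper: True if every element is >= the element to its right."""
    # zip pairs each element with its right neighbor; a row with fewer than
    # two elements has no pairs and is trivially non-increasing.
    return all(left >= right for left, right in zip(row, row[1:]))


print(is_non_increasing([5, 4, 4, 1]))  # True
print(is_non_increasing([5, 3, 6]))     # False
```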

A submatrix is defined by choosing a continuous range of rows and a continuous range of columns.

The important observation is that the sorting condition applies independently to each row of the chosen submatrix. The rows do not need to relate to one another vertically. We only care that inside each selected row segment, values never increase when moving to the right.

The constraints are large:

  • m, n <= 1000
  • m * n can be as large as 10^6

This immediately rules out brute-force enumeration of all submatrices, because the number of submatrices in a matrix is already O(m^2 * n^2).
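To put a number on that bound, the count of submatrices is the number of row ranges times the number of column ranges. A small sketch, with `total_submatrices` as an illustrative name only:

```python
def total_submatrices(m: int, n: int) -> int:
    """Illustrative helper: number of (top, bottom, left, right) choices in an m x n grid."""
    # Pairs top <= bottom among m rows: m * (m + 1) // 2.
    row_ranges = m * (m + 1) // 2
    # Pairs left <= right among n columns: n * (n + 1) // 2.
    col_ranges = n * (n + 1) // 2
    return row_ranges * col_ranges


print(total_submatrices(3, 3))        # 36
print(total_submatrices(1000, 1000))  # 250500250000
```

The second value is roughly 2.5 * 10^11, far above 2^31 - 1, so both enumeration cost and answer size matter at the upper constraint.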

The problem guarantees that all grid values are positive integers, and k is also positive. We only need to count valid submatrices, not construct them.

Several edge cases are important:

  • Cells larger than k cannot belong to any valid submatrix.
  • Single-cell submatrices are valid if their value is at most k.
  • No valid row segment can contain both grid[i][j] and grid[i][j+1] when grid[i][j] < grid[i][j+1].
  • The entire matrix may already satisfy both conditions, in which case every submatrix counts and the answer is as large as possible.
  • Valid regions separated by cells larger than k must be counted independently, since no valid submatrix may include such a cell.

Approaches

Brute Force Approach

The brute-force solution considers every possible submatrix.

A submatrix is identified by:

  • top row
  • bottom row
  • left column
  • right column

For each submatrix, we would:

  1. Scan all elements to verify every value is <= k
  2. Check every row segment to ensure it is non-increasing

This approach is straightforward and obviously correct because it explicitly validates every candidate submatrix.

However, the complexity is enormous.

There are:

  • O(m^2) row ranges
  • O(n^2) column ranges

So there are O(m^2 * n^2) total submatrices.

Checking each one takes up to O(m * n) time in the worst case.

The total complexity becomes:

O(m^3 * n^3)

This is completely infeasible for matrices up to 1000 x 1000.
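For small inputs the brute force is still useful as a reference when testing faster versions. A minimal sketch, assuming a hypothetical function name `count_brute_force`:

```python
from typing import List


def count_brute_force(grid: List[List[int]], k: int) -> int:
    """Hypothetical reference checker: enumerate every submatrix, O(m^3 * n^3)."""
    m, n = len(grid), len(grid[0])
    total = 0
    for top in range(m):
        for bottom in range(top, m):
            for left in range(n):
                for right in range(left, n):
                    ok = True
                    for i in range(top, bottom + 1):
                        segment = grid[i][left:right + 1]
                        # Condition 1: every value is at most k.
                        # Condition 2: no increase from left to right.
                        if max(segment) > k or any(
                            a < b for a, b in zip(segment, segment[1:])
                        ):
                            ok = False
                            break
                    if ok:
                        total += 1
    return total


print(count_brute_force([[4, 3, 2, 1], [8, 7, 6, 1]], 3))  # 8
```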

Key Insight

Instead of validating each submatrix independently, we can preprocess how far each row can extend to the right while maintaining validity.

For every cell (i, j), define:

width[i][j] = maximum valid width starting at (i, j)

A width is valid if:

  • all values are <= k
  • the row segment remains non-increasing

For example, take the row

[5,4,3,2]

with any k >= 5.

At position 0, width is 4.

At position 1, width is 3.
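The definition can be checked directly by scanning right from each cell. This is a sketch under the assumption k = 5 for the example row; `max_valid_width` is a hypothetical name, and this O(n)-per-cell scan is only for illustration (the walkthrough replaces it with a right-to-left recurrence):

```python
from typing import List


def max_valid_width(row: List[int], j: int, k: int) -> int:
    """Hypothetical helper: largest w such that row[j:j+w] is non-increasing
    with every value <= k, found by a direct scan."""
    if row[j] > k:
        return 0  # the starting cell itself is invalid
    w = 1
    # Extend while the next value exists and does not increase. If the
    # previous value is <= k and the next is <= it, the next is <= k too.
    while j + w < len(row) and row[j + w] <= row[j + w - 1]:
        w += 1
    return w


row = [5, 4, 3, 2]
print([max_valid_width(row, j, k=5) for j in range(len(row))])  # [4, 3, 2, 1]
```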

Once we know these widths, the problem becomes similar to counting submatrices in a histogram.

For every left column j and every pair of rows top <= bottom, the number of valid submatrices using exactly those rows and that left column equals the minimum of width[top..bottom][j].

Counting therefore reduces to summing these minimums efficiently, one column at a time.

Approach Comparison

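Approach    | Time Complexity | Space Complexity | Notes
----------- | --------------- | ---------------- | -----------------------------------------------
Brute Force | O(m^3 * n^3)    | O(1)             | Enumerates and validates every submatrix
Optimal     | O(m * n)        | O(m * n)         | Precomputes valid widths and counts efficiently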

Algorithm Walkthrough

Step 1: Compute valid row widths

Create a matrix width of size m x n.

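For every row, process columns from right to left.

At each position (i, j):

  • If grid[i][j] > k, then width[i][j] = 0, because no valid segment can contain this cell.
  • Otherwise, if j = n - 1, then width[i][j] = 1, because the cell is valid and nothing lies to its right.
  • Otherwise, if grid[i][j] >= grid[i][j+1], the segment starting at (i, j+1) extends one cell to the left: width[i][j] = width[i][j+1] + 1.
  • Otherwise the row increases at this point, so width[i][j] = 1.

The extension case never reaches past a cell larger than k: if grid[i][j] <= k and grid[i][j+1] <= grid[i][j], then grid[i][j+1] <= k as well.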

This preprocessing tells us the maximum valid horizontal span beginning at every cell.
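The recurrence for a single row can be sketched as a hypothetical helper `row_widths`, which takes one row rather than the whole grid (an illustrative simplification):

```python
from typing import List


def row_widths(row: List[int], k: int) -> List[int]:
    """Hypothetical helper: Step 1 recurrence for one row, right to left, O(n)."""
    n = len(row)
    width = [0] * n
    for j in range(n - 1, -1, -1):
        if row[j] > k:
            width[j] = 0                  # value too large: no segment starts here
        elif j == n - 1:
            width[j] = 1                  # last column: only the cell itself
        elif row[j] >= row[j + 1]:
            width[j] = width[j + 1] + 1   # extend the segment starting at j + 1
        else:
            width[j] = 1                  # row increases here: stop after one cell
    return width


print(row_widths([4, 3, 2, 1], 3))  # [0, 3, 2, 1]
print(row_widths([8, 7, 6, 1], 3))  # [0, 0, 0, 1]
```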

Step 2: Count submatrices column by column

Now fix a left column j and treat the widths in that column like histogram heights.

For each column j:

  • Iterate downward through the rows, treating each row as the bottom row
  • For each bottom row, move the top row upward one row at a time
  • Maintain the minimum width among the rows from top to bottom

Suppose the current rows are top..bottom. Then:

min_width = minimum(width[top..bottom][j])

This minimum is the number of valid submatrices with:

  • left boundary at j
  • top row at top
  • bottom row at bottom

If min_width = w, there are exactly w valid choices for the right boundary: columns j through j + w - 1. Every row accepts each of these boundaries, and at least one row rejects column j + w.

Add w to the answer.
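Step 2 on its own, before the stack optimization, can be sketched as below. `count_with_running_min` is a hypothetical name, and it takes a precomputed width matrix as input:

```python
from typing import List


def count_with_running_min(width: List[List[int]]) -> int:
    """Hypothetical helper: O(m^2 * n) count; for each left column and bottom
    row, walk upward and add the running minimum width."""
    m, n = len(width), len(width[0])
    total = 0
    for j in range(n):                        # fixed left boundary
        for bottom in range(m):
            min_width = width[bottom][j]
            for top in range(bottom, -1, -1):  # extend upward one row at a time
                min_width = min(min_width, width[top][j])
                if min_width == 0:
                    break                      # taller intervals cannot recover
                total += min_width             # valid right boundaries for top..bottom
    return total


# Width matrix of Example 1 (grid [[4,3,2,1],[8,7,6,1]], k = 3).
print(count_with_running_min([[0, 3, 2, 1], [0, 0, 0, 1]]))  # 8
```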

Step 3: Use monotonic stack optimization

A naive upward scan would still be O(m^2 * n).

We optimize using a monotonic stack.

For each column:

  • Treat row widths like histogram heights
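  • Keep the stack strictly increasing in width from bottom to top
  • Efficiently compute the sum of minimums for all row intervals ending at the current row

This is the same technique used in problems such as LeetCode 907 (Sum of Subarray Minimums) and LeetCode 1504 (Count Submatrices With All Ones).

The stack stores pairs of:

  • width value
  • count of consecutive rows (candidate top rows) whose interval minimum is that width

We also maintain current_sum, the sum of minimum widths over all row intervals that end at the current row.

For each new row with width w:

  • Pop every entry whose width is >= w, since those rows now have minimum w for every interval ending here
  • Subtract each popped entry's contribution, width * count, from current_sum
  • Merge its count into the new entry's count, which starts at 1

Then push (w, count), add w * count to current_sum, and add current_sum to the answer.

Each row is pushed once and popped at most once per column, so one column costs O(m) amortized.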
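The stack update can be isolated on a single list of heights. This is a sketch with a hypothetical function name, `interval_min_sums`, returning current_sum after each row so it can be compared against a direct computation:

```python
from typing import List


def interval_min_sums(heights: List[int]) -> List[int]:
    """Hypothetical helper: for each index i, sum(min(heights[t:i+1]) for t <= i),
    maintained with a monotonic stack of (height, count) pairs."""
    stack = []          # heights strictly increasing from bottom to top
    current_sum = 0     # sum of minimums over all intervals ending at i
    result = []
    for h in heights:
        count = 1
        # Every interval that starts in a popped entry's rows and ends here
        # now has minimum h, so fold those rows into the new entry.
        while stack and stack[-1][0] >= h:
            prev_h, prev_count = stack.pop()
            current_sum -= prev_h * prev_count
            count += prev_count
        current_sum += h * count
        stack.append((h, count))
        result.append(current_sum)
    return result


print(interval_min_sums([3, 1, 2]))  # [3, 2, 4]
```

For [3, 1, 2], the intervals ending at index 2 are [2], [1, 2], and [3, 1, 2], with minimums 2, 1, 1, which sum to 4.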

Why it works

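For a fixed left boundary j, row r can end at any right boundary from j to j + width[r][j] - 1, because any prefix of a valid row segment is also valid. A multi-row submatrix is valid only if every one of its rows accepts the chosen right boundary, so it can extend only as far right as the smallest width among its rows. Therefore, for a fixed left boundary and bottom row, the number of valid submatrices equals the sum of minimum widths over all vertical intervals ending at that row. The monotonic stack does not visit those intervals one by one; it keeps their sum in current_sum and updates it in amortized O(1) time per row.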
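The central claim (the count of right boundaries equals the minimum row width) can be checked mechanically on small grids. A sketch with hypothetical names `count_direct`, `count_by_min_width`, and `claim_holds`; it validates every candidate from scratch, so it is only suitable for tiny inputs:

```python
from typing import List


def valid_segment(seg: List[int], k: int) -> bool:
    # Both conditions: every value <= k, and non-increasing left to right.
    return max(seg) <= k and all(a >= b for a, b in zip(seg, seg[1:]))


def count_direct(grid: List[List[int]], k: int, top: int, bottom: int, left: int) -> int:
    # Hypothetical: try every right boundary and validate every row.
    n = len(grid[0])
    return sum(
        all(valid_segment(grid[i][left:right + 1], k) for i in range(top, bottom + 1))
        for right in range(left, n)
    )


def count_by_min_width(grid: List[List[int]], k: int, top: int, bottom: int, left: int) -> int:
    # Hypothetical: the claim, i.e. the smallest row width starting at `left`.
    n = len(grid[0])

    def width(i: int) -> int:
        # Longest valid segment of row i that starts at column `left`.
        w = 0
        while left + w < n and valid_segment(grid[i][left:left + w + 1], k):
            w += 1
        return w

    return min(width(i) for i in range(top, bottom + 1))


def claim_holds(grid: List[List[int]], k: int) -> bool:
    # Compare both counts for every (top, bottom, left) triple.
    m, n = len(grid), len(grid[0])
    return all(
        count_direct(grid, k, top, bottom, left) == count_by_min_width(grid, k, top, bottom, left)
        for top in range(m)
        for bottom in range(top, m)
        for left in range(n)
    )


print(claim_holds([[4, 3, 2, 1], [8, 7, 6, 1]], 3))  # True
```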

Python Solution

from typing import List

class Solution:
    def countSubmatrices(self, grid: List[List[int]], k: int) -> int:
        m, n = len(grid), len(grid[0])

        width = [[0] * n for _ in range(m)]

        # Precompute maximum valid widths
        for i in range(m):
            for j in range(n - 1, -1, -1):
                if grid[i][j] > k:
                    width[i][j] = 0
                elif j == n - 1:
                    width[i][j] = 1
                elif grid[i][j] >= grid[i][j + 1]:
                    width[i][j] = width[i][j + 1] + 1
                else:
                    width[i][j] = 1

        answer = 0

        # Process each column independently
        for j in range(n):
            stack = []
            current_sum = 0

            for i in range(m):
                w = width[i][j]
                count = 1

                while stack and stack[-1][0] >= w:
                    prev_width, prev_count = stack.pop()
                    current_sum -= prev_width * prev_count
                    count += prev_count

                current_sum += w * count
                stack.append((w, count))

                answer += current_sum

        return answer

The implementation has two major phases.

The first phase computes the width matrix. Processing each row from right to left lets us determine how far each position can extend while preserving both validity conditions. If a cell exceeds k, it cannot participate in any valid submatrix. Otherwise, we check whether the next value to the right preserves non-increasing order.

The second phase processes each column independently using a monotonic stack whose widths are strictly increasing from bottom to top. For every row, it maintains the sum of minimum widths across all vertical intervals ending at that row. Each minimum width equals the number of valid right boundaries for that interval.

The running variable current_sum stores the total contribution of all intervals ending at the current row. Adding it directly to the answer accumulates all valid submatrices.

Go Solution

func countSubmatrices(grid [][]int, k int) int64 {
    m := len(grid)
    n := len(grid[0])

    width := make([][]int, m)
    for i := range width {
        width[i] = make([]int, n)
    }

    // Precompute widths
    for i := 0; i < m; i++ {
        for j := n - 1; j >= 0; j-- {
            if grid[i][j] > k {
                width[i][j] = 0
            } else if j == n-1 {
                width[i][j] = 1
            } else if grid[i][j] >= grid[i][j+1] {
                width[i][j] = width[i][j+1] + 1
            } else {
                width[i][j] = 1
            }
        }
    }

    type Pair struct {
        width int
        count int
    }

    var answer int64 = 0

    for j := 0; j < n; j++ {
        stack := []Pair{}
        currentSum := 0

        for i := 0; i < m; i++ {
            w := width[i][j]
            count := 1

            for len(stack) > 0 && stack[len(stack)-1].width >= w {
                top := stack[len(stack)-1]
                stack = stack[:len(stack)-1]

                currentSum -= top.width * top.count
                count += top.count
            }

            currentSum += w * count
            stack = append(stack, Pair{w, count})

            answer += int64(currentSum)
        }
    }

    return answer
}

The Go implementation follows exactly the same algorithmic structure as the Python version.

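The main Go-specific detail is integer overflow handling. The answer can reach 250,500,250,000 for a 1000 x 1000 grid, far beyond the 32-bit range, so the final answer uses int64. Go's int is already 64 bits on 64-bit platforms, but int64 makes the width explicit.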

Slices are used as stacks, with manual push and pop operations. The Pair struct stores the width value and the number of consecutive rows merged into that entry.

Worked Examples

Example 1

Input:

grid =
[
  [4,3,2,1],
  [8,7,6,1]
]
k = 3

Step 1: Build width matrix

Cells greater than 3 become 0.

Row 0: [0,3,2,1]
Row 1: [0,0,0,1]

Explanation:

  • 3 >= 2 >= 1, so widths expand
  • 4, 8, 7, 6 exceed k

Final width matrix:

Row | Values
--- | ------------
0   | [0, 3, 2, 1]
1   | [0, 0, 0, 1]

Step 2: Process columns

Column 0

Widths: [0, 0]

No valid submatrices added.

Column 1

Widths: [3, 0]

Processing:

Row | Width | Current Sum | Added
--- | ----- | ----------- | -----
0   | 3     | 3           | 3
1   | 0     | 0           | 0

Contribution: 3

Column 2

Widths: [2, 0]

Contribution: 2

Column 3

Widths: [1, 1]

Processing:

Row | Width | Current Sum | Added
--- | ----- | ----------- | -----
0   | 1     | 1           | 1
1   | 1     | 2           | 2

Contribution: 3

Total:

3 + 2 + 3 = 8
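The per-column contributions above can be reproduced with a tracing variant of the stack pass. This is a sketch; `column_contributions` is a hypothetical name, and it takes the width matrix from Step 1 as input:

```python
from typing import List


def column_contributions(width: List[List[int]]) -> List[int]:
    """Hypothetical tracing helper: the stack pass, totaled per left column."""
    m, n = len(width), len(width[0])
    contributions = []
    for j in range(n):
        stack, current_sum, column_total = [], 0, 0
        for i in range(m):
            w, count = width[i][j], 1
            while stack and stack[-1][0] >= w:
                prev_w, prev_count = stack.pop()
                current_sum -= prev_w * prev_count
                count += prev_count
            current_sum += w * count
            stack.append((w, count))
            column_total += current_sum   # the "Added" column in the tables
        contributions.append(column_total)
    return contributions


contrib = column_contributions([[0, 3, 2, 1], [0, 0, 0, 1]])
print(contrib, sum(contrib))  # [0, 3, 2, 3] 8
```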

Example 2

Input:

grid =
[
 [1,1,1],
 [1,1,1],
 [1,1,1]
]
k = 1

Every row is non-increasing and every value is valid.

Width matrix:

Row | Values
--- | -------
0   | [3,2,1]
1   | [3,2,1]
2   | [3,2,1]

All submatrices are valid.

A 3 x 3 matrix contains:

(3 * 4 / 2)^2 = 36

submatrices.

Answer:

36

Example 3

Input:

grid = [[1]]
k = 1

Width matrix:

[[1]]

Only one valid submatrix exists.

Answer:

1

Complexity Analysis

Measure | Complexity | Explanation
------- | ---------- | ------------------------------------------------
Time    | O(m * n)   | Each cell is processed a constant number of times
Space   | O(m * n)   | The width matrix stores one integer per cell

The preprocessing step scans the matrix once. During the monotonic stack phase, each row is pushed once and popped at most once per column, so the stack work is O(m) per column and O(m * n) in total.

Test Cases

from typing import List

class Solution:
    def countSubmatrices(self, grid: List[List[int]], k: int) -> int:
        m, n = len(grid), len(grid[0])

        width = [[0] * n for _ in range(m)]

        for i in range(m):
            for j in range(n - 1, -1, -1):
                if grid[i][j] > k:
                    width[i][j] = 0
                elif j == n - 1:
                    width[i][j] = 1
                elif grid[i][j] >= grid[i][j + 1]:
                    width[i][j] = width[i][j + 1] + 1
                else:
                    width[i][j] = 1

        answer = 0

        for j in range(n):
            stack = []
            current_sum = 0

            for i in range(m):
                w = width[i][j]
                count = 1

                while stack and stack[-1][0] >= w:
                    prev_width, prev_count = stack.pop()
                    current_sum -= prev_width * prev_count
                    count += prev_count

                current_sum += w * count
                stack.append((w, count))

                answer += current_sum

        return answer

sol = Solution()

assert sol.countSubmatrices([[4,3,2,1],[8,7,6,1]], 3) == 8  # provided example
assert sol.countSubmatrices([[1,1,1],[1,1,1],[1,1,1]], 1) == 36  # all valid
assert sol.countSubmatrices([[1]], 1) == 1  # single cell valid
assert sol.countSubmatrices([[5]], 1) == 0  # single cell invalid
assert sol.countSubmatrices([[3,2,1]], 3) == 6  # single decreasing row
assert sol.countSubmatrices([[1,2,3]], 3) == 3  # increasing row: only single cells are valid
assert sol.countSubmatrices([[3,2],[2,1]], 3) == 9  # fully valid 2x2 matrix
assert sol.countSubmatrices([[10,9],[8,7]], 5) == 0  # all values exceed k
assert sol.countSubmatrices([[2,2,2]], 2) == 6  # equal elements allowed
assert sol.countSubmatrices([[5,4,3],[3,2,1]], 5) == 18  # all submatrices valid

Test Summary

Test                       | Why
-------------------------- | ------------------------------------------
[[4,3,2,1],[8,7,6,1]], k=3 | Validates mixed valid and invalid regions
3x3 all ones               | Validates maximum valid coverage
[[1]]                      | Smallest valid input
[[5]], k=1                 | Smallest invalid input
Single decreasing row      | Tests horizontal extension
Single increasing row      | Tests row-order constraint handling
Fully valid 2x2 matrix     | Tests rectangular counting
All values exceed k        | Ensures invalid cells contribute zero
Equal adjacent values      | Confirms non-increasing allows equality
Entire matrix valid        | Stress test for full coverage

Edge Cases

One important edge case occurs when every value exceeds k. In this scenario, every width becomes zero during preprocessing. A buggy implementation might still count empty intervals or leave stale contributions on the stack. The current solution handles this naturally: a width of 0 pops every stack entry and merges them into a single entry of width 0, so current_sum drops to 0 and nothing is added to the answer.

Another tricky case is rows containing equal adjacent values. Since the condition is non-increasing rather than strictly decreasing, equal values are valid. The implementation uses >= when extending widths, ensuring segments like [5,5,5] are counted correctly.

A third important edge case is alternating valid and invalid regions inside a row. For example:

[5,4,10,3,2]

with k = 5.

The value 10 splits the row into two independent valid regions. Preprocessing gives the widths [2, 1, 0, 2, 1]: the invalid cell gets width 0, and the cell to its left gets width 1 because 4 < 10, so no counted segment crosses position 2.
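Listing the valid segments of that row by direct enumeration makes the split visible. A sketch with a hypothetical helper name, `valid_segments`:

```python
from typing import List, Tuple


def valid_segments(row: List[int], k: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: all (start, end) column pairs whose segment is
    valid, found by checking each segment directly."""
    result = []
    for start in range(len(row)):
        for end in range(start, len(row)):
            seg = row[start:end + 1]
            if max(seg) <= k and all(a >= b for a, b in zip(seg, seg[1:])):
                result.append((start, end))
    return result


print(valid_segments([5, 4, 10, 3, 2], 5))
# [(0, 0), (0, 1), (1, 1), (3, 3), (3, 4), (4, 4)]
```

None of the six segments contains column 2, and the count matches the width sum 2 + 1 + 0 + 2 + 1 = 6.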

A final subtle case involves very tall matrices where widths shrink rapidly between rows. The monotonic stack logic is essential here. Without careful removal of larger widths when a smaller width appears, the algorithm would overcount submatrices whose rows cannot all extend equally far to the right.