LeetCode 2916 - Subarrays Distinct Element Sum of Squares II
Difficulty: 🔴 Hard
Topics: Array, Dynamic Programming, Binary Indexed Tree, Segment Tree
Solution
Problem Understanding
The problem asks us to consider every possible non-empty subarray of the input array nums. For each subarray, we compute how many distinct values appear inside it. We then square that distinct count, and finally sum the squared values across all subarrays.
Formally, if a subarray contains k distinct elements, then its contribution to the answer is:
$k^2$
The final answer is the sum of these values for every subarray, modulo:
$10^9 + 7$
For example, if nums = [1,2,1], then the subarray [1,2] contains two distinct values, so it contributes:
$2^2 = 4$
The constraints are extremely important:
- nums.length can be as large as 10^5
- Values in the array can also be as large as 10^5
These limits immediately rule out any solution that explicitly enumerates all subarrays and recomputes distinct counts from scratch. There are:
$\frac{n(n+1)}{2}$
subarrays in total, which is already about 5 * 10^9 when n = 10^5.
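As a quick sanity check of that figure, the count can be evaluated directly. The helper name `subarray_count` is illustrative and not part of the original solution.

```python
def subarray_count(n: int) -> int:
    """Number of non-empty contiguous subarrays of an array of length n."""
    return n * (n + 1) // 2


print(subarray_count(10**5))  # 5000050000, roughly 5 * 10^9
```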
The key challenge is therefore not generating subarrays, but efficiently tracking how distinct counts change as we extend subarrays.
Several edge cases are important:
- Arrays where all values are identical, because distinct counts never grow beyond 1
- Arrays where all values are unique, because distinct counts grow as fast as possible
- Repeated patterns like [1,2,1,2,1,2], because the same value repeatedly changes many subarrays simultaneously
- Large arrays near the upper constraint limit, where only near-linear or O(n log n) solutions are feasible
Approaches
Brute Force Approach
The most direct solution is to generate every subarray and compute its number of distinct elements.
For every starting index i, we expand the subarray one element at a time toward the right. We maintain a hash set containing the distinct elements currently inside the subarray. Every time we extend the subarray to index j, we insert nums[j] into the set, compute the set size, square it, and add it to the answer.
This approach is correct because every subarray is examined exactly once, and the set accurately tracks the distinct values inside that subarray.
However, the time complexity is far too large. There are O(n^2) subarrays, and although each insertion into the hash set is efficient, we still process every subarray individually. With n = 10^5, this becomes completely infeasible.
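The brute force described above might look like the following minimal sketch. The function name `sum_counts_brute` is an assumption chosen for illustration; it is only practical for small inputs, but it is handy as a reference when checking faster code.

```python
from typing import List

MOD = 10**9 + 7


def sum_counts_brute(nums: List[int]) -> int:
    """O(n^2) reference: grow a set of seen values for every start index."""
    total = 0
    for i in range(len(nums)):
        seen = set()
        for j in range(i, len(nums)):
            seen.add(nums[j])        # extend the subarray nums[i..j]
            total += len(seen) ** 2  # squared distinct count of nums[i..j]
    return total % MOD


print(sum_counts_brute([1, 2, 1]))  # 15
```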
Key Insight for the Optimal Solution
The crucial observation is that when we append a new element nums[i], we do not need to recompute distinct counts for all subarrays from scratch.
Suppose we know the distinct counts of all subarrays ending at position i - 1. When we extend those subarrays with nums[i], only some of them gain one additional distinct element.
Specifically:
- Any subarray that starts after the previous occurrence of nums[i] does not yet contain that value, so it gains a new distinct element
- Any subarray that starts at or before the previous occurrence already contains nums[i] and keeps the same distinct count
This means we can think of the update as a range increment problem.
If the previous occurrence of nums[i] was at index prev, then every subarray starting in the range:
$[prev+1,\ i]$
gets its distinct count increased by 1.
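The range-increment view can be sketched with a plain list before any specialised data structure is involved. This version is still quadratic, because it applies each increment element by element. It is meant only to show the update rule, and the name `sum_counts_range_naive` is hypothetical.

```python
from typing import Dict, List

MOD = 10**9 + 7


def sum_counts_range_naive(nums: List[int]) -> int:
    """Range-increment formulation with explicit loops (O(n^2) overall)."""
    n = len(nums)
    dp = [0] * n                    # dp[l] = distinct count of nums[l..i]
    last_seen: Dict[int, int] = {}  # value -> most recent index
    total = 0
    for i, value in enumerate(nums):
        prev = last_seen.get(value, -1)
        for l in range(prev + 1, i + 1):  # starts that do not yet contain value
            dp[l] += 1
        total += sum(c * c for c in dp[: i + 1])  # subarrays ending at i
        last_seen[value] = i
    return total % MOD


print(sum_counts_range_naive([1, 2, 1]))  # 15
```

The two inner operations, the increment loop and the sum of squares, are the parts a faster structure has to replace.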
We therefore need a data structure that supports:
- Range increment updates
- Efficient querying of sums
- Efficient maintenance of squared contributions
A segment tree with lazy propagation is ideal for this task.
Approach Comparison
| Approach | Time Complexity | Space Complexity | Notes |
|---|---|---|---|
| Brute Force | O(n²) | O(n) | Enumerates every subarray and tracks distinct values with a set |
| Optimal | O(n log n) | O(n) | Uses segment tree with lazy propagation and range updates |
Algorithm Walkthrough
Step 1: Define the DP Interpretation
Let:
$dp[l]$
represent the distinct count of the subarray:
$nums[l..i]$
for the current ending index i.
As we move from left to right, we continuously update these distinct counts.
Step 2: Track Previous Occurrences
We maintain a hash map:
last_seen[value] = most recent index
When processing nums[i], we look up its previous occurrence.
If the previous occurrence is prev, then all subarrays starting after prev gain one new distinct element when extended to i.
Step 3: Convert the Problem Into Range Updates
For every starting index:
l in [prev + 1, i]
the distinct count increases by 1.
So we perform a range increment on that interval.
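Steps 2 and 3 together reduce to computing one inclusive interval per index. The sketch below lists those intervals; `update_ranges` is an illustrative helper, not a function from the solution itself.

```python
from typing import Dict, List, Tuple


def update_ranges(nums: List[int]) -> List[Tuple[int, int]]:
    """For each index i, return the inclusive range (prev + 1, i) of start
    positions whose distinct count grows when nums[i] is appended."""
    last_seen: Dict[int, int] = {}
    ranges = []
    for i, value in enumerate(nums):
        prev = last_seen.get(value, -1)  # -1 means "never seen before"
        ranges.append((prev + 1, i))
        last_seen[value] = i
    return ranges


print(update_ranges([1, 2, 1]))  # [(0, 0), (0, 1), (1, 2)]
```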
Step 4: Maintain Both Sum and Sum of Squares
The answer requires squared distinct counts.
Suppose a value changes from:
$x$
to:
$x+1$
Then its square changes by:
$(x+1)^2 - x^2 = 2x + 1$
Therefore, the segment tree stores:
- The sum of distinct counts
- The sum of squared distinct counts
When applying a range increment, both quantities can be updated efficiently using lazy propagation.
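To see why both aggregates can be updated without touching individual elements, sum the identity (x + v)^2 = x^2 + 2vx + v^2 over a segment. The sketch below applies that summed form and compares it with a direct recomputation; `increment_aggregates` is a name chosen here for illustration.

```python
from typing import Tuple


def increment_aggregates(total: int, total_sq: int, length: int, v: int) -> Tuple[int, int]:
    """Return (sum, sum of squares) after adding v to each of `length` values,
    given only the old sum and the old sum of squares."""
    new_sq = total_sq + 2 * v * total + length * v * v
    new_total = total + length * v
    return new_total, new_sq


counts = [2, 1, 1]
s, sq = sum(counts), sum(c * c for c in counts)     # 4 and 6
print(increment_aggregates(s, sq, len(counts), 1))  # (7, 17)

# Direct recomputation on [3, 2, 2] gives the same pair.
bumped = [c + 1 for c in counts]
print((sum(bumped), sum(c * c for c in bumped)))    # (7, 17)
```

The sum of squares must be updated before the plain sum, because the formula uses the old sum.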
Step 5: Add Contributions to the Final Answer
After processing index i, every subarray ending at i has an updated distinct count.
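The segment tree root stores the sum of squared distinct counts for all subarrays ending at i. Start positions greater than i have not been incremented yet and still hold 0, so they contribute nothing to the root.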
We add that value to the global answer.
Step 6: Continue Until the End
We repeat the process for every index in the array.
At the end, the accumulated total is the required answer.
Why it works
The invariant is that after processing position i, the segment tree stores the distinct counts for every subarray ending at i.
When a new value appears at position i, only subarrays that do not already contain that value gain one additional distinct element. The interval [prev + 1, i] precisely captures those subarrays.
Because the segment tree correctly applies these range increments and maintains both sums and squared sums, every subarray contribution is counted exactly once.
Python Solution
from typing import List

MOD = 10**9 + 7


class SegmentTree:
    def __init__(self, n: int):
        self.n = n
        size = 4 * n
        self.sum_vals = [0] * size  # sum of distinct counts per node
        self.sum_sq = [0] * size    # sum of squared distinct counts per node
        self.lazy = [0] * size      # pending increment not yet pushed to children

    def apply(self, node: int, left: int, right: int, val: int) -> None:
        # Add val to every position in [left, right] using only the aggregates.
        length = right - left + 1
        # sum_sq must be updated first because it depends on the old sum_vals.
        self.sum_sq[node] = (
            self.sum_sq[node]
            + 2 * val * self.sum_vals[node]
            + length * val * val
        ) % MOD
        self.sum_vals[node] = (
            self.sum_vals[node]
            + length * val
        ) % MOD
        self.lazy[node] += val

    def push(self, node: int, left: int, right: int) -> None:
        # Hand the pending increment down to both children.
        if self.lazy[node] == 0:
            return
        mid = (left + right) // 2
        self.apply(node * 2, left, mid, self.lazy[node])
        self.apply(node * 2 + 1, mid + 1, right, self.lazy[node])
        self.lazy[node] = 0

    def update(
        self,
        node: int,
        left: int,
        right: int,
        ql: int,
        qr: int,
        val: int
    ) -> None:
        # Node range fully inside the query range: update lazily and stop.
        if ql <= left and right <= qr:
            self.apply(node, left, right, val)
            return
        self.push(node, left, right)
        mid = (left + right) // 2
        if ql <= mid:
            self.update(node * 2, left, mid, ql, qr, val)
        if qr > mid:
            self.update(node * 2 + 1, mid + 1, right, ql, qr, val)
        # Recompute this node's aggregates from its children.
        self.sum_vals[node] = (
            self.sum_vals[node * 2]
            + self.sum_vals[node * 2 + 1]
        ) % MOD
        self.sum_sq[node] = (
            self.sum_sq[node * 2]
            + self.sum_sq[node * 2 + 1]
        ) % MOD


class Solution:
    def sumCounts(self, nums: List[int]) -> int:
        n = len(nums)
        seg = SegmentTree(n)
        last_seen = {}
        answer = 0
        for i, value in enumerate(nums):
            prev = last_seen.get(value, -1)
            # Starts in [prev + 1, i] gain one distinct element.
            seg.update(1, 0, n - 1, prev + 1, i, 1)
            # The root holds the sum of squares over all subarrays ending at i.
            answer = (answer + seg.sum_sq[1]) % MOD
            last_seen[value] = i
        return answer
The implementation follows the algorithm directly.
The SegmentTree class maintains three arrays:
- sum_vals, storing the sum of distinct counts
- sum_sq, storing the sum of squared distinct counts
- lazy, storing pending range increments
The apply function updates an entire segment efficiently without descending into children immediately. The update formula comes from the algebraic identity:
$(x+v)^2 = x^2 + 2vx + v^2$
The push method propagates lazy updates downward only when necessary.
Inside the main solution loop, we locate the previous occurrence of the current value. Every starting position after that occurrence gains one additional distinct element, so we increment the range [prev + 1, i].
After the update, the segment tree root stores the total squared contribution of all subarrays ending at i, which we add to the answer.
Go Solution
package main

const MOD int64 = 1_000_000_007

type SegmentTree struct {
	sumVals []int64
	sumSq   []int64
	lazy    []int64
}

func NewSegmentTree(n int) *SegmentTree {
	size := 4 * n
	return &SegmentTree{
		sumVals: make([]int64, size),
		sumSq:   make([]int64, size),
		lazy:    make([]int64, size),
	}
}

// apply adds val to every position in [left, right] using only the aggregates.
func (st *SegmentTree) apply(node, left, right int, val int64) {
	length := int64(right - left + 1)
	// sumSq is updated first because it depends on the old sumVals.
	st.sumSq[node] = (st.sumSq[node] +
		2*val*st.sumVals[node] +
		length*val*val) % MOD
	st.sumVals[node] = (st.sumVals[node] +
		length*val) % MOD
	st.lazy[node] += val
}

// push hands the pending increment down to both children.
func (st *SegmentTree) push(node, left, right int) {
	if st.lazy[node] == 0 {
		return
	}
	mid := (left + right) / 2
	st.apply(node*2, left, mid, st.lazy[node])
	st.apply(node*2+1, mid+1, right, st.lazy[node])
	st.lazy[node] = 0
}

func (st *SegmentTree) update(
	node, left, right,
	ql, qr int,
	val int64,
) {
	if ql <= left && right <= qr {
		st.apply(node, left, right, val)
		return
	}
	st.push(node, left, right)
	mid := (left + right) / 2
	if ql <= mid {
		st.update(node*2, left, mid, ql, qr, val)
	}
	if qr > mid {
		st.update(node*2+1, mid+1, right, ql, qr, val)
	}
	// Recompute this node's aggregates from its children.
	st.sumVals[node] = (st.sumVals[node*2] +
		st.sumVals[node*2+1]) % MOD
	st.sumSq[node] = (st.sumSq[node*2] +
		st.sumSq[node*2+1]) % MOD
}

func sumCounts(nums []int) int {
	n := len(nums)
	seg := NewSegmentTree(n)
	lastSeen := map[int]int{}
	var answer int64 = 0
	for i, value := range nums {
		prev := -1
		if idx, exists := lastSeen[value]; exists {
			prev = idx
		}
		// Starts in [prev+1, i] gain one distinct element.
		seg.update(1, 0, n-1, prev+1, i, 1)
		answer = (answer + seg.sumSq[1]) % MOD
		lastSeen[value] = i
	}
	return int(answer)
}
The Go implementation mirrors the Python version closely.
The main difference is integer handling. Because intermediate values can become very large, the implementation uses int64 throughout the segment tree. The final answer is converted back to int before returning.
Go slices are used instead of Python lists, and maps replace Python dictionaries.
Worked Examples
Example 1
nums = [1,2,1]
Initial state:
| Index | Distinct Count |
|---|---|
| 0 | 0 |
| 1 | 0 |
| 2 | 0 |
Processing index 0, value = 1
Previous occurrence:
prev = -1
Update range:
[0, 0]
Distinct counts become:
| Start Index | Subarray | Distinct Count |
|---|---|---|
| 0 | [1] | 1 |
Squared contribution:
$1^2 = 1$
Running answer:
1
Processing index 1, value = 2
Previous occurrence:
prev = -1
Update range:
[0, 1]
Distinct counts:
| Start Index | Subarray | Distinct Count |
|---|---|---|
| 0 | [1,2] | 2 |
| 1 | [2] | 1 |
Squared contributions:
$2^2 + 1^2 = 5$
Running answer:
1 + 5 = 6
Processing index 2, value = 1
Previous occurrence:
prev = 0
Update range:
[1, 2]
Distinct counts:
| Start Index | Subarray | Distinct Count |
|---|---|---|
| 0 | [1,2,1] | 2 |
| 1 | [2,1] | 2 |
| 2 | [1] | 1 |
Squared contributions:
$2^2 + 2^2 + 1^2 = 9$
Final answer:
6 + 9 = 15
Example 2
nums = [2,2]
Processing index 0
Distinct counts:
| Subarray | Count | Square |
|---|---|---|
| [2] | 1 | 1 |
Running answer:
1
Processing index 1
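The previous occurrence of 2 is at index 0, so prev = 0 and the update range is [1, 1]. Only the subarray starting at index 1 gains a new distinct element; the subarray starting at index 0 already contains 2 and keeps its count.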
Updated subarrays:
| Subarray | Count | Square |
|---|---|---|
| [2,2] | 1 | 1 |
| [2] | 1 | 1 |
Contribution:
2
Final answer:
3
Complexity Analysis
| Measure | Complexity | Explanation |
|---|---|---|
| Time | O(n log n) | Each index performs one segment tree range update |
| Space | O(n) | Segment tree arrays and hash map storage |
The segment tree supports each range update in O(log n) time because lazy propagation avoids visiting every element individually. Since we process n elements total, the overall complexity becomes O(n log n).
The memory usage is linear because the segment tree requires O(n) storage and the last_seen map stores at most one entry per distinct value.
Test Cases
sol = Solution()
assert sol.sumCounts([1, 2, 1]) == 15
# Basic mixed duplicates example
assert sol.sumCounts([2, 2]) == 3
# All elements identical
assert sol.sumCounts([1]) == 1
# Single element array
assert sol.sumCounts([1, 2, 3]) == 20
# All elements distinct
assert sol.sumCounts([1, 1, 1, 1]) == 10
# Every subarray has distinct count 1
assert sol.sumCounts([1, 2, 1, 2]) == 28
# Alternating repeated pattern
assert sol.sumCounts([5, 4, 3, 2, 1]) == 105
# Strictly decreasing distinct values
assert sol.sumCounts([1, 2, 3, 1]) == 43
# Repeat after long gap
assert sol.sumCounts([1, 2, 2, 1]) == 25
# Multiple overlapping duplicate regions
assert sol.sumCounts([100000]) == 1
# Maximum value constraint with minimal length
| Test | Why |
|---|---|
| [1,2,1] | Validates the main example |
| [2,2] | Ensures duplicate handling works |
| [1] | Smallest possible input |
| [1,2,3] | All elements distinct |
| [1,1,1,1] | Distinct count never increases |
| [1,2,1,2] | Repeated alternating pattern |
| [5,4,3,2,1] | Maximum distinct growth |
| [1,2,3,1] | Tests previous occurrence logic |
| [1,2,2,1] | Tests overlapping duplicate effects |
| [100000] | Validates upper bound element values |
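Hand-computed expected values are easy to get wrong, so it may help to cross-check any implementation against an independent oracle on small random arrays. The sketch below is one way to do that. The names `oracle` and `cross_check`, the array sizes and the value range are all assumptions made for illustration.

```python
import random
from typing import Callable, List

MOD = 10**9 + 7


def oracle(nums: List[int]) -> int:
    """Direct definition: square the distinct count of every slice (O(n^3))."""
    n = len(nums)
    return sum(
        len(set(nums[i:j])) ** 2
        for i in range(n)
        for j in range(i + 1, n + 1)
    ) % MOD


def cross_check(candidate: Callable[[List[int]], int], trials: int = 200, seed: int = 0) -> None:
    """Compare candidate against the oracle on small random arrays."""
    rng = random.Random(seed)
    for _ in range(trials):
        nums = [rng.randint(1, 4) for _ in range(rng.randint(1, 8))]
        expected, got = oracle(nums), candidate(nums)
        if got != expected:
            raise AssertionError(f"{nums}: expected {expected}, got {got}")
```

With the Solution class from the Python section in scope, a call such as `cross_check(Solution().sumCounts)` would exercise many duplicate patterns that a fixed list of cases can miss.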
Edge Cases
One important edge case is when all elements are identical, such as [7,7,7,7]. In this scenario, every subarray always has exactly one distinct element. A buggy implementation might accidentally increment distinct counts repeatedly for the same value. The last_seen logic prevents this by only incrementing subarrays starting after the previous occurrence.
Another important case is when all elements are distinct, such as [1,2,3,4,5]. Here, every new element increases the distinct count for every active subarray. This stresses the range update mechanism because the updated interval grows continuously. The segment tree handles this efficiently with lazy propagation.
A third tricky case involves overlapping duplicates, such as [1,2,1,2,1]. Multiple values repeatedly reappear, and different subsets of subarrays must be updated each time. The interval [prev + 1, i] precisely captures which subarrays gain a new distinct element, ensuring no subarray is updated too many times or too few times.
Finally, arrays near the maximum constraint size require careful performance considerations. Any algorithm that explicitly processes all subarrays will time out. The segment tree solution avoids this by reducing each iteration to logarithmic work.