LeetCode 2916 - Subarrays Distinct Element Sum of Squares II

Difficulty: 🔴 Hard
Topics: Array, Dynamic Programming, Binary Indexed Tree, Segment Tree

Solution

Problem Understanding

The problem asks us to consider every possible non-empty subarray of the input array nums. For each subarray, we compute how many distinct values appear inside it. We then square that distinct count, and finally sum the squared values across all subarrays.

Formally, if a subarray contains k distinct elements, then its contribution to the answer is:

$k^2$

The final answer is the sum of these values for every subarray, modulo:

$10^9 + 7$

For example, if nums = [1,2,1], then the subarray [1,2] contains two distinct values, so it contributes:

$2^2 = 4$

The constraints are extremely important:

  • nums.length can be as large as 10^5
  • Values in the array can also be as large as 10^5

These limits immediately rule out any solution that explicitly enumerates all subarrays and recomputes distinct counts from scratch. There are:

$\frac{n(n+1)}{2}$

subarrays in total, which is already about 5 * 10^9 when n = 10^5.
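As a quick arithmetic check of that count, the formula can be evaluated directly. The helper name `subarray_count` is purely illustrative and not part of the solution.

```python
def subarray_count(n: int) -> int:
    """Number of non-empty contiguous subarrays of an array of length n."""
    return n * (n + 1) // 2


print(subarray_count(3))        # 6
print(subarray_count(10**5))    # 5000050000
```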

The key challenge is therefore not generating subarrays, but efficiently tracking how distinct counts change as we extend subarrays.

Several edge cases are important:

  • Arrays where all values are identical, because distinct counts never grow beyond 1
  • Arrays where all values are unique, because distinct counts grow as fast as possible
  • Repeated patterns like [1,2,1,2,1,2], because the same value repeatedly changes many subarrays simultaneously
  • Large arrays near the upper constraint limit, where only near-linear solutions such as O(n log n) are feasible

Approaches

Brute Force Approach

The most direct solution is to generate every subarray and compute its number of distinct elements.

For every starting index i, we expand the subarray one element at a time toward the right. We maintain a hash set containing the distinct elements currently inside the subarray. Every time we extend the subarray to index j, we insert nums[j] into the set, compute the set size, square it, and add it to the answer.

This approach is correct because every subarray is examined exactly once, and the set accurately tracks the distinct values inside that subarray.

However, the time complexity is far too large. There are O(n^2) subarrays, and although each hash set insertion takes O(1) time on average, we still process every subarray individually. With n = 10^5, that is roughly 5 * 10^9 insertions, which is infeasible within typical time limits.
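A minimal sketch of the brute force described above might look like the following. The function name `sum_counts_brute` is an assumption for illustration; it is only practical for small inputs, but it is handy as a reference when checking a faster solution.

```python
from typing import List

MOD = 10**9 + 7


def sum_counts_brute(nums: List[int]) -> int:
    """O(n^2) reference: grow a set for every start index i."""
    total = 0
    for i in range(len(nums)):
        seen = set()                    # distinct values in nums[i..j]
        for j in range(i, len(nums)):
            seen.add(nums[j])
            total += len(seen) ** 2     # squared distinct count of nums[i..j]
    return total % MOD


print(sum_counts_brute([1, 2, 1]))  # 15
```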

Key Insight for the Optimal Solution

The crucial observation is that when we append a new element nums[i], we do not need to recompute distinct counts for all subarrays from scratch.

Suppose we know the distinct counts of all subarrays ending at position i - 1. When we extend those subarrays with nums[i], only some of them gain one additional distinct element.

Specifically:

  • Any subarray that starts after the previous occurrence of nums[i], and therefore does not yet contain that value, gains a new distinct element
  • Any subarray that already contains nums[i] keeps the same distinct count

This means we can think of the update as a range increment problem.

If the previous occurrence of nums[i] was at index prev (with prev = -1 when the value has not appeared before), then every subarray ending at i and starting in the range:

$[prev+1,\ i]$

gets its distinct count increased by 1.

We therefore need a data structure that supports:

  • Range increment updates
  • Efficient querying of sums
  • Efficient maintenance of squared contributions

A segment tree with lazy propagation is ideal for this task.
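Before bringing in a segment tree, the range-increment view can be checked with a plain list. The sketch below is an illustrative O(n^2) simulation, not the final solution; `sum_counts_range_sim` and `counts` are hypothetical names. It applies each increment element by element, which is exactly the work the segment tree is meant to batch.

```python
from typing import Dict, List

MOD = 10**9 + 7


def sum_counts_range_sim(nums: List[int]) -> int:
    """Simulate the range increments with a plain list (O(n^2) overall)."""
    n = len(nums)
    counts = [0] * n                    # counts[l] = distinct count of nums[l..i]
    last_seen: Dict[int, int] = {}
    total = 0
    for i, value in enumerate(nums):
        prev = last_seen.get(value, -1)
        # Every start index in [prev + 1, i] gains one distinct element.
        for l in range(prev + 1, i + 1):
            counts[l] += 1
        # Add the squared counts of all subarrays ending at i.
        total += sum(c * c for c in counts[: i + 1])
        last_seen[value] = i
    return total % MOD


print(sum_counts_range_sim([1, 2, 1]))  # 15
```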

Approach Comparison

| Approach | Time Complexity | Space Complexity | Notes |
| --- | --- | --- | --- |
| Brute Force | O(n²) | O(n) | Enumerates every subarray and tracks distinct values with a set |
| Optimal | O(n log n) | O(n) | Uses segment tree with lazy propagation and range updates |

Algorithm Walkthrough

Step 1: Define the DP Interpretation

Let:

$dp[l]$

represent the distinct count of the subarray:

$nums[l..i]$

for the current ending index i.

As we move from left to right, we continuously update these distinct counts.

Step 2: Track Previous Occurrences

We maintain a hash map:

last_seen[value] = most recent index

When processing nums[i], we look up its previous occurrence.

If the previous occurrence is prev, then all subarrays starting after prev gain one new distinct element when extended to i.
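The previous-occurrence lookup on its own can be sketched as follows. The helper `previous_occurrences` is a hypothetical name used only for illustration, and -1 is assumed as the marker for a value that has not appeared before, matching the solution code later in this article.

```python
from typing import Dict, List


def previous_occurrences(nums: List[int]) -> List[int]:
    """For each index i, return the index of the previous equal value, or -1."""
    last_seen: Dict[int, int] = {}
    prev: List[int] = []
    for i, value in enumerate(nums):
        prev.append(last_seen.get(value, -1))
        last_seen[value] = i            # i is now the most recent index of value
    return prev


print(previous_occurrences([1, 2, 1]))  # [-1, -1, 0]
```

Each entry prev[i] determines the update range [prev[i] + 1, i] used at step i.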

Step 3: Convert the Problem Into Range Updates

For every starting index:

l in [prev + 1, i]

the distinct count increases by 1.

So we perform a range increment on that interval.

Step 4: Maintain Both Sum and Sum of Squares

The answer requires squared distinct counts.

Suppose a value changes from:

$x$

to:

$x+1$

Then its square changes by:

$(x+1)^2 - x^2 = 2x + 1$

Therefore, the segment tree stores:

  • The sum of distinct counts
  • The sum of squared distinct counts

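When a range increment of 1 is applied to a segment of length len with sum S and sum of squares Q, the new values are S + len and Q + 2S + len. Both quantities can therefore be updated in O(1) per node, and the increment can be deferred to the children using lazy propagation.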
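The same bookkeeping works for an arbitrary increment v, since the sum over a segment of (x + v)^2 equals the sum of x^2 plus 2v times the sum of x plus len * v^2. The sketch below checks that identity on a small list; `incremented_sums` is a hypothetical helper, not part of the solution classes.

```python
from typing import List, Tuple


def incremented_sums(total: int, total_sq: int, length: int, v: int) -> Tuple[int, int]:
    """Return (sum, sum of squares) after adding v to each of `length` values."""
    new_sq = total_sq + 2 * v * total + length * v * v   # uses the old sum
    new_total = total + length * v
    return new_total, new_sq


values: List[int] = [2, 1, 0]
s = sum(values)                        # 3
sq = sum(x * x for x in values)        # 5

shifted = [x + 1 for x in values]      # [3, 2, 1]
expected = (sum(shifted), sum(x * x for x in shifted))

assert incremented_sums(s, sq, len(values), 1) == expected
print(expected)  # (6, 14)
```

Note that the sum of squares has to be computed from the old sum, which is why the update order matters.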

Step 5: Add Contributions to the Final Answer

After processing index i, every subarray ending at i has an updated distinct count.

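The segment tree root stores the sum of squared distinct counts for all subarrays ending at i. Start positions greater than i have not been incremented yet and still hold 0, so they do not affect the root.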

We add that value to the global answer.

Step 6: Continue Until the End

We repeat the process for every index in the array.

At the end, the accumulated total is the required answer.

Why it works

The invariant is that after processing position i, the segment tree stores the distinct counts for every subarray ending at i.

When a new value appears at position i, only subarrays that do not already contain that value gain one additional distinct element. The interval [prev + 1, i] precisely captures those subarrays.

Because the segment tree correctly applies these range increments and maintains both sums and squared sums, every subarray contribution is counted exactly once.

Python Solution

from typing import List

MOD = 10**9 + 7

class SegmentTree:
    def __init__(self, n: int):
        self.n = n
        size = 4 * n

        self.sum_vals = [0] * size
        self.sum_sq = [0] * size
        self.lazy = [0] * size

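    def apply(self, node: int, left: int, right: int, val: int) -> None:
        # Add val to every position covered by this node.
        length = right - left + 1

        # sum((x + val)^2) = sum(x^2) + 2*val*sum(x) + length*val^2.
        # Update sum_sq first because it needs the old sum_vals.
        self.sum_sq[node] = (
            self.sum_sq[node]
            + 2 * val * self.sum_vals[node]
            + length * val * val
        ) % MOD

        self.sum_vals[node] = (
            self.sum_vals[node]
            + length * val
        ) % MOD

        # Record the pending increment so children can be updated later.
        self.lazy[node] += val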

    def push(self, node: int, left: int, right: int) -> None:
        if self.lazy[node] == 0:
            return

        mid = (left + right) // 2

        self.apply(node * 2, left, mid, self.lazy[node])
        self.apply(node * 2 + 1, mid + 1, right, self.lazy[node])

        self.lazy[node] = 0

    def update(
        self,
        node: int,
        left: int,
        right: int,
        ql: int,
        qr: int,
        val: int
    ) -> None:
        if ql <= left and right <= qr:
            self.apply(node, left, right, val)
            return

        self.push(node, left, right)

        mid = (left + right) // 2

        if ql <= mid:
            self.update(node * 2, left, mid, ql, qr, val)

        if qr > mid:
            self.update(node * 2 + 1, mid + 1, right, ql, qr, val)

        self.sum_vals[node] = (
            self.sum_vals[node * 2]
            + self.sum_vals[node * 2 + 1]
        ) % MOD

        self.sum_sq[node] = (
            self.sum_sq[node * 2]
            + self.sum_sq[node * 2 + 1]
        ) % MOD

class Solution:
    def sumCounts(self, nums: List[int]) -> int:
        n = len(nums)

        seg = SegmentTree(n)
        last_seen = {}

        answer = 0

        for i, value in enumerate(nums):
            prev = last_seen.get(value, -1)

            seg.update(
                1,
                0,
                n - 1,
                prev + 1,
                i,
                1
            )

            answer = (answer + seg.sum_sq[1]) % MOD

            last_seen[value] = i

        return answer

The implementation follows the algorithm directly.

The SegmentTree class maintains three arrays:

  • sum_vals, storing the sum of distinct counts
  • sum_sq, storing the sum of squared distinct counts
  • lazy, storing pending range increments

The apply function updates an entire segment efficiently without descending into children immediately. The update formula comes from the algebraic identity:

$(x+v)^2 = x^2 + 2vx + v^2$

The push method propagates lazy updates downward only when necessary.

Inside the main solution loop, we locate the previous occurrence of the current value. Every starting position after that occurrence gains one additional distinct element, so we increment the range [prev + 1, i].

After the update, the segment tree root stores the total squared contribution of all subarrays ending at i, which we add to the answer.

Go Solution

package main

const MOD int64 = 1_000_000_007

type SegmentTree struct {
	sumVals []int64
	sumSq   []int64
	lazy    []int64
}

func NewSegmentTree(n int) *SegmentTree {
	size := 4 * n

	return &SegmentTree{
		sumVals: make([]int64, size),
		sumSq:   make([]int64, size),
		lazy:    make([]int64, size),
	}
}

func (st *SegmentTree) apply(node, left, right int, val int64) {
	length := int64(right - left + 1)

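	// sum((x + val)^2) = sumSq + 2*val*sumVals + length*val^2.
	// sumSq is updated first because it needs the old sumVals.
	st.sumSq[node] = (st.sumSq[node] +
		2*val*st.sumVals[node] +
		length*val*val) % MOD

	st.sumVals[node] = (st.sumVals[node] + length*val) % MOD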

	st.lazy[node] += val
}

func (st *SegmentTree) push(node, left, right int) {
	if st.lazy[node] == 0 {
		return
	}

	mid := (left + right) / 2

	st.apply(node*2, left, mid, st.lazy[node])
	st.apply(node*2+1, mid+1, right, st.lazy[node])

	st.lazy[node] = 0
}

func (st *SegmentTree) update(
	node, left, right,
	ql, qr int,
	val int64,
) {
	if ql <= left && right <= qr {
		st.apply(node, left, right, val)
		return
	}

	st.push(node, left, right)

	mid := (left + right) / 2

	if ql <= mid {
		st.update(node*2, left, mid, ql, qr, val)
	}

	if qr > mid {
		st.update(node*2+1, mid+1, right, ql, qr, val)
	}

	// Recompute this node from its two children.
	st.sumVals[node] = (st.sumVals[node*2] + st.sumVals[node*2+1]) % MOD

	st.sumSq[node] = (st.sumSq[node*2] + st.sumSq[node*2+1]) % MOD
}

func sumCounts(nums []int) int {
	n := len(nums)

	seg := NewSegmentTree(n)

	lastSeen := map[int]int{}

	var answer int64 = 0

	for i, value := range nums {
		prev := -1

		if idx, exists := lastSeen[value]; exists {
			prev = idx
		}

		seg.update(
			1,
			0,
			n-1,
			prev+1,
			i,
			1,
		)

		answer = (answer + seg.sumSq[1]) % MOD

		lastSeen[value] = i
	}

	return int(answer)
}

The Go implementation mirrors the Python version closely.

The main difference is integer handling. Products such as 2 * val * sumVals[node] and length * val * val can far exceed the 32-bit range before the modulo is taken, so the implementation uses int64 throughout the segment tree. The final answer is converted back to int before returning.

Go slices are used instead of Python lists, and maps replace Python dictionaries.

Worked Examples

Example 1

nums = [1,2,1]

Initial state:

| Index | Distinct Count |
| --- | --- |
| 0 | 0 |
| 1 | 0 |
| 2 | 0 |

Processing index 0, value = 1

Previous occurrence:

prev = -1

Update range:

[0, 0]

Distinct counts become:

| Start Index | Subarray | Distinct Count |
| --- | --- | --- |
| 0 | [1] | 1 |

Squared contribution:

$1^2 = 1$

Running answer:

1

Processing index 1, value = 2

Previous occurrence:

prev = -1

Update range:

[0, 1]

Distinct counts:

| Start Index | Subarray | Distinct Count |
| --- | --- | --- |
| 0 | [1,2] | 2 |
| 1 | [2] | 1 |

Squared contributions:

$2^2 + 1^2 = 5$

Running answer:

1 + 5 = 6

Processing index 2, value = 1

Previous occurrence:

prev = 0

Update range:

[1, 2]

Distinct counts:

| Start Index | Subarray | Distinct Count |
| --- | --- | --- |
| 0 | [1,2,1] | 2 |
| 1 | [2,1] | 2 |
| 2 | [1] | 1 |

Squared contributions:

$2^2 + 2^2 + 1^2 = 9$

Final answer:

6 + 9 = 15

Example 2

nums = [2,2]

Processing index 0

Distinct counts:

| Subarray | Count | Square |
| --- | --- | --- |
| [2] | 1 | 1 |

Running answer:

1

Processing index 1

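The previous occurrence of 2 is at index 0, so prev = 0 and the update range is [1, 1]. Only the subarray starting at index 1 gains a new distinct element; the subarray [2,2] starting at index 0 keeps its count of 1.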

Updated subarrays:

| Subarray | Count | Square |
| --- | --- | --- |
| [2,2] | 1 | 1 |
| [2] | 1 | 1 |

Contribution:

2

Final answer:

3

Complexity Analysis

| Measure | Complexity | Explanation |
| --- | --- | --- |
| Time | O(n log n) | Each index performs one segment tree range update |
| Space | O(n) | Segment tree arrays and hash map storage |

The segment tree supports each range update in O(log n) time because lazy propagation avoids visiting every element individually. Since we process n elements total, the overall complexity becomes O(n log n).

The memory usage is linear because the segment tree requires O(n) storage and the last_seen map stores at most one entry per distinct value.

Test Cases

sol = Solution()

assert sol.sumCounts([1, 2, 1]) == 15
# Basic mixed duplicates example

assert sol.sumCounts([2, 2]) == 3
# All elements identical

assert sol.sumCounts([1]) == 1
# Single element array

assert sol.sumCounts([1, 2, 3]) == 20
# All elements distinct

assert sol.sumCounts([1, 1, 1, 1]) == 10
# Every subarray has distinct count 1

assert sol.sumCounts([1, 2, 1, 2]) == 28
# Alternating repeated pattern

assert sol.sumCounts([5, 4, 3, 2, 1]) == 105
# Strictly decreasing distinct values

assert sol.sumCounts([1, 2, 3, 1]) == 43
# Repeat after long gap

assert sol.sumCounts([1, 2, 2, 1]) == 25
# Multiple overlapping duplicate regions

assert sol.sumCounts([100000]) == 1
# Maximum value constraint with minimal length

| Test | Why |
| --- | --- |
| [1,2,1] | Validates the main example |
| [2,2] | Ensures duplicate handling works |
| [1] | Smallest possible input |
| [1,2,3] | All elements distinct |
| [1,1,1,1] | Distinct count never increases |
| [1,2,1,2] | Repeated alternating pattern |
| [5,4,3,2,1] | Maximum distinct growth |
| [1,2,3,1] | Tests previous occurrence logic |
| [1,2,2,1] | Tests overlapping duplicate effects |
| [100000] | Validates upper bound element values |

Edge Cases

One important edge case is when all elements are identical, such as [7,7,7,7]. In this scenario, every subarray always has exactly one distinct element. A buggy implementation might accidentally increment distinct counts repeatedly for the same value. The last_seen logic prevents this by only incrementing subarrays starting after the previous occurrence.

Another important case is when all elements are distinct, such as [1,2,3,4,5]. Here, every new element increases the distinct count for every active subarray. This stresses the range update mechanism because the updated interval grows continuously. The segment tree handles this efficiently with lazy propagation.
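For the all-distinct case specifically, the expected answer has a closed form, which can serve as a cheap sanity check on large inputs. Assuming all n values are pairwise distinct, there are n - k + 1 subarrays of length k, each contributing k^2, and that sum simplifies to n(n+1)^2(n+2)/12. The helper names below are hypothetical and not part of the solution.

```python
MOD = 10**9 + 7


def all_distinct_answer(n: int) -> int:
    """Expected result for an array of n pairwise distinct values.

    sum over k = 1..n of (n - k + 1) * k^2 = n * (n + 1)^2 * (n + 2) / 12
    """
    return (n * (n + 1) ** 2 * (n + 2) // 12) % MOD


def all_distinct_direct(n: int) -> int:
    """Same quantity computed term by term, for cross-checking the formula."""
    return sum((n - k + 1) * k * k for k in range(1, n + 1)) % MOD


print(all_distinct_answer(3))  # 20
print(all_distinct_answer(5))  # 105
```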

A third tricky case involves overlapping duplicates, such as [1,2,1,2,1]. Multiple values repeatedly reappear, and different subsets of subarrays must be updated each time. The interval [prev + 1, i] precisely captures which subarrays gain a new distinct element, ensuring no subarray is updated too many times or too few times.

Finally, arrays near the maximum constraint size require careful performance considerations. Any algorithm that explicitly processes all subarrays will time out. The segment tree solution avoids this by reducing each iteration to logarithmic work.