LeetCode 3762 - Minimum Operations to Equalize Subarrays

Difficulty: 🔴 Hard
Topics: Array, Math, Binary Search, Segment Tree

Solution

Problem Understanding

We are given an array nums and a fixed integer k.

A single operation allows us to choose any element and either increase it by exactly k or decrease it by exactly k. We may perform as many operations as needed.

For each query [l, r], we only consider the subarray nums[l..r]. The goal is to determine the minimum number of operations required to make every element in that subarray equal. If it is impossible to make them equal, we return -1 for that query.

The key observation is that adding or subtracting k never changes an element's remainder modulo k. Therefore, every value can only move within its own residue class.

For example, when k = 3:

  • 1 can become ..., -5, -2, 1, 4, 7, 10, ...
  • 2 can become ..., -4, -1, 2, 5, 8, 11, ...

A value with remainder 1 modulo 3 can never become a value with remainder 2 modulo 3.
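
The invariant is easy to check numerically. The snippet below is a minimal sketch; the helper name `reachable` is made up for illustration and is not part of the solution.

```python
from typing import List


def reachable(value: int, k: int, steps: int) -> List[int]:
    """Values obtainable from `value` with a net shift of at most `steps` operations."""
    return [value + k * d for d in range(-steps, steps + 1)]


# With k = 3, every value reachable from 1 keeps remainder 1,
# and every value reachable from 2 keeps remainder 2.
from_one = reachable(1, 3, 3)
from_two = reachable(2, 3, 3)
print(from_one)  # [-8, -5, -2, 1, 4, 7, 10]
print(from_two)  # [-7, -4, -1, 2, 5, 8, 11]
print({v % 3 for v in from_one}, {v % 3 for v in from_two})  # {1} {2}
```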

The constraints are large:

  • n ≤ 4 × 10^4
  • queries.length ≤ 4 × 10^4

A solution that processes each query independently by sorting or scanning the entire subarray would be far too slow.

Some important edge cases are:

  • A query containing only one element always requires 0 operations.
  • If even two elements in the range have different remainders modulo k, the answer is immediately -1.
  • Large values up to 10^9 require 64-bit arithmetic for operation counts.
  • Long ranges require efficient median and distance computations.

Approaches

Brute Force

For each query:

  1. Check whether all elements have the same remainder modulo k.
  2. If not, return -1.
  3. Convert every value into its scaled form.
  4. Sort the subarray.
  5. Use the median to compute the minimum total number of operations.

This works because the minimum sum of absolute deviations is achieved at the median.

However, sorting every queried subarray is extremely expensive. In the worst case, a query may contain O(n) elements, giving a complexity near O(q · n log n), which is far too slow for n = q = 40000.
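
The five steps above can be sketched directly. This is only a reference implementation for cross-checking answers on small inputs; the function name `min_operations_brute` is an illustrative choice, not part of the final solution.

```python
from typing import List


def min_operations_brute(nums: List[int], k: int, queries: List[List[int]]) -> List[int]:
    """Reference implementation: roughly O(q * n log n), for checking answers only."""
    answers = []
    for l, r in queries:
        window = nums[l:r + 1]
        # Steps 1-2: all remainders must match, otherwise the query is infeasible.
        if len({x % k for x in window}) > 1:
            answers.append(-1)
            continue
        # Steps 3-4: scale each value down to its quotient and sort.
        scaled = sorted(x // k for x in window)
        # Step 5: the lower median minimizes the sum of absolute deviations.
        median = scaled[(len(scaled) - 1) // 2]
        answers.append(sum(abs(b - median) for b in scaled))
    return answers


print(min_operations_brute([1, 4, 7], 3, [[0, 1], [0, 2]]))          # [1, 2]
print(min_operations_brute([1, 2, 4], 2, [[0, 2], [0, 0], [1, 2]]))  # [-1, 0, 1]
```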

Key Insight

Suppose every value in a query has the same remainder r modulo k.

Then each element can be written as:

$$nums[i] = r + k \cdot b_i$$

where:

$$b_i = \left\lfloor \frac{nums[i]}{k} \right\rfloor$$

Making all numbers equal is equivalent to choosing an integer target t and minimizing:

$$\sum |b_i - t|$$

This is the classical minimum absolute deviation problem, whose optimum occurs at the median.
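
One way to see why the median is optimal is the standard pairing argument, sketched here under notation introduced only for this paragraph: write the transformed values in sorted order as b_(1) ≤ b_(2) ≤ ... ≤ b_(s), where s is the number of elements in the range. Pairing the j-th smallest value with the j-th largest gives, for any target t:

$$|b_{(j)} - t| + |b_{(s+1-j)} - t| \ge b_{(s+1-j)} - b_{(j)}$$

with equality exactly when t lies between b_(j) and b_(s+1-j). Summing over all pairs:

$$\sum_{i} |b_i - t| \ge \sum_{j=1}^{\lfloor s/2 \rfloor} \left( b_{(s+1-j)} - b_{(j)} \right)$$

A median lies inside every one of these nested intervals, so it attains the lower bound. When s is odd, the unpaired middle element is the median itself and contributes zero.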

Therefore each query becomes:

  1. Verify all remainders are identical.
  2. Find the median of the transformed values b_i inside the range.
  3. Compute the sum of distances to that median.

The challenge is answering these range median and range distance queries efficiently.

A Merge Sort Tree solves this:

  • Range k-th smallest queries give the median.
  • Range count/sum queries allow computing total distance to the median.
  • Residue uniformity is checked using a simple prefix array.

| Approach | Time Complexity | Space Complexity | Notes |
|---|---|---|---|
| Brute Force | O(q · n log n) | O(n) | Sort every queried subarray |
| Optimal | O((n + q) log² n log V) | O(n log n) | Merge Sort Tree with range median queries |

Here V is the number of distinct transformed values. Since V ≤ n, the log V factor is at most log n.

Algorithm Walkthrough

Step 1: Compute residue information

For every index:

$$rem[i] = nums[i] \bmod k$$

Create a prefix array recording residue changes:

change[i] = 1 if rem[i] != rem[i-1], otherwise 0   (for i ≥ 1; change[0] = 0)

Then:

prefixChange[i]

stores the total number of residue transitions up to position i.

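A query [l, r] has identical residues iff no transition occurs at positions l+1 through r. The transition recorded at l compares against index l-1, which lies outside the range, so prefixChange[l] is subtracted rather than prefixChange[l-1]: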

prefixChange[r] - prefixChange[l] == 0
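
A minimal sketch of this check follows. The helper names `build_change_prefix` and `same_residue` are illustrative; the full solution inlines the same logic.

```python
from typing import List


def build_change_prefix(nums: List[int], k: int) -> List[int]:
    """prefix[i] = number of indices j in 1..i with nums[j] % k != nums[j-1] % k."""
    prefix = [0] * len(nums)
    for i in range(1, len(nums)):
        prefix[i] = prefix[i - 1] + (nums[i] % k != nums[i - 1] % k)
    return prefix


def same_residue(prefix: List[int], l: int, r: int) -> bool:
    """True when no residue transition occurs at positions l+1..r."""
    return prefix[r] - prefix[l] == 0


prefix = build_change_prefix([1, 4, 7, 2, 5], 3)   # residues: 1, 1, 1, 2, 2
print(prefix)                      # [0, 0, 0, 1, 1]
print(same_residue(prefix, 0, 2))  # True
print(same_residue(prefix, 2, 4))  # False: index 2 has residue 1, index 3 has residue 2
print(same_residue(prefix, 3, 4))  # True
```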

Step 2: Build transformed values

Define:

$$b_i = nums[i] // k$$

When all residues are equal, every operation changes b_i by exactly 1.

Therefore the answer becomes the minimum sum of absolute differences to a common integer.

Step 3: Build a Merge Sort Tree

For every segment tree node store:

  • sorted values in that segment
  • prefix sums of the sorted values

This allows answering:

  • how many values in a range are ≤ x
  • the sum of all values in a range that are ≤ x

Both operations take O(log² n).
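
What a single node contributes can be shown in isolation. The sketch below models one node only, with made-up segment values; a real query combines the results of O(log n) such nodes.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import Tuple

# One node covering a segment whose values are [5, 1, 4, 1].
node_sorted = sorted([5, 1, 4, 1])            # [1, 1, 4, 5]
node_prefix = [0, *accumulate(node_sorted)]   # [0, 1, 2, 6, 11]


def count_and_sum_leq(x: int) -> Tuple[int, int]:
    """Count and sum of this node's values that are <= x, via one binary search."""
    pos = bisect_right(node_sorted, x)
    return pos, node_prefix[pos]


print(count_and_sum_leq(4))  # (3, 6)
print(count_and_sum_leq(0))  # (0, 0)
```

Storing the prefix sums with a leading 0 means the position returned by `bisect_right` indexes the prefix array directly, with no off-by-one adjustment.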

Step 4: Find the median of a query range

Let:

len = r - l + 1
kth = (len + 1) // 2

The median is the kth smallest element.

Binary search over the globally sorted distinct values.

For a candidate value x, query how many elements in [l, r] are ≤ x.

This determines whether the median lies to the left or right.
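
The search itself does not depend on how the count is obtained. In the sketch below, a linear scan stands in for the Merge Sort Tree count query, and the candidate list is built from the window alone for brevity (the full solution uses the distinct values of the whole array). The function name `kth_smallest` is an illustrative choice.

```python
from typing import Callable, List


def kth_smallest(distinct_sorted: List[int], kth: int,
                 count_leq: Callable[[int], int]) -> int:
    """Smallest candidate x with count_leq(x) >= kth, i.e. the kth smallest in the range."""
    lo, hi = 0, len(distinct_sorted) - 1
    while lo < hi:
        mid = (lo + hi) // 2
        if count_leq(distinct_sorted[mid]) >= kth:
            hi = mid        # the answer is at mid or to its left
        else:
            lo = mid + 1    # too few elements <= candidate, move right
    return distinct_sorted[lo]


window = [7, 2, 9, 2, 5]
candidates = sorted(set(window))
kth = (len(window) + 1) // 2
print(kth_smallest(candidates, kth, lambda x: sum(v <= x for v in window)))  # 5
```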

Step 5: Compute distance to the median

Let the median value be m.

Query:

  • count of values ≤ m
  • sum of values ≤ m

Let:

cntL = count(≤ m)
sumL = sum(≤ m)

cnt = range length
sum = total range sum

cntR = cnt - cntL
sumR = sum - sumL

The total operations are:

$$m \cdot cntL - sumL + sumR - m \cdot cntR$$

which equals:

$$\sum |b_i - m|$$
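
The identity can be checked on a plain list, without any tree. This is a small sketch; `cost_to_median` is a hypothetical helper that mirrors the names cntL, sumL, cntR and sumR used above.

```python
def cost_to_median(values, m):
    """Evaluate m*cntL - sumL + sumR - m*cntR for a plain list of values."""
    left = [v for v in values if v <= m]
    cnt_l, sum_l = len(left), sum(left)
    cnt_r, sum_r = len(values) - cnt_l, sum(values) - sum_l
    return m * cnt_l - sum_l + sum_r - m * cnt_r


values = [0, 1, 2, 3, 4]   # transformed values of [1, 4, 7, 10, 13] with k = 3
m = 2
print(cost_to_median(values, m))         # 6
print(sum(abs(v - m) for v in values))   # 6
```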

Why it works

If a range contains different remainders modulo k, no sequence of ±k operations can make all values equal because remainders never change.

When all remainders match, every value can be represented as r + k·b_i. Each operation changes b_i by exactly one. Therefore minimizing the number of operations is equivalent to minimizing the sum of absolute deviations of the b_i values. A fundamental property of absolute deviation states that the minimum is achieved at any median. The Merge Sort Tree efficiently finds that median and computes the corresponding distance sum, giving the optimal answer.

Python Solution

from typing import List
from bisect import bisect_right

class Solution:
    def minOperations(self, nums: List[int], k: int, queries: List[List[int]]) -> List[int]:
        n = len(nums)

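        # Scaled value b_i = nums[i] // k; within one residue class,
        # a single operation changes b_i by exactly 1.
        values = [x // k for x in nums]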

        change_prefix = [0] * n
        for i in range(1, n):
            change_prefix[i] = change_prefix[i - 1]
            if nums[i] % k != nums[i - 1] % k:
                change_prefix[i] += 1

        class MergeSortTree:
            def __init__(self, arr):
                self.n = len(arr)
                self.tree = [[] for _ in range(self.n * 4)]
                self.pref = [[] for _ in range(self.n * 4)]
                self.build(1, 0, self.n - 1, arr)

            def build(self, node, left, right, arr):
                if left == right:
                    self.tree[node] = [arr[left]]
                    self.pref[node] = [0, arr[left]]
                    return

                mid = (left + right) // 2
                self.build(node * 2, left, mid, arr)
                self.build(node * 2 + 1, mid + 1, right, arr)

                merged = []
                a = self.tree[node * 2]
                b = self.tree[node * 2 + 1]

                i = j = 0
                while i < len(a) and j < len(b):
                    if a[i] <= b[j]:
                        merged.append(a[i])
                        i += 1
                    else:
                        merged.append(b[j])
                        j += 1

                merged.extend(a[i:])
                merged.extend(b[j:])

                self.tree[node] = merged

                pref = [0]
                s = 0
                for v in merged:
                    s += v
                    pref.append(s)
                self.pref[node] = pref

            def query_leq(self, node, left, right, ql, qr, x):
                if ql <= left and right <= qr:
                    pos = bisect_right(self.tree[node], x)
                    return pos, self.pref[node][pos]

                if right < ql or left > qr:
                    return 0, 0

                mid = (left + right) // 2

                c1, s1 = self.query_leq(
                    node * 2, left, mid, ql, qr, x
                )
                c2, s2 = self.query_leq(
                    node * 2 + 1, mid + 1, right, ql, qr, x
                )

                return c1 + c2, s1 + s2

        mst = MergeSortTree(values)

        sorted_values = sorted(set(values))

        def count_leq(l, r, x):
            return mst.query_leq(1, 0, n - 1, l, r, x)[0]

        def count_sum_leq(l, r, x):
            return mst.query_leq(1, 0, n - 1, l, r, x)

        answers = []

        for l, r in queries:
            if change_prefix[r] - change_prefix[l] != 0:
                answers.append(-1)
                continue

            length = r - l + 1
            kth = (length + 1) // 2

            lo = 0
            hi = len(sorted_values) - 1

            while lo < hi:
                mid = (lo + hi) // 2
                if count_leq(l, r, sorted_values[mid]) >= kth:
                    hi = mid
                else:
                    lo = mid + 1

            median = sorted_values[lo]

            cnt_le, sum_le = count_sum_leq(l, r, median)

            total_cnt = length

            # Every value is <= the global maximum, so this returns the full range sum.
            _, total_sum = count_sum_leq(l, r, sorted_values[-1])

            cnt_gt = total_cnt - cnt_le
            sum_gt = total_sum - sum_le

            cost = (
                median * cnt_le - sum_le
                + sum_gt - median * cnt_gt
            )

            answers.append(cost)

        return answers

The implementation begins by transforming each value into its scaled representation nums[i] // k. A prefix array tracks where residues modulo k change, allowing constant time feasibility checks for every query.

The Merge Sort Tree stores a sorted list and prefix sums for every segment. This enables efficient range counting and range sum queries. Using these operations, the median is found through binary search over all distinct values. Once the median is known, the formula for absolute deviation is evaluated using counts and sums on both sides of the median.

Go Solution

package main

import "sort"

type Node struct {
	vals []int
	pref []int64
}

type MergeSortTree struct {
	tree []Node
	n    int
}

func NewMergeSortTree(arr []int) *MergeSortTree {
	n := len(arr)
	mst := &MergeSortTree{
		tree: make([]Node, 4*n),
		n:    n,
	}
	mst.build(1, 0, n-1, arr)
	return mst
}

func (mst *MergeSortTree) build(node, l, r int, arr []int) {
	if l == r {
		mst.tree[node].vals = []int{arr[l]}
		mst.tree[node].pref = []int64{0, int64(arr[l])}
		return
	}

	mid := (l + r) / 2

	mst.build(node*2, l, mid, arr)
	mst.build(node*2+1, mid+1, r, arr)

	a := mst.tree[node*2].vals
	b := mst.tree[node*2+1].vals

	merged := make([]int, 0, len(a)+len(b))

	i, j := 0, 0
	for i < len(a) && j < len(b) {
		if a[i] <= b[j] {
			merged = append(merged, a[i])
			i++
		} else {
			merged = append(merged, b[j])
			j++
		}
	}

	merged = append(merged, a[i:]...)
	merged = append(merged, b[j:]...)

	pref := make([]int64, len(merged)+1)
	for i, v := range merged {
		pref[i+1] = pref[i] + int64(v)
	}

	mst.tree[node].vals = merged
	mst.tree[node].pref = pref
}

func (mst *MergeSortTree) query(node, l, r, ql, qr, x int) (int, int64) {
	if ql <= l && r <= qr {
		pos := sort.Search(len(mst.tree[node].vals),
			func(i int) bool {
				return mst.tree[node].vals[i] > x
			})

		return pos, mst.tree[node].pref[pos]
	}

	if r < ql || l > qr {
		return 0, 0
	}

	mid := (l + r) / 2

	c1, s1 := mst.query(node*2, l, mid, ql, qr, x)
	c2, s2 := mst.query(node*2+1, mid+1, r, ql, qr, x)

	return c1 + c2, s1 + s2
}

func minOperations(nums []int, k int, queries [][]int) []int64 {
	n := len(nums)

	values := make([]int, n)
	for i := 0; i < n; i++ {
		values[i] = nums[i] / k
	}

	change := make([]int, n)
	for i := 1; i < n; i++ {
		change[i] = change[i-1]
		if nums[i]%k != nums[i-1]%k {
			change[i]++
		}
	}

	mst := NewMergeSortTree(values)

	distinct := append([]int(nil), values...)
	sort.Ints(distinct)

	uniq := make([]int, 0)
	for _, v := range distinct {
		if len(uniq) == 0 || uniq[len(uniq)-1] != v {
			uniq = append(uniq, v)
		}
	}

	ans := make([]int64, len(queries))

	for qi, q := range queries {
		l, r := q[0], q[1]

		if change[r]-change[l] != 0 {
			ans[qi] = -1
			continue
		}

		length := r - l + 1
		kth := (length + 1) / 2

		lo, hi := 0, len(uniq)-1

		for lo < hi {
			mid := (lo + hi) / 2

			cnt, _ := mst.query(1, 0, n-1, l, r, uniq[mid])

			if cnt >= kth {
				hi = mid
			} else {
				lo = mid + 1
			}
		}

		median := uniq[lo]

		cntLe, sumLe := mst.query(1, 0, n-1, l, r, median)
		// Every value is <= the global maximum, so this query returns the full range sum.
		_, totalSum := mst.query(1, 0, n-1, l, r, uniq[len(uniq)-1])

		cntGt := length - cntLe
		sumGt := totalSum - sumLe

		cost := int64(median)*int64(cntLe) - sumLe +
			sumGt - int64(median)*int64(cntGt)

		ans[qi] = cost
	}

	return ans
}

The Go version mirrors the Python implementation. The main difference is that all sums and answers use int64 because the total number of operations can exceed the range of a 32-bit integer. Binary searches are implemented using sort.Search, and prefix sums inside the Merge Sort Tree are also stored as int64.

Worked Examples

Example 1

nums = [1,4,7]
k = 3
queries = [[0,1],[0,2]]

Transformed values:

| Index | nums | nums % 3 | nums // 3 |
|---|---|---|---|
| 0 | 1 | 1 | 0 |
| 1 | 4 | 1 | 1 |
| 2 | 7 | 1 | 2 |

All residues are identical.

Query [0,1]

Range values:

[0,1]

Median:

0

or

1

Using median = 0:

|0-0| + |1-0| = 1

Answer:

1

Query [0,2]

Range values:

[0,1,2]

Median:

1

Cost:

|0-1| + |1-1| + |2-1|
= 1 + 0 + 1
= 2

Answer:

2

Final result:

[1,2]

Example 2

nums = [1,2,4]
k = 2
queries = [[0,2],[0,0],[1,2]]

Residues:

| Value | Residue |
|---|---|
| 1 | 1 |
| 2 | 0 |
| 4 | 0 |

Query [0,2]

Residues are not identical.

Answer:

-1

Query [0,0]

Single element.

Answer:

0

Query [1,2]

Transformed values:

[1,2]

Median:

1

Cost:

|1-1| + |2-1| = 1

Answer:

1

Final result:

[-1,0,1]

Complexity Analysis

| Measure | Complexity | Explanation |
|---|---|---|
| Time | O((n + q) log² n log V) | Each query binary-searches the median; every step issues an O(log² n) range count query |
| Space | O(n log n) | Merge Sort Tree stores sorted lists at every level |

The Merge Sort Tree contains every value in O(log n) nodes, producing O(n log n) memory usage. Each count or sum query requires traversing O(log n) nodes and performing binary search inside each node, resulting in O(log² n) time. Finding the median adds an additional binary search over the value domain.
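
The memory claim can be checked by counting list entries without building any lists. This is a rough sketch; `stored_elements` is a hypothetical helper that follows the same midpoint split as the build routine.

```python
def stored_elements(n: int) -> int:
    """Total number of list entries across all nodes of a merge sort tree over n items."""
    def count(left: int, right: int) -> int:
        size = right - left + 1          # this node stores one entry per covered index
        if left == right:
            return size
        mid = (left + right) // 2
        return size + count(left, mid) + count(mid + 1, right)
    return count(0, n - 1)


print(stored_elements(8))  # 32, i.e. 8 entries on each of 4 levels
# For n = 40000 the tree has at most 17 levels, so at most 17 * 40000 entries.
print(stored_elements(40_000) <= 17 * 40_000)  # True
```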

Test Cases

sol = Solution()

# Example 1
assert sol.minOperations(
    [1, 4, 7],
    3,
    [[0, 1], [0, 2]]
) == [1, 2]

# Example 2
assert sol.minOperations(
    [1, 2, 4],
    2,
    [[0, 2], [0, 0], [1, 2]]
) == [-1, 0, 1]

# Single element range
assert sol.minOperations(
    [10],
    5,
    [[0, 0]]
) == [0]

# All already equal
assert sol.minOperations(
    [8, 8, 8],
    4,
    [[0, 2]]
) == [0]

# Different residues, impossible
assert sol.minOperations(
    [1, 2, 3],
    2,
    [[0, 2]]
) == [-1]

# Two elements needing one operation
assert sol.minOperations(
    [2, 4],
    2,
    [[0, 1]]
) == [1]

# Larger valid range
assert sol.minOperations(
    [1, 4, 7, 10, 13],
    3,
    [[0, 4]]
) == [6]

# Multiple mixed queries (residues are 1, 1, 1, 2, 2, so any range
# containing both index 2 and index 3 is infeasible)
assert sol.minOperations(
    [1, 4, 7, 2, 5],
    3,
    [[0, 2], [2, 4], [0, 4]]
) == [2, -1, -1]

# Large values
assert sol.minOperations(
    [10**9, 10**9 - 3],
    3,
    [[0, 1]]
) == [1]

| Test | Why |
|---|---|
| [1,4,7] | Validates the first example |
| [1,2,4] | Validates impossible and valid ranges together |
| Single element | Minimum size query |
| Already equal | Zero-cost range |
| Different residues | Impossible transformation |
| Two elements | Smallest nontrivial valid range |
| Arithmetic progression | Larger median computation |
| Mixed queries | Tests residue boundary detection |
| Large values | Exercises inputs near the 10^9 limit (the answer itself is small) |

Edge Cases

Single Element Subarray

A subarray containing only one element is already equalized. The correct answer is always 0. The implementation handles this naturally because the median equals the only value and the absolute deviation sum is zero.

Mixed Residues Modulo k

Consider:

nums = [1, 2]
k = 2

The remainders are 1 and 0. Since adding or subtracting 2 never changes a remainder modulo 2, these values can never become equal. The residue-change prefix array detects this in constant time and immediately returns -1.

Large Numerical Values

Values may be as large as 10^9, and a query may contain up to 40000 elements. The total operation count can therefore exceed the range of 32-bit integers. The implementation stores sums and answers using 64-bit arithmetic (int64 in Go and Python's arbitrary precision integers), preventing overflow.
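
A concrete instance makes the overflow risk tangible. The input below is a hypothetical worst case that assumes k = 1 is allowed, with half of the range at 1 and half at 10^9; it is computed by brute force rather than with the tree.

```python
INT32_MAX = 2**31 - 1

# Hypothetical worst case with k = 1: 20000 values equal to 1, 20000 equal to 10**9.
half = 40_000 // 2
values = [1] * half + [10**9] * half
median = sorted(values)[(len(values) - 1) // 2]   # lower median
cost = sum(abs(v - median) for v in values)

print(median)             # 1
print(cost)               # 19999999980000
print(cost > INT32_MAX)   # True
```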

Even Length Subarrays

For an even number of elements there are two valid medians. Any value between the two middle elements minimizes the sum of absolute deviations. The algorithm selects the lower median via the (len + 1) // 2 convention, which still produces the optimal minimum cost.
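
A small check of this claim, using made-up values and a hypothetical helper `total_distance`:

```python
def total_distance(values, target):
    """Sum of absolute differences between each value and the target."""
    return sum(abs(v - target) for v in values)


values = sorted([3, 9, 4, 12])                 # [3, 4, 9, 12]
lower = values[(len(values) + 1) // 2 - 1]     # kth = (len + 1) // 2, 1-indexed
upper = values[len(values) // 2]

print(lower, upper)  # 4 9
# Every integer target between the two middle elements gives the same cost.
print([total_distance(values, t) for t in range(lower, upper + 1)])
# [14, 14, 14, 14, 14, 14]
print(total_distance(values, 3))  # 16, stepping outside the interval costs more
```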