LeetCode 1825 - Finding MK Average

LeetCode Problem 1825

Difficulty: 🔴 Hard
Topics: Design, Queue, Heap (Priority Queue), Data Stream, Ordered Set

Solution

Problem Understanding

The problem asks us to design a data structure that continuously receives numbers from a stream and can efficiently compute a special average called the MKAverage.

For every query, we only care about the most recent m elements in the stream. From those m elements:

  1. Remove the smallest k numbers.
  2. Remove the largest k numbers.
  3. Compute the floor of the average of the remaining m - 2k elements.

If fewer than m elements have been inserted so far, the result must be -1.
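The definition above can be applied literally to a list of everything seen so far. The following is a minimal sketch, not an efficient solution; the function name `mk_average_of_window` is a hypothetical helper introduced only for illustration.

```python
def mk_average_of_window(stream, m, k):
    """Apply the MKAverage definition directly to the numbers seen so far."""
    if len(stream) < m:
        return -1                       # fewer than m elements inserted
    last_m = sorted(stream[-m:])        # most recent m elements, ascending
    middle = last_m[k:m - k]            # drop the k smallest and k largest
    return sum(middle) // len(middle)   # floor of the average


print(mk_average_of_window([3, 1], 3, 1))               # -1
print(mk_average_of_window([3, 1, 10], 3, 1))           # 3
print(mk_average_of_window([3, 1, 10, 5, 5, 5], 3, 1))  # 5
```

Because 2k < m, the slice `last_m[k:m - k]` always contains at least one element, so the division is safe.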

The challenge is that the stream changes dynamically. Every time a new number is added, the oldest number among the last m elements may disappear from consideration. This means the structure behaves like a sliding window over the stream.

The constraints are large:

  • m can be up to 10^5
  • Up to 10^5 operations are performed

A naive implementation that repeatedly sorts the last m elements for every query would be far too slow.

The key difficulty is maintaining three groups efficiently:

  • The smallest k elements
  • The middle m - 2k elements
  • The largest k elements

We also need to support:

  • Insertion of new elements
  • Removal of old elements
  • Fast retrieval of the middle section sum

Important edge cases include:

  • Fewer than m elements inserted
  • Duplicate values
  • Elements moving between partitions when insertions/removals occur
  • Very large streams where repeated sorting becomes impossible

The constraints guarantee that:

  • 2k < m, so the middle section is always non-empty
  • All numbers are positive integers

With up to 10^5 operations in total, the design is fast enough as long as each operation runs in roughly logarithmic time.

Approaches

Brute Force Approach

The simplest approach is to store all stream elements in a queue.

Whenever calculateMKAverage() is called:

  1. Take the last m elements.
  2. Sort them.
  3. Remove the first k and last k.
  4. Sum the remaining values.
  5. Return the floor average.

This works because sorting directly gives the smallest and largest elements.

However, the time complexity is too high. Sorting m elements costs O(m log m) per query. When both m and the number of queries are in the tens of thousands, the total work can reach billions of comparisons, which is infeasible.
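For reference, the brute-force idea can be sketched as a small class. This is an illustrative baseline rather than a submission; the class name `BruteForceMKAverage` is hypothetical, and it keeps only the last m values (via `deque(maxlen=m)`) instead of the whole stream, which does not change the result.

```python
from collections import deque


class BruteForceMKAverage:
    """Baseline that re-sorts the window on every query."""

    def __init__(self, m: int, k: int) -> None:
        self.m = m
        self.k = k
        # maxlen=m makes the deque drop the oldest value automatically.
        self.window = deque(maxlen=m)

    def addElement(self, num: int) -> None:
        self.window.append(num)            # O(1)

    def calculateMKAverage(self) -> int:
        if len(self.window) < self.m:
            return -1
        ordered = sorted(self.window)      # O(m log m) on every call
        middle = ordered[self.k:self.m - self.k]
        return sum(middle) // len(middle)
```

A baseline like this is still useful for checking a faster implementation on small inputs.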

Optimal Approach

The key observation is that we do not need to repeatedly sort the entire window from scratch.

Instead, we maintain the last m elements split into three balanced ordered groups:

  • low: smallest k elements
  • mid: middle m - 2k elements
  • high: largest k elements

We also maintain:

  • A queue containing the last m elements
  • The sum of all elements currently inside mid

When a new element arrives:

  1. Insert it into the appropriate group.
  2. If the window exceeds size m, remove the oldest element.
  3. Rebalance the groups so their sizes remain valid.

Because the answer only depends on the middle section, maintaining mid_sum allows calculateMKAverage() to run in O(1) time.

To support efficient insertion/removal while maintaining sorted order, we use balanced ordered structures.

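In Python, SortedList from the third-party sortedcontainers package is a good fit: it keeps duplicates and supports add, remove, membership tests, and indexed access to the smallest and largest elements.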

In Go, we implement a Treap to support logarithmic operations.
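Where sortedcontainers is not installed, the handful of operations this solution needs can be approximated with the standard library. The sketch below is a hypothetical stand-in (the name `ListMultiset` is not from any library) built on `bisect`; it finds positions by binary search but shifts list elements on insert and delete, so each update is O(n) in the worst case rather than logarithmic. It is meant for experimenting with the algorithm, not for the full constraints.

```python
from bisect import bisect_left, insort


class ListMultiset:
    """Sorted multiset on a plain list; mimics the SortedList calls used here."""

    def __init__(self) -> None:
        self._items = []

    def add(self, value: int) -> None:
        insort(self._items, value)          # binary search, then O(n) shift

    def remove(self, value: int) -> None:
        i = bisect_left(self._items, value)
        if i == len(self._items) or self._items[i] != value:
            raise ValueError(f"{value} not in multiset")
        del self._items[i]                  # removes a single occurrence

    def pop(self, index: int = -1) -> int:
        return self._items.pop(index)       # pop(0) = smallest, pop(-1) = largest

    def __contains__(self, value: int) -> bool:
        i = bisect_left(self._items, value)
        return i < len(self._items) and self._items[i] == value

    def __getitem__(self, index: int) -> int:
        return self._items[index]

    def __len__(self) -> int:
        return len(self._items)
```

Since truthiness falls back to `__len__`, expressions such as `not self.low` behave the same way as with SortedList.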

Approach Comparison

| Approach | Time Complexity | Space Complexity | Notes |
| --- | --- | --- | --- |
| Brute Force | O(m log m) per query | O(m) | Re-sorts window every time |
| Optimal | O(log m) per operation | O(m) | Maintains balanced partitions dynamically |

Algorithm Walkthrough

Data Structures

We maintain:

  • A queue storing the most recent m elements
  • low, an ordered set containing the smallest k elements
  • mid, an ordered set containing the middle elements
  • high, an ordered set containing the largest k elements
  • mid_sum, the sum of elements inside mid

Step-by-Step Process

  1. Initialize the data structure.

Create empty ordered structures for low, mid, and high. Also initialize the queue and mid_sum.

  2. Add a new element.

When inserting a number:

  • If it belongs among the smallest values, place it in low
  • Else if it belongs among the largest values, place it in high
  • Otherwise place it in mid

Whenever an element enters or leaves mid, update mid_sum.

  3. Maintain window size.

If the queue size exceeds m, remove the oldest element from whichever group currently contains it.

  4. Rebalance the groups.

After insertion/removal, group sizes may become invalid.

We repeatedly rebalance:

  • If low has too many elements, move its largest element into mid
  • If high has too many elements, move its smallest element into mid
  • If low has too few elements, move the smallest element from mid
  • If high has too few elements, move the largest element from mid

Every transfer updates mid_sum appropriately.

  5. Calculate the MKAverage.

If fewer than m elements exist, return -1.

Otherwise:

$$\text{answer} = \left\lfloor \frac{\text{mid\_sum}}{m - 2k} \right\rfloor$$

Why It Works

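Once the window holds m elements, the following invariant holds after every call to addElement:

  • low contains exactly the smallest k elements
  • high contains exactly the largest k elements
  • mid contains every remaining element
  • every element of low ≤ every element of mid ≤ every element of high

Each insertion routes the new value by the current boundaries, and each rebalancing transfer moves only a boundary element (the largest of low, the smallest or largest of mid, or the smallest of high). The ordering between groups is therefore never violated while the sizes are restored. Since mid_sum tracks the exact sum of the middle section, dividing by m - 2k and taking the floor produces the correct MKAverage.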
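While debugging, the invariant can be checked mechanically. The helper below is a hypothetical sketch (`partitions_are_valid` is not part of the solution) that takes the three groups as plain lists and assumes a full window with k ≥ 1 and 2k < m, so none of the groups is empty once the size check passes.

```python
def partitions_are_valid(low, mid, high, m, k, mid_sum):
    """Check a full-window partition given as three lists of ints."""
    # Sizes: k smallest, m - 2k middle, k largest.
    if (len(low), len(mid), len(high)) != (k, m - 2 * k, k):
        return False
    # Ordering between groups; ties are allowed because of duplicates.
    if max(low) > min(mid) or max(mid) > min(high):
        return False
    # The cached sum must match the actual contents of mid.
    return sum(mid) == mid_sum


print(partitions_are_valid([1], [3], [10], m=3, k=1, mid_sum=3))  # True
print(partitions_are_valid([5], [5], [5], m=3, k=1, mid_sum=5))   # True
print(partitions_are_valid([3], [1], [10], m=3, k=1, mid_sum=1))  # False
```

Calling a check like this after each addElement in a test run is slow (linear per call) but tends to pinpoint the first operation that breaks the structure.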

Python Solution

from collections import deque
from sortedcontainers import SortedList

class MKAverage:

    def __init__(self, m: int, k: int):
        self.m = m
        self.k = k
        
        self.queue = deque()
        
        self.low = SortedList()
        self.mid = SortedList()
        self.high = SortedList()
        
        self.mid_sum = 0

    def _add_to_group(self, num: int) -> None:
        # Route num by the current boundaries: at or below low's maximum
        # goes to low, at or above high's minimum goes to high, else mid.
        if not self.low or num <= self.low[-1]:
            self.low.add(num)
        elif not self.high or num >= self.high[0]:
            self.high.add(num)
        else:
            self.mid.add(num)
            self.mid_sum += num

    def _remove_from_group(self, num: int) -> None:
        # Equal values are interchangeable, so removing one occurrence
        # from the first partition that contains num is safe.
        if num in self.low:
            self.low.remove(num)
        elif num in self.high:
            self.high.remove(num)
        else:
            self.mid.remove(num)
            self.mid_sum -= num

    def _rebalance(self) -> None:
        # Shrink oversized outer partitions by spilling their boundary
        # element into mid.
        while len(self.low) > self.k:
            value = self.low.pop(-1)
            self.mid.add(value)
            self.mid_sum += value

        while len(self.high) > self.k:
            value = self.high.pop(0)
            self.mid.add(value)
            self.mid_sum += value

        # Refill undersized outer partitions from the matching end of mid.
        while len(self.low) < self.k and self.mid:
            value = self.mid.pop(0)
            self.mid_sum -= value
            self.low.add(value)

        while len(self.high) < self.k and self.mid:
            value = self.mid.pop(-1)
            self.mid_sum -= value
            self.high.add(value)

    def addElement(self, num: int) -> None:
        self.queue.append(num)

        self._add_to_group(num)
        self._rebalance()

        if len(self.queue) > self.m:
            old = self.queue.popleft()
            self._remove_from_group(old)
            self._rebalance()

    def calculateMKAverage(self) -> int:
        if len(self.queue) < self.m:
            return -1

        return self.mid_sum // (self.m - 2 * self.k)

# Your MKAverage object will be instantiated and called as such:
# obj = MKAverage(m, k)
# obj.addElement(num)
# param_2 = obj.calculateMKAverage()

The implementation maintains three sorted partitions.

The _add_to_group() method determines where a newly inserted number belongs based on current partition boundaries.

The _remove_from_group() method removes the outgoing element from whichever partition currently contains it.

The _rebalance() method is the most important part of the solution. It restores the invariants that:

  • low has exactly k elements
  • high has exactly k elements
  • all remaining elements stay in mid

The mid_sum variable is updated whenever elements move in or out of mid. This allows average computation in constant time.

The queue ensures we always know which element should leave the sliding window once more than m elements have been inserted.

Go Solution

package main

import (
	"container/list"
	"math/rand"
)

type Node struct {
	key      int
	priority int
	count    int
	size     int
	left     *Node
	right    *Node
}

func nodeSize(n *Node) int {
	if n == nil {
		return 0
	}
	return n.size
}

func update(n *Node) {
	if n != nil {
		n.size = n.count + nodeSize(n.left) + nodeSize(n.right)
	}
}

func rotateRight(y *Node) *Node {
	x := y.left
	y.left = x.right
	x.right = y
	update(y)
	update(x)
	return x
}

func rotateLeft(x *Node) *Node {
	y := x.right
	x.right = y.left
	y.left = x
	update(x)
	update(y)
	return y
}

func insert(root *Node, key int) *Node {
	if root == nil {
		return &Node{
			key:      key,
			priority: rand.Int(),
			count:    1,
			size:     1,
		}
	}

	if key == root.key {
		root.count++
	} else if key < root.key {
		root.left = insert(root.left, key)
		if root.left.priority > root.priority {
			root = rotateRight(root)
		}
	} else {
		root.right = insert(root.right, key)
		if root.right.priority > root.priority {
			root = rotateLeft(root)
		}
	}

	update(root)
	return root
}

func erase(root *Node, key int) *Node {
	if root == nil {
		return nil
	}

	if key < root.key {
		root.left = erase(root.left, key)
	} else if key > root.key {
		root.right = erase(root.right, key)
	} else {
		if root.count > 1 {
			root.count--
		} else {
			if root.left == nil {
				return root.right
			}
			if root.right == nil {
				return root.left
			}

			if root.left.priority > root.right.priority {
				root = rotateRight(root)
				root.right = erase(root.right, key)
			} else {
				root = rotateLeft(root)
				root.left = erase(root.left, key)
			}
		}
	}

	update(root)
	return root
}

func getMin(root *Node) int {
	for root.left != nil {
		root = root.left
	}
	return root.key
}

func getMax(root *Node) int {
	for root.right != nil {
		root = root.right
	}
	return root.key
}

func contains(root *Node, key int) bool {
	if root == nil {
		return false
	}

	if key == root.key {
		return true
	}

	if key < root.key {
		return contains(root.left, key)
	}

	return contains(root.right, key)
}

type OrderedSet struct {
	root *Node
	size int
}

func (s *OrderedSet) Add(x int) {
	s.root = insert(s.root, x)
	s.size++
}

func (s *OrderedSet) Remove(x int) {
	s.root = erase(s.root, x)
	s.size--
}

func (s *OrderedSet) Min() int {
	return getMin(s.root)
}

func (s *OrderedSet) Max() int {
	return getMax(s.root)
}

func (s *OrderedSet) Contains(x int) bool {
	return contains(s.root, x)
}

func (s *OrderedSet) Len() int {
	return s.size
}

type MKAverage struct {
	m      int
	k      int
	queue  *list.List
	low    OrderedSet
	mid    OrderedSet
	high   OrderedSet
	midSum int64
}

func Constructor(m int, k int) MKAverage {
	return MKAverage{
		m:     m,
		k:     k,
		queue: list.New(),
	}
}

func (this *MKAverage) addToGroup(num int) {
	if this.low.Len() == 0 || num <= this.low.Max() {
		this.low.Add(num)
	} else if this.high.Len() == 0 || num >= this.high.Min() {
		this.high.Add(num)
	} else {
		this.mid.Add(num)
		this.midSum += int64(num)
	}
}

func (this *MKAverage) removeFromGroup(num int) {
	if this.low.Contains(num) {
		this.low.Remove(num)
	} else if this.high.Contains(num) {
		this.high.Remove(num)
	} else {
		this.mid.Remove(num)
		this.midSum -= int64(num)
	}
}

func (this *MKAverage) rebalance() {
	for this.low.Len() > this.k {
		val := this.low.Max()
		this.low.Remove(val)
		this.mid.Add(val)
		this.midSum += int64(val)
	}

	for this.high.Len() > this.k {
		val := this.high.Min()
		this.high.Remove(val)
		this.mid.Add(val)
		this.midSum += int64(val)
	}

	for this.low.Len() < this.k && this.mid.Len() > 0 {
		val := this.mid.Min()
		this.mid.Remove(val)
		this.midSum -= int64(val)
		this.low.Add(val)
	}

	for this.high.Len() < this.k && this.mid.Len() > 0 {
		val := this.mid.Max()
		this.mid.Remove(val)
		this.midSum -= int64(val)
		this.high.Add(val)
	}
}

func (this *MKAverage) AddElement(num int) {
	this.queue.PushBack(num)

	this.addToGroup(num)
	this.rebalance()

	if this.queue.Len() > this.m {
		front := this.queue.Front()
		old := front.Value.(int)
		this.queue.Remove(front)

		this.removeFromGroup(old)
		this.rebalance()
	}
}

func (this *MKAverage) CalculateMKAverage() int {
	if this.queue.Len() < this.m {
		return -1
	}

	return int(this.midSum / int64(this.m-2*this.k))
}

/**
 * Your MKAverage object will be instantiated and called as such:
 * obj := Constructor(m, k);
 * obj.AddElement(num);
 * param_2 := obj.CalculateMKAverage();
 */

Go's standard library has no ordered multiset comparable to SortedList, which Python gets from the third-party sortedcontainers package.

Instead, the Go solution implements a Treap, which is a randomized balanced binary search tree. The Treap supports:

  • insertion
  • deletion
  • minimum lookup
  • maximum lookup
  • duplicate counts

all in expected O(log n) time.

The midSum field is declared as int64 so that the running sum cannot overflow a 32-bit integer, regardless of the platform's int width.
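A quick back-of-the-envelope check shows why 32 bits are not enough. The snippet assumes the LeetCode bound of 10^5 on each stream value, which is not restated in this article; the constant names are illustrative.

```python
M_MAX = 10**5            # largest window size allowed by the constraints
NUM_MAX = 10**5          # assumed upper bound on each stream value
INT32_MAX = 2**31 - 1

# With k = 1 the middle section holds m - 2 elements.
worst_case_sum = (M_MAX - 2) * NUM_MAX

print(worst_case_sum)               # 9999800000
print(worst_case_sum > INT32_MAX)   # True
```

Python integers are arbitrary precision, so the Python solution needs no special handling; the concern applies only to fixed-width types such as those in Go.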

Worked Examples

Example 1

Input:

m = 3
k = 1

Step 1

Add 3

| Structure | Values |
| --- | --- |
| Queue | [3] |
| low | [3] |
| mid | [] |
| high | [] |

Not enough elements yet.

Result: -1

Step 2

Add 1

After rebalancing:

| Structure | Values |
| --- | --- |
| Queue | [3,1] |
| low | [1] |
| mid | [] |
| high | [3] |

Still fewer than m.

Result: -1

Step 3

Add 10

After rebalancing:

| Structure | Values |
| --- | --- |
| Queue | [3,1,10] |
| low | [1] |
| mid | [3] |
| high | [10] |

mid_sum = 3

Average:

$$\left\lfloor \frac{3}{1} \right\rfloor = 3$$

Result: 3

Step 4

Add 5

Window exceeds size 3, remove oldest element 3.

Current window:

[1,10,5]

After rebalancing:

| Structure | Values |
| --- | --- |
| low | [1] |
| mid | [5] |
| high | [10] |

Result:

$$\left\lfloor \frac{5}{1} \right\rfloor = 5$$

Step 5

Add 5

Window becomes:

[10,5,5]

After rebalancing:

| Structure | Values |
| --- | --- |
| low | [5] |
| mid | [5] |
| high | [10] |

Result:

$$\left\lfloor \frac{5}{1} \right\rfloor = 5$$

Step 6

Add 5

Window becomes:

[5,5,5]

After rebalancing:

| Structure | Values |
| --- | --- |
| low | [5] |
| mid | [5] |
| high | [5] |

Result:

$$\left\lfloor \frac{5}{1} \right\rfloor = 5$$

Complexity Analysis

| Measure | Complexity | Explanation |
| --- | --- | --- |
| Time | O(log m) per operation | Ordered set insertion/removal/rebalancing |
| Space | O(m) | Stores last m elements and partitions |

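Each insertion or removal on an ordered structure costs O(log m) (expected, in the case of the Treap). A single addElement performs at most one insertion, one removal, and a constant number of transfers during rebalancing, so it remains logarithmic. calculateMKAverage() runs in O(1) because it only divides the cached mid_sum.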

The memory usage is linear because all active elements inside the sliding window must be stored.

Test Cases

# Example from problem statement
obj = MKAverage(3, 1)
obj.addElement(3)
obj.addElement(1)
assert obj.calculateMKAverage() == -1  # not enough elements

obj.addElement(10)
assert obj.calculateMKAverage() == 3  # middle element is 3

obj.addElement(5)
obj.addElement(5)
obj.addElement(5)
assert obj.calculateMKAverage() == 5  # middle element is 5

# Minimum valid middle size
obj = MKAverage(5, 2)
obj.addElement(1)
obj.addElement(2)
obj.addElement(3)
obj.addElement(4)
obj.addElement(5)
assert obj.calculateMKAverage() == 3  # only middle value remains

# Duplicate values
obj = MKAverage(6, 1)
for x in [5, 5, 5, 5, 5, 5]:
    obj.addElement(x)
assert obj.calculateMKAverage() == 5  # all values identical

# Sliding window behavior
obj = MKAverage(4, 1)
for x in [1, 2, 3, 4]:
    obj.addElement(x)
assert obj.calculateMKAverage() == 2  # middle is [2,3]

obj.addElement(100)
assert obj.calculateMKAverage() == 3  # window becomes [2,3,4,100]

# Large extremes
obj = MKAverage(5, 1)
for x in [100000, 1, 99999, 2, 50000]:
    obj.addElement(x)
assert obj.calculateMKAverage() == 50000  # middle is [2,50000,99999]

# Repeated removals and insertions
obj = MKAverage(3, 1)
obj.addElement(1)
obj.addElement(2)
obj.addElement(3)
assert obj.calculateMKAverage() == 2

obj.addElement(4)
assert obj.calculateMKAverage() == 3

obj.addElement(5)
assert obj.calculateMKAverage() == 4

# Window not yet full
obj = MKAverage(10, 3)
for x in [1, 2, 3]:
    obj.addElement(x)
assert obj.calculateMKAverage() == -1  # insufficient elements
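Hand-written cases can be complemented by a randomized cross-check against the definition. The sketch below is a hypothetical harness (`cross_check` and its parameters are illustrative names); it accepts any callable `factory(m, k)` returning an object with `addElement` and `calculateMKAverage`, and recomputes the expected answer by sorting the last m values.

```python
import random


def cross_check(factory, trials=200, max_m=8, max_value=10, seed=0):
    """Compare factory(m, k) with the definition on random small streams.

    Returns None if every query matches, otherwise a tuple
    (m, k, stream_so_far, expected, actual) for the first mismatch.
    """
    rng = random.Random(seed)
    for _ in range(trials):
        m = rng.randint(3, max_m)
        k = rng.randint(1, (m - 1) // 2)       # keeps 2k < m
        candidate = factory(m, k)
        stream = []
        for _ in range(3 * m):                 # long enough to force evictions
            num = rng.randint(1, max_value)    # small range forces duplicates
            stream.append(num)
            candidate.addElement(num)
            if len(stream) < m:
                expected = -1
            else:
                middle = sorted(stream[-m:])[k:m - k]
                expected = sum(middle) // len(middle)
            actual = candidate.calculateMKAverage()
            if actual != expected:
                return (m, k, list(stream), expected, actual)
    return None
```

With the Python class above in scope, `cross_check(MKAverage)` should return None; a returned tuple gives a small reproducible stream to debug.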

Test Summary

| Test | Why |
| --- | --- |
| Example from problem statement | Verifies core functionality |
| Minimum valid middle size | Tests the case where only one middle element remains |
| Duplicate values | Ensures duplicates are handled correctly |
| Sliding window behavior | Verifies old elements are removed properly |
| Large extremes | Tests partition correctness with large values |
| Repeated removals and insertions | Verifies repeated rebalancing |
| Window not yet full | Ensures -1 is returned correctly |

Edge Cases

One important edge case occurs when fewer than m elements have been inserted. A buggy implementation might attempt to compute an average before enough elements exist. The solution explicitly checks the queue length before performing any calculation and returns -1 immediately.

Another tricky case involves duplicate values. Since many elements may have identical values, using ordinary sets would fail because duplicates would collapse into one entry. Both implementations therefore use multiset behavior. Python's SortedList naturally supports duplicates, while the Go Treap stores a count field for repeated values.
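The difference between a set and a multiset is easy to see on a window of identical values. This small illustration uses only the standard library, with a sorted plain list standing in for an ordered multiset:

```python
from bisect import insort

window = [5, 5, 5]

as_set = set(window)                 # duplicates collapse into one entry
as_sorted_list = []
for value in window:
    insort(as_sorted_list, value)    # keeps every occurrence, in order

print(len(as_set))                   # 1
print(as_sorted_list)                # [5, 5, 5]
```

With a plain set, the window [5, 5, 5] would appear to hold a single element, and the partition sizes could never reach k, m - 2k, and k.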

A third important edge case occurs when the sliding window removes an element that belongs to one partition, causing partition sizes to become invalid. For example, removing the only element in low requires pulling a replacement from mid. The rebalance step guarantees that after every insertion or deletion, all partitions regain their required sizes.

Another subtle case appears when all values are identical. In this situation, partition boundaries are ambiguous because every value compares equally. The implementation handles this correctly for two reasons: rebalancing depends only on partition sizes, not on strict inequalities, and equal values are interchangeable, so it does not matter which partition an evicted duplicate is removed from.