LeetCode 3478 - Choose K Elements With Maximum Sum


LeetCode Problem 3478

Difficulty: 🟡 Medium
Topics: Array, Sorting, Heap (Priority Queue)

Solution

Problem Understanding

This problem asks us to compute a "maximum sum" for each index in an array based on a relationship with other elements. Specifically, for each index i, we consider every index j in the array, before or after i, with nums1[j] < nums1[i], and choose at most k of the corresponding nums2[j] values so that their sum is as large as possible. The output is an array of the same length as nums1 in which position i stores this maximum sum.

The input consists of two integer arrays of equal length, nums1 and nums2, and a positive integer k. nums1 supplies the values used for comparison, and nums2 supplies the values whose sum we want to maximize. The arrays can hold up to 100,000 elements with values up to 1,000,000. With n = 100,000, a naive O(n²) solution would examine on the order of 10^10 pairs, which is too slow.
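The value limits also decide the result type. Python integers never overflow, but languages with fixed-width integers need a wide enough accumulator. A quick upper-bound check, assuming the limits above and k at least n - 1 (the constant names are illustrative):

```python
# Upper bound on a single answer under the stated limits: an answer sums
# at most n - 1 values (index i itself is never eligible), each <= 1,000,000
MAX_N = 100_000
MAX_VALUE = 1_000_000
INT32_MAX = 2**31 - 1  # 2,147,483,647

worst_case_sum = (MAX_N - 1) * MAX_VALUE
print(worst_case_sum)              # 99999000000
print(worst_case_sum > INT32_MAX)  # True
```

So a 32-bit signed accumulator is not enough in the worst case.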

Important edge cases include:

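  • All elements in nums1 are equal, so no index satisfies nums1[j] < nums1[i] and the output is all zeros.
  • Repeated values in nums1 more generally. Because the comparison is strict, an index must never count another index that shares its nums1 value.
  • k is larger than the number of eligible indices. We take only as many elements as exist, not more.
  • Strictly increasing or strictly decreasing nums1 (all values distinct), where the set of eligible indices grows by one at each step of the sorted order, so the top k values must be maintained incrementally rather than recomputed.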

Approaches

Brute Force Approach

A brute-force approach would iterate over every index i in nums1, then iterate over all indices j to find those where nums1[j] < nums1[i]. We would then sort the eligible nums2[j] values in descending order and sum the top k. While this approach is correct, it has time complexity O(n² log n) due to sorting for each i. Given n can be up to 100,000, this is impractical.
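Even though it is too slow for the full constraints, a direct transcription of the brute force is useful as a reference when testing faster versions. A minimal sketch; `find_max_sum_brute_force` is a hypothetical name, not part of the LeetCode interface:

```python
from typing import List


def find_max_sum_brute_force(nums1: List[int], nums2: List[int], k: int) -> List[int]:
    """Reference O(n^2 log n) solution: for each i, sort the eligible nums2 values."""
    n = len(nums1)
    answer = [0] * n
    for i in range(n):
        # Collect nums2[j] for every index j with a strictly smaller nums1 value
        eligible = [nums2[j] for j in range(n) if nums1[j] < nums1[i]]
        # Sort descending and keep at most k of them (slicing handles fewer than k)
        eligible.sort(reverse=True)
        answer[i] = sum(eligible[:k])
    return answer


print(find_max_sum_brute_force([4, 2, 1, 5, 3], [10, 20, 30, 40, 50], 2))  # [80, 30, 0, 80, 50]
print(find_max_sum_brute_force([2, 2, 2, 2], [3, 1, 2, 3], 1))            # [0, 0, 0, 0]
```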

Optimal Approach

The key insight is to process the indices in increasing order of nums1. When we reach index i, every element processed so far has nums1[j] <= nums1[i]. Because the condition is strict, elements with nums1[j] == nums1[i] must not contribute, so indices that share a nums1 value are answered together, before any of their nums2 values are added. Alongside the sweep we maintain a min-heap holding at most k of the largest nums2 values seen so far, plus their running sum; at the start of each group of equal nums1 values, that sum is exactly the answer for every index in the group. The min-heap keeps the smallest retained value at the root, so evicting it costs O(log k), and no per-index sorting is needed.

The approach leverages sorting and a priority queue (min-heap) to maintain a rolling top k sum efficiently.
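The rolling top-k sum can be shown in isolation. The sketch below is a hypothetical helper class, `RollingTopKSum`, built on Python's `heapq`; it is fed the nums2 values of Example 1 in sorted-nums1 order (30, 20, 50, 10, 40) with k = 2:

```python
import heapq
from typing import List


class RollingTopKSum:
    """Keeps the sum of the k largest values added so far (hypothetical helper)."""

    def __init__(self, k: int) -> None:
        self.k = k
        self.heap: List[int] = []  # min-heap: the smallest kept value sits at heap[0]
        self.total = 0             # sum of the values currently in the heap

    def add(self, value: int) -> None:
        heapq.heappush(self.heap, value)
        self.total += value
        if len(self.heap) > self.k:
            # Evict the smallest kept value; later additions only add competition,
            # so it can never be needed in the top k again
            self.total -= heapq.heappop(self.heap)


tracker = RollingTopKSum(2)
for v in [30, 20, 50, 10, 40]:
    tracker.add(v)
    print(v, tracker.total)
# 30 30
# 20 50
# 50 80
# 10 80
# 40 90
```

Once the heap is full, `heapq.heappushpop(self.heap, value)` could replace the push-then-pop pair with a single call; the explicit version is kept here for readability.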

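| Approach | Time Complexity | Space Complexity | Notes |
|----------|-----------------|------------------|-------|
| Brute Force | O(n² log n) | O(n) | Iterate over all pairs and sort the eligible nums2 values for each i |
| Optimal | O(n log n) | O(n + k) | Sort indices by nums1 and maintain a min-heap of size at most k |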

Algorithm Walkthrough

  1. Create a list of indices of nums1 sorted by their values. This allows us to process elements in increasing order of nums1.

  2. Initialize a min-heap to store up to k largest nums2 values seen so far. Also initialize a variable current_sum to track the sum of the heap elements.

  3. Initialize an array answer of size n filled with zeros.

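  4. Iterate over the sorted indices one group at a time, where a group is a run of indices that share the same nums1 value. For each group:

       a. For every index i in the group, the heap contains the (up to k) largest nums2[j] values with nums1[j] < nums1[i], so set answer[i] = current_sum.

       b. Only after the whole group has been answered, push nums2[i] for each index in the group. Whenever the heap exceeds size k, pop the smallest element and subtract it from current_sum.

  5. Return answer.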
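The walkthrough can be written with the groups of equal nums1 values made explicit, using `itertools.groupby` on the sorted indices. This is a sketch of an alternative formulation of the same sweep; `find_max_sum_grouped` is a hypothetical name, not part of the LeetCode interface:

```python
import heapq
from itertools import groupby
from typing import List


def find_max_sum_grouped(nums1: List[int], nums2: List[int], k: int) -> List[int]:
    n = len(nums1)
    answer = [0] * n
    min_heap: List[int] = []  # the (at most k) largest nums2 values added so far
    current_sum = 0           # running sum of min_heap

    # Process indices in increasing order of nums1
    order = sorted(range(n), key=lambda i: nums1[i])

    # groupby yields consecutive runs of indices that share a nums1 value
    for _, run in groupby(order, key=lambda i: nums1[i]):
        group = list(run)

        # Answer the whole group first: the heap holds only strictly smaller nums1 values
        for i in group:
            answer[i] = current_sum

        # Then add the group's nums2 values, evicting the smallest whenever size exceeds k
        for i in group:
            heapq.heappush(min_heap, nums2[i])
            current_sum += nums2[i]
            if len(min_heap) > k:
                current_sum -= heapq.heappop(min_heap)

    return answer


print(find_max_sum_grouped([4, 2, 1, 5, 3], [10, 20, 30, 40, 50], 2))  # [80, 30, 0, 80, 50]
print(find_max_sum_grouped([2, 2, 2, 2], [3, 1, 2, 3], 1))            # [0, 0, 0, 0]
```

Each group is materialized with `list(run)` because a `groupby` group iterator can be consumed only once, and the sketch loops over it twice.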

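Why it works: Processing in increasing order of nums1 means that every index with a strictly smaller nums1 value has already been added when a group is answered, and answering the whole group before adding any of its members keeps equal values out. The min-heap keeps exactly the k largest of the added values: whenever it grows past k, the smallest value is evicted, and an evicted value can never re-enter the top k because later insertions only add competition.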

Python Solution

from typing import List
import heapq

class Solution:
    def findMaxSum(self, nums1: List[int], nums2: List[int], k: int) -> List[int]:
        n = len(nums1)
        answer = [0] * n
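        min_heap = []      # min-heap of the (at most k) largest nums2 values added so far
        current_sum = 0    # running sum of the values in min_heap

        # Sort indices based on nums1 values
        sorted_indices = sorted(range(n), key=lambda x: nums1[x])

        for pos, idx in enumerate(sorted_indices):
            if pos > 0 and nums1[idx] == nums1[sorted_indices[pos - 1]]:
                # The condition is strict, so equal nums1 values must not count each other:
                # reuse the answer recorded for the previous index of this group
                answer[idx] = answer[sorted_indices[pos - 1]]
            else:
                # The heap holds only values whose nums1 is strictly smaller
                answer[idx] = current_sum

            # Add the current element, then evict the smallest value if the heap exceeds k
            heapq.heappush(min_heap, nums2[idx])
            current_sum += nums2[idx]

            if len(min_heap) > k:
                removed = heapq.heappop(min_heap)
                current_sum -= removed

        return answer

Explanation: We sort the indices by nums1 so that elements are processed in ascending order. For the first index of each group of equal nums1 values, current_sum (the sum of the min-heap) is the maximum sum of at most k nums2 values with strictly smaller nums1. Later indices in the same group copy that answer instead of reading current_sum, because by then the heap may already contain nums2 values from their own group. After recording the answer, we push the current nums2 value and pop the smallest value if the heap grows beyond k.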
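Because the tie case is easy to get wrong, it can help to cross-check the heap sweep against the brute-force definition on many small random inputs. The sketch below is a standalone test harness, not part of the submission; `sweep` and `brute_force` are hypothetical names, and the tiny value range for nums1 is chosen deliberately so that repeated values are common:

```python
import heapq
import random
from typing import List


def sweep(nums1: List[int], nums2: List[int], k: int) -> List[int]:
    """Sorted sweep with a size-k min-heap; equal nums1 values share one answer."""
    order = sorted(range(len(nums1)), key=lambda i: nums1[i])
    answer = [0] * len(nums1)
    heap: List[int] = []
    total = 0
    for pos, idx in enumerate(order):
        if pos > 0 and nums1[idx] == nums1[order[pos - 1]]:
            answer[idx] = answer[order[pos - 1]]  # same group, same answer
        else:
            answer[idx] = total  # heap holds only strictly smaller nums1 values
        heapq.heappush(heap, nums2[idx])
        total += nums2[idx]
        if len(heap) > k:
            total -= heapq.heappop(heap)
    return answer


def brute_force(nums1: List[int], nums2: List[int], k: int) -> List[int]:
    """Direct definition: sum of the k largest nums2[j] with nums1[j] < nums1[i]."""
    n = len(nums1)
    return [
        sum(sorted((nums2[j] for j in range(n) if nums1[j] < nums1[i]), reverse=True)[:k])
        for i in range(n)
    ]


random.seed(0)
for _ in range(1000):
    n = random.randint(1, 8)
    k = random.randint(1, n)
    nums1 = [random.randint(1, 3) for _ in range(n)]  # small range -> many ties
    nums2 = [random.randint(1, 10) for _ in range(n)]
    assert sweep(nums1, nums2, k) == brute_force(nums1, nums2, k), (nums1, nums2, k)
print("all random cases agree")
```

Removing the tie check from `sweep` should make the assertion fail: on Example 2, for instance, the unguarded loop returns [0, 3, 3, 3] instead of [0, 0, 0, 0].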

Go Solution

package main

import (
    "container/heap"
    "sort"
)

type IntHeap []int

func (h IntHeap) Len() int           { return len(h) }
func (h IntHeap) Less(i, j int) bool { return h[i] < h[j] }
func (h IntHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *IntHeap) Push(x any)        { *h = append(*h, x.(int)) }
func (h *IntHeap) Pop() any {
    old := *h
    n := len(old)
    x := old[n-1]
    *h = old[0 : n-1]
    return x
}

func findMaxSum(nums1 []int, nums2 []int, k int) []int64 {
    n := len(nums1)
    answer := make([]int64, n)
    
    indices := make([]int, n)
    for i := 0; i < n; i++ {
        indices[i] = i
    }
    
    sort.Slice(indices, func(i, j int) bool { return nums1[indices[i]] < nums1[indices[j]] })
    
    minHeap := &IntHeap{}
    heap.Init(minHeap)
    var currentSum int64 = 0
    
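    for pos, idx := range indices {
        if pos > 0 && nums1[idx] == nums1[indices[pos-1]] {
            // Equal nums1 values must not count each other: reuse the group's answer
            answer[idx] = answer[indices[pos-1]]
        } else {
            // The heap holds only values whose nums1 is strictly smaller
            answer[idx] = currentSum
        }

        // Add the current element, then evict the smallest value if the heap exceeds k
        heap.Push(minHeap, nums2[idx])
        currentSum += int64(nums2[idx])

        if minHeap.Len() > k {
            removed := heap.Pop(minHeap).(int)
            currentSum -= int64(removed)
        }
    }

    return answer
}

Explanation: The Go implementation mirrors the Python version, including the check that gives indices with equal nums1 values the same answer. The container/heap package provides heap operations over any type that implements heap.Interface, so we define IntHeap with the Len, Less, Swap, Push, and Pop methods; Less compares with <, which makes it a min-heap. Sums are accumulated in int64 because an answer can reach about k × 1,000,000, which can exceed the 32-bit range.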

Worked Examples

Example 1: nums1 = [4,2,1,5,3], nums2 = [10,20,30,40,50], k = 2

Sorted indices by nums1: [2, 1, 4, 0, 3]

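| Step | idx | nums1[idx] | Heap after step | current_sum after step | answer[idx] |
|------|-----|------------|-----------------|------------------------|-------------|
| 1 | 2 | 1 | [30] | 30 | 0 |
| 2 | 1 | 2 | [20, 30] | 50 | 30 |
| 3 | 4 | 3 | [30, 50] | 80 | 50 |
| 4 | 0 | 4 | [30, 50] | 80 | 80 |
| 5 | 3 | 5 | [40, 50] | 90 | 80 |

answer[idx] is read from current_sum before nums2[idx] is pushed; the Heap and current_sum columns show the state after the push and any eviction.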

Output: [80,30,0,80,50]
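The table can be reproduced mechanically. This sketch, with a hypothetical `trace` function, runs the same sorted sweep and records one row per step; the heap contents are sorted only for display, since a binary heap's internal order is not sorted order:

```python
import heapq
from typing import List, Tuple

Row = Tuple[int, int, int, List[int], int, int]


def trace(nums1: List[int], nums2: List[int], k: int) -> List[Row]:
    """One row per step: (step, idx, nums1[idx], heap after the update,
    current_sum after the update, answer[idx])."""
    order = sorted(range(len(nums1)), key=lambda i: nums1[i])
    answer = [0] * len(nums1)
    heap: List[int] = []
    current_sum = 0
    rows: List[Row] = []
    for pos, idx in enumerate(order):
        if pos > 0 and nums1[idx] == nums1[order[pos - 1]]:
            # Same nums1 as the previous index: reuse its answer (equal values don't count)
            answer[idx] = answer[order[pos - 1]]
        else:
            answer[idx] = current_sum
        heapq.heappush(heap, nums2[idx])
        current_sum += nums2[idx]
        if len(heap) > k:
            current_sum -= heapq.heappop(heap)
        rows.append((pos + 1, idx, nums1[idx], sorted(heap), current_sum, answer[idx]))
    return rows


for row in trace([4, 2, 1, 5, 3], [10, 20, 30, 40, 50], 2):
    print(*row)
# 1 2 1 [30] 30 0
# 2 1 2 [20, 30] 50 30
# 3 4 3 [30, 50] 80 50
# 4 0 4 [30, 50] 80 80
# 5 3 5 [40, 50] 90 80
```

Calling `trace([2, 2, 2, 2], [3, 1, 2, 3], 1)` shows answer[idx] staying 0 at every step even though the heap is non-empty after the first step, which is the tie behavior Example 2 relies on.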

Example 2: nums1 = [2,2,2,2], nums2 = [3,1,2,3], k = 1

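All nums1 values are equal, so no index j satisfies nums1[j] < nums1[i] and every answer is 0. This is exactly the case that requires tie handling: the four indices form a single group, so all of them must be answered before any of their nums2 values enters the heap.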

Output: [0,0,0,0]

Complexity Analysis

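| Measure | Complexity | Explanation |
|---------|------------|-------------|
| Time | O(n log n) | Sorting takes O(n log n); each element is pushed once and popped at most once at O(log k) per operation, so the heap work is O(n log k), which sorting dominates |
| Space | O(n + k) | O(n) for the sorted index list and the answer array, plus a heap of at most k + 1 entries |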

Sorting ensures that elements are processed in increasing order of nums1, and the heap maintains the top k values with O(log k) work per update, so the overall cost is dominated by the initial sort.

Test Cases

# provided examples
assert Solution().findMaxSum([4,2,1,5,3], [10,20,30,40,50], 2) == [80,30,