LeetCode 27: Remove Element
A clear explanation of removing all occurrences of a value from an array in place using a write pointer.
Problem Restatement
We are given an integer array nums and an integer val.
We need to remove every occurrence of val from nums in place. After removal, the first k elements of nums should contain only the values that are not equal to val.
Return k, the number of remaining elements.
The values after the first k positions do not matter. The order of the remaining elements may be changed. The problem requires O(1) extra memory.
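Only the first k slots are inspected, and their order is free. A checker for this problem therefore has to compare the kept prefix as a multiset rather than element by element. Below is a minimal sketch of such a checker. `judge` and its `original` parameter are hypothetical names for illustration, not LeetCode's actual grader.

```python
from collections import Counter


def judge(nums: list[int], val: int, k: int, original: list[int]) -> bool:
    """Hypothetical checker: accepts the kept values in any order.

    `original` is a copy of nums taken before the in-place call.
    """
    expected = [x for x in original if x != val]
    # k must equal the number of values that are not val.
    if k != len(expected):
        return False
    # Only the first k slots matter; compare them as multisets.
    return Counter(nums[:k]) == Counter(expected)


print(judge([2, 2, 2, 3], 3, 2, [3, 2, 2, 3]))  # True
print(judge([4, 0, 3, 1, 0, 2, 2, 2], 2, 5, [0, 1, 2, 2, 3, 0, 4, 2]))  # True (reordered)
print(judge([2, 3, 2, 3], 3, 2, [3, 2, 2, 3]))  # False (a 3 is still in the prefix)
```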
Input and Output
| Item | Meaning |
|---|---|
| Input | An integer array nums and an integer val |
| Output | The number of elements not equal to val |
| Required mutation | Put the kept elements in the first k positions |
| Extra space | O(1) |
Function shape:
```python
def removeElement(nums: list[int], val: int) -> int:
    ...
```
Examples
Example 1:
nums = [3, 2, 2, 3]
val = 3
We remove all 3s.
The remaining values are:
[2, 2]
So we return:
2
The first two elements of nums should be 2 and 2.
Example 2:
nums = [0, 1, 2, 2, 3, 0, 4, 2]
val = 2
We remove all 2s.
The remaining values are:
[0, 1, 3, 0, 4]
So we return:
5
The first five elements of nums should contain those five values.
First Thought: Brute Force
A simple solution is to create a new array that stores only values different from val.
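```python
class Solution:
    def removeElement(self, nums: list[int], val: int) -> int:
        # Collect every value that should stay.
        kept = []
        for num in nums:
            if num != val:
                kept.append(num)

        # Copy the kept values back into the front of nums.
        for i in range(len(kept)):
            nums[i] = kept[i]

        return len(kept)
```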
This gives the right answer, but it uses another array.
Problem With Brute Force
The problem asks us to modify nums in place with constant extra memory.
The brute force version uses O(n) extra memory because kept can hold all n elements when no value equals val.
We need the same filtering idea, but we should write the kept values directly into the front of nums.
Key Insight
Use a write pointer.
The write pointer tells us where the next valid value should go.
As we scan the array, every value different from val should be kept. Every value equal to val should be skipped.
| Pointer | Meaning |
|---|---|
| read | Scans every value in the original array |
| write | Marks the next position for a kept value |
At all times, nums[0:write] contains the values we have decided to keep, in the same relative order as in the original array.
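A small trace makes the invariant concrete. The sketch below uses a hypothetical helper, `trace_prefix`. It runs the same write-pointer loop on Example 2 with an explicit `read` index and records `nums[0:write]` after each step.

```python
def trace_prefix(nums: list[int], val: int) -> list[list[int]]:
    """Run the write-pointer loop and record nums[0:write] after each read."""
    write = 0
    snapshots = []
    for read in range(len(nums)):
        if nums[read] != val:
            nums[write] = nums[read]  # append to the kept prefix
            write += 1
        snapshots.append(nums[:write])  # copy of the prefix at this step
    return snapshots


nums = [0, 1, 2, 2, 3, 0, 4, 2]
for read, prefix in enumerate(trace_prefix(nums, 2)):
    print(read, prefix)
# 0 [0]
# 1 [0, 1]
# 2 [0, 1]
# 3 [0, 1]
# 4 [0, 1, 3]
# 5 [0, 1, 3, 0]
# 6 [0, 1, 3, 0, 4]
# 7 [0, 1, 3, 0, 4]

print(nums)  # [0, 1, 3, 0, 4, 0, 4, 2]
```

The last line shows the slots after k = 5. They still hold leftover values (0, 4, 2). This is allowed because only the first k positions are checked.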
Algorithm
Start with:
write = 0
Then scan every number in nums.
For each number:
```python
if num != val:
    nums[write] = num
    write += 1
```
When the loop ends, write is the number of elements not equal to val.
Return write.
Correctness
At the start, write = 0, so the kept prefix is empty.
During the scan, when we see a value equal to val, we skip it. This is correct because that value should not appear in the first k positions.
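When we see a value different from val, we copy it into nums[write]. This appends it to the kept prefix. Then we increase write, so the prefix length grows by one. The copy never destroys unread data. write grows by at most one per element read, so it never exceeds the index of the element currently being read. Any slot we overwrite has therefore already been read.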
After processing every element, the prefix nums[0:write] contains exactly the elements that are not equal to val. Therefore, write is exactly the required answer k.
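The key step in the argument is that a write never overwrites a value that has not been read yet. One way to check this mechanically is to run the loop with the invariants asserted. `remove_element_checked` is a hypothetical debugging helper. The `original` copy it makes exists only for the checks, so the helper is not an O(1)-space solution. The same reasoning explains why the Implementation's `for num in nums` loop is safe: Python's list iterator advances by index, and every assignment lands at or before the current index.

```python
def remove_element_checked(nums: list[int], val: int) -> int:
    """Write-pointer removal with the correctness invariants asserted."""
    original = nums[:]  # used only by the assertions below
    write = 0
    for read in range(len(nums)):
        # write never passes read, so nums[read] is still the original value.
        assert write <= read
        assert nums[read] == original[read]
        if nums[read] != val:
            nums[write] = nums[read]
            write += 1
        # The prefix holds exactly the kept values seen so far, in order.
        assert nums[:write] == [x for x in original[: read + 1] if x != val]
    return write


print(remove_element_checked([0, 1, 2, 2, 3, 0, 4, 2], 2))  # 5
```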
Complexity
| Metric | Value | Why |
|---|---|---|
| Time | O(n) | We scan the array once |
| Space | O(1) | We only use one integer pointer |
Implementation
```python
class Solution:
    def removeElement(self, nums: list[int], val: int) -> int:
        write = 0

        for num in nums:
            if num != val:
                nums[write] = num
                write += 1

        return write
```
Code Explanation
We start with write = 0.
write = 0
This means no valid elements have been written yet.
Then we scan each value:
for num in nums:
If the value should stay, we place it at the next available position:
nums[write] = num
Then we move write forward:
write += 1
If the value equals val, we do nothing. Skipping it keeps it out of the kept prefix.
Finally:
return write
This returns the number of values kept.
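The write-pointer method copies every kept value, even when val is rare. The problem allows the kept values to be reordered, so a common alternative (not part of the original write-up) fills each removed slot with the last element still under consideration. `SolutionSwap` is a hypothetical name for this sketch.

```python
class SolutionSwap:
    """Alternative sketch: overwrite a removed slot with the last live element.

    The kept values may end up in a different order, which the problem allows.
    """

    def removeElement(self, nums: list[int], val: int) -> int:
        i = 0
        n = len(nums)  # nums[0:n] is the region still under consideration
        while i < n:
            if nums[i] == val:
                # Pull in the last candidate and shrink the region.
                # Do not advance i: the moved value has not been checked yet.
                nums[i] = nums[n - 1]
                n -= 1
            else:
                i += 1
        return n


nums = [0, 1, 2, 2, 3, 0, 4, 2]
k = SolutionSwap().removeElement(nums, 2)
print(k, nums[:k])  # 5 [0, 1, 4, 0, 3]
```

Each iteration either advances i or shrinks n, so the loop runs at most n times and uses O(1) extra space. The trade-off is that the kept values lose their original order, which the problem statement permits.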
Testing
```python
def check(nums: list[int], val: int, expected: list[int]) -> None:
    original = nums[:]
    k = Solution().removeElement(nums, val)
    assert k == len(expected), (original, val, k, expected)
    # The problem allows any order for the kept values, so compare them sorted.
    assert sorted(nums[:k]) == sorted(expected), (original, val, nums[:k], expected)


def run_tests():
    check([3, 2, 2, 3], 3, [2, 2])
    check([0, 1, 2, 2, 3, 0, 4, 2], 2, [0, 1, 3, 0, 4])
    check([], 1, [])
    check([1, 1, 1], 1, [])
    check([1, 2, 3], 4, [1, 2, 3])
    check([4], 4, [])
    check([4], 3, [4])
    print("all tests passed")


run_tests()
```
Test meaning:
| Test | Why |
|---|---|
| [3,2,2,3], val = 3 | Basic removal |
| [0,1,2,2,3,0,4,2], val = 2 | Multiple removed values |
| Empty array | No values to process |
| All values removed | Return 0 |
| No values removed | Return original length |
| Single removed value | Minimum non-empty removal |
| Single kept value | Minimum non-empty kept case |
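The hand-written cases above cover the edges. A randomized comparison against a plain filtered copy can add further confidence. This sketch repeats the Implementation's class so it runs on its own. `random_check`, its trial count, and the value ranges are arbitrary choices for illustration. Because the write-pointer solution preserves order, the prefix can be compared exactly.

```python
import random


class Solution:
    def removeElement(self, nums: list[int], val: int) -> int:
        write = 0
        for num in nums:
            if num != val:
                nums[write] = num
                write += 1
        return write


def random_check(trials: int = 1000, seed: int = 0) -> None:
    """Compare the in-place result with a simple filtered copy."""
    rng = random.Random(seed)
    for _ in range(trials):
        # A small value range makes val appear often, including "all removed".
        nums = [rng.randint(0, 3) for _ in range(rng.randint(0, 12))]
        val = rng.randint(0, 3)
        expected = [x for x in nums if x != val]
        k = Solution().removeElement(nums, val)
        assert k == len(expected)
        assert nums[:k] == expected  # the write pointer keeps original order
    print("random tests passed")


random_check()
```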