LeetCode 2623 - Memoize

This problem asks us to create a memoized version of a function. Memoization is a technique where you store the results of expensive function calls and return the cached result when the same inputs occur again.


Difficulty: 🟡 Medium

Solution

Problem Understanding

Here, the input is a function fn that is one of sum, fib, or factorial, and we must return a memoized version of it.

The key points to understand are:

  1. sum(a, b) takes two integers and returns a + b. Argument order matters for caching: a result cached for sum(3, 2) must not be reused for sum(2, 3), so the two count as separate calls.
  2. fib(n) returns 1 if n <= 1 and fib(n - 1) + fib(n - 2) otherwise. For example, fib(5) = fib(4) + fib(3) = 8.
  3. factorial(n) returns 1 if n <= 1 and n * factorial(n - 1) otherwise. For example, factorial(3) = 3 * factorial(2) = 6.
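The three target functions from the list above can be written down directly. This is a minimal sketch, assuming the base case returns 1 for n <= 1, which matches the worked examples (fib(5) = 8, factorial(3) = 6); the name sum_fn is used only to avoid shadowing Python's built-in sum.

```python
def sum_fn(a: int, b: int) -> int:
    # a + b is commutative, but the memoizer must still cache
    # (a, b) and (b, a) as separate entries.
    return a + b


def fib(n: int) -> int:
    # Assumed base case: fib(0) == fib(1) == 1, so fib(5) == 8.
    return 1 if n <= 1 else fib(n - 1) + fib(n - 2)


def factorial(n: int) -> int:
    # Assumed base case: factorial(n) == 1 for n <= 1.
    return 1 if n <= 1 else n * factorial(n - 1)


print(sum_fn(2, 3), fib(5), factorial(3))  # 5 8 6
```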

We need to count how many calls actually reach the original function; calls answered from the cache do not count. Each input action is either "call", with the arguments to pass, or "getCallCount", which returns the number of uncached calls so far.
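To make the action format concrete, the sketch below replays a list of actions against a function. The names run_actions, actions, and values are hypothetical, and the assumption that values[i] holds the argument list for the i-th action (empty for getCallCount) is modeled on the description above, not taken from the judge's actual harness.

```python
from typing import Callable, List


def run_actions(fn: Callable[..., int], actions: List[str], values: List[List[int]]) -> List[int]:
    """Replay "call"/"getCallCount" actions, caching results by argument tuple."""
    cache = {}        # argument tuple -> result
    call_count = 0    # number of times fn was actually invoked
    outputs = []
    for action, args in zip(actions, values):
        if action == "call":
            key = tuple(args)
            if key not in cache:
                call_count += 1
                cache[key] = fn(*args)
            outputs.append(cache[key])
        elif action == "getCallCount":
            outputs.append(call_count)
        else:
            raise ValueError(f"unknown action: {action}")
    return outputs


# The sum example worked through later in this article.
print(run_actions(lambda a, b: a + b,
                  ["call", "call", "getCallCount", "call", "getCallCount"],
                  [[2, 2], [2, 2], [], [1, 2], []]))  # [4, 4, 1, 3, 2]
```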

Constraints inform us about the input scale:

  • Arguments for sum go up to 10^5, while n for fib and factorial is at most 10, so storing every distinct input in the cache is feasible.
  • There can be up to 10^5 actions, so each call should cost only an expected O(1) cache lookup.

Edge cases to consider include repeated calls with the same arguments, calls to functions with minimal inputs like 0 or 1, and calls with arguments in different orders for sum.
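The argument-order edge case is handled automatically when the cache key is the argument tuple itself, because Python tuples compare element by element in order. A quick check with plain dicts standing in for the cache:

```python
sum_cache = {(2, 3): 5}

# (3, 2) is a different key from (2, 3), so it is a cache miss even though
# 3 + 2 == 2 + 3: the memoizer has to call sum again for it.
print((2, 3) in sum_cache)  # True
print((3, 2) in sum_cache)  # False

# A single-argument function such as fib would use 1-tuples like (5,),
# which are distinct keys from the bare integer 5.
fib_cache = {(5,): 8}
print((5,) in fib_cache, 5 in fib_cache)  # True False
```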

Approaches

A brute-force approach would call the original function every time without caching. It returns the right values but breaks the call-count requirement, since every call reaches fn, and it repeats work on every call: each fib(n) is exponential in n because of redundant recursive calls, and each factorial(n) costs O(n).
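To see how quickly the uncached recursion grows, the sketch below counts how many times a naive fib is entered for a given n. It assumes the problem's definition (returning 1 for n <= 1); count_fib_calls is a hypothetical helper written only for this measurement.

```python
def count_fib_calls(n: int) -> int:
    """Number of times naive recursive fib(n) is invoked, including the top call."""
    calls = 0

    def fib(k: int) -> int:
        nonlocal calls
        calls += 1
        return 1 if k <= 1 else fib(k - 1) + fib(k - 2)

    fib(n)
    return calls


print([count_fib_calls(n) for n in range(1, 11)])
# [1, 3, 5, 9, 15, 25, 41, 67, 109, 177]
```

The count grows by a factor of roughly 1.6 per step (the golden ratio), which is where the usual O(2^n) upper bound comes from.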

The optimal approach caches results in a hash map (a dict in Python, a map in Go) keyed by the call's arguments in order: the pair (a, b) for sum and the single argument n for fib and factorial. A counter records how many calls actually reach the original function.

Hash-map lookups and insertions take expected O(1) time, so repeated inputs are never recomputed and even 10^5 actions run quickly.
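For comparison, Python's standard library already packages this hash-map technique. functools.lru_cache(maxsize=None) caches results keyed by the call's arguments, and its cache_info().misses counter equals the number of calls that reached the wrapped function, which is the quantity getCallCount reports. This is a swapped-in standard-library alternative, not the solution developed in this article.

```python
from functools import lru_cache

# Wrap sum with an unbounded cache; maxsize=None disables LRU eviction.
memo_sum = lru_cache(maxsize=None)(lambda a, b: a + b)

memo_sum(2, 2)   # miss: the lambda runs
memo_sum(2, 2)   # hit: served from the cache
memo_sum(1, 2)   # miss

info = memo_sum.cache_info()
print(info.misses, info.hits)  # 2 1
```

A hand-written closure is still the natural fit when the result must expose a method named exactly getCallCount.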

| Approach | Time Complexity | Space Complexity | Notes |
|---|---|---|---|
| Brute Force | O(2^n) per fib call, O(n) per factorial call | O(n) recursion stack | Calls the original function every time, repeating work for fib and factorial |
| Optimal | Expected O(1) per cache hit; a miss costs one call to fn | O(k) for k distinct inputs | Uses a hash map to store results and tracks the call count |

Algorithm Walkthrough

  1. Define a dictionary cache to store previously computed results. The key will be the function arguments, and the value will be the result.
  2. Define an integer call_count to track actual function calls.
  3. Define a wrapper function memoized_fn that:
      • Converts its arguments into a tuple to use as a cache key.
      • Checks if the key exists in cache.
      • If yes, returns the cached value.
      • If no, increments call_count, calls the original function, stores the result in cache, and returns it.
  4. Implement a method getCallCount that simply returns call_count.
  5. Return both the memoized function and getCallCount as methods or attributes.

Why it works: The algorithm ensures that each unique set of arguments is computed at most once. The invariant is that cache always holds the result of the first computation for every input seen so far. The call counter increments only on cache misses, so getCallCount reports exactly the number of calls that reached the original function.
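One consequence of this invariant: each miss stores exactly one new cache entry and increments the counter once, while a hit does neither, so the counter always equals the number of cached keys. The variant below (memoize_by_size is a hypothetical name, not part of the original solution) relies on that and derives getCallCount from len(cache) instead of keeping a separate counter.

```python
import math
from typing import Any, Callable


def memoize_by_size(fn: Callable[..., Any]) -> Callable[..., Any]:
    cache: dict = {}  # argument tuple -> cached result

    def memoized_fn(*args):
        if args not in cache:
            # Miss: exactly one new entry per real call to fn.
            cache[args] = fn(*args)
        return cache[args]

    # The number of real calls equals the number of distinct keys cached.
    memoized_fn.getCallCount = lambda: len(cache)
    return memoized_fn


memo_fact = memoize_by_size(math.factorial)
print(memo_fact(2), memo_fact(3), memo_fact(2), memo_fact.getCallCount())  # 2 6 2 2
```

This removes one piece of mutable state, but it only works because entries are never evicted; with a bounded cache the two numbers would diverge.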

Python Solution

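from typing import Any, Callable


def memoize(fn: Callable[..., Any]) -> Callable[..., Any]:
    cache = {}       # argument tuple -> cached result
    call_count = 0   # number of calls that actually reached fn

    def memoized_fn(*args):
        nonlocal call_count
        key = tuple(args)  # hashable and order-sensitive: (2, 3) != (3, 2)
        if key in cache:
            return cache[key]  # cache hit: fn is not called
        result = fn(*args)
        cache[key] = result
        call_count += 1
        return result

    # Expose the counter. The lambda reads call_count from the enclosing
    # scope, so it always sees the latest value.
    memoized_fn.getCallCount = lambda: call_count
    return memoized_fn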

The memoized_fn function converts arguments to a tuple for use as a key in the dictionary. The nonlocal keyword allows modification of call_count inside the nested function. The getCallCount method is added as an attribute to allow querying without exposing internal details.

Go Solution

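package main

import (
    "fmt"
    "strconv"
    "strings"
)

// MemoizedFunc wraps fn with a result cache and counts the calls that reach fn.
type MemoizedFunc struct {
    fn        func(...int) int
    cache     map[string]int
    callCount int
}

func NewMemoizedFunc(fn func(...int) int) *MemoizedFunc {
    return &MemoizedFunc{
        fn:    fn,
        cache: make(map[string]int),
    }
}

// Call returns the cached result for args if present; otherwise it calls fn,
// caches the result, and increments callCount.
func (m *MemoizedFunc) Call(args ...int) int {
    // Build an order-sensitive key such as "2,3". strconv.Itoa keeps every
    // integer distinct; string(rune(v)) would map invalid code points
    // (for example the surrogate range 55296-57343) to the same U+FFFD.
    parts := make([]string, len(args))
    for i, v := range args {
        parts[i] = strconv.Itoa(v)
    }
    key := strings.Join(parts, ",")

    if val, exists := m.cache[key]; exists {
        return val
    }
    result := m.fn(args...)
    m.cache[key] = result
    m.callCount++
    return result
}

// GetCallCount reports how many calls actually reached the original function.
func (m *MemoizedFunc) GetCallCount() int {
    return m.callCount
}

func main() {
    sum := NewMemoizedFunc(func(args ...int) int { return args[0] + args[1] })
    fmt.Println(sum.Call(2, 2), sum.Call(2, 2), sum.GetCallCount()) // 4 4 1
}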

In Go, we represent the memoized function as a struct with the original function, a cache map, and a call count. The key is serialized from arguments as a string. Methods Call and GetCallCount encapsulate the behavior.
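Whatever the language, a string key has to be unambiguous. The Python sketch below (string_key is a hypothetical helper) shows why a separator between the decimal numbers matters: without one, (1, 23) and (12, 3) would both collapse to the key "123".

```python
def string_key(args, sep=","):
    # Join the decimal forms with a separator that never appears inside a number.
    return sep.join(str(v) for v in args)


print(string_key((1, 23)), string_key((12, 3)))                     # 1,23 12,3
print(string_key((1, 23), sep="") == string_key((12, 3), sep=""))  # True: collision
```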

Worked Examples

Example 1: sum

| Action | Input | Cache Before | Call Count Before | Cache After | Output |
|---|---|---|---|---|---|
| call | (2, 2) | {} | 0 | {(2,2):4} | 4 |
| call | (2, 2) | {(2,2):4} | 1 | {(2,2):4} | 4 |
| getCallCount | - | {(2,2):4} | 1 | {(2,2):4} | 1 |
| call | (1, 2) | {(2,2):4} | 1 | {(2,2):4, (1,2):3} | 3 |
| getCallCount | - | {(2,2):4, (1,2):3} | 2 | {(2,2):4, (1,2):3} | 2 |

Example 2: factorial

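| Action | Input | Cache Before | Call Count Before | Cache After | Output |
|---|---|---|---|---|---|
| call | 2 | {} | 0 | {2:2} | 2 |
| call | 3 | {2:2} | 1 | {2:2, 3:6} | 6 |
| call | 2 | {2:2, 3:6} | 2 | {2:2, 3:6} | 2 |
| getCallCount | - | {2:2, 3:6} | 2 | {2:2, 3:6} | 2 |
| call | 3 | {2:2, 3:6} | 2 | {2:2, 3:6} | 6 |
| getCallCount | - | {2:2, 3:6} | 2 | {2:2, 3:6} | 2 |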

Example 3: fib

| Action | Input | Cache Before | Call Count Before | Cache After | Output |
|---|---|---|---|---|---|
| call | 5 | {} | 0 | {5:8} | 8 |
| getCallCount | - | {5:8} | 1 | {5:8} | 1 |

Complexity Analysis

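| Measure | Complexity | Explanation |
|---|---|---|
| Time | Expected O(1) per cache hit | A hit is one hash lookup; a miss adds the cost of one call to fn plus an O(1) insertion |
| Space | O(k) | The cache stores one entry per distinct argument combination, where k is the number of such combinations |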

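Memoization ensures that each distinct argument combination is computed only once, so a repeated call costs a single hash lookup. Only top-level calls are cached, though: fn's own recursive calls invoke the original function, so a cache miss on fib(n) still costs the full exponential recursion. This is also why memoizing fib(5) counts as one call rather than many.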
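A rough measurement of the gap between a miss and a hit. In this sketch functools.lru_cache stands in for the article's memoize, and invocations is a hypothetical global counter of real fib invocations; because fib recurses into the undecorated function, only the top-level call is cached.

```python
from functools import lru_cache

invocations = 0  # counts every real entry into fib, including recursive ones


def fib(n: int) -> int:
    global invocations
    invocations += 1
    # Recurses into the undecorated fib, not into memo_fib.
    return 1 if n <= 1 else fib(n - 1) + fib(n - 2)


memo_fib = lru_cache(maxsize=None)(fib)

memo_fib(10)             # miss: runs the full recursion
after_miss = invocations
memo_fib(10)             # hit: fib is not entered at all
after_hit = invocations
print(after_miss, after_hit, memo_fib.cache_info().misses)  # 177 177 1
```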

Test Cases

# Provided examples
sum_fn = lambda a, b: a + b
memoSum = memoize(sum_fn)
assert memoSum(2, 2) == 4
assert memoSum(2, 2) == 4
assert memoSum.getCallCount() == 1
assert memoSum(1, 2) == 3
assert memoSum.getCallCount() == 2

fact_fn = lambda n: 1 if n <= 1 else n * fact_fn(n - 1)
memoFact = memoize(fact_fn)
assert memoFact(2) == 2
assert memoFact(3) == 6
assert memoFact(2) == 2
assert memoFact.getCallCount() == 2
assert memoFact(3) == 6
assert memoFact.getCallCount() == 2

fib_fn = lambda n: 1 if n <= 1 else fib_fn(n - 1) + fib_fn(n - 2)
memoFib = memoize(fib_fn)
assert memoFib(5) == 8
assert memoFib.getCallCount() == 1

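# Edge cases
assert memoSum(0, 0) == 0          # sum with zeros: a new key
assert memoSum.getCallCount() == 3
assert memoSum(2, 1) == 3          # (2, 1) is cached separately from (1, 2)
assert memoSum.getCallCount() == 4
assert memoFact(1) == 1            # factorial base case
assert memoFact.getCallCount() == 3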