LeetCode 2572 - Count the Number of Square-Free Subsets

The problem asks us to count all non-empty subsets of a given array nums such that the product of the elements in each subset is square-free. A square-free integer is one that is not divisible by the square of any prime, or equivalently by any perfect square other than 1.

LeetCode Problem 2572

Difficulty: 🟡 Medium
Topics: Array, Math, Dynamic Programming, Bit Manipulation, Number Theory, Bitmask

Solution

Problem Understanding

The problem asks us to count all non-empty subsets of a given array nums such that the product of the elements in each subset is square-free. A square-free integer is one that is not divisible by the square of any prime, or equivalently by any perfect square other than 1. In other words, no prime appears more than once in its prime factorization.

The input is an array nums of length up to 1000, where each element satisfies 1 <= nums[i] <= 30. The output should be the total number of square-free subsets, modulo 10^9 + 7. Subsets are considered different if they use different indices in the array, even if they contain the same values.

The constraints are important: although nums can have up to 1000 elements, each element is at most 30, so we can precompute the prime factorization of every value from 1 to 30. Two edge cases stand out. Values divisible by the square of a prime (4, 8, 9, 12, 16, 18, 20, 24, 25, 27, 28) can never appear in a square-free subset, because they make the product non-square-free on their own. The value 1 has no prime factors, so it can be added to any square-free subset without changing which primes that subset uses.
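As a quick check of which values get excluded, this minimal Python sketch tests each value from 1 to 30 against the squares of the primes up to 29. The names `PRIMES` and `has_squared_prime` are illustrative, not part of the problem statement.

```python
# Primes up to 30; only their squares can divide a value in 1..30.
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]


def has_squared_prime(n: int) -> bool:
    """True if p * p divides n for some prime p <= 29."""
    return any(n % (p * p) == 0 for p in PRIMES)


excluded = [n for n in range(1, 31) if has_squared_prime(n)]
print(excluded)  # [4, 8, 9, 12, 16, 18, 20, 24, 25, 27, 28]
```

Only 2, 3 and 5 actually matter here (7 * 7 = 49 already exceeds 30), but checking all ten primes keeps the helper simple.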

Approaches

Brute Force Approach: We could generate all possible subsets of nums (there are 2^n - 1 non-empty subsets) and for each subset, compute the product and check whether it is square-free. While this would give the correct answer, it is too slow for n = 1000 since 2^1000 is astronomically large.
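For small inputs the brute force is still handy as a reference to test a faster solution against. A minimal sketch, assuming n stays small (roughly 20 or fewer, since it visits every subset); `brute_force_count` and `is_square_free` are hypothetical helper names. Because every element is at most 30, every prime factor of a product is at most 29, so checking the squares of those ten primes is enough.

```python
from itertools import combinations
from math import prod
from typing import List

PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
MOD = 10**9 + 7


def is_square_free(x: int) -> bool:
    # Valid only when every prime factor of x is at most 29,
    # which holds for products of values in 1..30.
    return all(x % (p * p) != 0 for p in PRIMES)


def brute_force_count(nums: List[int]) -> int:
    """Count non-empty index subsets with a square-free product (small n only)."""
    n = len(nums)
    total = 0
    for size in range(1, n + 1):
        # Subsets are distinguished by indices, so combine index positions.
        for idx in combinations(range(n), size):
            if is_square_free(prod(nums[i] for i in idx)):
                total += 1
    return total % MOD


print(brute_force_count([3, 4, 4, 5]))  # 3
print(brute_force_count([1]))           # 1
```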

Optimal Approach: The key insight is to use bitmasking with dynamic programming based on the prime factorization of numbers from 1 to 30. Since there are only 10 primes ≤ 30, we can encode the presence of each prime in a subset as a bit in a 10-bit mask. We iterate over nums, skip numbers with squared prime factors, and dynamically update counts of valid subsets represented by their prime masks. This avoids computing actual products and efficiently counts all square-free combinations.
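To make the encoding concrete, here is a small sketch of the mask for a single value: bit i is set when the i-th prime divides it. The helper name `prime_mask` is an assumption for illustration; it returns None for values with a repeated prime factor.

```python
from typing import Optional

PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]


def prime_mask(n: int) -> Optional[int]:
    """Bit i is set iff PRIMES[i] divides n; None if some prime divides n twice."""
    mask = 0
    for i, p in enumerate(PRIMES):
        if n % p == 0:
            n //= p
            if n % p == 0:  # p appears at least twice
                return None
            mask |= 1 << i
    return mask


print(bin(prime_mask(6)))              # 0b11  (primes 2 and 3)
print(bin(prime_mask(10)))             # 0b101 (primes 2 and 5)
print(prime_mask(4))                   # None
print(prime_mask(6) & prime_mask(10))  # 1: both use prime 2, so they conflict
print(prime_mask(6) & prime_mask(7))   # 0: disjoint, and 6 * 7 = 42 is square-free
```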

| Approach | Time Complexity | Space Complexity | Notes |
|---|---|---|---|
| Brute Force | O(2^n * n) | O(1) | Generates all subsets and checks whether each product is square-free; infeasible for n = 1000 |
| Optimal | O(n * 2^10) | O(2^10) | DP with bitmasks over the prime factorizations of 1 to 30; feasible because there are only 10 primes |

Algorithm Walkthrough

  1. Identify all primes ≤ 30: The primes are [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]. Assign each prime an index from 0 to 9.

  2. Precompute number masks: For each integer from 1 to 30, create a bitmask representing its prime factors. If a number contains a squared prime factor (like 4 = 2^2), mark it invalid.

  3. Initialize DP table: Let dp[mask] store the number of square-free subsets whose product has exactly the primes in mask. Initially, dp[0] = 1, representing the empty subset.

  4. Iterate over nums. For each number:

     a. Skip it if it is invalid (has a squared prime factor).

     b. Get the number's prime mask.

     c. Iterate over the dp entries as they were before this number was processed (the Python solution takes a snapshot; the Go solution collects new combinations in a temporary map), so the number is added at most once to any subset.

     d. For each existing mask that shares no primes with the number's mask, add its count to dp[existing_mask | number_mask].

  5. Sum DP values: After processing all numbers, sum every dp[mask], including dp[0], and subtract 1 for the empty subset. dp[0] also counts the non-empty subsets made only of 1s, so it must not be dropped entirely.

  6. Return the result modulo 10^9 + 7.
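The per-number update in this walkthrough can also be written over a fixed-size list of 1024 counts instead of a dictionary. The sketch below is one such variant, not the solution given later; it walks masks from high to low, the usual order for in-place 0/1-knapsack updates, so each count is read before the current number can write to it. The names `mask_of`, `FULL` and `count_square_free_subsets` are hypothetical.

```python
from typing import List

MOD = 10**9 + 7
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
FULL = 1 << len(PRIMES)  # 1024 possible prime masks


def mask_of(n: int) -> int:
    """Prime bitmask of n, or -1 if n is divisible by the square of a prime."""
    mask = 0
    for i, p in enumerate(PRIMES):
        if n % (p * p) == 0:
            return -1
        if n % p == 0:
            mask |= 1 << i
    return mask


def count_square_free_subsets(nums: List[int]) -> int:
    dp = [0] * FULL  # dp[mask]: subsets whose product uses exactly these primes
    dp[0] = 1        # the empty subset
    for num in nums:
        nmask = mask_of(num)
        if nmask == -1:
            continue  # a squared prime factor: never usable
        # High to low: this number only writes to mask | nmask >= mask,
        # so each dp[mask] is read before the current number changes it.
        for mask in range(FULL - 1, -1, -1):
            if dp[mask] and mask & nmask == 0:
                target = mask | nmask
                dp[target] = (dp[target] + dp[mask]) % MOD
    # dp[0] also holds subsets made only of 1s; subtract just the empty subset.
    return (sum(dp) - 1) % MOD


print(count_square_free_subsets([3, 4, 4, 5]))  # 3
print(count_square_free_subsets([1]))           # 1
```

A list keeps the inner loop at a fixed 1024 iterations per number, matching the O(n * 2^10) bound, whereas the dictionary version only visits masks that have actually been reached.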

Why it works: The DP table groups subsets by the set of primes in their product. Masks are combined only when they are disjoint, so no prime ever appears twice and every counted subset is square-free. Each number is combined only with counts recorded before it was processed, so each index is used at most once per subset.
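The key claim, that two square-free values can be combined exactly when their masks are disjoint, can be checked exhaustively for the 19 square-free values in 1 to 30. This is a brute-force sanity check with illustrative helper names rather than a proof; it extends to whole subsets because the mask of a square-free product is the OR of its factors' masks.

```python
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]


def mask_of(n: int) -> int:
    """Bit i is set when PRIMES[i] divides n."""
    return sum(1 << i for i, p in enumerate(PRIMES) if n % p == 0)


def square_free(x: int) -> bool:
    # Every prime factor here is at most 29, so these squares are all we need.
    return all(x % (p * p) != 0 for p in PRIMES)


def disjoint_iff_square_free() -> bool:
    values = [n for n in range(1, 31) if square_free(n)]
    for a in values:
        for b in values:
            disjoint = (mask_of(a) & mask_of(b)) == 0
            if disjoint != square_free(a * b):
                return False
    return True


print(disjoint_iff_square_free())  # True
```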

Python Solution

from typing import List

class Solution:
    def squareFreeSubsets(self, nums: List[int]) -> int:
        MOD = 10**9 + 7
        primes = [2,3,5,7,11,13,17,19,23,29]
        # Map each number to its prime factor mask; -1 means it contains a squared prime
        num_mask = {}
        for num in range(1, 31):
            mask = 0
            x = num
            valid = True
            for i, p in enumerate(primes):
                cnt = 0
                while x % p == 0:
                    x //= p
                    cnt += 1
                if cnt > 1:
                    valid = False
                    break
                elif cnt == 1:
                    mask |= (1 << i)
            if valid:
                num_mask[num] = mask
        
        dp = {0: 1}  # mask -> count of subsets
        for num in nums:
            if num not in num_mask:
                continue
            nmask = num_mask[num]
            dp_items = list(dp.items())
            for mask, count in dp_items:
                if mask & nmask == 0:  # No overlapping primes
                    new_mask = mask | nmask
                    dp[new_mask] = (dp.get(new_mask, 0) + count) % MOD
        
        return (sum(dp.values()) - 1) % MOD  # exclude empty subset

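Implementation Notes: First, we precompute masks for 1 to 30, leaving numbers with squared primes out of num_mask so they are skipped. The dp dictionary maps prime masks to counts. For each number, we iterate over a snapshot, list(dp.items()), because adding keys to a dict while iterating over it raises a RuntimeError, and because the number should only extend subsets that existed before it. We merge the number into a mask only if their primes do not overlap, which guarantees square-free subsets. Finally, we subtract 1 to exclude the empty subset.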
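The snapshot is not just a style choice in Python. The sketch below isolates one update step (the helper `add_number` and its `use_snapshot` flag are hypothetical, and the modulo is left out for brevity) and shows that iterating over `dp.items()` directly while adding keys fails in CPython with a RuntimeError.

```python
from typing import Dict


def add_number(dp: Dict[int, int], nmask: int, use_snapshot: bool = True) -> Dict[int, int]:
    """Apply one DP step for a square-free number with prime mask nmask."""
    # With use_snapshot=True we iterate over a frozen list of (mask, count) pairs.
    items = list(dp.items()) if use_snapshot else dp.items()
    for mask, count in items:
        if mask & nmask == 0:
            new_mask = mask | nmask
            dp[new_mask] = dp.get(new_mask, 0) + count
    return dp


print(add_number({0: 1}, 0b10))  # {0: 1, 2: 1}
try:
    add_number({0: 1}, 0b10, use_snapshot=False)
except RuntimeError as err:
    print("RuntimeError:", err)  # dictionary changed size during iteration
```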

Go Solution

func squareFreeSubsets(nums []int) int {
    MOD := 1_000_000_007
    primes := []int{2,3,5,7,11,13,17,19,23,29}
    
    numMask := make(map[int]int)
    for num := 1; num <= 30; num++ {
        mask := 0
        x := num
        valid := true
        for i, p := range primes {
            cnt := 0
            for x%p == 0 {
                x /= p
                cnt++
            }
            if cnt > 1 {
                valid = false
                break
            } else if cnt == 1 {
                mask |= 1 << i
            }
        }
        if valid {
            numMask[num] = mask
        }
    }
    
    dp := make(map[int]int)
    dp[0] = 1
    for _, num := range nums {
        nmask, ok := numMask[num]
        if !ok {
            continue
        }
        temp := make(map[int]int)
        for mask, count := range dp {
            if mask & nmask == 0 {
                newMask := mask | nmask
                temp[newMask] = (temp[newMask] + count) % MOD
            }
        }
        for k, v := range temp {
            dp[k] = (dp[k] + v) % MOD
        }
    }
    
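    // dp[0] also counts non-empty subsets made only of 1s, so add every
    // entry and subtract 1 for the empty subset.
    res := 0
    for _, count := range dp {
        res = (res + count) % MOD
    }
    return (res - 1 + MOD) % MOD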
}

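Go-Specific Notes: Go's built-in map plays the role of Python's dictionary. Go allows inserting into a map during a range loop, but entries created during iteration may or may not be visited, so we collect new combinations in a temporary map temp and merge it into dp afterwards. Each addition is reduced modulo MOD immediately; since both operands are below MOD, their sum stays below 2^31 - 1 and fits even in a 32-bit int.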

Worked Examples

Example 1: nums = [3,4,4,5]

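Masks are written as their low four bits (primes 7, 5, 3, 2 from left to right).

| Step | num | Valid? | nmask | DP before | DP after |
|---|---|---|---|---|---|
| 1 | 3 | yes | 0010 | {0:1} | {0:1, 0010:1} |
| 2 | 4 | no | - | {0:1, 0010:1} | unchanged |
| 3 | 4 | no | - | {0:1, 0010:1} | unchanged |
| 4 | 5 | yes | 0100 | {0:1, 0010:1} | {0:1, 0010:1, 0100:1, 0110:1} |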

Sum of all counts: 1 + 1 + 1 + 1 = 4. Subtracting 1 for the empty subset gives 3, corresponding to the subsets [3], [5], and [3,5].
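A short script can reproduce this trace. This sketch re-implements the dictionary DP inline with a hypothetical helper `mask_or_none` and prints each dp state with masks shown as 4-bit strings, so the rows line up with the table.

```python
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]


def mask_or_none(n):
    """Prime bitmask of n, or None if n has a squared prime factor."""
    mask = 0
    for i, p in enumerate(PRIMES):
        if n % (p * p) == 0:
            return None
        if n % p == 0:
            mask |= 1 << i
    return mask


def trace(nums):
    dp = {0: 1}
    for step, num in enumerate(nums, start=1):
        nmask = mask_or_none(num)
        if nmask is not None:
            for mask, count in list(dp.items()):  # snapshot of the old counts
                if mask & nmask == 0:
                    dp[mask | nmask] = dp.get(mask | nmask, 0) + count
        shown = {format(m, "04b"): c for m, c in sorted(dp.items())}
        print(step, num, "valid" if nmask is not None else "invalid", shown)
    return dp


final = trace([3, 4, 4, 5])
print(sum(final.values()) - 1)  # 3
```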

Example 2: nums = [1]

| Step | num | Valid? | nmask | DP before | DP after |
|---|---|---|---|---|---|
| 1 | 1 | yes | 0000 | {0:1} | {0:2} |

Exclude empty subset: 2 - 1 = 1.
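Example 2 also shows why dp[0] belongs in the final sum: each 1 has an empty mask, so it doubles every count, including dp[0]. The sketch below (the helper name `final_dp` is hypothetical) runs the same dictionary DP on [1, 1, 2] and compares summing every entry minus 1 against dropping dp[0] entirely.

```python
MOD = 10**9 + 7
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]


def final_dp(nums):
    """Run the mask DP and return the raw dp dictionary."""
    dp = {0: 1}
    for num in nums:
        if any(num % (p * p) == 0 for p in PRIMES):
            continue  # contains a squared prime
        nmask = sum(1 << i for i, p in enumerate(PRIMES) if num % p == 0)
        for mask, count in list(dp.items()):
            if mask & nmask == 0:
                dp[mask | nmask] = (dp.get(mask | nmask, 0) + count) % MOD
    return dp


dp = final_dp([1, 1, 2])
print(dp)                                       # {0: 4, 1: 4}
print((sum(dp.values()) - 1) % MOD)             # 7 (correct)
print(sum(c for m, c in dp.items() if m != 0))  # 4 (misses the three subsets made only of 1s)
```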

Complexity Analysis

Measure Complexity Explanation
Time O(n * 2^10) For each number, iterate over all 1024 possible masks (10