Generic Strategy to solve K-Sum Problems

Introduction

Hey there! Welcome back. If you are preparing for tech interviews, you have probably come across problems like 2-sum, 3-sum and 4-sum. If not, you will learn about them here.

Beyond that, in this blog I will discuss a generic strategy for solving these kinds of problems. Let's put them in a single category called k-sum problems.

Once you learn this strategy, you will be able to come up with a solution to any problem in this category, be it 2-sum, 3-sum or 4-sum, as long as the solution fits within the given time and space constraints.

Let's delve into the process.

2-Sum Problem

In this problem, we are given an array of N integers, call it nums, and a target. We want to find all the unique pairs of integers (nums[i], nums[j]) with i != j whose sum is equal to the target. The array may contain duplicate integers.

Let's understand this with an example

nums = [2, 7, 11, 15], target = 9

Here, the only answer pair is formed by the first and second integers, i.e. (2, 7).

Let's discuss the solution to this problem.

Naive Solution:

The most straightforward solution is to use nested loops that try every possible pair of integers from the array. There are n(n-1)/2 such pairs, which is on the order of n^2.

int n = nums.size();
vector<vector<int>> answer;
//Outer loop which selects the first element of the pair
for(int i=0;i<n;i++){

   //Inner nested loop which selects the second element of the pair
   for(int j=i+1; j < n; j++){
       if(nums[i] + nums[j] == target){
           answer.push_back({nums[i],nums[j]});
       }
   }
}

The time complexity of this solution is O(N^2) because of the nested loops, while the extra space is O(1) apart from the answer list itself. Note that if nums contains duplicate values, these loops can report the same pair of values more than once, so the result would still need deduplicating.

Binary Search Solution:

Suppose we sort the array and pick integers one by one, say nums[i]. Then in the rest of the array, i.e. the index range [i+1, n-1], we need to search for target - nums[i]. Since the array is sorted we can use binary search, and if we find the value then (nums[i], target - nums[i]) is a valid pair.

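//Sort the given nums array
sort(nums.begin(), nums.end());

//Find the size of the nums array
int n = nums.size();

//Iterate over the elements of nums and binary search for
//the complement in the rest of the array
for(int i=0;i<n;i++){
   //Skip repeated values so each pair is added only once
   if(i > 0 && nums[i] == nums[i-1]) continue;

   //Search only in [i+1, n-1] for the second element of the pair
   //(std::binary_search returns true when the value is present)
   if(binary_search(nums.begin() + i + 1, nums.end(), target - nums[i])){
       answer.push_back({nums[i], target - nums[i]});
   }
}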

The time complexity of this approach is O(NlogN): sorting costs O(NlogN), and we then perform up to N binary searches, each of which takes O(logN). The extra space used by the search itself is O(1).
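
For comparison, here is a minimal Python sketch of the same binary search idea, using bisect_left from the standard library. The function name two_sum_binary_search is my own choice for illustration, and the sketch works on a sorted copy so the caller's list is left untouched.

```python
from bisect import bisect_left
from typing import List


def two_sum_binary_search(nums: List[int], target: int) -> List[List[int]]:
    """Return the unique value pairs [a, b] with a + b == target and a <= b."""
    nums = sorted(nums)  # work on a sorted copy of the input
    n = len(nums)
    pairs = []
    for i in range(n):
        # Skip repeated first elements so each pair is reported once
        if i > 0 and nums[i] == nums[i - 1]:
            continue
        need = target - nums[i]
        # Search only to the right of i, i.e. in nums[i+1 .. n-1]
        j = bisect_left(nums, need, i + 1, n)
        if j < n and nums[j] == need:
            pairs.append([nums[i], need])
    return pairs


print(two_sum_binary_search([2, 7, 11, 15], 9))  # prints [[2, 7]]
```

Restricting the search to the right of i is what keeps an element from being paired with itself.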

2-Pointers Solution:

Observe that in a sorted array of size n, the first element is the smallest and the last element is the largest. Adding them gives some value X = nums[0] + nums[n-1]. If X is larger than the target, the largest element cannot be part of any valid pair, because pairing it with anything other than the smallest element only makes the sum bigger. So we drop it and try the next largest element nums[n-2], compare the new sum X = nums[0] + nums[n-2] with the target, and so on until the sum is no longer larger than the target.

On the other hand, if X = nums[0] + nums[n-1] is smaller than the target, we must include a larger number in X. The best choice is the next greater number, the second smallest, which gives the new sum X = nums[1] + nums[n-1]. We compare it with the target again, and so on.

In this way, we eventually arrive at the target pair if it exists. Note that this argument relies on the array being sorted.
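
To make the pointer movement concrete, here is a small Python sketch that records each step of the scan on the example above. The helper name trace_two_pointers is hypothetical. It assumes the input is already sorted, and it stops at the first match instead of collecting all pairs.

```python
from typing import List


def trace_two_pointers(nums: List[int], target: int) -> List[str]:
    """Record each pointer move on a sorted list; stop at the first match."""
    left, right = 0, len(nums) - 1
    steps = []
    while left < right:
        x = nums[left] + nums[right]
        if x == target:
            steps.append(f"{nums[left]} + {nums[right]} = {x}: found")
            break
        if x > target:
            # Sum too large: discard the current largest element
            steps.append(f"{nums[left]} + {nums[right]} = {x} > {target}: move right pointer left")
            right -= 1
        else:
            # Sum too small: discard the current smallest element
            steps.append(f"{nums[left]} + {nums[right]} = {x} < {target}: move left pointer right")
            left += 1
    return steps


for line in trace_two_pointers([2, 7, 11, 15], 9):
    print(line)
# 2 + 15 = 17 > 9: move right pointer left
# 2 + 11 = 13 > 9: move right pointer left
# 2 + 7 = 9: found
```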

Implementation in C++:

#include<bits/stdc++.h>
using namespace std;

//Assumes nums is already sorted in ascending order
vector<vector<int>>twoSum(vector<int>&nums, int target){
    //Calculate the size of the nums array
    int n = nums.size();

    //Initialize two pointers left and right
    int left = 0, right = n-1;

    //Vector to store the two sum pairs
    vector<vector<int>>result;

    //Iterate until the two pointers meet
    while(left < right){
         //Check if the sum of the two integers
         //is equal to the target
         if(nums[left] + nums[right] == target){
             result.push_back({nums[left], nums[right]});
             //Now skip the duplicate integers
             while(left + 1 < n and nums[left] == nums[left+1])left++;
             left++;
             right--;
         }
         else if(nums[left] + nums[right] < target)
            left++;
         else
            right--;
    }
    return result;
}

Implementation in Java:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class Solution {
    public List<List<Integer>> twoSum(List<Integer> nums, int target) {
        // Calculate the size of the nums list
        int n = nums.size();

        // Initialize two pointers left and right
        int left = 0, right = n - 1;

        // List to store the two sum pairs
        List<List<Integer>> result = new ArrayList<>();

        // Iterate until the two pointers meet (nums must be sorted)
        while (left < right) {
            // Check if the sum of two integers is equal to the target
            if (nums.get(left) + nums.get(right) == target) {
                result.add(Arrays.asList(nums.get(left), nums.get(right)));

                // Handle duplicate integers
                while (left + 1 < n && nums.get(left).equals(nums.get(left + 1))) {
                    left++;
                }
                left++;
                right--;
            } else if (nums.get(left) + nums.get(right) < target) {
                left++;
            } else {
                right--;
            }
        }

        return result;
    }
}

Implementation in Python:

from typing import List

def twoSum(nums: List[int], target: int) -> List[List[int]]:
    # Calculate the size of the nums list
    n = len(nums)

    # Initialize two pointers left and right
    left = 0
    right = n - 1

    # List to store the two sum pairs
    result = []

    # Iterate until the two pointers meet (nums must be sorted)
    while left < right:
        # Check if the sum of two integers is equal to the target
        if nums[left] + nums[right] == target:
            result.append([nums[left], nums[right]])

            # Handle duplicate integers
            while left + 1 < n and nums[left] == nums[left + 1]:
                left += 1
            left += 1
            right -= 1
        elif nums[left] + nums[right] < target:
            left += 1
        else:
            right -= 1

    return result

Time and Space Complexity:

The time complexity of the two-pointer scan is O(N), since each step moves one of the two pointers and every element is visited at most once. If the input is not already sorted, sorting it first adds O(NlogN). The space complexity is O(1) as we use no extra space apart from the result.

That was the popular 2-sum problem. Let's move on to the next step, the 3-sum problem.

3-Sum Problem

In this problem, we are given an integer array nums and we need to return all triplets [nums[i], nums[j], nums[k]] such that i != j, i != k, j != k, and nums[i] + nums[j] + nums[k] == 0.

Notice that the solution set must not contain duplicate triplets.

Let's understand the problem with an example: for nums = [-1, 0, 1, 2, -1, -4] and a target sum of 0, the triplets are [[-1, -1, 2], [-1, 0, 1]].

Now for the solution. First of all, we sort the given array as we did for the 2-sum problem. We need to choose three integers that sum up to 0, so we iterate over the array and choose integers nums[i] one by one, from the first up to the 3rd last, i.e. i = 0 to i = n-3. We take nums[i] to be the first integer of the tuple.

Then we need a pair of integers in the remaining array, from i+1 to n-1, to complete the tuple. That remaining problem is exactly a 2-sum problem with target - nums[i] as its target.

//Iterate over all the elements of the array till the 3rd last
for(int i = 0; i < n-2; i++){
   //Choose the first integer of the tuple
   int remain = target - nums[i];

   //Find the pairs (b, c) in nums[i+1..n-1] which sum up to remain.
   //Here twoSum is assumed to take a start index so that it only
   //searches to the right of i
   for(auto &p : twoSum(nums, i+1, remain)){
      //Each pair gives one of the target tuples {nums[i], b, c}
      answer.push_back({nums[i], p[0], p[1]});
   }
}

Is that all we need to do, or have we missed something? There can be duplicates in the array and we don't want our final answer to contain duplicate tuples, so we need to skip over repeated values of the first integer.

//Iterate over all the elements of the array till the 3rd last
for(int i = 0; i < n-2; i++){
   //Choose the first integer of the tuple
   int remain = target - nums[i];

   //Find the pairs (b, c) in nums[i+1..n-1] which sum up to remain.
   //Here twoSum is assumed to take a start index so that it only
   //searches to the right of i
   for(auto &p : twoSum(nums, i+1, remain)){
      //Each pair gives one of the target tuples {nums[i], b, c}
      answer.push_back({nums[i], p[0], p[1]});
   }

   //Here, we need to skip the similar integers
   while(i + 1 < n and nums[i] == nums[i+1])i++;
}

In this way, all the tuples get pushed into the answer array.

Implementation in C++:

#include<bits/stdc++.h>
using namespace std;

vector<vector<int>>threeSum(vector<int>nums, int target){
    //Find the size of the nums array
    int n = nums.size();

    //Sort the given array
    sort(nums.begin(), nums.end());

    //Declare a vector to store the tuples which sum up to target
    vector<vector<int>>result;

    //Iterate over each element of the nums array
    //up to the 3rd last element
    for(int i=0;i<n-2;i++){
        //Initialize two pointers to find the two sum pair
        int left = i+1, right = n-1;
        while(left < right){
            //Find the sum of the three integers
            int sum = nums[i] + nums[left] + nums[right];

            //If the sum is equal to the target
            if(sum == target){
                result.push_back({nums[i], nums[left], nums[right]});
                //Skip the duplicates of the second integer
                while(left + 1 < n and nums[left] == nums[left + 1])left++;
                left++;
                right--;
            }
            else if(sum < target){
                left++;
            }else{
                right--;
            }
        }
        //Skip the duplicates of the first integer
        while(i + 1 < n and nums[i] == nums[i+1])i++;
    }

    return result;
}

Implementation in Java:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class ThreeSum {
    public static List<List<Integer>> threeSum(List<Integer> nums, int target) {
        // Find the size of the nums array
        int n = nums.size();

        // Sort the given array
        nums.sort(null);

        // Declare a list to store the tuples which sums up to target
        List<List<Integer>> result = new ArrayList<>();

        // Iterate over each element of the nums array
        // up to the 3rd last element
        for (int i = 0; i < n - 2; i++) {
            // Initialize two pointers to find the two sum pair
            int left = i + 1, right = n - 1;
            while (left < right) {
                // Find the sum of the three integers
                int sum = nums.get(i) + nums.get(left) + nums.get(right);

                // If the sum is equal to the target
                if (sum == target) {
                    result.add(Arrays.asList(nums.get(i), nums.get(left), nums.get(right)));
                    while (left + 1 < n && nums.get(left).equals(nums.get(left + 1)))
                        left++;
                    left++;
                    right--;
                } else if (sum < target) {
                    left++;
                } else {
                    right--;
                }
            }
            while (i + 1 < n && nums.get(i).equals(nums.get(i + 1)))
                i++;
        }

        return result;
    }

    public static void main(String[] args) {
        List<Integer> nums = Arrays.asList(-1, 0, 1, 2, -1, -4);
        int target = 0;
        List<List<Integer>> result = threeSum(nums, target);
        for (List<Integer> tuple : result) {
            System.out.println(tuple);
        }
    }
}

Implementation in Python:

def threeSum(nums, target):
    # Sort the given array
    nums.sort()

    # Declare a list to store the tuples which sum up to target
    result = []

    # Iterate over each element of the nums array
    # up to the 3rd last element
    n = len(nums)
    for i in range(n - 2):
        # Skip repeated first elements. Incrementing i inside a Python
        # for loop would not affect the next iteration, so we check here
        if i > 0 and nums[i] == nums[i - 1]:
            continue

        # Initialize two pointers to find the two sum pair
        left = i + 1
        right = n - 1
        while left < right:
            # Find the sum of the three integers
            total = nums[i] + nums[left] + nums[right]

            # If the sum is equal to the target
            if total == target:
                result.append([nums[i], nums[left], nums[right]])
                # Skip the duplicates of the second integer
                while left + 1 < n and nums[left] == nums[left + 1]:
                    left += 1
                left += 1
                right -= 1
            elif total < target:
                left += 1
            else:
                right -= 1

    return result

# Example usage
nums = [-1, 0, 1, 2, -1, -4]
target = 0
result = threeSum(nums, target)
print(result)

Time and Space Complexity:

First, we sort the array, which takes O(NlogN). Then the outer loop runs over N-2 integers, and for each one we run the two-pointer 2-sum scan, which takes O(N). Hence the overall time complexity is O(N^2 + NlogN), i.e. O(N^2). The extra space is O(1), not counting the result list.

Now, let's ask ourselves how the 4-sum problem, in which we need to return the set of all quadruples having a sum equal to the target, can be solved using the same strategy. As you can probably guess, the solution is to sort the array first, then use two nested loops to select the first two integers of the quadruple, and then find the remaining pair using twoSum. In the same way, in any K-sum problem we can use K-2 nested loops to select the first K-2 integers, and the base case of two integers is handled by the twoSum function.
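
As a rough illustration of that idea, here is a Python sketch of 4-sum with two nested loops around the two-pointer scan. The name four_sum and the choice to sort a copy of the input are my own assumptions, not part of any reference solution.

```python
from typing import List


def four_sum(nums: List[int], target: int) -> List[List[int]]:
    """Return the unique quadruples of values that sum to target."""
    nums = sorted(nums)  # work on a sorted copy of the input
    n = len(nums)
    result = []
    for i in range(n - 3):
        # Skip repeated values for the first integer
        if i > 0 and nums[i] == nums[i - 1]:
            continue
        for j in range(i + 1, n - 2):
            # Skip repeated values for the second integer
            if j > i + 1 and nums[j] == nums[j - 1]:
                continue
            # Two-pointer 2-sum over nums[j+1 .. n-1]
            left, right = j + 1, n - 1
            while left < right:
                total = nums[i] + nums[j] + nums[left] + nums[right]
                if total == target:
                    result.append([nums[i], nums[j], nums[left], nums[right]])
                    # Skip repeated values for the third integer
                    while left < right and nums[left] == nums[left + 1]:
                        left += 1
                    left += 1
                    right -= 1
                elif total < target:
                    left += 1
                else:
                    right -= 1
    return result


print(four_sum([1, 0, -1, 0, -2, 2], 0))
# prints [[-2, -1, 1, 2], [-2, 0, 0, 2], [-1, 0, 0, 1]]
```

Each extra loop multiplies the running time by N, so this sketch runs in O(N^3).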

K-Sum Problem

In the K-sum problem, we take the same idea to k dimensions: given an array of integers nums, an integer K and a target, we need to find all unique sets of K integers [nums[a], nums[b], nums[c], nums[d], ...] (K elements) such that:

  • The indices a, b, c, d, ... (K of them) are all greater than or equal to 0 and less than n.

  • The K indices a, b, c, d, ... are all distinct.

  • nums[a] + nums[b] + nums[c] + nums[d] + ... (K terms) == target.

  • You can return the answer in any order.
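
The conditions above can be written down directly as a brute-force reference in Python. This is only a sketch for checking small inputs, since it enumerates every combination of K indices. The name k_sum_brute_force is hypothetical.

```python
from itertools import combinations
from typing import List


def k_sum_brute_force(nums: List[int], k: int, target: int) -> List[List[int]]:
    """Try every choice of k distinct indices and keep the unique value sets."""
    unique = set()
    # combinations() picks k distinct positions; sorting first makes each
    # tuple ascending, so equal value sets collapse to one entry in the set
    for combo in combinations(sorted(nums), k):
        if sum(combo) == target:
            unique.add(combo)
    return [list(c) for c in sorted(unique)]


print(k_sum_brute_force([-1, 0, 1, 2, -1, -4], 3, 0))
# prints [[-1, -1, 2], [-1, 0, 1]]
```

A slow but obviously correct version like this is handy for cross-checking a faster solution on small random arrays.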

Approach:

Keeping in mind the ideas we have learned so far in this article, if we look carefully at the pattern in 2-sum, 3-sum, 4-sum and so on, we can see a recursive relationship: a k-sum problem becomes a (k-1)-sum problem once the first integer is fixed, and 2-sum is the base case.

Since this is just recursion, we can solve it for any integer K as follows.

Imagine a recursive function Ksum(nums, k, target) with three parameters: the array nums, the integer k and the target. It returns the array of k-sum sets in nums. The function selects the first integer, say val1, and recursively calls itself for a set of size k-1 with target sum target - val1, i.e. Ksum(nums, k-1, target - val1). That call in turn selects another value val2 and calls Ksum(nums, k-2, target - val1 - val2), and so on until we reach the base case k == 2, where we call the twoSum function to get the target pair.

Implementation in C++:

#include<bits/stdc++.h>
using namespace std;

//answer array to store the sets
vector<vector<int>>ans;

//Two sum function: two-pointer scan over the sorted range [left, n-1]
void twoSum(vector<int>&nums, int left, long long target, vector<int>&path){
    int start = left, end = nums.size()-1;
    while(start < end){
         //Cast before adding so the sum cannot overflow int
         long long sum = (long long)nums[start] + nums[end];
         if(sum == target){
             path.push_back(nums[start]);
             path.push_back(nums[end]);
             ans.push_back(path);
             path.pop_back();
             path.pop_back();
             //skip the duplicates
             while(start + 1 < end and nums[start] == nums[start+1])start++;
             start++;
             end--;
         }else if(sum > target)end--;
         else start++;
    }
}

//K sum function which collects all the K-sum sets into ans.
//nums must be sorted first; the initial call is
//kSum(nums, 0, nums.size()-1, k, target, path) with an empty path
void kSum(vector<int>&nums, int left, int right, int k, long long target, vector<int>&path){
    //Base case
    if(k == 2){
        twoSum(nums,left,target,path);
        return;
    }

    while(left < right){
         //Fix nums[left] as the next integer of the set
         path.push_back(nums[left]);
         //The remaining target is kept in long long to avoid overflow
         long long rem = target - nums[left];
         kSum(nums,left+1,right,k-1,rem,path);
         path.pop_back();
         //skip the duplicates
         while(left + 1 < right and nums[left] == nums[left+1])left++;
         left++;
    }
}
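
For readers following along in Python, here is a sketch of the same recursion. The names k_sum and solve are my own, and unlike the C++ version above it returns the result rather than filling a global array. It sorts a copy of the input before recursing.

```python
from typing import List


def k_sum(nums: List[int], k: int, target: int) -> List[List[int]]:
    """Return the unique sets of k values from nums that sum to target."""
    nums = sorted(nums)  # work on a sorted copy of the input
    n = len(nums)
    result: List[List[int]] = []

    def solve(start: int, k: int, target: int, path: List[int]) -> None:
        if k == 2:
            # Base case: two-pointer 2-sum over nums[start .. n-1]
            left, right = start, n - 1
            while left < right:
                total = nums[left] + nums[right]
                if total == target:
                    result.append(path + [nums[left], nums[right]])
                    # Skip repeated values of the left element
                    while left < right and nums[left] == nums[left + 1]:
                        left += 1
                    left += 1
                    right -= 1
                elif total < target:
                    left += 1
                else:
                    right -= 1
            return
        # Fix one value, then solve (k-1)-sum on the elements to its right
        for i in range(start, n - k + 1):
            if i > start and nums[i] == nums[i - 1]:
                continue  # skip repeated values at this depth
            solve(i + 1, k - 1, target - nums[i], path + [nums[i]])

    if k >= 2:
        solve(0, k, target, [])
    return result


print(k_sum([-1, 0, 1, 2, -1, -4], 3, 0))  # prints [[-1, -1, 2], [-1, 0, 1]]
print(k_sum([1, 0, -1, 0, -2, 2], 4, 0))
# prints [[-2, -1, 1, 2], [-2, 0, 0, 2], [-1, 0, 0, 1]]
```

Python integers do not overflow, so the sketch needs no equivalent of the long long casts in the C++ code.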

Time and Space Complexity:

The time complexity of the K-sum problem is O(N^(k-2) * N), i.e. O(N^(k-1)), for k >= 3: each of the k-2 levels of recursion loops over the array, and the twoSum base case is linear. Sorting adds O(NlogN), which only matters for k = 2. As special cases, this gives O(N^(2-1)) = O(N) for the 2-sum scan and O(N^(3-1)) = O(N^2) for 3-sum. The space complexity of the given implementation is O(k), as we use a path array to temporarily store the current k-sum set and the recursion depth is at most k. This does not count the answer array.

Practice Problems to get in shape:

  1. Two Sum

  2. Three Sum

  3. Four Sum

  4. 3SUM

So that's all I have in this blog. If you found it helpful, please share it with your friends, and stay tuned for future updates :D