Count Partitions (IB and Codeforces)

Problem

You've got an array a[1], a[2], ..., a[n] of n integers. Count the number of ways to split all the elements of the array into three non-empty contiguous parts so that the sum of elements in each part is the same.

Example:

Input: nums = [1, 2, 3, 0, 3]
Output: 2
Explanation: There are 2 ways to make partitions:
 1. (1,2)+(3)+(0,3).
 2. (1,2)+(3,0)+(3).

Input: nums = [0, 1, -1, 0]
Output: 1
Explanation: There is only 1 way to make a partition:
 1. (0)+(1,-1)+(0).
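To make the definition concrete, here is a brute-force reference sketch. The name `count_partitions_brute` is hypothetical and not part of the original solution; it assumes, as in the Codeforces version, that all three parts must be non-empty. It tries every pair of cut points using prefix sums, so it runs in O(n^2) and is only meant for checking small inputs.

```python
from itertools import accumulate


def count_partitions_brute(nums):
    """O(n^2) reference: try every pair of cuts (i, j).

    The parts are nums[:i], nums[i:j] and nums[j:], all non-empty.
    """
    n = len(nums)
    prefix = [0] + list(accumulate(nums))  # prefix[k] == sum(nums[:k])
    total = prefix[n]
    ways = 0
    for i in range(1, n - 1):          # first part nums[:i] is non-empty
        for j in range(i + 1, n):      # middle nums[i:j] and last nums[j:] are non-empty
            first = prefix[i]
            middle = prefix[j] - prefix[i]
            last = total - prefix[j]
            if first == middle == last:
                ways += 1
    return ways


print(count_partitions_brute([1, 2, 3, 0, 3]))  # 2
print(count_partitions_brute([0, 1, -1, 0]))    # 1
```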

Thought Process

  • Since the array must be split into 3 parts with equal sums, the sum of the whole array must be divisible by 3; otherwise the answer is 0.

  • Scan the array while keeping a running prefix sum. If the prefix sum up to nums[i] is 2/3 of the total sum, the elements after nums[i] sum to the remaining 1/3, so nums[i] can end the middle part. Every earlier index whose prefix sum is exactly 1/3 of the total can end the first part, so nums[i] contributes as many valid partitions as there are such earlier indices (tracked as the counter oneThird in the code).

  • Why does this work?

    • Every valid partition is determined by where its 2nd part ends, which must be at an element between nums[1] and nums[n-2] (0-indexed) so that all three parts are non-empty. When adding nums[i] makes the running sum equal to 2/3 of the total, the end of the 2nd part is fixed at i, and the partitions that use it differ only in where the 1st part ends. So the number of (1/3, 2/3) pairs ending at i is exactly the number of prefix sums equal to 1/3 of the total seen before i.

    • One key detail: check whether the current sum is 2/3 of the total before checking whether it is 1/3. When the total is 0, both targets are 0, and checking 1/3 first would count index i as a first-part end and then immediately pair it with itself as a second-part end, leaving the middle part empty.
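The effect of the check order can be seen directly. This sketch (hypothetical name `count_with_order`) runs the same single pass as the solution but lets you swap the order of the two checks. With a total of 0, checking 1/3 first overcounts, while the two orders agree whenever the total is non-zero, because then 1/3 and 2/3 of the total are different values.

```python
def count_with_order(nums, two_thirds_first=True):
    """Single-pass count with a selectable check order.

    two_thirds_first=True matches the solution; False reproduces the bug.
    """
    total = sum(nums)
    if total % 3 != 0:
        return 0
    third = total // 3
    c_sum = 0
    one_third_seen = 0  # earlier prefixes whose sum equals `third`
    count = 0
    for x in nums[:-1]:  # leave at least one element for the last part
        c_sum += x
        if two_thirds_first:
            if c_sum == 2 * third:
                count += one_third_seen
            if c_sum == third:
                one_third_seen += 1
        else:
            # Wrong order: when third == 0 this index is counted as a first-part
            # end and then paired with itself, which leaves the middle part empty.
            if c_sum == third:
                one_third_seen += 1
            if c_sum == 2 * third:
                count += one_third_seen
    return count


print(count_with_order([0, 0, 0]))                          # 1
print(count_with_order([0, 0, 0], two_thirds_first=False))  # 3 (should be 1)
```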

Solution

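class Solution:
    def solve(self, n, nums):
        # n is the array length passed in by the judge; len(nums) is used below.
        total = sum(nums)

        # Three parts with equal sums require the total to be divisible by 3.
        if total % 3 != 0:
            return 0

        oneThirdSum = total // 3
        cSum = 0      # running prefix sum
        oneThird = 0  # earlier prefixes whose sum equals oneThirdSum
        count = 0

        # Stop before the last element so the third part is never empty.
        for i in range(len(nums) - 1):
            cSum += nums[i]

            # Check 2/3 first: a 1/3 prefix ending at this same index must not
            # be paired with it (matters when the total is 0).
            if cSum == 2 * oneThirdSum:
                count += oneThird

            if cSum == oneThirdSum:
                oneThird += 1

        return count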
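For comparison, a commonly used alternative (not the author's method) counts from the other side: precompute, for every index, how many suffixes starting there or later sum to 1/3 of the total, then pair each 1/3 prefix with suffix starts at least two positions to its right so the middle part is non-empty. The function name below is hypothetical. Unlike the solution above, this sketch uses O(n) extra space for the suffix counts.

```python
def count_partitions_suffix(nums):
    """Alternative O(n) count using suffix counts (O(n) extra space)."""
    n = len(nums)
    total = sum(nums)
    if n < 3 or total % 3 != 0:
        return 0
    third = total // 3

    # suffix_ok[j] = number of k in [j, n-1] with sum(nums[k:]) == third
    suffix_ok = [0] * (n + 1)
    running = 0
    for k in range(n - 1, -1, -1):
        running += nums[k]
        suffix_ok[k] = suffix_ok[k + 1] + (1 if running == third else 0)

    count = 0
    prefix = 0
    for i in range(n - 2):  # first part nums[0..i] leaves room for two more parts
        prefix += nums[i]
        if prefix == third:
            # The last part must start at k >= i + 2 so the middle part is non-empty;
            # its sum is then total - 2 * third == third automatically.
            count += suffix_ok[i + 2]
    return count


print(count_partitions_suffix([1, 2, 3, 0, 3]))  # 2
```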

Key Facts

  • We only increase the total count when we come across a 2/3 prefix sum because:

    • each earlier 1/3 prefix sum marks a possible end of the first part, the 2/3 prefix sum marks the end of the middle part, and the elements after it form the last part. Each of these three parts sums to 1/3 of the total sum of the array.
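The pairing described above can be made visible by recording where the 1/3 prefixes end instead of only counting them. The sketch below (hypothetical name `list_partitions`) returns the actual (first, middle, last) parts. Because it copies slices for every partition, it is meant for illustration on small inputs, not as a replacement for the counting solution.

```python
def list_partitions(nums):
    """Return every equal-sum split as a (first, middle, last) tuple of lists."""
    total = sum(nums)
    if total % 3 != 0:
        return []
    third = total // 3
    one_third_ends = []  # indices i where sum(nums[:i + 1]) == third
    result = []
    c_sum = 0
    for j in range(len(nums) - 1):  # j is a candidate end of the middle part
        c_sum += nums[j]
        if c_sum == 2 * third:
            # Pair this middle-part end with every earlier first-part end.
            for i in one_third_ends:
                result.append((nums[:i + 1], nums[i + 1:j + 1], nums[j + 1:]))
        if c_sum == third:
            one_third_ends.append(j)
    return result


for first, middle, last in list_partitions([1, 2, 3, 0, 3]):
    print(first, middle, last)
# [1, 2] [3] [0, 3]
# [1, 2] [3, 0] [3]
```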

Time Complexity

  • Time: O(n)

  • Space: O(1)
