
I asked this question on StackOverflow, but I think this is a more appropriate place for it.

This is a problem from an Introduction to Algorithms course:

You have an array $a$ of $n$ positive integers (the array need not be sorted, and the elements need not be distinct). Suggest an $O(n)$ algorithm to find the largest sum of elements that is divisible by $n$.

Example: $a = [6, 1, 13, 4, 9, 8, 25], n = 7$. The answer is $56$ (with elements $6, 13, 4, 8, 25$).

It's relatively easy to find it in $O(n^2)$ time using dynamic programming, storing for each remainder $0, 1, 2, \dots, n - 1$ the largest sum achieving it.
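To make this concrete, here is a rough Python sketch of that quadratic table (the function and variable names are mine): `best[r]` holds the largest subset sum with remainder $r$ found so far, and every element updates all $n$ slots, giving $O(n^2)$ overall.

```python
def largest_divisible_sum_quadratic(a):
    """O(n^2) DP: best[r] = largest subset sum with remainder r, or None if unreachable."""
    n = len(a)
    best = [None] * n
    best[0] = 0  # the empty subset
    for x in a:
        new_best = list(best)
        for r, s in enumerate(best):
            if s is None:
                continue
            r2 = (r + x) % n
            if new_best[r2] is None or s + x > new_best[r2]:
                new_best[r2] = s + x
        best = new_best
    return best[0]  # a result of 0 means only the empty subset worked


print(largest_divisible_sum_quadratic([6, 1, 13, 4, 9, 8, 25]))  # 56
```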

Also, if we restrict attention to a contiguous sequence of elements, it's easy to find the optimal such sequence in $O(n)$ time by storing partial sums modulo $n$: let $S[i] = a[0] + a[1] + \dots + a[i]$; for each remainder $r$ remember the largest index $j$ such that $S[j] \equiv r \pmod{n}$; then for each $i$ consider $S[j] - S[i]$, where $j$ is the index corresponding to $r = S[i] \bmod n$.
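A rough sketch of that contiguous variant (again, the names are mine), using the convention that the empty prefix has sum $0$:

```python
def largest_contiguous_divisible_sum(a):
    """O(n): largest sum of a contiguous segment divisible by n = len(a)."""
    n = len(a)
    prefix = [0]  # prefix[i] = a[0] + ... + a[i-1]
    for x in a:
        prefix.append(prefix[-1] + x)
    # last_index[r] = largest j with prefix[j] % n == r
    last_index = {}
    for j, s in enumerate(prefix):
        last_index[s % n] = j  # later j overwrites earlier ones
    best = None
    for i, s in enumerate(prefix):
        j = last_index[s % n]
        if j > i:  # the segment a[i..j-1] has a sum divisible by n
            seg = prefix[j] - prefix[i]
            if best is None or seg > best:
                best = seg
    return best


print(largest_contiguous_divisible_sum([6, 1, 13, 4, 9, 8, 25]))  # 42, i.e. 9 + 8 + 25
```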

But is there an $O(n)$-time solution for the general case? Any suggestions will be appreciated! I suspect this has something to do with linear algebra, but I'm not sure exactly what.

Alternatively, can this be done in $O(n \log n)$ time?

2 Answers


Here are a few random ideas:

  • The dynamic-programming algorithm can be flipped to look for a smallest sum instead of a largest sum: you end up looking for a sum congruent to the remainder of the sum of the entire array, instead of one congruent to zero. If we process the elements in increasing order, this sometimes allows the dynamic algorithm to terminate before processing the entire array (see the sketch after this list).

    The cost would be $O(n k)$ if we processed $k$ elements. There isn't an $\Omega(n \log n)$ lower bound on this algorithm, because we don't have to sort all the elements: it only takes $O(n \log k)$ time to get the $k$ smallest elements.

  • If we cared about the set with the largest size, instead of the set with the largest sum, we might be able to use fast-Fourier-transform-based polynomial multiplication to solve the problem in $O(n (\log n)^2 (\log \log n))$ time, similar to what's done in 3SUM when the domain range is limited. (Note: use repeated squaring to do a binary search, or else you'll get $O(n k (\log n) (\log \log n))$ where $k$ is the number of omitted elements.)

  • When $n$ is composite, and almost all remainders are a multiple of one of $n$'s factors, significant time might be saved by focusing on the remainders that aren't a multiple of that factor.

  • When a remainder $r$ is very common, or there are only a few distinct remainders present, keeping track of "next open slot if you start from here and keep advancing by $r$" information can save a lot of scanning-for-jumps-into-open-spots time.

  • You can shave a log factor by only tracking reachability with bit masks (in the flipped dynamic algorithm), then backtracking once you reach the target remainder (see the bitset sketch after this list).

  • The dynamic programming algorithm is very amenable to being run in parallel. With a processor for each buffer slot you can get down to $O(n)$ time. Alternatively, by using $O(n^2)$ breadth and divide-and-conquer aggregation instead of iterative aggregation, the circuit depth can get all the way down to $O(\log^2 n)$.

  • (Meta) I strongly suspect that the problem you were given is about contiguous sums. If you linked to the actual problem, it would be easy to verify that. Otherwise I'm very surprised by how difficult this problem is, given that it was assigned in a course called "Introduction to Algorithms". But maybe you covered a trick in class that makes it trivial.
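To illustrate the first bullet, here is a hedged Python sketch of the flipped dynamic program: it searches for the smallest removable subset sum congruent to the total modulo $n$, takes elements in increasing order, and stops once no unprocessed (larger) element could yield a smaller removal. The names and the exact early-termination test are my own choices.

```python
def largest_divisible_sum_flipped(a):
    """Find the largest subset sum divisible by n by removing the smallest
    subset sum congruent to total % n."""
    n = len(a)
    total = sum(a)
    target = total % n
    if target == 0:
        return total  # the whole array already works
    # smallest[r] = smallest subset sum with remainder r found so far
    smallest = [None] * n
    smallest[0] = 0  # the empty removal
    for x in sorted(a):  # or extract only the k smallest elements with a heap
        if smallest[target] is not None and smallest[target] <= x:
            # Every unprocessed element is >= x, so any subset using one of them
            # can't beat the removal we already have.
            break
        new_smallest = list(smallest)
        for r, s in enumerate(smallest):
            if s is None:
                continue
            r2 = (r + x) % n
            if new_smallest[r2] is None or s + x < new_smallest[r2]:
                new_smallest[r2] = s + x
        smallest = new_smallest
    if smallest[target] is None:
        return None  # can't happen when len(a) == n, by pigeonhole on prefix sums
    return total - smallest[target]


print(largest_divisible_sum_flipped([6, 1, 13, 4, 9, 8, 25]))  # 56 (removes 1 and 9)
```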
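And a minimal sketch of the reachability-only bit-mask idea: bit $r$ of a single big integer records whether some subset of the processed elements sums to $r \pmod{n}$, and adding an element is a rotate-and-OR. Backtracking to recover the actual set isn't shown.

```python
def reachable_remainders(a):
    """Which remainders mod n are achievable as a subset sum (reachability only)."""
    n = len(a)
    full = (1 << n) - 1
    reach = 1  # bit 0 set: the empty subset has remainder 0
    for x in a:
        k = x % n
        # Adding x maps every reachable remainder r to (r + k) % n: a rotation of the mask.
        rotated = ((reach << k) | (reach >> (n - k))) & full
        reach |= rotated
    return [r for r in range(n) if (reach >> r) & 1]


print(reachable_remainders([6, 1, 13, 4, 9, 8, 25]))  # remainders some subset can hit
```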

Craig Gidney

My proposed algorithm goes as follows:

A sum is divisible by $n$ if you only add summands that are themselves multiples of $n$.

Before you start, you create a hashmap with an int as key and a list of indices as value. You also create a result list containing indices.

You then loop over the array and add every index whose value is $0 \bmod n$ to your result list. For every other index you do the following:

You subtract this element's value mod $n$ from $n$. The result is the key into your hashmap, which stores indices of elements with the required value. You add this index to the list under that key and move on.

After you have finished looping over the array, you compute the output. You do this by sorting each list in the hashmap according to the value its indices point to. Then you consider every pair of keys in the hashmap that sums to $n$; so if $n = 7$, you look up 3 and 4 in the hashmap. If both have an entry, you take the two largest values, remove them from their lists, and add them to your result list.

One last recommendation: I still haven't tested this algorithm, so write a test case against it using a brute-force algorithm (a sketch follows below).
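In that spirit, here is a brute-force oracle (exponential time, so only for small arrays) that any candidate implementation can be checked against; the function name is mine:

```python
from itertools import combinations


def brute_force_largest_divisible_sum(a):
    """Reference answer: try every nonempty subset, keep the largest sum divisible by len(a)."""
    n = len(a)
    best = None
    for k in range(1, n + 1):
        for combo in combinations(a, k):
            s = sum(combo)
            if s % n == 0 and (best is None or s > best):
                best = s
    return best


assert brute_force_largest_divisible_sum([6, 1, 13, 4, 9, 8, 25]) == 56
```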