from functools import reduce
import argparse
import signal
import sys

def signal_handler(sig, frame):
    """SIGINT handler: report where the search stopped and the collapse
    counts gathered so far, then exit with status 0.

    The global ``a`` stores a[1:] shifted by a[0]; the shift is undone here
    so the reported configuration has the same form as ``--start`` input.
    """
    head = a[0]
    current = [head] + [x - head for x in a[1:]]
    print("\nFinishing at {}, found collapses: {}".format(current, nc))
    sys.exit(0)

parser = argparse.ArgumentParser(description='Search for collapsing solutions.')
parser.add_argument('--boxes', '-b', type=int, dest='boxes', help='Number of boxes')
parser.add_argument('--colours', '-c', type=int, dest='colours', help='Number of colours')
parser.add_argument('--start', '-s', type=int, nargs='+', action='append', dest='start', help='Starting configuration')

args = parser.parse_args()

# Number of colours; replaced by the length of --start when one is given.
c = args.colours or 3
# Column width used by showList when printing the status line.
pd = 1

# The search state `a` keeps the first entry as-is and stores every later
# entry offset by it: a[0] == b[0] and a[i] == b[0] + b[i] for i >= 1,
# where b is the user-facing configuration.
if args.start is None:
    a = [1] + [2] * c
else:
    # action='append' collects one list per --start flag; only the first is used.
    first, *rest = args.start[0]
    a = [first] + [x + first for x in rest]
    c = len(a) - 1

n = args.boxes or c
# nc[k] counts the collapses reported at level k by the search loop.
nc = [0] * (c + 1)

# Install the Ctrl-C handler only now: it reads `a` and `nc`, so an interrupt
# arriving before both exist would make the handler itself fail.
signal.signal(signal.SIGINT, signal_handler)

b = [a[0]] + [x - a[0] for x in a[1:]]
print("Starting at {} with {} boxes and {} colours".format(b,n,c))

# masks[k] lists every length-c 0/1 vector with exactly k ones, where bit j
# of the counter i becomes element j. Within each bucket they appear in
# increasing order of i.
masks = [[] for _ in range(c + 1)]
for i in range(2 ** c):
    m = [(i >> j) & 1 for j in range(c)]
    masks[sum(m)].append(m)

def gcd(a,b):
    """Return the greatest common divisor of the integers *a* and *b*.

    Uses the iterative Euclidean algorithm (remainders instead of repeated
    subtraction), so large or very unequal arguments neither recurse deeply
    nor take time proportional to their quotient. Signs are ignored and the
    result is non-negative; gcd(0, 0) is 0.
    """
    a, b = abs(a), abs(b)
    while b:
        a, b = b, a % b
    return a

def showList(b):
    """Print *b* on one line as right-aligned, comma-separated values.

    Each entry is padded to the module-level width ``pd``. Before printing,
    ``pd`` is widened to fit the longest entry of *b*; it never shrinks. The
    line ends with a carriage return rather than a newline, so the next
    status line is written over it.
    """
    global pd
    pd = max([pd] + [len(str(x)) for x in b])
    cells = [str(x).rjust(pd) for x in b]
    print(", ".join(cells), end="\r")
                 


# Search bound: stop once the last stored entry a[-1] (= b[0] + b[-1])
# reaches this value. `>=` (not `==`) so a --start configuration that already
# lies at or beyond the bound cannot loop forever.
LIMIT = 400

while True:
    # Undo the offset storage to get the user-facing configuration b.
    b = [a[0]] + [x - a[0] for x in a[1:]]
    showList(b)
    # Only configurations whose entries have gcd 1 are checked.
    if reduce(gcd, b) == 1:
        lhs = sum(b) ** n
        rhs = 0
        sgn = 1
        # Accumulate an alternating-sign sum over subset sizes c-1 down to 1.
        # Each term is (b[0] + the b[1:] entries selected by a mask) ** n.
        # After each size is added, compare against lhs.
        for i in range(c - 1, 0, -1):
            for m in masks[i]:
                trm = b[0] + sum(bit * v for bit, v in zip(m, b[1:]))
                rhs += sgn * trm ** n
            if lhs == rhs:
                # Level c-i+1 is also the index into nc.
                print("{} collapses at level {} with {} = {}".format(str(b),c-i+1,lhs,rhs))
                nc[c + 1 - i] += 1
            sgn = -sgn
    # Advance to the next configuration: find the lowest index whose entry can
    # grow while staying below the next entry (the last index always can),
    # increment it, and reset every lower entry to its minimum
    # (a[0] = 1, others 2).
    for i in range(len(a)):
        if i == len(a) - 1 or a[i] + 1 < a[i + 1]:
            a[i] += 1
            for j in range(i):
                a[j] = 1 if j == 0 else 2
            break
    if a[-1] >= LIMIT:
        break

# Normal completion previously exited silently; report the results here too
# (the Ctrl-C handler is otherwise the only place that prints them).
print("\nSearch limit reached after {}, found collapses: {}".format(b, nc))
