import itertools as it
from time import time
import sys
from config import chars, ords, RTList, RFList, RTqstart, RTDict
from tree_creator_base_final import RFHelper, InitialiseRTList

'''
Copyright Dr Paul Brown and Prof. Trevor Fenner 2020.
'''

def RFHelper_adjmat(n, q, L):
    # generate 2-tuples comprising the canonical weight sequence of the tree
    # with the root removed together with the adjacency matrix of the rooted
    # tree; this routine is only used in UFT (i.e., at the top level)
    #
    # n is the order of the whole tree and L the order up to which the caches
    # RTList / RFList / RTDict were initialised (see main).  Vertex 0 is the
    # root (the centroid); vertices 1 .. q hold the first subtree of the root
    # ("a") and vertices q + 1 .. n - 1 hold the remaining subtrees ("b").
    # For q == 2 every subtree of the root has order at most 2; in the other
    # branches the first subtree a is a rooted tree of order exactly q.
    # NOTE: outside the q == 2 branch, row lists are reused between yielded
    # matrices (a_mat, the ls_adjs rows and the rows buffered by the tees),
    # so consumers should treat each yielded matrix as read-only.

    # block of zeros for the vertices in the first subtree of the root
    a_zeros = [0] * q
    # block of zeros for the vertices in the other subtrees of the root
    b_zeros = [0] * (n - q - 1)

    if q == 2:
        # trees whose root subtrees are t paths of order 2 followed by
        # leaves; t runs from the maximum possible down to 0 (the star)
        for t in range((n - 1) // 2, -1, -1):
            adjmat = [[0] * n for i in range(n)]
            # weight 2 vertices
            for i in range(1, 2 * t + 1, 2):
                adjmat[0][i] = 1        # edge from root
                adjmat[i][0] = 1        # edge to root
                adjmat[i][i + 1] = 1    # edge to leaf
                adjmat[i + 1][i] = 1    # edge from leaf
            # leaves
            for i in range(2 * t + 1, n):
                adjmat[0][i] = 1        # edge to leaf
                adjmat[i][0] = 1        # edge from leaf
            yield '21' * t + '1' * (n - 1 - 2 * t), adjmat

    elif q == (n - 1) // 2:
        # construct the first row (the root) and the last n - q - 1 rows of
        # each adjacency matrix using the cache RTDict, by inserting q zeros
        # into columns 1 to q of the matrix of the subtree of order n - q,
        # corresponding to b
        ls_adjs = [[[vx_row[0]] + a_zeros + vx_row[1:]
                            for vx_row in RTDict[t]] for t in RTList[n - q]]
        # add link from each centroid to its first child (vertex 1)
        for adj in ls_adjs: adj[0][1] = 1
        start = 0
        # for even n, skip the leading entries of RTList[n - q], one per tree
        # in RTList[n // 2]; presumably these are the trees whose root has a
        # subtree of order n // 2, which would make the tree bicentroidal
        if n % 2 == 0: start = len(RTList[n // 2])
        for a in RTList[q]:
            # append n - q - 1 zeros to each of rows 1 to q of the
            # adjacency matrix of the first subtree
            a_mat = [[0] + vx_row + b_zeros for vx_row in RTDict[a]]
            # add edge from the root of the first subtree to the root
            # note that a_mat[0][0] corresponds to the [1][0]
            # element of the adjacency matrix of the whole tree
            a_mat[0][0] = 1
            # enumerate RFList[n - q][start:] to synchronise with ls_adjs
            for i, b_hash in enumerate(RFList[n - q][start:], start):
                b_mat = ls_adjs[i]
                # b_mat[0] is the row corresponding to the root
                # a_mat corresponds to the rows of the first subtree
                # b_mat[1:] corresponds to the rows of the rest of the tree
                yield a + b_hash, [b_mat[0]] + a_mat + b_mat[1:]
            # each later a in RTList[q] starts one entry further on, so that
            # the b parts paired with a never lead with a subtree larger
            # than a (assumes RTList[q] is ordered consistently with the
            # corresponding entries of RFList[n - q])
            start +=1

    elif q >= n - L:
        # the remainder b has order n - q <= L, so its matrices can be taken
        # directly from the cache, exactly as in the previous branch
        ls_adjs = [[[vx_adj[0]] + a_zeros + vx_adj[1:]
                            for vx_adj in RTDict[t]] for t in RTList[n - q]]
        for adj in ls_adjs: adj[0][1] = 1
        for a in RTList[q]:
            a_mat = [[0] + vx_row + b_zeros for vx_row in RTDict[a]]
            # add edge from the root of the first subtree to the root
            a_mat[0][0] = 1
            # presumably the first index in RFList[n - q] whose largest
            # subtree has order <= q - TODO confirm against config
            start = RTqstart[n - q][q]
            # a + 'z' compares above every sequence starting with a (assuming
            # 'z' sorts above all weight characters), so the test below
            # accepts b once its leading subtree is no larger than a
            a_sentinel = a + 'z'
            # advance start to the first acceptable b
            for start, bhash in enumerate(RFList[n - q][start:], start):
                if a_sentinel >= bhash: break
            for i, b_hash in enumerate(RFList[n - q][start:], start):
                b_mat = ls_adjs[i]
                yield a + b_hash, [b_mat[0]] + a_mat + b_mat[1:]

    # recursive case
    else:
        # RFGenArr[r] presumably yields the root-removed weight sequences of
        # the trees of order n - q whose largest subtree has order r (see
        # RFHelper in tree_creator_base_final) - TODO confirm
        RFGenArr = [None for _ in range(q + 1)]
        # allocate array for generators of adjacency matrices
        adjgens = [None for _ in range(q + 1)]
        for r in range(2, q + 1):
            RFGenArr[r] = RFHelper(n - q, r, L)
            # split into two independent copies: one stays in RFGenArr[r]
            # for the weight sequences, the other (RFGA) feeds the matrix
            # generator so that both streams stay in lockstep
            RFGenArr[r], RFGA = it.tee(RFGenArr[r], 2)
            # b occupies vertices q + 1 .. n - 1 of the whole tree
            adjgens[r] = (Adjsplit(f, q + 1, n) for f in RFGA)
        for a in RTList[q]:
            a_sentinel = a + 'z'
            a_mat = [[0] + vx_row + b_zeros for vx_row in RTDict[a]]
            # add edge from the root of the first subtree to the root
            a_mat[0][0] = 1
            for r in range(q, 1, - 1):
                # re-tee before every use: the copy stored back stays
                # unconsumed for the next a, while the tee buffer replays the
                # already generated items instead of recomputing them
                RFGenArr[r], RFHelperListGen = it.tee(RFGenArr[r], 2)
                # make a copy of the adjacency matrices generator
                adjgens[r], gen_adjs = it.tee(adjgens[r], 2)
                if r == q:
                    # only when the largest subtree of b can be as large as a
                    # must b be filtered; gen_adjs is advanced in step with
                    # RFHelperListGen so (cen_row, bhash_mat) matches b_hash
                    cen_row, bhash_mat = next(gen_adjs)
                    for b_hash in RFHelperListGen:
                        if a_sentinel >= b_hash: break
                        cen_row, bhash_mat = next(gen_adjs)
                    # yield the first acceptable b; the loop below continues
                    # from the same position of RFHelperListGen
                    yield a + b_hash, [cen_row] + a_mat + bhash_mat
                for b_hash in RFHelperListGen:
                    cen_row, bhash_mat = next(gen_adjs)
                    yield a + b_hash, [cen_row] + a_mat + bhash_mat


def AdjMatFromWS(tree):
    # returns the adjacency matrix (n lists of n 0/1 entries) of the tree
    # described by the weight sequence `tree`
    #
    # tree is a string of weight characters, one per vertex, with vertices
    # numbered 0 .. n - 1 in sequence order; ords maps each character to the
    # integer weight of its vertex, i.e. the order of the subtree it roots.
    # When the first weight is exactly n / 2 the sequence describes a
    # bicentroidal tree, and the two centroids (vertices 0 and n // 2) are
    # joined by an edge.  An empty sequence yields an empty matrix.

    n = len(tree)
    A = [[0] * n for _ in range(n)]
    if n == 0:
        return A
    ws = [ords[t] for t in tree]
    for i in range(n):
        # the first child of vertex i is at i + 1; jumping by a child's
        # weight skips its whole subtree and lands on the next child, until
        # the subtree of i ends at i + ws[i]
        j = i + 1
        while j < i + ws[i]:
            A[i][j] = 1
            A[j][i] = 1
            j += ws[j]        # get next neighbour of i
    # add edge joining the two centroids where applicable; integer
    # arithmetic is required because n / 2 is a float in Python 3 and
    # cannot be used as a list index
    half = n // 2
    if 2 * ws[0] == n:
        A[0][half] = 1
        A[half][0] = 1
    return A

def Adjsplit(ws, stpt, n):
    # returns a 2-tuple (root_row, forest_rows) assembled from the cache RTDict:
    #   root_row    - the row of the n x n adjacency matrix for the root
    #                 (vertex 0), with an edge to vertex 1 and to the root of
    #                 every subtree listed in ws
    #   forest_rows - the rows for the vertices of ws, which occupy indices
    #                 stpt .. stpt + len(ws) - 1 of the whole tree
    # ws is the weight sequence bhash (consecutive rooted subtrees) and stpt
    # is the start index of bhash in the weight sequence of the whole tree
    root_row = [0] * n
    # vertex 1 is the root of the first subtree, always a child of the root
    root_row[1] = 1
    forest_rows = []
    pos = 0
    end = len(ws)
    while pos < end:
        # the weight at pos is the order of the subtree rooted there
        size = ords[ws[pos]]
        offset = stpt + pos
        left = [0] * offset
        right = [0] * (n - offset - size)
        # shift the cached matrix of this subtree into its columns
        sub_rows = [left + row + right for row in RTDict[ws[pos:pos + size]]]
        # back-edge from the subtree root to the root of the whole tree
        sub_rows[0][0] = 1
        forest_rows.extend(sub_rows)
        # edge from the root of the whole tree to this subtree root
        root_row[offset] = 1
        pos += size
    return root_row, forest_rows

def InitialiseRTDict(L):
    # fill the global cache RTDict, which maps the weight sequence of a
    # rooted tree to its adjacency matrix, for rooted trees up to order L
    # (orders 1 to 4 are always filled in)

    # rooted trees of order 1 to 4 are written out explicitly
    seeds = (
        ('1', [[0]]),
        ('21', [[0, 1], [1, 0]]),
        ('321', [[0, 1, 0], [1, 0, 1], [0, 1, 0]]),
        ('311', [[0, 1, 1], [1, 0, 0], [1, 0, 0]]),
        ('4321', [[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]]),
        ('4311', [[0, 1, 0, 0], [1, 0, 1, 1], [0, 1, 0, 0], [0, 1, 0, 0]]),
        ('4211', [[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]]),
        ('4111', [[0, 1, 1, 1], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0]]),
    )
    for ws, mat in seeds:
        RTDict[ws] = mat

    # larger rooted trees are derived from their weight sequences, taken
    # from RTList (presumably already filled by InitialiseRTList)
    order = 5
    while order <= L:
        for ws in RTList[order]:
            RTDict[ws] = AdjMatFromWS(ws)
        order += 1

def UFT(n, L):
    # yield 2-tuples (canonical weight sequence, adjacency matrix) for the
    # unicentroidal free trees of order n: the centroid is the root, whose
    # character chars[n] is prepended to every root-removed sequence that
    # RFHelper_adjmat produces for q = (n - 1) // 2 down to 2
    centroid = chars[n]
    q = (n - 1) // 2
    while q >= 2:
        for tail, adj in RFHelper_adjmat(n, q, L):
            yield centroid + tail, adj
        q -= 1

def BFT(n):
    # yield 2-tuples (canonical weight sequence, adjacency matrix) for the
    # bicentroidal free trees of order n: two rooted trees of order n // 2,
    # occupying vertices 0 .. n // 2 - 1 and n // 2 .. n - 1, whose roots
    # (the centroids) are joined by an edge.  Each tree in RTList[n // 2] is
    # paired with itself and with every tree listed after it.
    # NOTE: row lists are reused between matrices yielded by the same call,
    # so treat each yielded matrix as read-only.
    half = n // 2
    pad = [0] * half
    halves = RTList[half]

    # rows for each candidate second subtree: shifted right by `half`
    # columns, plus the edge from the second centroid to the first
    lower = {}
    for ws in halves:
        rows = [pad + row for row in RTDict[ws]]
        rows[0][0] = 1
        lower[ws] = rows

    for idx, first in enumerate(halves):
        # rows of the first subtree, padded on the right, plus the edge from
        # the first centroid to the second (column n // 2)
        upper = [row + pad for row in RTDict[first]]
        upper[0][half] = 1
        for second in halves[idx:]:
            yield first + second, upper + lower[second]

def FreeTrees(n, L):
    # yield 2-tuples (canonical weight sequence, adjacency matrix) for every
    # free tree of order n; orders 1 to 4 are listed explicitly, larger
    # orders combine the unicentroidal trees from UFT with, when n is even,
    # the bicentroidal trees from BFT

    # the small cases are rebuilt on each call so no matrix is shared
    small = {
        1: [('1', [[0]])],
        2: [('11', [[0, 1], [1, 0]])],
        3: [('311', [[0, 1, 1], [1, 0, 0], [1, 0, 0]])],
        4: [('4111', [[0, 1, 1, 1], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0]]),
            ('2121', [[0, 1, 1, 0], [1, 0, 0, 0], [1, 0, 0, 1], [0, 0, 1, 0]])],
    }
    if n in small:
        for ws, mat in small[n]:
            yield ws, mat
        return

    for ws, mat in UFT(n, L):
        yield ws, mat
    # add bicentroidal trees when n is even
    if n % 2 == 0:
        for ws, mat in BFT(n):
            yield ws, mat

def main(N):
    # initialise the rooted-tree caches up to order N // 2 + 1, count the
    # free trees of order N produced by FreeTrees, and print the label, N,
    # the count and the elapsed time in seconds
    t0 = time()
    L = N // 2 + 1
    InitialiseRTList(L)
    InitialiseRTDict(L)
    total = sum(1 for _tree in FreeTrees(N, L))
    print('BRFE(adjmat)', N, total, time() - t0)


if __name__ == '__main__':
    # with an order given on the command line (e.g. `python <file> 12`),
    # count the free trees of that order only; otherwise check every order
    # from 0 up to 27
    if len(sys.argv) > 1:
        main(int(sys.argv[1]))
    else:
        for i in range(28):
            main(i)

