diff --git a/aux/permutations.py b/aux/permutations.py new file mode 100644 index 0000000..80423f1 --- /dev/null +++ b/aux/permutations.py @@ -0,0 +1,69 @@ +"""Gives the permutations of a row + +Given a row in P notation (space separated coefficients), prints out all the P +rows resulting from label permutation. +""" + +import argparse +import itertools + + +def get_idx(a_out, b_out, a_in, b_in, a_out_i, b_out_i, a_in_i, b_in_i): + return b_out_i + b_out * (b_in_i + b_in * (a_out_i + a_out * a_in_i)) + + +if __name__ == "__main__": + doclines = __doc__.splitlines() + description = doclines[0] + epilog = "\n".join(doclines[1:]) + + parser = argparse.ArgumentParser(description=description, epilog=epilog) + parser.add_argument("", type=int, help="Number of outputs for Alice.") + parser.add_argument("", type=int, help="Number of outputs for Bob.") + parser.add_argument("", type=int, help="Number of inputs for Alice.") + parser.add_argument("", type=int, help="Number of inputs for Bob.") + # parser.add_argument( + # "--overwrite", action="store_true", help="Overwrite an existing file." + # ) + args = vars(parser.parse_args()) + a_out = args[""] + b_out = args[""] + a_in = args[""] + b_in = args[""] + + row = [int(x) for x in input().split(',')] + + b_out_permutations = list(itertools.permutations(range(b_out))) + a_out_permutations = list(itertools.permutations(range(a_out))) + + for a_in_perm in itertools.permutations(range(a_in)): + for b_in_perm in itertools.permutations(range(b_in)): + for a_out_perms in itertools.product(a_out_permutations, repeat=a_in): + for b_out_perms in itertools.product(b_out_permutations, repeat=b_in): + perm_row = ["-" for _ in range(len(row))] + for a_in_i in range(a_in): + for b_in_i in range(b_in): + for a_out_i in range(a_out): + for b_out_i in range(b_out): + u = get_idx( + a_out, + b_out, + a_in, + b_in, + a_out_i, + b_out_i, + a_in_i, + b_in_i, + ) + v = get_idx( + a_out, + b_out, + a_in, + b_in, + a_out_perms[a_in_perm[a_in_i]][a_out_i], + b_out_perms[b_in_perm[b_in_i]][b_out_i], + a_in_perm[a_in_i], + b_in_perm[b_in_i], + ) + perm_row[v] = row[u] + print(perm_row) diff --git a/aux/swaptest.py b/aux/swaptest.py new file mode 100644 index 0000000..8d7ac75 --- /dev/null +++ b/aux/swaptest.py @@ -0,0 +1,69 @@ +"""Simple test script to check that the defined conditions go over all swapping pairs. + +This is just to help me reason.""" + +import argparse + +if __name__ == "__main__": + doclines = __doc__.splitlines() + description = doclines[0] + epilog = "\n".join(doclines[1:]) + + parser = argparse.ArgumentParser(description=description, epilog=epilog) + parser.add_argument( + "", type=int, help="Number of outputs for both parties." + ) + parser.add_argument("", type=int, help="Number of inputs for both parties.") + # parser.add_argument( + # "--overwrite", action="store_true", help="Overwrite an existing file." + # ) + args = vars(parser.parse_args()) + out = args[""] + in_ = args[""] + + marked = list( + "x" if (a_in == b_in and a_out == b_out) else "-" + for a_in in range(in_) + for a_out in range(out) + for b_in in range(in_) + for b_out in range(out) + ) + + def idx(a_out, b_out, a_in, b_in): + return b_out + out * (b_in + in_ * (a_out + out * a_in)) + + def swap(a_out_i, b_out_i, a_in_i, b_in_i): + u = idx(a_out_i, b_out_i, a_in_i, b_in_i) + v = idx(b_out_i, a_out_i, b_in_i, a_in_i) + #print(f"({a_out_i}, {b_out_i} | {a_in_i}, {b_in_i})") + #print(u, v) + + if marked[u] != "-" or marked[v] != "-": + print(f"FAILED for ({a_out_i}, {b_out_i} | {a_in_i}, {b_in_i})") + print(marked) + exit(1) + + marked[u] = "x" + marked[v] = "x" + + for a_in_i in range(in_): + for b_in_i in range(a_in_i): + for a_out_i in range(out): + for b_out_i in range(out): + swap(a_out_i, b_out_i, a_in_i, b_in_i) + + for a_out_i in range(out): + for b_out_i in range(a_out_i): + swap(a_out_i, b_out_i, a_in_i, a_in_i) + + if not all(x == "x" for x in marked): + print("FAILED missing entries") + for a_out_i in range(out): + for b_out_i in range(out): + for a_in_i in range(in_): + for b_in_i in range(in_): + if marked[idx(a_out_i, b_out_i, a_in_i, b_in_i)] == "-": + print(f"({a_out_i}, {b_out_i} | {a_in_i}, {b_in_i})") + exit(1) + + print("SUCCESS")