Python script to generate CG label permutations

This commit is contained in:
Miguel M 2023-04-26 13:09:30 +01:00
parent 77d61f36e4
commit 0b79ca9ac3
1 changed files with 170 additions and 0 deletions

170
aux/permutations.py Normal file
View File

@ -0,0 +1,170 @@
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
"""Generates the permutations LUTs for the C project.
One central aspect of this project is testing for equivalences under
permutations. These permutations follow rules known ahead of time, and so they
can be precomputed into look-up tables as well. This script produces those LUTs.
The order for the symbolic entries of the probabilities vector is known, and
follows the pattern (for notation P[A outputs, B outputs | A inputs, B inputs]):
P[0,0|0,0] ... P[<A outputs>-2,<B outputs>-2|<A inputs>-1,<B outputs>-1],
P_A[0|0] ... P_A[<A outputs>-2|<A inputs>-1]
P_B[0|0] ... P_B[<B outputs>-2|<B inputs>-1]
For the elisions, the indexing order is B's outputs varying first, then B's
inputs, then A's outputs, then A's inputs. For the partials, we consider
(following the same reasoning) that the outputs vary "faster" than the inputs.
Let, for compactness, O_A/O_B be A's/B's outputs, and I_A/I_B be A's/B's
inputs. Thus, a permutation may be stored as a string of size
(O_A-1)(O_B-1)(I_A)(I_B) + (O_A-1)(I_A) + (O_B-1)(I_B)
Where entry i gives the new index j after the permutation.
"""
import argparse
import os
import sys
import datetime
import itertools
THIS_ROOT = os.path.dirname(os.path.realpath(__file__))
OUT_DIRECTORY = os.path.realpath(os.path.join(THIS_ROOT, "../contrib/permutation"))
FILE_NAME = "permutations_{}ao_{}bo_{}ai_{}bi"
HEADER_TEMPLATE = """// Generated by permutations.py on {date}
{doc}
#ifndef ACED_CONTRIB_PERMUTATION_{file_name}_H_
#define ACED_CONTRIB_PERMUTATION_{file_name}_H_
#include <stddef.h>
#define PERM_ARRAY kPermAo{a_out}_Bo{b_out}_Ai{a_in}_Bi{b_in}
// Array of permutations.
// Each permutation is an array of size `kPermLen`, implying the two-row form of
// the permutation (i.e., giving the index into which that entry maps).
const size_t kPermAo{a_out}_Bo{b_out}_Ai{a_in}_Bi{b_in}[{perm_count}][{perm_len}] = {{\n{perm_arr}\n}};
// Size of `kPermAo{a_out}_Bo{b_out}_Ai{a_in}_Bi{b_in}`.
const size_t kPermsLen = {perm_count};
// Size of any array in `kPermAo{a_out}_Bo{b_out}_Ai{a_in}_Bi{b_in}`.
const size_t kPermLen = {perm_len};
// Number of input labels for A.
const size_t kAInputs = {a_in};
// Number of output labels for A.
const size_t kAOutputs = {a_out};
// Number of input labels for B.
const size_t kBInputs = {b_in};
// Number of output labels for B.
const size_t kBOutputs = {b_out};
#endif // ACED_CONTRIB_PERMUTATION_{file_name}_H_"""
def main():
doclines = __doc__.splitlines()
description = doclines[0]
epilog = "\n".join(doclines[1:])
parser = argparse.ArgumentParser(description=description, epilog=epilog)
parser.add_argument("<A outputs>", type=int, help="Number of outputs for Alice.")
parser.add_argument("<B outputs>", type=int, help="Number of outputs for Bob.")
parser.add_argument("<A inputs>", type=int, help="Number of inputs for Alice.")
parser.add_argument("<B inputs>", type=int, help="Number of inputs for Bob.")
parser.add_argument(
"--overwrite", action="store_true", help="Overwrite an existing file."
)
args = vars(parser.parse_args())
a_out, b_out, a_in, b_in = (
args["<A outputs>"],
args["<B outputs>"],
args["<A inputs>"],
args["<B inputs>"],
)
overwrite = args["overwrite"]
out_filename = FILE_NAME.format(a_out, b_out, a_in, b_in)
out_filepath = os.path.join(OUT_DIRECTORY, out_filename + ".h")
if os.path.exists(out_filepath) and not overwrite:
print(
"File already exists and --overwrite is not set, aborting.\n"
"(Was going to write to {})".format(out_filepath),
file=sys.stderr,
)
exit(1)
# Calculate the permutations
perm_len = (
(a_out - 1) * (b_out - 1) * a_in * b_in
+ (a_out - 1) * a_in
+ (b_out - 1) * b_in
)
permutations = []
for a_in_perm in itertools.permutations(range(a_in)):
for a_out_perm in itertools.permutations(range(a_out - 1)):
for b_in_perm in itertools.permutations(range(b_in)):
for b_out_perm in itertools.permutations(range(b_out - 1)):
permutation = []
# Joint probabilities
for new_a_in in a_in_perm:
for new_a_out in a_out_perm:
for new_b_in in b_in_perm:
for new_b_out in b_out_perm:
idx = new_b_out + (b_out - 1) * (
new_b_in
+ b_in * (new_a_out + (a_out - 1) * new_a_in)
)
permutation.append(idx)
# A's partials
for new_a_in in a_in_perm:
for new_a_out in a_out_perm:
idx = (
(a_out - 1) * (b_out - 1) * a_in * b_in
+ (a_out - 1) * new_a_in
+ new_a_out
)
permutation.append(idx)
# B's partials
for new_b_in in b_in_perm:
for new_b_out in b_out_perm:
idx = (
(a_out - 1) * (b_out - 1) * a_in * b_in
+ (a_out - 1) * a_in
+ (b_out - 1) * new_b_in
+ new_b_out
)
permutation.append(idx)
permutations.append(permutation)
assert len(permutation) == perm_len
# Write out the results
with open(out_filepath, "w") as out_file:
out_file.write(
HEADER_TEMPLATE.format(
date=datetime.datetime.now(),
doc="\n".join("// " + x for x in epilog.splitlines()),
file_name=out_filename.upper(),
a_out=a_out,
b_out=b_out,
a_in=a_in,
b_in=b_in,
perm_count=len(permutations),
perm_len=perm_len,
perm_arr=",\n".join(
" {" + ",".join(str(elem) for elem in perm) + "}"
for perm in permutations
),
)
)
if __name__ == "__main__":
main()