Compare commits

...

2 Commits

Author SHA1 Message Date
Miguel M 0b79ca9ac3 Python script to generate CG label permutations 2023-04-26 13:09:30 +01:00
Miguel M 77d61f36e4 Google style changes + write errors to stderr 2023-04-23 19:36:36 +01:00
4 changed files with 185 additions and 12 deletions

170
aux/permutations.py Normal file
View File

@ -0,0 +1,170 @@
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
"""Generates the permutations LUTs for the C project.
One central aspect of this project is testing for equivalences under
permutations. These permutations follow rules known ahead of time, and so they
can be precomputed into look-up tables as well. This script produces those LUTs.
The order for the symbolic entries of the probabilities vector is known, and
follows the pattern (for notation P[A outputs, B outputs | A inputs, B inputs]):
P[0,0|0,0] ... P[<A outputs>-2,<B outputs>-2|<A inputs>-1,<B outputs>-1],
P_A[0|0] ... P_A[<A outputs>-2|<A inputs>-1]
P_B[0|0] ... P_B[<B outputs>-2|<B inputs>-1]
For the elisions, the indexing order is B's outputs varying first, then B's
inputs, then A's outputs, then A's inputs. For the partials, we consider
(following the same reasoning) that the outputs vary "faster" than the inputs.
Let, for compactness, O_A/O_B be A's/B's outputs, and I_A/I_B be A's/B's
inputs. Thus, a permutation may be stored as a string of size
(O_A-1)(O_B-1)(I_A)(I_B) + (O_A-1)(I_A) + (O_B-1)(I_B)
Where entry i gives the new index j after the permutation.
"""
import argparse
import os
import sys
import datetime
import itertools
THIS_ROOT = os.path.dirname(os.path.realpath(__file__))
OUT_DIRECTORY = os.path.realpath(os.path.join(THIS_ROOT, "../contrib/permutation"))
FILE_NAME = "permutations_{}ao_{}bo_{}ai_{}bi"
HEADER_TEMPLATE = """// Generated by permutations.py on {date}
{doc}
#ifndef ACED_CONTRIB_PERMUTATION_{file_name}_H_
#define ACED_CONTRIB_PERMUTATION_{file_name}_H_
#include <stddef.h>
#define PERM_ARRAY kPermAo{a_out}_Bo{b_out}_Ai{a_in}_Bi{b_in}
// Array of permutations.
// Each permutation is an array of size `kPermLen`, implying the two-row form of
// the permutation (i.e., giving the index into which that entry maps).
const size_t kPermAo{a_out}_Bo{b_out}_Ai{a_in}_Bi{b_in}[{perm_count}][{perm_len}] = {{\n{perm_arr}\n}};
// Size of `kPermAo{a_out}_Bo{b_out}_Ai{a_in}_Bi{b_in}`.
const size_t kPermsLen = {perm_count};
// Size of any array in `kPermAo{a_out}_Bo{b_out}_Ai{a_in}_Bi{b_in}`.
const size_t kPermLen = {perm_len};
// Number of input labels for A.
const size_t kAInputs = {a_in};
// Number of output labels for A.
const size_t kAOutputs = {a_out};
// Number of input labels for B.
const size_t kBInputs = {b_in};
// Number of output labels for B.
const size_t kBOutputs = {b_out};
#endif // ACED_CONTRIB_PERMUTATION_{file_name}_H_"""
def main():
doclines = __doc__.splitlines()
description = doclines[0]
epilog = "\n".join(doclines[1:])
parser = argparse.ArgumentParser(description=description, epilog=epilog)
parser.add_argument("<A outputs>", type=int, help="Number of outputs for Alice.")
parser.add_argument("<B outputs>", type=int, help="Number of outputs for Bob.")
parser.add_argument("<A inputs>", type=int, help="Number of inputs for Alice.")
parser.add_argument("<B inputs>", type=int, help="Number of inputs for Bob.")
parser.add_argument(
"--overwrite", action="store_true", help="Overwrite an existing file."
)
args = vars(parser.parse_args())
a_out, b_out, a_in, b_in = (
args["<A outputs>"],
args["<B outputs>"],
args["<A inputs>"],
args["<B inputs>"],
)
overwrite = args["overwrite"]
out_filename = FILE_NAME.format(a_out, b_out, a_in, b_in)
out_filepath = os.path.join(OUT_DIRECTORY, out_filename + ".h")
if os.path.exists(out_filepath) and not overwrite:
print(
"File already exists and --overwrite is not set, aborting.\n"
"(Was going to write to {})".format(out_filepath),
file=sys.stderr,
)
exit(1)
# Calculate the permutations
perm_len = (
(a_out - 1) * (b_out - 1) * a_in * b_in
+ (a_out - 1) * a_in
+ (b_out - 1) * b_in
)
permutations = []
for a_in_perm in itertools.permutations(range(a_in)):
for a_out_perm in itertools.permutations(range(a_out - 1)):
for b_in_perm in itertools.permutations(range(b_in)):
for b_out_perm in itertools.permutations(range(b_out - 1)):
permutation = []
# Joint probabilities
for new_a_in in a_in_perm:
for new_a_out in a_out_perm:
for new_b_in in b_in_perm:
for new_b_out in b_out_perm:
idx = new_b_out + (b_out - 1) * (
new_b_in
+ b_in * (new_a_out + (a_out - 1) * new_a_in)
)
permutation.append(idx)
# A's partials
for new_a_in in a_in_perm:
for new_a_out in a_out_perm:
idx = (
(a_out - 1) * (b_out - 1) * a_in * b_in
+ (a_out - 1) * new_a_in
+ new_a_out
)
permutation.append(idx)
# B's partials
for new_b_in in b_in_perm:
for new_b_out in b_out_perm:
idx = (
(a_out - 1) * (b_out - 1) * a_in * b_in
+ (a_out - 1) * a_in
+ (b_out - 1) * new_b_in
+ new_b_out
)
permutation.append(idx)
permutations.append(permutation)
assert len(permutation) == perm_len
# Write out the results
with open(out_filepath, "w") as out_file:
out_file.write(
HEADER_TEMPLATE.format(
date=datetime.datetime.now(),
doc="\n".join("// " + x for x in epilog.splitlines()),
file_name=out_filename.upper(),
a_out=a_out,
b_out=b_out,
a_in=a_in,
b_in=b_in,
perm_count=len(permutations),
perm_len=perm_len,
perm_arr=",\n".join(
" {" + ",".join(str(elem) for elem in perm) + "}"
for perm in permutations
),
)
)
if __name__ == "__main__":
main()

View File

@ -1,5 +1,5 @@
#ifndef __H_MATRIX
#define __H_MATRIX
#ifndef ACED__MATRIX_H_
#define ACED__MATRIX_H_
#include <stddef.h>
#include <inttypes.h>
@ -33,7 +33,7 @@ typedef struct
} matrix_bufs_t;
// Prints a debug view of a `matrix_t` to standard output.
void debug_print_grid(matrix_t *grid);
void DebugPrintGrid(matrix_t *grid);
// Parse a matrix from the standard input.
//
@ -41,6 +41,6 @@ void debug_print_grid(matrix_t *grid);
// l is an m-sized column vector. This input is given in standard input as
//
// <row length> <whitespace separated entries>
matrix_t parse_matrix(void);
matrix_t ParseMatrix(void);
#endif // __H_MATRIX
#endif // ACED__MATRIX_H_

View File

@ -6,7 +6,7 @@
int main(int argc, char *argv[])
{
matrix_t grid = parse_matrix();
debug_print_grid(&grid);
matrix_t grid = ParseMatrix();
DebugPrintGrid(&grid);
return 0;
}

View File

@ -4,7 +4,7 @@
#include <ctype.h>
// Prints a debug view of a `matrix_t` to standard output.
void debug_print_grid(matrix_t *matrix)
void DebugPrintGrid(matrix_t *matrix)
{
printf("Grid ( row_len: %zu, data_head: %p, entries_head: %p, data_len: %zu, entries_len: %zu ) { ",
(matrix->row_len), (void *)(matrix->data_head), (void *)(matrix->entries_head),
@ -84,7 +84,8 @@ matrix_bufs_t compact_buffers(matrix_t *matrix) {
matrix->entries_head = realloc(matrix->entries_head, matrix->entries_len * sizeof(grid_data_t *));
if ((matrix->data_head == NULL) || (matrix->entries_head == NULL))
{
printf(
fprintf(
stderr,
"Failed to reallocate matrix data buffers. Was trying to "
"allocate for %zu data entries, aborting.",
matrix->data_len);
@ -103,7 +104,7 @@ matrix_bufs_t compact_buffers(matrix_t *matrix) {
// what we have is contiguous blocks of strings, structured as
//
// <string length> <characters...>
matrix_t parse_matrix(void)
matrix_t ParseMatrix(void)
{
size_t row_len;
scanf("%zu", &row_len);
@ -127,7 +128,8 @@ matrix_t parse_matrix(void)
// A little input sanitization.
if (!(isdigit(scanned) || (scanned == '-') || (scanned == '.')))
{
printf(
fprintf(
stderr,
"Foreign character %c (%x) found when processing input, aborting.",
scanned, scanned);
exit(EXIT_FAILURE);
@ -172,7 +174,8 @@ matrix_t parse_matrix(void)
// of the number of entries in a row.
if (!((entries_len % row_len) == 0))
{
printf("Number of entries is not consistent with provided row length. "
fprintf(stderr,
"Number of entries is not consistent with provided row length. "
"Got row length of %zu, and read %zu entries. Aborting.",
row_len, entries_len);
// No need to free the buffers.