Compare commits
2 Commits
168a8cf88c
...
0b79ca9ac3
Author | SHA1 | Date |
---|---|---|
Miguel M | 0b79ca9ac3 | |
Miguel M | 77d61f36e4 |
|
@ -0,0 +1,170 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""Generates the permutations LUTs for the C project.
|
||||
|
||||
One central aspect of this project is testing for equivalences under
|
||||
permutations. These permutations follow rules known ahead of time, and so they
|
||||
can be precomputed into look-up tables as well. This script produces those LUTs.
|
||||
|
||||
The order for the symbolic entries of the probabilities vector is known, and
|
||||
follows the pattern (for notation P[A outputs, B outputs | A inputs, B inputs]):
|
||||
|
||||
P[0,0|0,0] ... P[<A outputs>-2,<B outputs>-2|<A inputs>-1,<B outputs>-1],
|
||||
P_A[0|0] ... P_A[<A outputs>-2|<A inputs>-1]
|
||||
P_B[0|0] ... P_B[<B outputs>-2|<B inputs>-1]
|
||||
|
||||
For the elisions, the indexing order is B's outputs varying first, then B's
|
||||
inputs, then A's outputs, then A's inputs. For the partials, we consider
|
||||
(following the same reasoning) that the outputs vary "faster" than the inputs.
|
||||
|
||||
Let, for compactness, O_A/O_B be A's/B's outputs, and I_A/I_B be A's/B's
|
||||
inputs. Thus, a permutation may be stored as a string of size
|
||||
|
||||
(O_A-1)(O_B-1)(I_A)(I_B) + (O_A-1)(I_A) + (O_B-1)(I_B)
|
||||
|
||||
Where entry i gives the new index j after the permutation.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import datetime
|
||||
import itertools
|
||||
|
||||
|
||||
THIS_ROOT = os.path.dirname(os.path.realpath(__file__))
|
||||
OUT_DIRECTORY = os.path.realpath(os.path.join(THIS_ROOT, "../contrib/permutation"))
|
||||
FILE_NAME = "permutations_{}ao_{}bo_{}ai_{}bi"
|
||||
HEADER_TEMPLATE = """// Generated by permutations.py on {date}
|
||||
{doc}
|
||||
|
||||
#ifndef ACED_CONTRIB_PERMUTATION_{file_name}_H_
|
||||
#define ACED_CONTRIB_PERMUTATION_{file_name}_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#define PERM_ARRAY kPermAo{a_out}_Bo{b_out}_Ai{a_in}_Bi{b_in}
|
||||
|
||||
// Array of permutations.
|
||||
// Each permutation is an array of size `kPermLen`, implying the two-row form of
|
||||
// the permutation (i.e., giving the index into which that entry maps).
|
||||
const size_t kPermAo{a_out}_Bo{b_out}_Ai{a_in}_Bi{b_in}[{perm_count}][{perm_len}] = {{\n{perm_arr}\n}};
|
||||
// Size of `kPermAo{a_out}_Bo{b_out}_Ai{a_in}_Bi{b_in}`.
|
||||
const size_t kPermsLen = {perm_count};
|
||||
// Size of any array in `kPermAo{a_out}_Bo{b_out}_Ai{a_in}_Bi{b_in}`.
|
||||
const size_t kPermLen = {perm_len};
|
||||
// Number of input labels for A.
|
||||
const size_t kAInputs = {a_in};
|
||||
// Number of output labels for A.
|
||||
const size_t kAOutputs = {a_out};
|
||||
// Number of input labels for B.
|
||||
const size_t kBInputs = {b_in};
|
||||
// Number of output labels for B.
|
||||
const size_t kBOutputs = {b_out};
|
||||
|
||||
#endif // ACED_CONTRIB_PERMUTATION_{file_name}_H_"""
|
||||
|
||||
|
||||
def main():
|
||||
doclines = __doc__.splitlines()
|
||||
description = doclines[0]
|
||||
epilog = "\n".join(doclines[1:])
|
||||
|
||||
parser = argparse.ArgumentParser(description=description, epilog=epilog)
|
||||
parser.add_argument("<A outputs>", type=int, help="Number of outputs for Alice.")
|
||||
parser.add_argument("<B outputs>", type=int, help="Number of outputs for Bob.")
|
||||
parser.add_argument("<A inputs>", type=int, help="Number of inputs for Alice.")
|
||||
parser.add_argument("<B inputs>", type=int, help="Number of inputs for Bob.")
|
||||
parser.add_argument(
|
||||
"--overwrite", action="store_true", help="Overwrite an existing file."
|
||||
)
|
||||
args = vars(parser.parse_args())
|
||||
|
||||
a_out, b_out, a_in, b_in = (
|
||||
args["<A outputs>"],
|
||||
args["<B outputs>"],
|
||||
args["<A inputs>"],
|
||||
args["<B inputs>"],
|
||||
)
|
||||
overwrite = args["overwrite"]
|
||||
|
||||
out_filename = FILE_NAME.format(a_out, b_out, a_in, b_in)
|
||||
out_filepath = os.path.join(OUT_DIRECTORY, out_filename + ".h")
|
||||
|
||||
if os.path.exists(out_filepath) and not overwrite:
|
||||
print(
|
||||
"File already exists and --overwrite is not set, aborting.\n"
|
||||
"(Was going to write to {})".format(out_filepath),
|
||||
file=sys.stderr,
|
||||
)
|
||||
exit(1)
|
||||
|
||||
# Calculate the permutations
|
||||
perm_len = (
|
||||
(a_out - 1) * (b_out - 1) * a_in * b_in
|
||||
+ (a_out - 1) * a_in
|
||||
+ (b_out - 1) * b_in
|
||||
)
|
||||
permutations = []
|
||||
|
||||
for a_in_perm in itertools.permutations(range(a_in)):
|
||||
for a_out_perm in itertools.permutations(range(a_out - 1)):
|
||||
for b_in_perm in itertools.permutations(range(b_in)):
|
||||
for b_out_perm in itertools.permutations(range(b_out - 1)):
|
||||
permutation = []
|
||||
# Joint probabilities
|
||||
for new_a_in in a_in_perm:
|
||||
for new_a_out in a_out_perm:
|
||||
for new_b_in in b_in_perm:
|
||||
for new_b_out in b_out_perm:
|
||||
idx = new_b_out + (b_out - 1) * (
|
||||
new_b_in
|
||||
+ b_in * (new_a_out + (a_out - 1) * new_a_in)
|
||||
)
|
||||
permutation.append(idx)
|
||||
# A's partials
|
||||
for new_a_in in a_in_perm:
|
||||
for new_a_out in a_out_perm:
|
||||
idx = (
|
||||
(a_out - 1) * (b_out - 1) * a_in * b_in
|
||||
+ (a_out - 1) * new_a_in
|
||||
+ new_a_out
|
||||
)
|
||||
permutation.append(idx)
|
||||
# B's partials
|
||||
for new_b_in in b_in_perm:
|
||||
for new_b_out in b_out_perm:
|
||||
idx = (
|
||||
(a_out - 1) * (b_out - 1) * a_in * b_in
|
||||
+ (a_out - 1) * a_in
|
||||
+ (b_out - 1) * new_b_in
|
||||
+ new_b_out
|
||||
)
|
||||
permutation.append(idx)
|
||||
|
||||
permutations.append(permutation)
|
||||
assert len(permutation) == perm_len
|
||||
|
||||
# Write out the results
|
||||
with open(out_filepath, "w") as out_file:
|
||||
out_file.write(
|
||||
HEADER_TEMPLATE.format(
|
||||
date=datetime.datetime.now(),
|
||||
doc="\n".join("// " + x for x in epilog.splitlines()),
|
||||
file_name=out_filename.upper(),
|
||||
a_out=a_out,
|
||||
b_out=b_out,
|
||||
a_in=a_in,
|
||||
b_in=b_in,
|
||||
perm_count=len(permutations),
|
||||
perm_len=perm_len,
|
||||
perm_arr=",\n".join(
|
||||
" {" + ",".join(str(elem) for elem in perm) + "}"
|
||||
for perm in permutations
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,5 +1,5 @@
|
|||
#ifndef __H_MATRIX
|
||||
#define __H_MATRIX
|
||||
#ifndef ACED__MATRIX_H_
|
||||
#define ACED__MATRIX_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <inttypes.h>
|
||||
|
@ -33,7 +33,7 @@ typedef struct
|
|||
} matrix_bufs_t;
|
||||
|
||||
// Prints a debug view of a `matrix_t` to standard output.
|
||||
void debug_print_grid(matrix_t *grid);
|
||||
void DebugPrintGrid(matrix_t *grid);
|
||||
|
||||
// Parse a matrix from the standard input.
|
||||
//
|
||||
|
@ -41,6 +41,6 @@ void debug_print_grid(matrix_t *grid);
|
|||
// l is an m-sized column vector. This input is given in standard input as
|
||||
//
|
||||
// <row length> <whitespace separated entries>
|
||||
matrix_t parse_matrix(void);
|
||||
matrix_t ParseMatrix(void);
|
||||
|
||||
#endif // __H_MATRIX
|
||||
#endif // ACED__MATRIX_H_
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
int main(int argc, char *argv[])
|
||||
{
|
||||
matrix_t grid = parse_matrix();
|
||||
debug_print_grid(&grid);
|
||||
matrix_t grid = ParseMatrix();
|
||||
DebugPrintGrid(&grid);
|
||||
return 0;
|
||||
}
|
||||
|
|
13
src/matrix.c
13
src/matrix.c
|
@ -4,7 +4,7 @@
|
|||
#include <ctype.h>
|
||||
|
||||
// Prints a debug view of a `matrix_t` to standard output.
|
||||
void debug_print_grid(matrix_t *matrix)
|
||||
void DebugPrintGrid(matrix_t *matrix)
|
||||
{
|
||||
printf("Grid ( row_len: %zu, data_head: %p, entries_head: %p, data_len: %zu, entries_len: %zu ) { ",
|
||||
(matrix->row_len), (void *)(matrix->data_head), (void *)(matrix->entries_head),
|
||||
|
@ -84,7 +84,8 @@ matrix_bufs_t compact_buffers(matrix_t *matrix) {
|
|||
matrix->entries_head = realloc(matrix->entries_head, matrix->entries_len * sizeof(grid_data_t *));
|
||||
if ((matrix->data_head == NULL) || (matrix->entries_head == NULL))
|
||||
{
|
||||
printf(
|
||||
fprintf(
|
||||
stderr,
|
||||
"Failed to reallocate matrix data buffers. Was trying to "
|
||||
"allocate for %zu data entries, aborting.",
|
||||
matrix->data_len);
|
||||
|
@ -103,7 +104,7 @@ matrix_bufs_t compact_buffers(matrix_t *matrix) {
|
|||
// what we have is contiguous blocks of strings, structured as
|
||||
//
|
||||
// <string length> <characters...>
|
||||
matrix_t parse_matrix(void)
|
||||
matrix_t ParseMatrix(void)
|
||||
{
|
||||
size_t row_len;
|
||||
scanf("%zu", &row_len);
|
||||
|
@ -127,7 +128,8 @@ matrix_t parse_matrix(void)
|
|||
// A little input sanitization.
|
||||
if (!(isdigit(scanned) || (scanned == '-') || (scanned == '.')))
|
||||
{
|
||||
printf(
|
||||
fprintf(
|
||||
stderr,
|
||||
"Foreign character %c (%x) found when processing input, aborting.",
|
||||
scanned, scanned);
|
||||
exit(EXIT_FAILURE);
|
||||
|
@ -172,7 +174,8 @@ matrix_t parse_matrix(void)
|
|||
// of the number of entries in a row.
|
||||
if (!((entries_len % row_len) == 0))
|
||||
{
|
||||
printf("Number of entries is not consistent with provided row length. "
|
||||
fprintf(stderr,
|
||||
"Number of entries is not consistent with provided row length. "
|
||||
"Got row length of %zu, and read %zu entries. Aborting.",
|
||||
row_len, entries_len);
|
||||
// No need to free the buffers.
|
||||
|
|
Loading…
Reference in New Issue