From 30f9cdc1209789964fbd390d176a1a4047878a2c Mon Sep 17 00:00:00 2001 From: Miguel M Date: Thu, 4 May 2023 19:47:54 +0100 Subject: [PATCH] Merge the permutations back into the Cg-P conversion As promised. --- includes/cg.h | 30 +++++-------- src/cg.c | 119 ++++++++++++++++++++------------------------------ src/main.c | 55 +++++++++++++---------- 3 files changed, 91 insertions(+), 113 deletions(-) diff --git a/includes/cg.h b/includes/cg.h index 13c007a..1d562c4 100644 --- a/includes/cg.h +++ b/includes/cg.h @@ -35,12 +35,15 @@ #include "../includes/data.h" #include "../includes/permutation.h" -// Convert from CG representation to P representation. +// Convert from CG representation to P representation, under a given +// permutation. // // `p` isn't required to be initialized, only allocated. void FromCgToP(const size_t a_out, const size_t b_out, const size_t a_in, - const size_t b_in, const data_t* restrict cg, - data_t* restrict p); + const size_t b_in, const data_t* restrict cg, data_t* restrict p, + const size_t* a_in_perm, const size_t* b_in_perm, + const permutation_generator_t* restrict a_out_perms, + const permutation_generator_t* restrict b_out_perms); // Convert from P representation to CG representation. // @@ -49,17 +52,6 @@ void FromPToCg(const size_t a_out, const size_t b_out, const size_t a_in, const size_t b_in, data_t* restrict cg, const data_t* restrict p); -// Swaps the entries of a P coefficients list under a given permtuation. -// -// The coefficients are read from `from_p` and written to `into_p`. Therefore, -// `into_p` does not need to be initialized, only allocated. -void PPermute(const size_t a_out, const size_t b_out, const size_t a_in, - const size_t b_in, const size_t* restrict a_in_perm, - const size_t* restrict b_in_perm, - const permutation_generator_t* restrict a_out_perms, - const permutation_generator_t* restrict b_out_perms, - const data_t* restrict from_p, data_t* restrict into_p); - // Swaps the input labels of the two parties, and the output labels of the two // parties. // @@ -69,15 +61,15 @@ void PPermute(const size_t a_out, const size_t b_out, const size_t a_in, void PSwapParties(const size_t out, const size_t in, data_t* p); // Debug print a CG row. -void PrintCg(const size_t a_out, const size_t b_out, const size_t a_in, - const size_t b_in, const data_t* cg); +void PrintCg(size_t a_out, size_t b_out, size_t a_in, size_t b_in, + const data_t* cg); // Debug print a P row. -void PrintP(const size_t a_out, const size_t b_out, const size_t a_in, - const size_t b_in, const data_t* p); +void PrintP(size_t a_out, size_t b_out, size_t a_in, size_t b_in, + const data_t* p); // Print a P row as if it was in the input format. -void PrintPRaw(const size_t a_out, const size_t b_out, size_t a_in, size_t b_in, +void PrintPRaw(size_t a_out, size_t b_out, size_t a_in, size_t b_in, const data_t* p); #endif \ No newline at end of file diff --git a/src/cg.c b/src/cg.c index 9a2eef8..d642e8d 100644 --- a/src/cg.c +++ b/src/cg.c @@ -50,37 +50,35 @@ static inline size_t BOutPermuted(size_t b_out_i, size_t b_in_i, } void FromCgToP(const size_t a_out, const size_t b_out, const size_t a_in, - const size_t b_in, const data_t *restrict cg, - data_t *restrict p) { - // size_t p_row_size = a_out * a_in * b_out * b_in; + const size_t b_in, const data_t *restrict cg, data_t *restrict p, + const size_t *a_in_perm, const size_t *b_in_perm, + const permutation_generator_t *restrict a_out_perms, + const permutation_generator_t *restrict b_out_perms) { + size_t p_row_size = a_out * a_in * b_out * b_in; size_t cg_row_size = ((a_out - 1) * (b_out - 1) * a_in * b_in + (a_out - 1) * a_in + (b_out - 1) * b_in) + 1; // +1 accounts for L. + for (size_t i = 0; i < p_row_size; i++) { + p[i] = 0; + } + // Copy the given joint probabilities - // (and initialize to 0 those that aren't given). for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) { for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) { for (size_t a_out_i = 0; a_out_i < a_out - 1; a_out_i++) { for (size_t b_out_i = 0; b_out_i < b_out - 1; b_out_i++) { size_t cg_index = CgJointIndex(a_out, b_out, a_in, b_in, a_out_i, b_out_i, a_in_i, b_in_i); - size_t p_index = PIndex(a_out, b_out, a_in, b_in, a_out_i, b_out_i, - a_in_i, b_in_i); - p[p_index] = cg[cg_index]; + size_t p_index = PIndex( + a_out, b_out, a_in, b_in, + AOutPermuted(a_out_i, AInPermuted(a_in_i, a_in_perm), + a_out_perms), + BOutPermuted(b_out_i, BInPermuted(b_in_i, b_in_perm), + b_out_perms), + AInPermuted(a_in_i, a_in_perm), BInPermuted(b_in_i, b_in_perm)); + p[p_index] += cg[cg_index]; } - - { - size_t p_index = PIndex(a_out, b_out, a_in, b_in, a_out_i, b_out - 1, - a_in_i, b_in_i); - p[p_index] = 0; - } - } - - for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) { - size_t p_index = PIndex(a_out, b_out, a_in, b_in, a_out - 1, b_out_i, - a_in_i, b_in_i); - p[p_index] = 0; } } } @@ -93,8 +91,11 @@ void FromCgToP(const size_t a_out, const size_t b_out, const size_t a_in, for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) { size_t a_in_i = 0; size_t b_in_i = 0; - size_t p_index = - PIndex(a_out, b_out, a_in, b_in, a_out_i, b_out_i, a_in_i, b_in_i); + size_t p_index = PIndex( + a_out, b_out, a_in, b_in, + AOutPermuted(a_out_i, AInPermuted(a_in_i, a_in_perm), a_out_perms), + BOutPermuted(b_out_i, BInPermuted(b_in_i, b_in_perm), b_out_perms), + AInPermuted(a_in_i, a_in_perm), BInPermuted(b_in_i, b_in_perm)); p[p_index] += l; } } @@ -115,8 +116,11 @@ void FromCgToP(const size_t a_out, const size_t b_out, const size_t a_in, for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) { size_t b_in_i = 0; - size_t p_index = - PIndex(a_out, b_out, a_in, b_in, a_out_i, b_out_i, a_in_i, b_in_i); + size_t p_index = PIndex( + a_out, b_out, a_in, b_in, + AOutPermuted(a_out_i, AInPermuted(a_in_i, a_in_perm), a_out_perms), + BOutPermuted(b_out_i, BInPermuted(b_in_i, b_in_perm), b_out_perms), + AInPermuted(a_in_i, a_in_perm), BInPermuted(b_in_i, b_in_perm)); p[p_index] += a_marginal; // PrintP(a_out, b_out, a_in, b_in, p); } @@ -135,8 +139,11 @@ void FromCgToP(const size_t a_out, const size_t b_out, const size_t a_in, for (size_t a_out_i = 0; a_out_i < a_out; a_out_i++) { size_t a_in_i = 0; - size_t p_index = - PIndex(a_out, b_out, a_in, b_in, a_out_i, b_out_i, a_in_i, b_in_i); + size_t p_index = PIndex( + a_out, b_out, a_in, b_in, + AOutPermuted(a_out_i, AInPermuted(a_in_i, a_in_perm), a_out_perms), + BOutPermuted(b_out_i, BInPermuted(b_in_i, b_in_perm), b_out_perms), + AInPermuted(a_in_i, a_in_perm), BInPermuted(b_in_i, b_in_perm)); p[p_index] += b_marginal; // PrintP(a_out, b_out, a_in, b_in, p); } @@ -223,9 +230,8 @@ void FromPToCg(const size_t a_out, const size_t b_out, const size_t a_in, // 3. Contributions from the terms with the B output label OOBs for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) { for (size_t a_out_i = 0; a_out_i < a_out - 1; a_out_i++) { - size_t cg_marginal_index = - CgAMarginalIndex(a_out, b_out, a_in, b_in, a_out_i, a_in_i); - + size_t cg_marginal_index = (a_out - 1) * (b_out - 1) * a_in * b_in + + (a_out - 1) * a_in_i + a_out_i; for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) { size_t oob_p_index = PIndex(a_out, b_out, a_in, b_in, a_out_i, b_out - 1, a_in_i, b_in_i); @@ -266,55 +272,24 @@ void FromPToCg(const size_t a_out, const size_t b_out, const size_t a_in, } } -void PPermute(const size_t a_out, const size_t b_out, const size_t a_in, - const size_t b_in, const size_t *restrict a_in_perm, - const size_t *restrict b_in_perm, - const permutation_generator_t *restrict a_out_perms, - const permutation_generator_t *restrict b_out_perms, - const data_t *restrict from_p, data_t *restrict into_p) { - const size_t p_row_size = a_out * a_in * b_out * b_in; - for (size_t i = 0; i < p_row_size; i++) { - into_p[i] = 0; - } - for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) { - for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) { - for (size_t a_out_i = 0; a_out_i < a_out; a_out_i++) { - for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) { - size_t u = PIndex(a_out, b_out, a_in, b_in, a_out_i, b_out_i, a_in_i, - b_in_i); - size_t v = PIndex(a_out, b_out, a_in, b_in, - AOutPermuted(a_out_i, a_in_i, a_out_perms), - BOutPermuted(b_out_i, b_in_i, b_out_perms), - AInPermuted(a_in_i, a_in_perm), - BInPermuted(b_in_i, b_in_perm)); - into_p[u] += from_p[v]; - } - } - } - } -} - void PSwapParties(const size_t out, const size_t in, data_t *p) { - // These loops could probably be improved, but I'm unsure how. - for (size_t a_in_i = 0; a_in_i < in; a_in_i++) { - for (size_t b_in_i = 0; b_in_i < in; b_in_i++) { - for (size_t a_out_i = 0; a_out_i < out; a_out_i++) { - for (size_t b_out_i = 0; b_out_i < out; b_out_i++) { + for (size_t a_out_i = 0; a_out_i < out; a_out_i++) { + for (size_t b_out_i = a_out_i + 1; b_out_i < out; b_out_i++) { + for (size_t a_in_i = 0; a_in_i < in; a_in_i++) { + for (size_t b_in_i = a_in_i + 1; b_in_i < in; b_in_i++) { size_t u = PIndex(out, out, in, in, a_out_i, b_out_i, a_in_i, b_in_i); size_t v = PIndex(out, out, in, in, b_out_i, a_out_i, b_in_i, a_in_i); - if (u < v) { - data_t tmp = p[u]; - p[u] = p[v]; - p[v] = tmp; - } + data_t tmp = p[u]; + p[u] = p[v]; + p[v] = tmp; } } } } } -void PrintCg(const size_t a_out, const size_t b_out, const size_t a_in, - const size_t b_in, const data_t *cg) { +void PrintCg(size_t a_out, size_t b_out, size_t a_in, size_t b_in, + const data_t *cg) { for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) { for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) { for (size_t a_out_i = 0; a_out_i < a_out - 1; a_out_i++) { @@ -357,8 +332,8 @@ void PrintCg(const size_t a_out, const size_t b_out, const size_t a_in, } // Debug print a P row. -void PrintP(const size_t a_out, const size_t b_out, const size_t a_in, - const size_t b_in, const data_t *p) { +void PrintP(size_t a_out, size_t b_out, size_t a_in, size_t b_in, + const data_t *p) { for (size_t a_out_i = 0; a_out_i < a_out; a_out_i++) { for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) { for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) { @@ -376,8 +351,8 @@ void PrintP(const size_t a_out, const size_t b_out, const size_t a_in, printf("\n"); } -void PrintPRaw(const size_t a_out, const size_t b_out, const size_t a_in, - const size_t b_in, const data_t *p) { +void PrintPRaw(size_t a_out, size_t b_out, size_t a_in, size_t b_in, + const data_t *p) { const size_t len = a_out * b_out * a_in * b_in; for (size_t a_out_i = 0; a_out_i < a_out; a_out_i++) { for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) { diff --git a/src/main.c b/src/main.c index 799b500..eb33868 100644 --- a/src/main.c +++ b/src/main.c @@ -71,7 +71,6 @@ int main(int argc, char *argv[]) { size_t row_count = matrix.len / matrix.row_len; data_t *p_buf = malloc(a_out * b_out * a_in * b_in * sizeof(data_t)); - data_t *p_swap_buf = malloc(a_out * b_out * a_in * b_in * sizeof(data_t)); data_t *cg_buf = malloc((((a_out - 1) * (b_out - 1) * a_in * b_in + (a_out - 1) * a_in + (b_out - 1) * b_in) + 1) * @@ -132,25 +131,25 @@ int main(int argc, char *argv[]) { } data_t *rhs = matrix.head + rhs_i * matrix.row_len; - FromCgToP(a_out, b_out, a_in, b_in, rhs, p_buf); PermutationReset(&a_in_perm); while (!a_in_perm.exhausted) { PermutationReset(&b_in_perm); while (!b_in_perm.exhausted) { - ResetConditionalPermutations(a_out_perms, a_in); - while (!a_out_perms[a_in - 1].exhausted) { - ResetConditionalPermutations(b_out_perms, b_in); - while (!b_out_perms[b_in - 1].exhausted) { - PPermute(a_out, b_out, a_in, b_in, a_in_perm.permutation, - b_in_perm.permutation, a_out_perms, b_out_perms, p_buf, - p_swap_buf); + ResetConditionalPermutations(a_out_perms, a_out); + while (!a_out_perms[a_out - 1].exhausted) { + ResetConditionalPermutations(b_out_perms, b_out); + while (!b_out_perms[b_out - 1].exhausted) { + // Compare the two rows + FromCgToP(a_out, b_out, a_in, b_in, rhs, p_buf, + a_in_perm.permutation, b_in_perm.permutation, + a_out_perms, b_out_perms); { - FromPToCg(a_out, b_out, a_in, b_in, cg_buf, p_swap_buf); + FromPToCg(a_out, b_out, a_in, b_in, cg_buf, p_buf); _Bool equivalent = 1; - for (size_t i = 0; i < row_len; i++) { - if (lhs[i] != cg_buf[i]) { + for (size_t i = row_len; i > 0; i--) { + if (lhs[i - 1] != cg_buf[i - 1]) { equivalent = 0; break; } @@ -168,12 +167,12 @@ int main(int argc, char *argv[]) { // I don't expect this conditional to be very penalizing because // it's very predictable. if (a_in == b_in && a_out == b_out) { - // Use the results in p_swap_buf - PSwapParties(a_out, a_in, p_swap_buf); - FromPToCg(a_out, b_out, a_in, b_in, cg_buf, p_swap_buf); + // Use the results in p_buf + PSwapParties(a_out, a_in, p_buf); + FromPToCg(a_out, b_out, a_in, b_in, cg_buf, p_buf); _Bool equivalent = 1; - for (size_t i = 0; i < row_len; i++) { - if (lhs[i] != cg_buf[i]) { + for (size_t i = row_len; i > 0; i--) { + if (lhs[i - 1] != cg_buf[i - 1]) { equivalent = 0; break; } @@ -185,10 +184,10 @@ int main(int argc, char *argv[]) { } } - AdvanceConditionalPermutations(b_out_perms, b_in); + AdvanceConditionalPermutations(b_out_perms, b_out); } - AdvanceConditionalPermutations(a_out_perms, a_in); + AdvanceConditionalPermutations(a_out_perms, a_out); } PermutationNext(&b_in_perm); @@ -201,6 +200,17 @@ int main(int argc, char *argv[]) { } // For loop over rhs_i } // For loop over lhs_i + if (cg_format == kPFmt) { + PermutationReset(&a_in_perm); + PermutationReset(&b_in_perm); + for (size_t i = 0; i < a_in; i++) { + PermutationReset(a_out_perms + i); + } + for (size_t i = 0; i < b_in; i++) { + PermutationReset(b_out_perms + i); + } + } + // Print every unique row for (size_t i = 0; i < row_count; i++) { if (!seen[i]) { @@ -216,12 +226,14 @@ int main(int argc, char *argv[]) { case kPFmt: { printf("%zu: ", i); FromCgToP(a_out, b_out, a_in, b_in, matrix.head + i * matrix.row_len, - p_buf); + p_buf, a_in_perm.permutation, b_in_perm.permutation, + a_out_perms, b_out_perms); PrintP(a_out, b_out, a_in, b_in, p_buf); } break; case kPRawFmt: { FromCgToP(a_out, b_out, a_in, b_in, matrix.head + i * matrix.row_len, - p_buf); + p_buf, a_in_perm.permutation, b_in_perm.permutation, + a_out_perms, b_out_perms); PrintPRaw(a_out, b_out, a_in, b_in, p_buf); } break; } @@ -232,7 +244,6 @@ int main(int argc, char *argv[]) { free(seen); free(cg_buf); free(p_buf); - free(p_swap_buf); PermutationFree(&a_in_perm); PermutationFree(&b_in_perm); for (size_t i = 0; i < a_in; i++) {