Merge the permutations back into the Cg-P conversion
As promised.
This commit is contained in:
parent
18940c8fb8
commit
30f9cdc120
|
@ -35,12 +35,15 @@
|
|||
#include "../includes/data.h"
|
||||
#include "../includes/permutation.h"
|
||||
|
||||
// Convert from CG representation to P representation.
|
||||
// Convert from CG representation to P representation, under a given
|
||||
// permutation.
|
||||
//
|
||||
// `p` isn't required to be initialized, only allocated.
|
||||
void FromCgToP(const size_t a_out, const size_t b_out, const size_t a_in,
|
||||
const size_t b_in, const data_t* restrict cg,
|
||||
data_t* restrict p);
|
||||
const size_t b_in, const data_t* restrict cg, data_t* restrict p,
|
||||
const size_t* a_in_perm, const size_t* b_in_perm,
|
||||
const permutation_generator_t* restrict a_out_perms,
|
||||
const permutation_generator_t* restrict b_out_perms);
|
||||
|
||||
// Convert from P representation to CG representation.
|
||||
//
|
||||
|
@ -49,17 +52,6 @@ void FromPToCg(const size_t a_out, const size_t b_out, const size_t a_in,
|
|||
const size_t b_in, data_t* restrict cg,
|
||||
const data_t* restrict p);
|
||||
|
||||
// Swaps the entries of a P coefficients list under a given permtuation.
|
||||
//
|
||||
// The coefficients are read from `from_p` and written to `into_p`. Therefore,
|
||||
// `into_p` does not need to be initialized, only allocated.
|
||||
void PPermute(const size_t a_out, const size_t b_out, const size_t a_in,
|
||||
const size_t b_in, const size_t* restrict a_in_perm,
|
||||
const size_t* restrict b_in_perm,
|
||||
const permutation_generator_t* restrict a_out_perms,
|
||||
const permutation_generator_t* restrict b_out_perms,
|
||||
const data_t* restrict from_p, data_t* restrict into_p);
|
||||
|
||||
// Swaps the input labels of the two parties, and the output labels of the two
|
||||
// parties.
|
||||
//
|
||||
|
@ -69,15 +61,15 @@ void PPermute(const size_t a_out, const size_t b_out, const size_t a_in,
|
|||
void PSwapParties(const size_t out, const size_t in, data_t* p);
|
||||
|
||||
// Debug print a CG row.
|
||||
void PrintCg(const size_t a_out, const size_t b_out, const size_t a_in,
|
||||
const size_t b_in, const data_t* cg);
|
||||
void PrintCg(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
|
||||
const data_t* cg);
|
||||
|
||||
// Debug print a P row.
|
||||
void PrintP(const size_t a_out, const size_t b_out, const size_t a_in,
|
||||
const size_t b_in, const data_t* p);
|
||||
void PrintP(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
|
||||
const data_t* p);
|
||||
|
||||
// Print a P row as if it was in the input format.
|
||||
void PrintPRaw(const size_t a_out, const size_t b_out, size_t a_in, size_t b_in,
|
||||
void PrintPRaw(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
|
||||
const data_t* p);
|
||||
|
||||
#endif
|
119
src/cg.c
119
src/cg.c
|
@ -50,37 +50,35 @@ static inline size_t BOutPermuted(size_t b_out_i, size_t b_in_i,
|
|||
}
|
||||
|
||||
void FromCgToP(const size_t a_out, const size_t b_out, const size_t a_in,
|
||||
const size_t b_in, const data_t *restrict cg,
|
||||
data_t *restrict p) {
|
||||
// size_t p_row_size = a_out * a_in * b_out * b_in;
|
||||
const size_t b_in, const data_t *restrict cg, data_t *restrict p,
|
||||
const size_t *a_in_perm, const size_t *b_in_perm,
|
||||
const permutation_generator_t *restrict a_out_perms,
|
||||
const permutation_generator_t *restrict b_out_perms) {
|
||||
size_t p_row_size = a_out * a_in * b_out * b_in;
|
||||
size_t cg_row_size = ((a_out - 1) * (b_out - 1) * a_in * b_in +
|
||||
(a_out - 1) * a_in + (b_out - 1) * b_in) +
|
||||
1; // +1 accounts for L.
|
||||
|
||||
for (size_t i = 0; i < p_row_size; i++) {
|
||||
p[i] = 0;
|
||||
}
|
||||
|
||||
// Copy the given joint probabilities
|
||||
// (and initialize to 0 those that aren't given).
|
||||
for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
|
||||
for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
|
||||
for (size_t a_out_i = 0; a_out_i < a_out - 1; a_out_i++) {
|
||||
for (size_t b_out_i = 0; b_out_i < b_out - 1; b_out_i++) {
|
||||
size_t cg_index = CgJointIndex(a_out, b_out, a_in, b_in, a_out_i,
|
||||
b_out_i, a_in_i, b_in_i);
|
||||
size_t p_index = PIndex(a_out, b_out, a_in, b_in, a_out_i, b_out_i,
|
||||
a_in_i, b_in_i);
|
||||
p[p_index] = cg[cg_index];
|
||||
size_t p_index = PIndex(
|
||||
a_out, b_out, a_in, b_in,
|
||||
AOutPermuted(a_out_i, AInPermuted(a_in_i, a_in_perm),
|
||||
a_out_perms),
|
||||
BOutPermuted(b_out_i, BInPermuted(b_in_i, b_in_perm),
|
||||
b_out_perms),
|
||||
AInPermuted(a_in_i, a_in_perm), BInPermuted(b_in_i, b_in_perm));
|
||||
p[p_index] += cg[cg_index];
|
||||
}
|
||||
|
||||
{
|
||||
size_t p_index = PIndex(a_out, b_out, a_in, b_in, a_out_i, b_out - 1,
|
||||
a_in_i, b_in_i);
|
||||
p[p_index] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) {
|
||||
size_t p_index = PIndex(a_out, b_out, a_in, b_in, a_out - 1, b_out_i,
|
||||
a_in_i, b_in_i);
|
||||
p[p_index] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -93,8 +91,11 @@ void FromCgToP(const size_t a_out, const size_t b_out, const size_t a_in,
|
|||
for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) {
|
||||
size_t a_in_i = 0;
|
||||
size_t b_in_i = 0;
|
||||
size_t p_index =
|
||||
PIndex(a_out, b_out, a_in, b_in, a_out_i, b_out_i, a_in_i, b_in_i);
|
||||
size_t p_index = PIndex(
|
||||
a_out, b_out, a_in, b_in,
|
||||
AOutPermuted(a_out_i, AInPermuted(a_in_i, a_in_perm), a_out_perms),
|
||||
BOutPermuted(b_out_i, BInPermuted(b_in_i, b_in_perm), b_out_perms),
|
||||
AInPermuted(a_in_i, a_in_perm), BInPermuted(b_in_i, b_in_perm));
|
||||
p[p_index] += l;
|
||||
}
|
||||
}
|
||||
|
@ -115,8 +116,11 @@ void FromCgToP(const size_t a_out, const size_t b_out, const size_t a_in,
|
|||
|
||||
for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) {
|
||||
size_t b_in_i = 0;
|
||||
size_t p_index =
|
||||
PIndex(a_out, b_out, a_in, b_in, a_out_i, b_out_i, a_in_i, b_in_i);
|
||||
size_t p_index = PIndex(
|
||||
a_out, b_out, a_in, b_in,
|
||||
AOutPermuted(a_out_i, AInPermuted(a_in_i, a_in_perm), a_out_perms),
|
||||
BOutPermuted(b_out_i, BInPermuted(b_in_i, b_in_perm), b_out_perms),
|
||||
AInPermuted(a_in_i, a_in_perm), BInPermuted(b_in_i, b_in_perm));
|
||||
p[p_index] += a_marginal;
|
||||
// PrintP(a_out, b_out, a_in, b_in, p);
|
||||
}
|
||||
|
@ -135,8 +139,11 @@ void FromCgToP(const size_t a_out, const size_t b_out, const size_t a_in,
|
|||
|
||||
for (size_t a_out_i = 0; a_out_i < a_out; a_out_i++) {
|
||||
size_t a_in_i = 0;
|
||||
size_t p_index =
|
||||
PIndex(a_out, b_out, a_in, b_in, a_out_i, b_out_i, a_in_i, b_in_i);
|
||||
size_t p_index = PIndex(
|
||||
a_out, b_out, a_in, b_in,
|
||||
AOutPermuted(a_out_i, AInPermuted(a_in_i, a_in_perm), a_out_perms),
|
||||
BOutPermuted(b_out_i, BInPermuted(b_in_i, b_in_perm), b_out_perms),
|
||||
AInPermuted(a_in_i, a_in_perm), BInPermuted(b_in_i, b_in_perm));
|
||||
p[p_index] += b_marginal;
|
||||
// PrintP(a_out, b_out, a_in, b_in, p);
|
||||
}
|
||||
|
@ -223,9 +230,8 @@ void FromPToCg(const size_t a_out, const size_t b_out, const size_t a_in,
|
|||
// 3. Contributions from the terms with the B output label OOBs
|
||||
for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
|
||||
for (size_t a_out_i = 0; a_out_i < a_out - 1; a_out_i++) {
|
||||
size_t cg_marginal_index =
|
||||
CgAMarginalIndex(a_out, b_out, a_in, b_in, a_out_i, a_in_i);
|
||||
|
||||
size_t cg_marginal_index = (a_out - 1) * (b_out - 1) * a_in * b_in +
|
||||
(a_out - 1) * a_in_i + a_out_i;
|
||||
for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
|
||||
size_t oob_p_index = PIndex(a_out, b_out, a_in, b_in, a_out_i,
|
||||
b_out - 1, a_in_i, b_in_i);
|
||||
|
@ -266,55 +272,24 @@ void FromPToCg(const size_t a_out, const size_t b_out, const size_t a_in,
|
|||
}
|
||||
}
|
||||
|
||||
void PPermute(const size_t a_out, const size_t b_out, const size_t a_in,
|
||||
const size_t b_in, const size_t *restrict a_in_perm,
|
||||
const size_t *restrict b_in_perm,
|
||||
const permutation_generator_t *restrict a_out_perms,
|
||||
const permutation_generator_t *restrict b_out_perms,
|
||||
const data_t *restrict from_p, data_t *restrict into_p) {
|
||||
const size_t p_row_size = a_out * a_in * b_out * b_in;
|
||||
for (size_t i = 0; i < p_row_size; i++) {
|
||||
into_p[i] = 0;
|
||||
}
|
||||
for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
|
||||
for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
|
||||
for (size_t a_out_i = 0; a_out_i < a_out; a_out_i++) {
|
||||
for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) {
|
||||
size_t u = PIndex(a_out, b_out, a_in, b_in, a_out_i, b_out_i, a_in_i,
|
||||
b_in_i);
|
||||
size_t v = PIndex(a_out, b_out, a_in, b_in,
|
||||
AOutPermuted(a_out_i, a_in_i, a_out_perms),
|
||||
BOutPermuted(b_out_i, b_in_i, b_out_perms),
|
||||
AInPermuted(a_in_i, a_in_perm),
|
||||
BInPermuted(b_in_i, b_in_perm));
|
||||
into_p[u] += from_p[v];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PSwapParties(const size_t out, const size_t in, data_t *p) {
|
||||
// These loops could probably be improved, but I'm unsure how.
|
||||
for (size_t a_in_i = 0; a_in_i < in; a_in_i++) {
|
||||
for (size_t b_in_i = 0; b_in_i < in; b_in_i++) {
|
||||
for (size_t a_out_i = 0; a_out_i < out; a_out_i++) {
|
||||
for (size_t b_out_i = 0; b_out_i < out; b_out_i++) {
|
||||
for (size_t a_out_i = 0; a_out_i < out; a_out_i++) {
|
||||
for (size_t b_out_i = a_out_i + 1; b_out_i < out; b_out_i++) {
|
||||
for (size_t a_in_i = 0; a_in_i < in; a_in_i++) {
|
||||
for (size_t b_in_i = a_in_i + 1; b_in_i < in; b_in_i++) {
|
||||
size_t u = PIndex(out, out, in, in, a_out_i, b_out_i, a_in_i, b_in_i);
|
||||
size_t v = PIndex(out, out, in, in, b_out_i, a_out_i, b_in_i, a_in_i);
|
||||
if (u < v) {
|
||||
data_t tmp = p[u];
|
||||
p[u] = p[v];
|
||||
p[v] = tmp;
|
||||
}
|
||||
data_t tmp = p[u];
|
||||
p[u] = p[v];
|
||||
p[v] = tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PrintCg(const size_t a_out, const size_t b_out, const size_t a_in,
|
||||
const size_t b_in, const data_t *cg) {
|
||||
void PrintCg(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
|
||||
const data_t *cg) {
|
||||
for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
|
||||
for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
|
||||
for (size_t a_out_i = 0; a_out_i < a_out - 1; a_out_i++) {
|
||||
|
@ -357,8 +332,8 @@ void PrintCg(const size_t a_out, const size_t b_out, const size_t a_in,
|
|||
}
|
||||
|
||||
// Debug print a P row.
|
||||
void PrintP(const size_t a_out, const size_t b_out, const size_t a_in,
|
||||
const size_t b_in, const data_t *p) {
|
||||
void PrintP(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
|
||||
const data_t *p) {
|
||||
for (size_t a_out_i = 0; a_out_i < a_out; a_out_i++) {
|
||||
for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) {
|
||||
for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
|
||||
|
@ -376,8 +351,8 @@ void PrintP(const size_t a_out, const size_t b_out, const size_t a_in,
|
|||
printf("\n");
|
||||
}
|
||||
|
||||
void PrintPRaw(const size_t a_out, const size_t b_out, const size_t a_in,
|
||||
const size_t b_in, const data_t *p) {
|
||||
void PrintPRaw(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
|
||||
const data_t *p) {
|
||||
const size_t len = a_out * b_out * a_in * b_in;
|
||||
for (size_t a_out_i = 0; a_out_i < a_out; a_out_i++) {
|
||||
for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) {
|
||||
|
|
55
src/main.c
55
src/main.c
|
@ -71,7 +71,6 @@ int main(int argc, char *argv[]) {
|
|||
size_t row_count = matrix.len / matrix.row_len;
|
||||
|
||||
data_t *p_buf = malloc(a_out * b_out * a_in * b_in * sizeof(data_t));
|
||||
data_t *p_swap_buf = malloc(a_out * b_out * a_in * b_in * sizeof(data_t));
|
||||
data_t *cg_buf = malloc((((a_out - 1) * (b_out - 1) * a_in * b_in +
|
||||
(a_out - 1) * a_in + (b_out - 1) * b_in) +
|
||||
1) *
|
||||
|
@ -132,25 +131,25 @@ int main(int argc, char *argv[]) {
|
|||
}
|
||||
|
||||
data_t *rhs = matrix.head + rhs_i * matrix.row_len;
|
||||
FromCgToP(a_out, b_out, a_in, b_in, rhs, p_buf);
|
||||
|
||||
PermutationReset(&a_in_perm);
|
||||
while (!a_in_perm.exhausted) {
|
||||
PermutationReset(&b_in_perm);
|
||||
while (!b_in_perm.exhausted) {
|
||||
ResetConditionalPermutations(a_out_perms, a_in);
|
||||
while (!a_out_perms[a_in - 1].exhausted) {
|
||||
ResetConditionalPermutations(b_out_perms, b_in);
|
||||
while (!b_out_perms[b_in - 1].exhausted) {
|
||||
PPermute(a_out, b_out, a_in, b_in, a_in_perm.permutation,
|
||||
b_in_perm.permutation, a_out_perms, b_out_perms, p_buf,
|
||||
p_swap_buf);
|
||||
ResetConditionalPermutations(a_out_perms, a_out);
|
||||
while (!a_out_perms[a_out - 1].exhausted) {
|
||||
ResetConditionalPermutations(b_out_perms, b_out);
|
||||
while (!b_out_perms[b_out - 1].exhausted) {
|
||||
// Compare the two rows
|
||||
FromCgToP(a_out, b_out, a_in, b_in, rhs, p_buf,
|
||||
a_in_perm.permutation, b_in_perm.permutation,
|
||||
a_out_perms, b_out_perms);
|
||||
|
||||
{
|
||||
FromPToCg(a_out, b_out, a_in, b_in, cg_buf, p_swap_buf);
|
||||
FromPToCg(a_out, b_out, a_in, b_in, cg_buf, p_buf);
|
||||
_Bool equivalent = 1;
|
||||
for (size_t i = 0; i < row_len; i++) {
|
||||
if (lhs[i] != cg_buf[i]) {
|
||||
for (size_t i = row_len; i > 0; i--) {
|
||||
if (lhs[i - 1] != cg_buf[i - 1]) {
|
||||
equivalent = 0;
|
||||
break;
|
||||
}
|
||||
|
@ -168,12 +167,12 @@ int main(int argc, char *argv[]) {
|
|||
// I don't expect this conditional to be very penalizing because
|
||||
// it's very predictable.
|
||||
if (a_in == b_in && a_out == b_out) {
|
||||
// Use the results in p_swap_buf
|
||||
PSwapParties(a_out, a_in, p_swap_buf);
|
||||
FromPToCg(a_out, b_out, a_in, b_in, cg_buf, p_swap_buf);
|
||||
// Use the results in p_buf
|
||||
PSwapParties(a_out, a_in, p_buf);
|
||||
FromPToCg(a_out, b_out, a_in, b_in, cg_buf, p_buf);
|
||||
_Bool equivalent = 1;
|
||||
for (size_t i = 0; i < row_len; i++) {
|
||||
if (lhs[i] != cg_buf[i]) {
|
||||
for (size_t i = row_len; i > 0; i--) {
|
||||
if (lhs[i - 1] != cg_buf[i - 1]) {
|
||||
equivalent = 0;
|
||||
break;
|
||||
}
|
||||
|
@ -185,10 +184,10 @@ int main(int argc, char *argv[]) {
|
|||
}
|
||||
}
|
||||
|
||||
AdvanceConditionalPermutations(b_out_perms, b_in);
|
||||
AdvanceConditionalPermutations(b_out_perms, b_out);
|
||||
}
|
||||
|
||||
AdvanceConditionalPermutations(a_out_perms, a_in);
|
||||
AdvanceConditionalPermutations(a_out_perms, a_out);
|
||||
}
|
||||
|
||||
PermutationNext(&b_in_perm);
|
||||
|
@ -201,6 +200,17 @@ int main(int argc, char *argv[]) {
|
|||
} // For loop over rhs_i
|
||||
} // For loop over lhs_i
|
||||
|
||||
if (cg_format == kPFmt) {
|
||||
PermutationReset(&a_in_perm);
|
||||
PermutationReset(&b_in_perm);
|
||||
for (size_t i = 0; i < a_in; i++) {
|
||||
PermutationReset(a_out_perms + i);
|
||||
}
|
||||
for (size_t i = 0; i < b_in; i++) {
|
||||
PermutationReset(b_out_perms + i);
|
||||
}
|
||||
}
|
||||
|
||||
// Print every unique row
|
||||
for (size_t i = 0; i < row_count; i++) {
|
||||
if (!seen[i]) {
|
||||
|
@ -216,12 +226,14 @@ int main(int argc, char *argv[]) {
|
|||
case kPFmt: {
|
||||
printf("%zu: ", i);
|
||||
FromCgToP(a_out, b_out, a_in, b_in, matrix.head + i * matrix.row_len,
|
||||
p_buf);
|
||||
p_buf, a_in_perm.permutation, b_in_perm.permutation,
|
||||
a_out_perms, b_out_perms);
|
||||
PrintP(a_out, b_out, a_in, b_in, p_buf);
|
||||
} break;
|
||||
case kPRawFmt: {
|
||||
FromCgToP(a_out, b_out, a_in, b_in, matrix.head + i * matrix.row_len,
|
||||
p_buf);
|
||||
p_buf, a_in_perm.permutation, b_in_perm.permutation,
|
||||
a_out_perms, b_out_perms);
|
||||
PrintPRaw(a_out, b_out, a_in, b_in, p_buf);
|
||||
} break;
|
||||
}
|
||||
|
@ -232,7 +244,6 @@ int main(int argc, char *argv[]) {
|
|||
free(seen);
|
||||
free(cg_buf);
|
||||
free(p_buf);
|
||||
free(p_swap_buf);
|
||||
PermutationFree(&a_in_perm);
|
||||
PermutationFree(&b_in_perm);
|
||||
for (size_t i = 0; i < a_in; i++) {
|
||||
|
|
Loading…
Reference in New Issue