Compare commits

...

2 Commits

Author SHA1 Message Date
Miguel M f8eadd3523 2222 produces correct result (two classes) 2023-05-01 20:01:07 +01:00
Miguel M 51e1228af0 fix sanefile 2023-05-01 20:00:23 +01:00
4 changed files with 66 additions and 126 deletions

View File

@ -45,19 +45,20 @@ void FromCgToP(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
const permutation_generator_t* restrict a_out_perms,
const permutation_generator_t* restrict b_out_perms);
// Convert from CG representation to P representation, without applying any
// permutation.
//
// `p` isn't required to be initialized, only allocated.
void FromCgToPIdent(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
const data_t* restrict cg, data_t* restrict p);
// Convert from P representation to CG representation.
//
// `cg` isn't required to be initialized, only allocated.
void FromPToCg(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
data_t* restrict cg, const data_t* restrict p);
// Swaps the input labels of the two parties, and the output labels of the two
// parties.
//
// This function only makes sense if `a_out == b_out` and `a_in ==
// b_in`. Calling it otherwise may be undefined behaviour. `p` is modified in
// place.
void PSwapParties(size_t out, size_t in, data_t* p);
// Debug print a CG row.
void PrintCg(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
const data_t* cg);

View File

@ -8,7 +8,7 @@ from sane import *
COMPILER = "gcc"
COMPILE_FLAGS = [
"-std=c11",
"-g",
"-g3",
"-O0",
"-Wall",
"-Wextra",

118
src/cg.c
View File

@ -107,8 +107,7 @@ void FromCgToP(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
size_t p_index = PIndex(
a_out, b_out, a_in, b_in, AOutPermuted(a_out_i, a_in_i, a_out_perms),
BOutPermuted(b_out_i, b_in_i, b_out_perms), a_in_i, b_in_i);
// Negative sign comes from moving L to the LHS of the inequality
p[p_index] -= l;
p[p_index] += l;
}
}
@ -163,101 +162,6 @@ void FromCgToP(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
}
}
void FromCgToPIdent(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
const data_t *restrict cg, data_t *restrict p) {
// size_t p_row_size = a_out * a_in * b_out * b_in;
size_t cg_row_size = ((a_out - 1) * (b_out - 1) * a_in * b_in +
(a_out - 1) * a_in + (b_out - 1) * b_in) +
1; // +1 accounts for L.
// Copy the given joint probabilities
// (and initialize to 0 those that aren't given).
for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
for (size_t a_out_i = 0; a_out_i < a_out - 1; a_out_i++) {
for (size_t b_out_i = 0; b_out_i < b_out - 1; b_out_i++) {
size_t cg_index = CgJointIndex(a_out, b_out, a_in, b_in, a_out_i,
b_out_i, a_in_i, b_in_i);
size_t p_index = PIndex(a_out, b_out, a_in, b_in, a_out_i, b_out_i,
a_in_i, b_in_i);
p[p_index] = cg[cg_index];
}
{
size_t p_index = PIndex(a_out, b_out, a_in, b_in, a_out_i, b_out - 1,
a_in_i, b_in_i);
p[p_index] = 0;
}
}
for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) {
size_t p_index = PIndex(a_out, b_out, a_in, b_in, a_out - 1, b_out_i,
a_in_i, b_in_i);
p[p_index] = 0;
}
}
}
// L contributes to every joint probability, for a fixed choice of inputs,
// which we'll take to be 0,0.
data_t l = cg[cg_row_size - 1];
for (size_t a_out_i = 0; a_out_i < a_out; a_out_i++) {
for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) {
size_t a_in_i = 0;
size_t b_in_i = 0;
size_t p_index =
PIndex(a_out, b_out, a_in, b_in, a_out_i, b_out_i, a_in_i, b_in_i);
// Negative sign comes from moving L to the LHS of the inequality
p[p_index] -= l;
}
}
// Account for the marginal probabilities given.
// The convention will be that, where ambiguous, i.e., in the input of the
// other party, we will take it to be 0.
// A's marginals:
for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
for (size_t a_out_i = 0; a_out_i < a_out - 1; a_out_i++) {
size_t cg_index =
CgAMarginalIndex(a_out, b_out, a_in, b_in, a_out_i, a_in_i);
data_t a_marginal = cg[cg_index];
if (a_marginal == 0) {
continue;
}
for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) {
size_t b_in_i = 0;
size_t p_index =
PIndex(a_out, b_out, a_in, b_in, a_out_i, b_out_i, a_in_i, b_in_i);
p[p_index] += a_marginal;
// PrintP(a_out, b_out, a_in, b_in, p);
}
}
}
// B's marginals:
for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
for (size_t b_out_i = 0; b_out_i < b_out - 1; b_out_i++) {
size_t cg_index = (a_out - 1) * (b_out - 1) * a_in * b_in +
(a_out - 1) * a_in + b_out_i + (b_out - 1) * b_in_i;
data_t b_marginal = cg[cg_index];
if (b_marginal == 0) {
continue;
}
for (size_t a_out_i = 0; a_out_i < a_out; a_out_i++) {
size_t a_in_i = 0;
size_t p_index =
PIndex(a_out, b_out, a_in, b_in, a_out_i, b_out_i, a_in_i, b_in_i);
p[p_index] += b_marginal;
// PrintP(a_out, b_out, a_in, b_in, p);
}
}
}
}
void FromPToCg(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
data_t *restrict cg, const data_t *restrict p) {
size_t cg_len = ((a_out - 1) * (b_out - 1) * a_in * b_in +
@ -306,9 +210,7 @@ void FromPToCg(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
b_out - 1, a_in_i, b_in_i);
// L contribution
// The minus side comes from considering L to be on the RHS of the
// inequality.
*l -= p[oob_p_index];
*l += p[oob_p_index];
for (size_t a_out_i = 0; a_out_i < a_out - 1; a_out_i++) {
// Contributions to terms with no output labels OOBs
@ -380,6 +282,22 @@ void FromPToCg(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
}
}
void PSwapParties(size_t out, size_t in, data_t *p) {
for (size_t a_out_i = 0; a_out_i < out; a_out_i++) {
for (size_t b_out_i = 0; b_out_i < out; b_out_i++) {
for (size_t a_in_i = 0; a_in_i < in; a_in_i++) {
for (size_t b_in_i = 0; b_in_i < in; b_in_i++) {
size_t u = PIndex(out, out, in, in, a_out_i, b_out_i, a_in_i, b_in_i);
size_t v = PIndex(out, out, in, in, b_out_i, a_out_i, b_in_i, a_in_i);
data_t tmp = p[u];
p[u] = p[v];
p[v] = tmp;
}
}
}
}
}
void PrintCg(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
const data_t *cg) {
for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {

View File

@ -100,11 +100,6 @@ int main(int argc, char *argv[]) {
}
data_t *lhs = matrix.head + lhs_i * matrix.row_len;
FromCgToPIdent(a_out, b_out, a_in, b_in, lhs, p_buf);
FromPToCg(a_out, b_out, a_in, b_in, lhs, p_buf);
printf("LHS:");
PrintCg(a_out, b_out, a_in, b_in, lhs);
for (size_t rhs_i = lhs_i + 1; rhs_i < row_count; rhs_i++) {
if (seen[rhs_i]) {
@ -112,8 +107,6 @@ int main(int argc, char *argv[]) {
}
data_t *rhs = matrix.head + rhs_i * matrix.row_len;
printf("RHS:");
PrintCg(a_out, b_out, a_in, b_in, rhs);
PermutationReset(&a_in_perm);
while (!a_in_perm.exhausted) {
@ -127,19 +120,44 @@ int main(int argc, char *argv[]) {
FromCgToP(a_out, b_out, a_in, b_in, rhs, p_buf,
a_in_perm.permutation, b_in_perm.permutation,
a_out_perms, b_out_perms);
FromPToCg(a_out, b_out, a_in, b_in, cg_buf, p_buf);
_Bool equivalent = 1;
for (size_t i = 0; i < row_len; i++) {
if (lhs[i] != cg_buf[i]) {
equivalent = 0;
break;
{
FromPToCg(a_out, b_out, a_in, b_in, cg_buf, p_buf);
_Bool equivalent = 1;
for (size_t i = 0; i < row_len; i++) {
if (lhs[i] != cg_buf[i]) {
equivalent = 0;
break;
}
}
if (equivalent) {
seen[rhs_i] = 1;
goto skip_permutations;
}
}
if (equivalent) {
seen[rhs_i] = 1;
goto skip_permutations;
// If the number of output labels are the same, and the number of
// input labels is also the same, we can also check for equality
// under party swapping.
// I don't expect this conditional to be very penalizing because
// it's very perdictable.
if (a_in == b_in && a_out == b_out) {
// Use the results in p_buf
PSwapParties(a_out, a_in, p_buf);
FromPToCg(a_out, b_out, a_in, b_in, cg_buf, p_buf);
_Bool equivalent = 1;
for (size_t i = 0; i < row_len; i++) {
if (lhs[i] != cg_buf[i]) {
equivalent = 0;
break;
}
}
if (equivalent) {
seen[rhs_i] = 1;
goto skip_permutations;
}
}
AdvanceConditionalPermutations(b_out_perms, b_out);
@ -158,12 +176,15 @@ int main(int argc, char *argv[]) {
} // For loop over rhs_i
} // For loop over lhs_i
// Print every unique row
for (size_t i = 0; i < row_count; i++) {
if (!seen[i]) {
#ifdef RELEASE
PrintMatrixRow(&matrix, i);
#else
printf("%zu: ", i);
PrintCg(a_out, b_out, a_in, b_in, matrix.head + i * matrix.row_len);
#endif
}
}