Conditional out label permutations should consider the permuted in labels

This commit is contained in:
Miguel M 2023-05-02 19:42:02 +01:00
parent 6ac08cee1b
commit ef438d8e5c
1 changed files with 36 additions and 29 deletions

View File

@ -67,30 +67,35 @@ void FromCgToP(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
for (size_t b_out_i = 0; b_out_i < b_out - 1; b_out_i++) {
size_t cg_index = CgJointIndex(a_out, b_out, a_in, b_in, a_out_i,
b_out_i, a_in_i, b_in_i);
size_t p_index = PIndex(a_out, b_out, a_in, b_in,
AOutPermuted(a_out_i, a_in_i, a_out_perms),
BOutPermuted(b_out_i, b_in_i, b_out_perms),
AInPermuted(a_in_i, a_in_perm),
BInPermuted(b_in_i, b_in_perm));
size_t p_index = PIndex(
a_out, b_out, a_in, b_in,
AOutPermuted(a_out_i, AInPermuted(a_in_i, a_in_perm),
a_out_perms),
BOutPermuted(b_out_i, BInPermuted(b_in_i, b_in_perm),
b_out_perms),
AInPermuted(a_in_i, a_in_perm), BInPermuted(b_in_i, b_in_perm));
p[p_index] = cg[cg_index];
}
{
size_t p_index = PIndex(a_out, b_out, a_in, b_in,
AOutPermuted(a_out_i, a_in_i, a_out_perms),
BOutPermuted(b_out - 1, b_in_i, b_out_perms),
AInPermuted(a_in_i, a_in_perm),
BInPermuted(b_in_i, b_in_perm));
size_t p_index = PIndex(
a_out, b_out, a_in, b_in,
AOutPermuted(a_out_i, AInPermuted(a_in_i, a_in_perm),
a_out_perms),
BOutPermuted(b_out - 1, BInPermuted(b_in_i, b_in_perm),
b_out_perms),
AInPermuted(a_in_i, a_in_perm), BInPermuted(b_in_i, b_in_perm));
p[p_index] = 0;
}
}
for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) {
size_t p_index = PIndex(a_out, b_out, a_in, b_in,
AOutPermuted(a_out - 1, a_in_i, a_out_perms),
BOutPermuted(b_out_i, b_in_i, b_out_perms),
AInPermuted(a_in_i, a_in_perm),
BInPermuted(b_in_i, b_in_perm));
size_t p_index = PIndex(
a_out, b_out, a_in, b_in,
AOutPermuted(a_out - 1, AInPermuted(a_in_i, a_in_perm),
a_out_perms),
BOutPermuted(b_out_i, BInPermuted(b_in_i, b_in_perm), b_out_perms),
AInPermuted(a_in_i, a_in_perm), BInPermuted(b_in_i, b_in_perm));
p[p_index] = 0;
}
}
@ -105,8 +110,10 @@ void FromCgToP(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
size_t a_in_i = 0;
size_t b_in_i = 0;
size_t p_index = PIndex(
a_out, b_out, a_in, b_in, AOutPermuted(a_out_i, a_in_i, a_out_perms),
BOutPermuted(b_out_i, b_in_i, b_out_perms), a_in_i, b_in_i);
a_out, b_out, a_in, b_in,
AOutPermuted(a_out_i, AInPermuted(a_in_i, a_in_perm), a_out_perms),
BOutPermuted(b_out_i, BInPermuted(b_in_i, b_in_perm), b_out_perms),
AInPermuted(a_in_i, a_in_perm), BInPermuted(b_in_i, b_in_perm));
p[p_index] += l;
}
}
@ -127,11 +134,11 @@ void FromCgToP(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) {
size_t b_in_i = 0;
size_t p_index = PIndex(a_out, b_out, a_in, b_in,
AOutPermuted(a_out_i, a_in_i, a_out_perms),
BOutPermuted(b_out_i, b_in_i, b_out_perms),
AInPermuted(a_in_i, a_in_perm),
BInPermuted(b_in_i, b_in_perm));
size_t p_index = PIndex(
a_out, b_out, a_in, b_in,
AOutPermuted(a_out_i, AInPermuted(a_in_i, a_in_perm), a_out_perms),
BOutPermuted(b_out_i, BInPermuted(b_in_i, b_in_perm), b_out_perms),
AInPermuted(a_in_i, a_in_perm), BInPermuted(b_in_i, b_in_perm));
p[p_index] += a_marginal;
// PrintP(a_out, b_out, a_in, b_in, p);
}
@ -150,11 +157,11 @@ void FromCgToP(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
for (size_t a_out_i = 0; a_out_i < a_out; a_out_i++) {
size_t a_in_i = 0;
size_t p_index = PIndex(a_out, b_out, a_in, b_in,
AOutPermuted(a_out_i, a_in_i, a_out_perms),
BOutPermuted(b_out_i, b_in_i, b_out_perms),
AInPermuted(a_in_i, a_in_perm),
BInPermuted(b_in_i, b_in_perm));
size_t p_index = PIndex(
a_out, b_out, a_in, b_in,
AOutPermuted(a_out_i, AInPermuted(a_in_i, a_in_perm), a_out_perms),
BOutPermuted(b_out_i, BInPermuted(b_in_i, b_in_perm), b_out_perms),
AInPermuted(a_in_i, a_in_perm), BInPermuted(b_in_i, b_in_perm));
p[p_index] += b_marginal;
// PrintP(a_out, b_out, a_in, b_in, p);
}
@ -284,9 +291,9 @@ void FromPToCg(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
void PSwapParties(size_t out, size_t in, data_t *p) {
for (size_t a_out_i = 0; a_out_i < out; a_out_i++) {
for (size_t b_out_i = 0; b_out_i < out; b_out_i++) {
for (size_t b_out_i = a_out_i + 1; b_out_i < out; b_out_i++) {
for (size_t a_in_i = 0; a_in_i < in; a_in_i++) {
for (size_t b_in_i = 0; b_in_i < in; b_in_i++) {
for (size_t b_in_i = a_in_i + 1; b_in_i < in; b_in_i++) {
size_t u = PIndex(out, out, in, in, a_out_i, b_out_i, a_in_i, b_in_i);
size_t v = PIndex(out, out, in, in, b_out_i, a_out_i, b_in_i, a_in_i);
data_t tmp = p[u];