400 lines
15 KiB
C
400 lines
15 KiB
C
#include "../includes/cg.h"
|
|
|
|
#include <stdio.h>
|
|
|
|
// Offset of the joint term (a_out_i, b_out_i | a_in_i, b_in_i) inside a CG row.
// Joint terms occupy the front of the row and only output labels
// 0..a_out-2 (A) and 0..b_out-2 (B) are stored. Ordering, slowest to fastest:
// a_in_i, a_out_i, b_in_i, b_out_i. `a_in` is accepted for signature symmetry
// with the other index helpers but is not needed.
static inline size_t CgJointIndex(size_t a_out, size_t b_out,
                                  size_t a_in __attribute__((unused)),
                                  size_t b_in, size_t a_out_i, size_t b_out_i,
                                  size_t a_in_i, size_t b_in_i) {
  const size_t a_out_stored = a_out - 1;
  const size_t b_out_stored = b_out - 1;

  // Horner-style mixed-radix accumulation, most significant digit first.
  size_t offset = a_in_i;
  offset = offset * a_out_stored + a_out_i;
  offset = offset * b_in + b_in_i;
  offset = offset * b_out_stored + b_out_i;
  return offset;
}
|
|
|
|
// Offset of A's marginal term (a_out_i | a_in_i) inside a CG row.
// A's marginals start immediately after the (a_out-1)(b_out-1)*a_in*b_in
// joint terms; within the section a_in_i varies slowest and a_out_i fastest.
static inline size_t CgAMarginalIndex(size_t a_out, size_t b_out, size_t a_in,
                                      size_t b_in, size_t a_out_i,
                                      size_t a_in_i) {
  const size_t a_out_stored = a_out - 1;
  const size_t section_start = a_out_stored * (b_out - 1) * a_in * b_in;
  return section_start + a_in_i * a_out_stored + a_out_i;
}
|
|
|
|
// Offset of B's marginal term (b_out_i | b_in_i) inside a CG row.
// B's marginals follow the joint terms and A's (a_out-1)*a_in marginals;
// within the section b_in_i varies slowest and b_out_i fastest.
static inline size_t CgBMarginalIndex(size_t a_out, size_t b_out, size_t a_in,
                                      size_t b_in, size_t b_out_i,
                                      size_t b_in_i) {
  const size_t a_out_stored = a_out - 1;
  const size_t b_out_stored = b_out - 1;
  const size_t joint_terms = a_out_stored * b_out_stored * a_in * b_in;
  const size_t a_marginal_terms = a_out_stored * a_in;
  return joint_terms + a_marginal_terms + b_in_i * b_out_stored + b_out_i;
}
|
|
|
|
// Offset of the entry (a_out_i, b_out_i | a_in_i, b_in_i) inside a P row of
// a_out * b_out * a_in * b_in entries. Ordering, slowest to fastest:
// a_in_i, a_out_i, b_in_i, b_out_i. `a_in` is kept for signature symmetry but
// is not needed to compute the offset.
static inline size_t PIndex(size_t a_out, size_t b_out,
                            size_t a_in __attribute__((unused)), size_t b_in,
                            size_t a_out_i, size_t b_out_i, size_t a_in_i,
                            size_t b_in_i) {
  size_t offset = a_in_i;
  offset = offset * a_out + a_out_i;
  offset = offset * b_in + b_in_i;
  return offset * b_out + b_out_i;
}
|
|
|
|
// Relabels A's input `a_in_i` through the lookup table `a_in_perm`.
static inline size_t AInPermuted(size_t a_in_i, const size_t *a_in_perm) {
  const size_t *const slot = a_in_perm + a_in_i;
  return *slot;
}
|
|
|
|
// Relabels B's input `b_in_i` through the lookup table `b_in_perm`.
static inline size_t BInPermuted(size_t b_in_i, const size_t *b_in_perm) {
  const size_t *const slot = b_in_perm + b_in_i;
  return *slot;
}
|
|
|
|
static inline size_t AOutPermuted(size_t a_out_i, size_t a_in_i,
|
|
const permutation_generator_t *a_out_perms) {
|
|
return a_out_perms[a_in_i].permutation[a_out_i];
|
|
}
|
|
|
|
static inline size_t BOutPermuted(size_t b_out_i, size_t b_in_i,
|
|
const permutation_generator_t *b_out_perms) {
|
|
return b_out_perms[b_in_i].permutation[b_out_i];
|
|
}
|
|
|
|
void FromCgToP(const size_t a_out, const size_t b_out, const size_t a_in,
|
|
const size_t b_in, const data_t *restrict cg, data_t *restrict p,
|
|
const size_t *a_in_perm, const size_t *b_in_perm,
|
|
const permutation_generator_t *restrict a_out_perms,
|
|
const permutation_generator_t *restrict b_out_perms) {
|
|
size_t p_row_size = a_out * a_in * b_out * b_in;
|
|
size_t cg_row_size = ((a_out - 1) * (b_out - 1) * a_in * b_in +
|
|
(a_out - 1) * a_in + (b_out - 1) * b_in) +
|
|
1; // +1 accounts for L.
|
|
|
|
for (size_t i = 0; i < p_row_size; i++) {
|
|
p[i] = 0;
|
|
}
|
|
|
|
// Copy the given joint probabilities
|
|
for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
|
|
for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
|
|
for (size_t a_out_i = 0; a_out_i < a_out - 1; a_out_i++) {
|
|
for (size_t b_out_i = 0; b_out_i < b_out - 1; b_out_i++) {
|
|
size_t cg_index = CgJointIndex(a_out, b_out, a_in, b_in, a_out_i,
|
|
b_out_i, a_in_i, b_in_i);
|
|
size_t p_index = PIndex(
|
|
a_out, b_out, a_in, b_in,
|
|
AOutPermuted(a_out_i, AInPermuted(a_in_i, a_in_perm),
|
|
a_out_perms),
|
|
BOutPermuted(b_out_i, BInPermuted(b_in_i, b_in_perm),
|
|
b_out_perms),
|
|
AInPermuted(a_in_i, a_in_perm), BInPermuted(b_in_i, b_in_perm));
|
|
p[p_index] += cg[cg_index];
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// L contributes to every joint probability, for a fixed choice of inputs,
|
|
// which we'll take to be 0,0.
|
|
data_t l = cg[cg_row_size - 1];
|
|
|
|
for (size_t a_out_i = 0; a_out_i < a_out; a_out_i++) {
|
|
for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) {
|
|
size_t a_in_i = 0;
|
|
size_t b_in_i = 0;
|
|
size_t p_index = PIndex(
|
|
a_out, b_out, a_in, b_in,
|
|
AOutPermuted(a_out_i, AInPermuted(a_in_i, a_in_perm), a_out_perms),
|
|
BOutPermuted(b_out_i, BInPermuted(b_in_i, b_in_perm), b_out_perms),
|
|
AInPermuted(a_in_i, a_in_perm), BInPermuted(b_in_i, b_in_perm));
|
|
p[p_index] += l;
|
|
}
|
|
}
|
|
|
|
// Account for the marginal probabilities given.
|
|
// The convention will be that, where ambiguous, i.e., in the input of the
|
|
// other party, we will take it to be 0.
|
|
|
|
// A's marginals:
|
|
for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
|
|
for (size_t a_out_i = 0; a_out_i < a_out - 1; a_out_i++) {
|
|
size_t cg_index =
|
|
CgAMarginalIndex(a_out, b_out, a_in, b_in, a_out_i, a_in_i);
|
|
data_t a_marginal = cg[cg_index];
|
|
if (a_marginal == 0) {
|
|
continue;
|
|
}
|
|
|
|
for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) {
|
|
size_t b_in_i = 0;
|
|
size_t p_index = PIndex(
|
|
a_out, b_out, a_in, b_in,
|
|
AOutPermuted(a_out_i, AInPermuted(a_in_i, a_in_perm), a_out_perms),
|
|
BOutPermuted(b_out_i, BInPermuted(b_in_i, b_in_perm), b_out_perms),
|
|
AInPermuted(a_in_i, a_in_perm), BInPermuted(b_in_i, b_in_perm));
|
|
p[p_index] += a_marginal;
|
|
// PrintP(a_out, b_out, a_in, b_in, p);
|
|
}
|
|
}
|
|
}
|
|
|
|
// B's marginals:
|
|
for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
|
|
for (size_t b_out_i = 0; b_out_i < b_out - 1; b_out_i++) {
|
|
size_t cg_index = (a_out - 1) * (b_out - 1) * a_in * b_in +
|
|
(a_out - 1) * a_in + b_out_i + (b_out - 1) * b_in_i;
|
|
data_t b_marginal = cg[cg_index];
|
|
if (b_marginal == 0) {
|
|
continue;
|
|
}
|
|
|
|
for (size_t a_out_i = 0; a_out_i < a_out; a_out_i++) {
|
|
size_t a_in_i = 0;
|
|
size_t p_index = PIndex(
|
|
a_out, b_out, a_in, b_in,
|
|
AOutPermuted(a_out_i, AInPermuted(a_in_i, a_in_perm), a_out_perms),
|
|
BOutPermuted(b_out_i, BInPermuted(b_in_i, b_in_perm), b_out_perms),
|
|
AInPermuted(a_in_i, a_in_perm), BInPermuted(b_in_i, b_in_perm));
|
|
p[p_index] += b_marginal;
|
|
// PrintP(a_out, b_out, a_in, b_in, p);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Converts one row of full-probability coefficients `p` (laid out as in
// PIndex) into the equivalent row of Collins-Gisin coefficients `cg` (joint
// terms, A's marginals, B's marginals, then L; see CgJointIndex,
// CgAMarginalIndex, CgBMarginalIndex). Every CG entry is overwritten.
//
// CG stores only output labels below the last one of each party
// (a_last = a_out - 1, b_last = b_out - 1). Entries of `p` that involve a last
// label are redistributed with the identities the code applies below:
//   p(a_last, b_last | x, y) = L - sum_a pA(a|x) - sum_b pB(b|y)
//                              + sum_{a,b} p(a, b | x, y)
//   p(a, b_last | x, y)      = pA(a|x) - sum_b p(a, b | x, y)
//   p(a_last, b | x, y)      = pB(b|y) - sum_a p(a, b | x, y)
// where every sum runs over stored (non-last) labels only.
//
// With a_out == 0 or b_out == 0 there is no CG row to produce and
// `a_out - 1` / `b_out - 1` would wrap around to SIZE_MAX, so the function
// returns without writing anything.
void FromPToCg(const size_t a_out, const size_t b_out, const size_t a_in,
               const size_t b_in, data_t *restrict cg,
               const data_t *restrict p) {
  if (a_out == 0 || b_out == 0) {
    return;
  }

  const size_t a_last = a_out - 1;  // A's output label not stored in CG
  const size_t b_last = b_out - 1;  // B's output label not stored in CG
  const size_t cg_len =
      a_last * b_last * a_in * b_in + a_last * a_in + b_last * b_in + 1;

  // 1. Entries with both outputs stored are copied verbatim into the joint
  //    terms; the marginals start at zero and are accumulated below.
  for (size_t a_out_i = 0; a_out_i < a_last; a_out_i++) {
    for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
      for (size_t b_out_i = 0; b_out_i < b_last; b_out_i++) {
        for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
          const size_t p_index = PIndex(a_out, b_out, a_in, b_in, a_out_i,
                                        b_out_i, a_in_i, b_in_i);
          const size_t cg_index = CgJointIndex(a_out, b_out, a_in, b_in,
                                               a_out_i, b_out_i, a_in_i, b_in_i);
          cg[cg_index] = p[p_index];
        }
      }
      cg[CgAMarginalIndex(a_out, b_out, a_in, b_in, a_out_i, a_in_i)] = 0;
    }
  }
  for (size_t b_out_i = 0; b_out_i < b_last; b_out_i++) {
    for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
      cg[CgBMarginalIndex(a_out, b_out, a_in, b_in, b_out_i, b_in_i)] = 0;
    }
  }

  // 2. Entries where both outputs are the last label: the coefficient goes to
  //    L and to every joint term of (x, y), and is subtracted from A's
  //    marginals for x and B's marginals for y.
  data_t *const l = cg + (cg_len - 1);
  *l = 0;
  for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
    for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
      const data_t coeff =
          p[PIndex(a_out, b_out, a_in, b_in, a_last, b_last, a_in_i, b_in_i)];

      *l += coeff;

      for (size_t a_out_i = 0; a_out_i < a_last; a_out_i++) {
        for (size_t b_out_i = 0; b_out_i < b_last; b_out_i++) {
          cg[CgJointIndex(a_out, b_out, a_in, b_in, a_out_i, b_out_i, a_in_i,
                          b_in_i)] += coeff;
        }
        cg[CgAMarginalIndex(a_out, b_out, a_in, b_in, a_out_i, a_in_i)] -=
            coeff;
      }

      for (size_t b_out_i = 0; b_out_i < b_last; b_out_i++) {
        cg[CgBMarginalIndex(a_out, b_out, a_in, b_in, b_out_i, b_in_i)] -=
            coeff;
      }
    }
  }

  // 3. Entries where only B's output is the last label: the coefficient goes
  //    to A's marginal (a | x) and is subtracted from the joint terms
  //    (a, b | x, y) for every stored b.
  for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
    for (size_t a_out_i = 0; a_out_i < a_last; a_out_i++) {
      const size_t a_marginal_index =
          CgAMarginalIndex(a_out, b_out, a_in, b_in, a_out_i, a_in_i);
      for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
        const data_t coeff = p[PIndex(a_out, b_out, a_in, b_in, a_out_i,
                                      b_last, a_in_i, b_in_i)];
        cg[a_marginal_index] += coeff;
        for (size_t b_out_i = 0; b_out_i < b_last; b_out_i++) {
          cg[CgJointIndex(a_out, b_out, a_in, b_in, a_out_i, b_out_i, a_in_i,
                          b_in_i)] -= coeff;
        }
      }
    }
  }

  // 4. Entries where only A's output is the last label: the coefficient goes
  //    to B's marginal (b | y) and is subtracted from the joint terms
  //    (a, b | x, y) for every stored a.
  for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
    for (size_t b_out_i = 0; b_out_i < b_last; b_out_i++) {
      const size_t b_marginal_index =
          CgBMarginalIndex(a_out, b_out, a_in, b_in, b_out_i, b_in_i);
      for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
        const data_t coeff = p[PIndex(a_out, b_out, a_in, b_in, a_last,
                                      b_out_i, a_in_i, b_in_i)];
        cg[b_marginal_index] += coeff;
        for (size_t a_out_i = 0; a_out_i < a_last; a_out_i++) {
          cg[CgJointIndex(a_out, b_out, a_in, b_in, a_out_i, b_out_i, a_in_i,
                          b_in_i)] -= coeff;
        }
      }
    }
  }
}
|
|
|
|
// To perform a swap for every combination, we need only to consider tuples
|
|
//
|
|
// (oA, oB, iA, iB)
|
|
//
|
|
// such that, when lexicographically ordered (and we take this order to be
|
|
// right-to-left, in the sense that a greater iB places an entry later, then
|
|
// a greater iA, and so on), we have that
|
|
//
|
|
// (oA, oB, iA, iB) < (oB, oA, iB, iA)
|
|
//
|
|
// meaning the left-hand side comes before, in the ordering, to the right-hand
|
|
// side.
|
|
// (The case where the two sides are equal is not covered, but that case does
|
|
// not require a swap.)
|
|
//
|
|
// This can easily be seen to mean that we require
|
|
//
|
|
// iB < iA OR (iB = iA AND oB < oA)
|
|
void PSwapParties(const size_t out, const size_t in, data_t *p) {
|
|
for (size_t a_in_i = 0; a_in_i < in; a_in_i++) {
|
|
for (size_t b_in_i = 0; b_in_i < a_in_i; b_in_i++) {
|
|
for (size_t a_out_i = 0; a_out_i < out; a_out_i++) {
|
|
for (size_t b_out_i = 0; b_out_i < out; b_out_i++) {
|
|
size_t u = PIndex(out, out, in, in, a_out_i, b_out_i, a_in_i, b_in_i);
|
|
size_t v = PIndex(out, out, in, in, b_out_i, a_out_i, b_in_i, a_in_i);
|
|
data_t tmp = p[u];
|
|
p[u] = p[v];
|
|
p[v] = tmp;
|
|
}
|
|
}
|
|
}
|
|
|
|
for (size_t a_out_i = 0; a_out_i < out; a_out_i++) {
|
|
for (size_t b_out_i = 0; b_out_i < a_out_i; b_out_i++) {
|
|
size_t u = PIndex(out, out, in, in, a_out_i, b_out_i, a_in_i, a_in_i);
|
|
size_t v = PIndex(out, out, in, in, b_out_i, a_out_i, a_in_i, a_in_i);
|
|
data_t tmp = p[u];
|
|
p[u] = p[v];
|
|
p[v] = tmp;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
void PrintCg(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
|
|
const data_t *cg) {
|
|
for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
|
|
for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
|
|
for (size_t a_out_i = 0; a_out_i < a_out - 1; a_out_i++) {
|
|
for (size_t b_out_i = 0; b_out_i < b_out - 1; b_out_i++) {
|
|
size_t idx =
|
|
b_out_i +
|
|
(b_out - 1) * (b_in_i + b_in * (a_out_i + (a_out - 1) * a_in_i));
|
|
if (cg[idx] != 0) {
|
|
printf("(%zu,%zu|%zu,%zu): %" PRImDATA ", ", a_out_i, b_out_i,
|
|
a_in_i, b_in_i, cg[idx]);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
|
|
for (size_t a_out_i = 0; a_out_i < a_out - 1; a_out_i++) {
|
|
size_t idx = (a_out - 1) * (b_out - 1) * a_in * b_in + a_out_i +
|
|
(a_out - 1) * a_in_i;
|
|
if (cg[idx] != 0) {
|
|
printf("A(%zu|%zu): %" PRImDATA ", ", a_out_i, a_in_i, cg[idx]);
|
|
}
|
|
}
|
|
}
|
|
for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
|
|
for (size_t b_out_i = 0; b_out_i < b_out - 1; b_out_i++) {
|
|
size_t idx = (a_out - 1) * (b_out - 1) * a_in * b_in +
|
|
(a_out - 1) * a_in + b_out_i + (b_out - 1) * b_in_i;
|
|
if (cg[idx] != 0) {
|
|
printf("B(%zu|%zu): %" PRImDATA ", ", b_out_i, b_in_i, cg[idx]);
|
|
}
|
|
}
|
|
}
|
|
{
|
|
size_t idx = (a_out - 1) * (b_out - 1) * a_in * b_in + (a_out - 1) * a_in +
|
|
(b_out - 1) * b_in;
|
|
printf("L: %" PRImDATA, cg[idx]);
|
|
}
|
|
printf("\n");
|
|
}
|
|
|
|
// Debug print a P row.
|
|
void PrintP(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
|
|
const data_t *p) {
|
|
for (size_t a_out_i = 0; a_out_i < a_out; a_out_i++) {
|
|
for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) {
|
|
for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
|
|
for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
|
|
size_t idx = PIndex(a_out, b_out, a_in, b_in, a_out_i, b_out_i,
|
|
a_in_i, b_in_i);
|
|
if (p[idx] != 0) {
|
|
printf("%s%" PRImDATA "p(%zu,%zu|%zu,%zu)", p[idx] < 0 ? "" : "+",
|
|
p[idx], a_out_i, b_out_i, a_in_i, b_in_i);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
printf("\n");
|
|
}
|
|
|
|
void PrintPRaw(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
|
|
const data_t *p) {
|
|
const size_t len = a_out * b_out * a_in * b_in;
|
|
for (size_t a_out_i = 0; a_out_i < a_out; a_out_i++) {
|
|
for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) {
|
|
for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
|
|
for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
|
|
size_t idx = PIndex(a_out, b_out, a_in, b_in, a_out_i, b_out_i,
|
|
a_in_i, b_in_i);
|
|
printf("%" PRImDATA, p[idx]);
|
|
if (idx != len - 1) {
|
|
printf(", ");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
printf("\n");
|
|
} |