checkpoint: cg-p conversion not quite working yet

This commit is contained in:
Miguel M 2023-04-30 18:26:02 +01:00
parent 37f8b0a30a
commit 8c1bf8a57f
2 changed files with 359 additions and 0 deletions

56
includes/cg.h Normal file
View File

@ -0,0 +1,56 @@
#ifndef ACED_INCLUDES_CG_H
#define ACED_INCLUDES_CG_H
// This file provides utilities to convert from and to P and CG notation.
//
// P notation refers to writing a set of joint probabilities of a two player
// game (or associated coefficients) as the explicit set of all possible joint
// probabilities, i.e.,
//
// P[0,0|0,0] ... P[<A outputs>-1, <B outputs>-1 | <A inputs>-1, <B inputs>-1] L
//
// where L is the element associated to the identity.
// For the elisions, the indexing order is B's outputs varying first, then B's
// inputs, then A's outputs, then A's inputs. For the partials, we consider
// (following the same reasoning) that the outputs vary "faster" than the
// inputs.
//
// Collins-Gisin (CG) notation represents the coefficients associated to a set
// of probabilities as in P notation, but in a condensed form, namely
//
// P[0,0|0,0] ... P[<A outputs>-2,<B outputs>-2|<A inputs>-1,<B inputs>-1],
// P_A[0|0] ... P_A[<A outputs>-2|<A inputs>-1]
// P_B[0|0] ... P_B[<B outputs>-2|<B inputs>-1]
//
// where we introduce the notion of marginal distributions P_A and P_B, meaning
// the coefficients associated to, respectively
//
// for all(i_A, o_A), sum(o_B) p(o_A, o_B | i_A, i_B), for all(i_B)
// for all(i_B, o_B), sum(o_A) p(o_A, o_B | i_A, i_B), for all(i_A)
//
// The same elision rules apply.
#include <stddef.h>
#include "../includes/data.h"
// Convert from CG representation to P representation.
//
// `a_out`/`b_out` are the number of outputs and `a_in`/`b_in` the number of
// inputs of A and B, respectively.
//
// `cg` is read as a row of
//   (a_out-1)*(b_out-1)*a_in*b_in + (a_out-1)*a_in + (b_out-1)*b_in + 1
// entries (joint terms, then A's marginals, then B's marginals, then L).
// `p` is written as a row of a_out*b_out*a_in*b_in entries.
// The two rows are `restrict`-qualified, i.e., they must not overlap.
//
// `p` isn't required to be initialized, as this function is meant to
// overwrite every entry (provided that `p` is of the expected size).
//
// NOTE(review): the implementation does not store a separate L entry in `p`;
// it subtracts L from the joint coefficients with inputs (0,0) instead.
// Confirm this matches the P layout described at the top of this file.
void FromCgToP(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
const data_t* restrict cg, data_t* restrict p);
// Convert from P representation to CG representation.
//
// Takes the same dimensions and row sizes as `FromCgToP`, but here `p` is the
// input row and `cg` is the output row. The two rows must not overlap.
void FromPToCg(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
data_t* restrict cg, const data_t* restrict p);
// Debug print a CG row.
//
// Prints to stdout, on a single line, the non-zero joint terms and marginals
// of `cg`, followed by L.
void PrintCg(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
const data_t* cg);
// Debug print a P row.
//
// Prints to stdout, on a single line, the non-zero entries of `p`.
void PrintP(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
const data_t* p);
#endif

303
src/cg.c Normal file
View File

@ -0,0 +1,303 @@
#include "../includes/cg.h"
#include <stdio.h>
// Convert a row from CG representation to P representation.
//
// Parameters:
//   a_out, b_out  number of outputs of A and B.
//   a_in, b_in    number of inputs of A and B.
//   cg            input row, laid out as
//                   (a_out-1)*(b_out-1)*a_in*b_in joint coefficients,
//                   (a_out-1)*a_in marginals of A,
//                   (b_out-1)*b_in marginals of B,
//                   L.
//   p             output row of a_out*b_out*a_in*b_in joint coefficients.
//
// In both rows B's output varies fastest, then B's input, then A's output,
// then A's input (for the marginals: output first, then input).
//
// Every entry of `p` is written, so `p` does not need to be initialized by
// the caller. If any dimension is zero there is nothing to convert and the
// function returns without touching `p`.
void FromCgToP(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
               const data_t *restrict cg, data_t *restrict p) {
  // `a_out - 1` / `b_out - 1` would wrap around for zero outputs, and with
  // zero inputs `p` has no entries to receive L.
  if (a_out == 0 || b_out == 0 || a_in == 0 || b_in == 0) {
    return;
  }

  // Number of outputs per party that are explicitly represented in CG form.
  const size_t a_free = a_out - 1;
  const size_t b_free = b_out - 1;
  // Offsets of the sections of the CG row.
  const size_t a_marginals_base = a_free * b_free * a_in * b_in;
  const size_t b_marginals_base = a_marginals_base + a_free * a_in;
  const size_t l_index = b_marginals_base + b_free * b_in;

  // Joint coefficients: entries where neither output is the last one are
  // copied verbatim. Entries involving a last output have no CG counterpart
  // and start at zero, so that the accumulations below never read
  // uninitialized memory.
  for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
    for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
      for (size_t a_out_i = 0; a_out_i < a_out; a_out_i++) {
        for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) {
          const size_t p_index =
              b_out_i + b_out * (b_in_i + b_in * (a_out_i + a_out * a_in_i));
          if (a_out_i < a_free && b_out_i < b_free) {
            const size_t cg_index =
                b_out_i +
                b_free * (b_in_i + b_in * (a_out_i + a_free * a_in_i));
            p[p_index] = cg[cg_index];
          } else {
            p[p_index] = 0;
          }
        }
      }
    }
  }

  // L contributes to every joint probability, for a fixed choice of inputs,
  // which we take to be 0,0.
  const data_t l = cg[l_index];
  for (size_t a_out_i = 0; a_out_i < a_out; a_out_i++) {
    for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) {
      const size_t p_index =
          b_out_i + b_out * (0 + b_in * (a_out_i + a_out * 0));
      // Negative sign comes from moving L to the LHS of the inequality.
      p[p_index] -= l;
    }
  }

  // Account for the marginals. Where the input of the other party is
  // ambiguous, the convention is to take it to be 0.

  // A's marginals: A(a|x) is spread over (a, b | x, 0) for every b.
  for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
    for (size_t a_out_i = 0; a_out_i < a_free; a_out_i++) {
      const data_t a_marginal =
          cg[a_marginals_base + a_out_i + a_free * a_in_i];
      if (a_marginal == 0) {
        continue;
      }
      for (size_t b_out_i = 0; b_out_i < b_out; b_out_i++) {
        const size_t p_index =
            b_out_i + b_out * (0 + b_in * (a_out_i + a_out * a_in_i));
        p[p_index] += a_marginal;
      }
    }
  }

  // B's marginals: B(b|y) is spread over (a, b | 0, y) for every a.
  for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
    for (size_t b_out_i = 0; b_out_i < b_free; b_out_i++) {
      const data_t b_marginal =
          cg[b_marginals_base + b_out_i + b_free * b_in_i];
      if (b_marginal == 0) {
        continue;
      }
      for (size_t a_out_i = 0; a_out_i < a_out; a_out_i++) {
        const size_t p_index =
            b_out_i + b_out * (b_in_i + b_in * (a_out_i + a_out * 0));
        p[p_index] += b_marginal;
      }
    }
  }
}
// Convert a row from P representation to CG representation.
//
// Parameters:
//   a_out, b_out  number of outputs of A and B.
//   a_in, b_in    number of inputs of A and B.
//   cg            output row, laid out as
//                   (a_out-1)*(b_out-1)*a_in*b_in joint coefficients,
//                   (a_out-1)*a_in marginals of A,
//                   (b_out-1)*b_in marginals of B,
//                   L.
//   p             input row of a_out*b_out*a_in*b_in joint coefficients.
//
// In both rows B's output varies fastest, then B's input, then A's output,
// then A's input (for the marginals: output first, then input).
//
// Writing A and B for the last output index of each party, every term of `p`
// that involves a last output is eliminated through
//
//   P(a,B|x,y) = P_A(a|x) - sum(b<B) P(a,b|x,y)
//   P(A,b|x,y) = P_B(b|y) - sum(a<A) P(a,b|x,y)
//   P(A,B|x,y) = 1 - sum(a<A) P_A(a|x) - sum(b<B) P_B(b|y)
//                  + sum(a<A,b<B) P(a,b|x,y)
//
// Every entry of `cg` is written. If `a_out` or `b_out` is zero the function
// returns without touching `cg`.
void FromPToCg(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
               data_t *restrict cg, const data_t *restrict p) {
  // `a_out - 1` / `b_out - 1` would wrap around for zero outputs.
  if (a_out == 0 || b_out == 0) {
    return;
  }

  // Number of outputs per party that are explicitly represented in CG form;
  // these are also the indices of the eliminated (last) outputs in `p`.
  const size_t a_free = a_out - 1;
  const size_t b_free = b_out - 1;
  // Offsets of the sections of the CG row.
  const size_t a_marginals_base = a_free * b_free * a_in * b_in;
  const size_t b_marginals_base = a_marginals_base + a_free * a_in;
  const size_t l_index = b_marginals_base + b_free * b_in;

  // 1. Copy the common factors...
  for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
    for (size_t a_out_i = 0; a_out_i < a_free; a_out_i++) {
      for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
        for (size_t b_out_i = 0; b_out_i < b_free; b_out_i++) {
          const size_t p_index =
              b_out_i + b_out * (b_in_i + b_in * (a_out_i + a_out * a_in_i));
          const size_t cg_index =
              b_out_i + b_free * (b_in_i + b_in * (a_out_i + a_free * a_in_i));
          cg[cg_index] = p[p_index];
        }
      }
    }
  }
  // ...and zero-out both sets of marginals and L, which are contiguous at the
  // end of the row.
  for (size_t i = a_marginals_base; i <= l_index; i++) {
    cg[i] = 0;
  }

  // 2. Contributions from the terms with both output labels out of bounds.
  for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
    for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
      const data_t both_oob =
          p[b_free + b_out * (b_in_i + b_in * (a_free + a_out * a_in_i))];
      // The minus sign comes from considering L to be on the RHS of the
      // inequality.
      cg[l_index] -= both_oob;
      for (size_t a_out_i = 0; a_out_i < a_free; a_out_i++) {
        // Contributions to the terms with no output label out of bounds.
        for (size_t b_out_i = 0; b_out_i < b_free; b_out_i++) {
          const size_t cg_index =
              b_out_i + b_free * (b_in_i + b_in * (a_out_i + a_free * a_in_i));
          cg[cg_index] += both_oob;
        }
        // Contribution to A's marginal.
        cg[a_marginals_base + a_free * a_in_i + a_out_i] -= both_oob;
      }
      // Contributions to B's marginals.
      for (size_t b_out_i = 0; b_out_i < b_free; b_out_i++) {
        cg[b_marginals_base + b_free * b_in_i + b_out_i] -= both_oob;
      }
    }
  }

  // 3. Contributions from the terms with only B's output label out of bounds.
  for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
    for (size_t a_out_i = 0; a_out_i < a_free; a_out_i++) {
      const size_t a_marginal_index =
          a_marginals_base + a_free * a_in_i + a_out_i;
      for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
        const data_t b_oob =
            p[b_free + b_out * (b_in_i + b_in * (a_out_i + a_out * a_in_i))];
        cg[a_marginal_index] += b_oob;
        for (size_t b_out_i = 0; b_out_i < b_free; b_out_i++) {
          const size_t cg_index =
              b_out_i + b_free * (b_in_i + b_in * (a_out_i + a_free * a_in_i));
          cg[cg_index] -= b_oob;
        }
      }
    }
  }

  // 4. Contributions from the terms with only A's output label out of bounds.
  for (size_t b_in_i = 0; b_in_i < b_in; b_in_i++) {
    for (size_t b_out_i = 0; b_out_i < b_free; b_out_i++) {
      const size_t b_marginal_index =
          b_marginals_base + b_free * b_in_i + b_out_i;
      for (size_t a_in_i = 0; a_in_i < a_in; a_in_i++) {
        // The source term keeps B's actual output `b_out_i`; only A's output
        // is the eliminated one.
        const data_t a_oob =
            p[b_out_i + b_out * (b_in_i + b_in * (a_free + a_out * a_in_i))];
        cg[b_marginal_index] += a_oob;
        for (size_t a_out_i = 0; a_out_i < a_free; a_out_i++) {
          const size_t cg_index =
              b_out_i + b_free * (b_in_i + b_in * (a_out_i + a_free * a_in_i));
          cg[cg_index] -= a_oob;
        }
      }
    }
  }
}
// Debug helper: writes a CG row to stdout on a single line.
//
// Only non-zero coefficients are shown. Joint terms come first as
// "(a,b|x,y): v, ", then A's marginals as "A(a|x): v, ", then B's marginals
// as "B(b|y): v, ". L is always printed last (even when zero), followed by a
// newline.
void PrintCg(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
             const data_t *cg) {
  // Outputs per party that are explicitly stored in CG form.
  const size_t a_kept = a_out - 1;
  const size_t b_kept = b_out - 1;
  // Start of each trailing section of the row.
  const size_t a_section = a_kept * b_kept * a_in * b_in;
  const size_t b_section = a_section + a_kept * a_in;
  const size_t l_at = b_section + b_kept * b_in;

  // Joint terms.
  for (size_t x = 0; x < a_in; x++) {
    for (size_t y = 0; y < b_in; y++) {
      for (size_t a = 0; a < a_kept; a++) {
        for (size_t b = 0; b < b_kept; b++) {
          const size_t at = b + b_kept * (y + b_in * (a + a_kept * x));
          if (cg[at] == 0) {
            continue;
          }
          printf("(%zu,%zu|%zu,%zu): %" PRImDATA ", ", a, b, x, y, cg[at]);
        }
      }
    }
  }

  // A's marginals.
  for (size_t x = 0; x < a_in; x++) {
    for (size_t a = 0; a < a_kept; a++) {
      const size_t at = a_section + a + a_kept * x;
      if (cg[at] == 0) {
        continue;
      }
      printf("A(%zu|%zu): %" PRImDATA ", ", a, x, cg[at]);
    }
  }

  // B's marginals.
  for (size_t y = 0; y < b_in; y++) {
    for (size_t b = 0; b < b_kept; b++) {
      const size_t at = b_section + b + b_kept * y;
      if (cg[at] == 0) {
        continue;
      }
      printf("B(%zu|%zu): %" PRImDATA ", ", b, y, cg[at]);
    }
  }

  printf("L: %" PRImDATA, cg[l_at]);
  printf("\n");
}
// Debug helper: writes the non-zero entries of a P row to stdout on a single
// line, each as "(a,b|x,y): v, ", followed by a newline.
//
// Entries are visited with A's output outermost, then B's output, then A's
// input, then B's input.
void PrintP(size_t a_out, size_t b_out, size_t a_in, size_t b_in,
            const data_t *p) {
  for (size_t a = 0; a < a_out; a++) {
    for (size_t b = 0; b < b_out; b++) {
      for (size_t x = 0; x < a_in; x++) {
        // Part of the row offset that does not depend on B's input.
        const size_t a_offset = b_in * (a + a_out * x);
        for (size_t y = 0; y < b_in; y++) {
          const size_t at = b + b_out * (y + a_offset);
          if (p[at] == 0) {
            continue;
          }
          printf("(%zu,%zu|%zu,%zu): %" PRImDATA ", ", a, b, x, y, p[at]);
        }
      }
    }
  }
  printf("\n");
}