490 lines
15 KiB
Rust
490 lines
15 KiB
Rust
extern crate pest;
|
|
#[macro_use]
|
|
extern crate pest_derive;
|
|
#[macro_use]
|
|
extern crate bitflags;
|
|
|
|
pub mod vm;
|
|
|
|
use pyo3::prelude::*;
|
|
|
|
// Each of the A and B sets caches at most CACHE_SIZE_LIMIT elements. Every
// cached element is a u64 (8 bytes), so a full cache takes ~80 MB per set,
// i.e. ~160 MB for the A and B caches together.
const CACHE_SIZE_LIMIT: usize = 10_000_000;
|
|
|
|
#[pyclass]
|
|
pub struct BoundsResult {
|
|
#[pyo3(get)]
|
|
min_x_relations: u64,
|
|
#[pyo3(get)]
|
|
min_y_relations: u64,
|
|
#[pyo3(get)]
|
|
max_joint_relations: u128,
|
|
#[pyo3(get)]
|
|
single_bounding_x: u64,
|
|
#[pyo3(get)]
|
|
single_bounding_y: u64,
|
|
#[pyo3(get)]
|
|
joint_bounding_x: u64,
|
|
#[pyo3(get)]
|
|
joint_bounding_y: u64,
|
|
#[pyo3(get)]
|
|
bounding_window: u64,
|
|
}
|
|
|
|
#[pyclass]
|
|
pub struct Prover {
|
|
a_description: vm::VmCode,
|
|
b_description: vm::VmCode,
|
|
relationship: vm::VmCode,
|
|
hints: ProofHints,
|
|
}
|
|
|
|
bitflags! {
    /// Optional assumptions the caller may assert to shrink the search in
    /// `Prover::find_bounds`.
    struct ProofHints: u8 {
        /// When set, `find_bounds` examines only the window `(1 << p) - 1`
        /// instead of every window with `p` bits set.
        const SymmetricFn = 1 << 0;
    }
}
|
|
|
|
struct CachedSetIterator<'code> {
|
|
cache: Vec<u64>,
|
|
filter: &'code vm::VmCode,
|
|
stack: vm::VmStack,
|
|
n: u32,
|
|
p: u32,
|
|
k: u32,
|
|
cached: bool,
|
|
counter: u64,
|
|
cache_counter: usize,
|
|
top: u64,
|
|
is_x: bool,
|
|
}
|
|
|
|
impl<'code> CachedSetIterator<'code> {
|
|
fn create_x(filter: &'code vm::VmCode, n: u32, p: u32, k: u32) -> Self {
|
|
CachedSetIterator {
|
|
cache: Vec::with_capacity(CACHE_SIZE_LIMIT),
|
|
filter,
|
|
stack: vm::VmStack::from_code(filter),
|
|
n,
|
|
p,
|
|
k,
|
|
cached: false,
|
|
counter: 0,
|
|
cache_counter: 0,
|
|
top: 2_u64.pow(n as u32),
|
|
is_x: true,
|
|
}
|
|
}
|
|
|
|
fn create_y(filter: &'code vm::VmCode, n: u32, p: u32, k: u32) -> Self {
|
|
CachedSetIterator {
|
|
cache: Vec::with_capacity(CACHE_SIZE_LIMIT),
|
|
filter,
|
|
stack: vm::VmStack::from_code(filter),
|
|
n,
|
|
p,
|
|
k,
|
|
cached: false,
|
|
counter: 0,
|
|
cache_counter: 0,
|
|
top: 2_u64.pow(n as u32),
|
|
is_x: false,
|
|
}
|
|
}
|
|
|
|
fn reset(&mut self) {
|
|
self.counter = 0;
|
|
self.cache_counter = 0;
|
|
}
|
|
}
|
|
|
|
impl<'code> Iterator for &mut CachedSetIterator<'code> {
|
|
type Item = u64;
|
|
|
|
fn next(&mut self) -> Option<Self::Item> {
|
|
if self.cached {
|
|
if self.cache_counter < self.cache.len() {
|
|
let result = self.cache[self.cache_counter];
|
|
self.cache_counter += 1;
|
|
|
|
if self.cache_counter == CACHE_SIZE_LIMIT {
|
|
self.counter = self.cache.last().unwrap() + 1;
|
|
}
|
|
|
|
return Some(result);
|
|
} else if self.cache.len() < CACHE_SIZE_LIMIT {
|
|
self.reset();
|
|
return None;
|
|
}
|
|
}
|
|
|
|
loop {
|
|
if self.counter == self.top {
|
|
self.cached = true;
|
|
self.reset();
|
|
return None;
|
|
}
|
|
|
|
let included = if self.is_x {
|
|
vm::Vm::load(
|
|
self.filter,
|
|
vm::Registers::load(self.counter, 0, self.n, self.p, self.k),
|
|
&mut self.stack,
|
|
)
|
|
.run()
|
|
.output_bool()
|
|
.unwrap()
|
|
} else {
|
|
vm::Vm::load(
|
|
self.filter,
|
|
vm::Registers::load(0, self.counter, self.n, self.p, self.k),
|
|
&mut self.stack,
|
|
)
|
|
.run()
|
|
.output_bool()
|
|
.unwrap()
|
|
};
|
|
|
|
if !included {
|
|
self.counter += 1;
|
|
continue;
|
|
} else {
|
|
let result = self.counter;
|
|
if self.cache.len() < CACHE_SIZE_LIMIT {
|
|
self.cache.push(result);
|
|
}
|
|
|
|
self.counter += 1;
|
|
return Some(result);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// See https://stackoverflow.com/a/27755938
// In turn from http://graphics.stanford.edu/~seander/bithacks.html#NextBitPermutation

/// Iterates, in increasing numeric order, over every `bits`-bit integer that has
/// exactly `ones` bits set (`bits` must be at most 64).
///
/// After yielding the last value the iterator returns `None` once and rewinds
/// itself, so it can be iterated again. When `ones > bits` no such integer exists
/// and the iterator is always empty (instead of underflowing `bits - ones`).
struct FixedHammingWeight {
    /// Number of set bits in every yielded value.
    ones: u32,
    /// Value that will be yielded next.
    permutation: u64,
    /// Largest value in the sequence: `ones` set bits packed at the top of `bits`.
    top: u64,
    /// Set once `top` has been yielded (or permanently, for an empty sequence).
    exhausted: bool,
    /// Whether the sequence is empty because `ones > bits`.
    empty: bool,
}

impl FixedHammingWeight {
    fn new(bits: u32, ones: u32) -> Self {
        let empty = ones > bits;
        let first = Self::lowest(ones);
        let top = if empty {
            // Never yielded; any value works.
            first
        } else {
            // `bits - ones` may be 64 only when `ones == 0`, i.e. `first == 0`.
            first.checked_shl(bits - ones).unwrap_or(0)
        };
        FixedHammingWeight {
            ones,
            permutation: first,
            top,
            exhausted: empty,
            empty,
        }
    }

    /// Smallest integer with `ones` bits set (all 64 bits when `ones >= 64`).
    fn lowest(ones: u32) -> u64 {
        1_u64.checked_shl(ones).map_or(u64::MAX, |bit| bit - 1)
    }

    /// Rewinds to the first value of the sequence.
    fn reset(&mut self) {
        self.permutation = Self::lowest(self.ones);
        self.exhausted = self.empty;
    }
}

impl Iterator for FixedHammingWeight {
    type Item = u64;

    fn next(&mut self) -> Option<Self::Item> {
        if self.exhausted {
            self.reset();
            return None;
        }
        let current = self.permutation;
        if current == self.top {
            self.exhausted = true;
        } else {
            // Gosper's hack: next larger integer with the same number of set bits.
            // `current != top` guarantees `current != 0` and no overflow below.
            let t = current | (current - 1);
            self.permutation =
                (t + 1) | (((!t & (!t).wrapping_neg()) - 1) >> (current.trailing_zeros() + 1));
        }
        Some(current)
    }
}
|
|
|
|
#[pymethods]
|
|
impl Prover {
|
|
/// Initializes a searcher for lower bounds.
|
|
///
|
|
/// This object prepares the search for lower bounds according to Ambainis et al.'s
|
|
/// adversary method. The bounds themselves for specific values of the parameters can
|
|
/// be found with the associated method `find_bounds`.
|
|
///
|
|
/// Arguments:
|
|
/// a_description A description of elements that should belong to the A set,
|
|
/// as specified discriminating on x [string]
|
|
/// b_description A description of elements that should belong to the B set,
|
|
/// as specified discriminating on y [string]
|
|
/// relationship A description of a relation, in the form of a boolean expression
|
|
/// depending on x and y, and resolving to true if x is related to
|
|
/// y [string]
|
|
#[new]
|
|
pub fn py_new(
|
|
a_description: String,
|
|
b_description: String,
|
|
relationship: String,
|
|
) -> PyResult<Self> {
|
|
// Parse ASTs
|
|
|
|
let relationship = match vm::parsing::parse_relation(&relationship) {
|
|
Ok(relationship) => relationship,
|
|
Err(msg) => {
|
|
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
|
"When parsing `relationship`:\n{}",
|
|
msg
|
|
)));
|
|
}
|
|
};
|
|
|
|
let a_description = match vm::parsing::parse_relation(&a_description) {
|
|
Ok(a_description) => a_description,
|
|
Err(msg) => {
|
|
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
|
"When parsing `a_description`:\n{}",
|
|
msg
|
|
)));
|
|
}
|
|
};
|
|
|
|
let b_description = match vm::parsing::parse_relation(&b_description) {
|
|
Ok(b_description) => b_description,
|
|
Err(msg) => {
|
|
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
|
"When parsing `b_description`:\n{}",
|
|
msg
|
|
)));
|
|
}
|
|
};
|
|
|
|
// Compile into bytecode
|
|
let relationship = vm::compile_boolean(relationship);
|
|
let a_description = vm::compile_boolean(a_description);
|
|
let b_description = vm::compile_boolean(b_description);
|
|
|
|
// Ensure that a_description only references X, and b_description only references Y
|
|
if a_description
|
|
.any(|opcode| matches!(opcode, vm::OpCode::Variable(vm::parsing::ast::Variable::Y)))
|
|
{
|
|
return Err(pyo3::exceptions::PyValueError::new_err(
|
|
"`a_description` cannot reference variable y.",
|
|
));
|
|
}
|
|
if b_description
|
|
.any(|opcode| matches!(opcode, vm::OpCode::Variable(vm::parsing::ast::Variable::X)))
|
|
{
|
|
return Err(pyo3::exceptions::PyValueError::new_err(
|
|
"`b_description` cannot reference variable x.",
|
|
));
|
|
}
|
|
|
|
Ok(Prover {
|
|
relationship,
|
|
a_description,
|
|
b_description,
|
|
hints: ProofHints::empty(),
|
|
})
|
|
}
|
|
|
|
/// Hints to the prover that the function in question is totally symmetric,
|
|
/// or, equivalently, that `a_description` only depends on the Hamming weight
|
|
/// of x, and `b_description` only depends on the Hamming weight of y.
|
|
///
|
|
/// If you l your results will lower bound the upper bound.
|
|
pub fn hint_symmetric<'s>(mut py_self: PyRefMut<'s, Self>, py: Python) -> PyRefMut<'s, Self> {
|
|
(*py_self).hints |= ProofHints::SymmetricFn;
|
|
py_self
|
|
}
|
|
|
|
/// Finds a lower bound to the query complexity according to the specified parameters.
|
|
///
|
|
/// Arguments:
|
|
/// n The maximum number of bits to consider when exhausting bitstrings
|
|
/// belonging to either of the specified sets [int]
|
|
/// p The number of parallel queries [int]
|
|
/// k Parameter accepted in set specifications and relations [int]
|
|
pub fn find_bounds(&self, n: u32, p: u32, k: u32) -> PyResult<BoundsResult> {
|
|
if n > 63 {
|
|
return Err(pyo3::exceptions::PyValueError::new_err(
|
|
"More than 63 bits is not supported. (Because of `n` parameter.)",
|
|
));
|
|
}
|
|
if p > 63 {
|
|
return Err(pyo3::exceptions::PyValueError::new_err(
|
|
"More than 63 bits is not supported. (Because of `p` parameter.)",
|
|
));
|
|
}
|
|
if k > 63 {
|
|
return Err(pyo3::exceptions::PyValueError::new_err(
|
|
"More than 63 bits is not supported. (Because of `k` parameter.)",
|
|
));
|
|
}
|
|
|
|
let mut vm_stack = vm::VmStack::from_code(&self.relationship);
|
|
|
|
let mut a_set_iterator = CachedSetIterator::create_x(&self.a_description, n, p, k);
|
|
let mut b_set_iterator = CachedSetIterator::create_y(&self.b_description, n, p, k);
|
|
|
|
let mut min_x_relations = u64::MAX;
|
|
let mut min_y_relations = u64::MAX;
|
|
let mut max_joint_relations = 0_u128;
|
|
let mut single_bounding_x = 0_u64;
|
|
let mut single_bounding_y = 0_u64;
|
|
let mut joint_bounding_x = 0_u64;
|
|
let mut joint_bounding_y = 0_u64;
|
|
let mut bounding_window = 0_u64;
|
|
|
|
let mut window_iterator = FixedHammingWeight::new(n, p).into_iter();
|
|
let mut first_entry = [(1_u64 << p) - 1].into_iter();
|
|
let effective_window_iterator = if self.hints.contains(ProofHints::SymmetricFn) {
|
|
&mut first_entry as &mut dyn Iterator<Item = u64>
|
|
} else {
|
|
&mut window_iterator as &mut dyn Iterator<Item = u64>
|
|
};
|
|
|
|
for window in effective_window_iterator {
|
|
let mut max_x_relations = 0_u64;
|
|
let mut max_y_relations = 0_u64;
|
|
let mut bounding_x_candidate = 0_u64;
|
|
let mut bounding_y_candidate = 0_u64;
|
|
|
|
for x in &mut a_set_iterator {
|
|
let mut min_x_relations_candidate = 0_u64;
|
|
let mut max_x_relations_candidate = 0_u64;
|
|
|
|
for y in &mut b_set_iterator {
|
|
let related = vm::Vm::load(
|
|
&self.relationship,
|
|
vm::Registers::load(x, y, n, p, k),
|
|
&mut vm_stack,
|
|
)
|
|
.run()
|
|
.output_bool()
|
|
.unwrap();
|
|
if related {
|
|
min_x_relations_candidate += 1;
|
|
|
|
if (x & window) != (y & window) {
|
|
max_x_relations_candidate += 1;
|
|
}
|
|
}
|
|
}
|
|
|
|
if min_x_relations_candidate < min_x_relations {
|
|
min_x_relations = min_x_relations_candidate;
|
|
single_bounding_x = x;
|
|
}
|
|
if max_x_relations_candidate > max_x_relations {
|
|
max_x_relations = max_x_relations_candidate;
|
|
bounding_x_candidate = x;
|
|
}
|
|
}
|
|
|
|
for y in &mut b_set_iterator {
|
|
let mut min_y_relations_candidate = 0_u64;
|
|
let mut max_y_relations_candidate = 0_u64;
|
|
|
|
for x in &mut a_set_iterator {
|
|
let related = vm::Vm::load(
|
|
&self.relationship,
|
|
vm::Registers::load(x, y, n, p, k),
|
|
&mut vm_stack,
|
|
)
|
|
.run()
|
|
.output_bool()
|
|
.unwrap();
|
|
if related {
|
|
min_y_relations_candidate += 1;
|
|
|
|
if (x & window) != (y & window) {
|
|
max_y_relations_candidate += 1;
|
|
}
|
|
}
|
|
}
|
|
|
|
if min_y_relations_candidate < min_y_relations {
|
|
min_y_relations = min_y_relations_candidate;
|
|
single_bounding_y = y;
|
|
}
|
|
|
|
if max_y_relations_candidate > max_y_relations {
|
|
max_y_relations = max_y_relations_candidate;
|
|
bounding_y_candidate = y;
|
|
}
|
|
}
|
|
|
|
let max_joint_relations_candidate = max_x_relations as u128 * max_y_relations as u128;
|
|
if max_joint_relations_candidate > max_joint_relations {
|
|
max_joint_relations = max_joint_relations_candidate;
|
|
joint_bounding_x = bounding_x_candidate;
|
|
joint_bounding_y = bounding_y_candidate;
|
|
bounding_window = window;
|
|
}
|
|
|
|
Python::with_gil(|py| py.check_signals())?;
|
|
}
|
|
|
|
Ok(BoundsResult {
|
|
min_x_relations,
|
|
min_y_relations,
|
|
max_joint_relations,
|
|
single_bounding_x,
|
|
single_bounding_y,
|
|
joint_bounding_x,
|
|
joint_bounding_y,
|
|
bounding_window,
|
|
})
|
|
}
|
|
}
|
|
|
|
/// A python module to brute force lower bounds per Ambainis & Co.'s adversarial
|
|
/// method.
|
|
#[pymodule]
|
|
fn adversary(_py: Python, m: &PyModule) -> PyResult<()> {
|
|
m.add_class::<BoundsResult>()?;
|
|
m.add_class::<Prover>()?;
|
|
Ok(())
|
|
}
|
|
|
|
/// Checks that a cached replay yields exactly what the first scan produced,
/// and that the scan yields precisely the values accepted by the filter.
#[test]
fn test_cached_iterator() {
    let n = 10;
    let k = 3;
    let p = 5;
    let filter =
        vm::compile_boolean(vm::parsing::parse_relation("<= ham x k").expect("Valid expression"));
    let mut iterator = CachedSetIterator::create_x(&filter, n, p, k);

    // The first pass runs the VM and fills the cache; the second replays the cache.
    let first_pass: Vec<u64> = (&mut iterator).collect();
    let second_pass: Vec<u64> = (&mut iterator).collect();
    assert_eq!(first_pass, second_pass);

    // Values come out strictly increasing, and each satisfies the filter.
    assert!(first_pass.windows(2).all(|pair| pair[0] < pair[1]));
    for &x in &first_pass {
        assert!(x < 2_u64.pow(n));
        assert!(x.count_ones() <= k);
    }

    // No member of the set is missed.
    let expected = (0..2_u64.pow(n)).filter(|x| x.count_ones() <= k).count();
    assert_eq!(first_pass.len(), expected);
}
|
|
|
|
/// Checks that `FixedHammingWeight` yields every 10-bit value with exactly
/// three set bits, once each, in increasing order.
#[test]
fn test_hamming_strings() {
    let values: Vec<u64> = FixedHammingWeight::new(10, 3).collect();
    // C(10, 3) = 120.
    assert_eq!(values.len(), 120);
    assert!(values.windows(2).all(|pair| pair[0] < pair[1]));
    for &x in &values {
        assert_eq!(x.count_ones(), 3);
        assert!(x < 1 << 10);
    }
}
|
|
|
|
/// Smoke test: builds a prover through the Python-facing constructor and runs
/// a small search end to end, requiring it to succeed.
#[test]
fn test_full_run() {
    pyo3::prepare_freethreaded_python();
    pyo3::Python::with_gil(|_py| {
        let prover = Prover::py_new(
            String::from("= (ham x) k"),
            String::from("= (ham y) (+ k 1)"),
            String::from("<= ham (^ x y) p"),
        )
        .unwrap();
        let outcome = prover.find_bounds(10, 5, 3);
        std::hint::black_box(outcome).expect("Success");
    })
}
|