quantum_queries/src/lib.rs

440 lines
14 KiB
Rust

extern crate pest;
#[macro_use]
extern crate pest_derive;
mod vm;
use pyo3::prelude::*;
use vm::VmCode;
/// Maximum number of set members each `CachedSetIterator` keeps in memory.
///
/// At 8 bytes per `u64` this is ~80 MB per set, i.e. ~160 MB for the A and B sets
/// together. Members beyond the limit are not cached and are recomputed with the VM
/// on every pass.
const CACHE_SIZE_LIMIT: usize = 10_000_000;
/// Outcome of `Prover.find_bounds`: relation counts gathered by exhaustive search
/// over the A and B sets, plus the elements and window at which they were attained.
#[pyclass]
pub struct BoundsResult {
/// Smallest number of B elements related to a single x in A
/// (`u64::MAX` if no x was examined).
#[pyo3(get)]
min_x_relations: u64,
/// Smallest number of A elements related to a single y in B
/// (`u64::MAX` if no y was examined).
#[pyo3(get)]
min_y_relations: u64,
/// Largest, over all windows, of the product (max over x of the related y that
/// differ from x inside the window) * (max over y of the related x that differ
/// from y inside the window); 0 if no window yields a positive product.
#[pyo3(get)]
max_joint_relations: u128,
/// An x in A attaining `min_x_relations` (0 if none was recorded).
#[pyo3(get)]
single_bounding_x: u64,
/// A y in B attaining `min_y_relations` (0 if none was recorded).
#[pyo3(get)]
single_bounding_y: u64,
/// The x attaining the per-window x maximum within `bounding_window`
/// (0 if `max_joint_relations` stayed 0).
#[pyo3(get)]
joint_bounding_x: u64,
/// The y attaining the per-window y maximum within `bounding_window`
/// (0 if `max_joint_relations` stayed 0).
#[pyo3(get)]
joint_bounding_y: u64,
/// Bit mask with `p` of the low `n` bits set at which `max_joint_relations`
/// was attained (0 if `max_joint_relations` stayed 0).
#[pyo3(get)]
bounding_window: u64,
}
/// Compiled set descriptions and relation used to search for adversary lower bounds.
#[pyclass]
pub struct Prover {
/// Bytecode of the A-set membership predicate; may reference x but not y.
a_description: VmCode,
/// Bytecode of the B-set membership predicate; may reference y but not x.
b_description: VmCode,
/// Bytecode of the boolean relation between x and y.
relationship: VmCode,
}
/// Enumerates, in increasing order, the `n`-bit strings accepted by a compiled
/// predicate, caching up to `CACHE_SIZE_LIMIT` members so later passes can replay
/// them instead of re-running the VM.
///
/// The `Iterator` impl is on `&mut CachedSetIterator`, so iterate it as
/// `for v in &mut it`; the end of each pass rewinds it for the next loop.
struct CachedSetIterator<'code> {
// Members found so far, in increasing order, never more than CACHE_SIZE_LIMIT.
cache: Vec<u64>,
// Compiled membership predicate evaluated on each candidate.
filter: &'code VmCode,
// Bit width: candidates range over 0..2^n. Also loaded into the VM registers.
n: u32,
// Passed through to the VM registers.
p: u32,
// Passed through to the VM registers.
k: u32,
// True once a full pass has completed, so `cache` can be replayed.
cached: bool,
// Next candidate bitstring to test against `filter`.
counter: u64,
// Index of the next cache entry to replay.
cache_counter: usize,
// Exclusive upper bound of the candidates (2^n).
top: u64,
// Candidate goes into the x register (A set) when true, y register (B set) otherwise.
is_x: bool,
}
impl<'code> CachedSetIterator<'code> {
    /// Creates an iterator over the A set: every `x` in `0..2^n` for which `filter`
    /// evaluates to true with registers `(x, 0, n, p, k)`.
    ///
    /// `n` must be at most 63 so that `2^n` fits in a `u64` (`find_bounds` checks this).
    fn create_x(filter: &'code VmCode, n: u32, p: u32, k: u32) -> Self {
        Self::create_for(filter, n, p, k, true)
    }

    /// Creates an iterator over the B set: every `y` in `0..2^n` for which `filter`
    /// evaluates to true with registers `(0, y, n, p, k)`.
    ///
    /// `n` must be at most 63 so that `2^n` fits in a `u64` (`find_bounds` checks this).
    fn create_y(filter: &'code VmCode, n: u32, p: u32, k: u32) -> Self {
        Self::create_for(filter, n, p, k, false)
    }

    /// Shared constructor; `is_x` selects which register receives the candidate.
    fn create_for(filter: &'code VmCode, n: u32, p: u32, k: u32, is_x: bool) -> Self {
        let top = 2_u64.pow(n);
        // The cache can never hold more than `top` members, so for small `n` reserve
        // only that much rather than CACHE_SIZE_LIMIT u64s (~80 MB) per iterator.
        // The Vec still grows on demand up to CACHE_SIZE_LIMIT.
        let capacity = top.min(CACHE_SIZE_LIMIT as u64) as usize;
        CachedSetIterator {
            cache: Vec::with_capacity(capacity),
            filter,
            n,
            p,
            k,
            cached: false,
            counter: 0,
            cache_counter: 0,
            top,
            is_x,
        }
    }

    /// Rewinds the scan position and the replay position to the start of the set,
    /// keeping the cache and the `cached` flag.
    fn reset(&mut self) {
        self.counter = 0;
        self.cache_counter = 0;
    }
}
impl<'code> Iterator for &mut CachedSetIterator<'code> {
    type Item = u64;

    /// Yields the next member of the set, or `None` at the end of a pass.
    ///
    /// Once a full pass has completed (`cached`), members are replayed from `cache`.
    /// Only when the cache was truncated at `CACHE_SIZE_LIMIT` does the iterator fall
    /// back to evaluating `filter` on the candidates after the last cached member.
    /// Every `None` rewinds the iterator, so the next loop starts a fresh pass.
    fn next(&mut self) -> Option<Self::Item> {
        if self.cached {
            let replayed = self.cache.get(self.cache_counter).copied();
            match replayed {
                Some(member) => {
                    self.cache_counter += 1;
                    if self.cache_counter == CACHE_SIZE_LIMIT {
                        // The cache is full and `member` is its last entry; later
                        // calls resume the brute-force scan just after it.
                        self.counter = member + 1;
                    }
                    return Some(member);
                }
                None if self.cache.len() < CACHE_SIZE_LIMIT => {
                    // The cache holds the whole set, so this pass is complete.
                    self.reset();
                    return None;
                }
                // Full cache fully replayed: scan the uncached remainder below.
                None => {}
            }
        }
        while self.counter != self.top {
            let candidate = self.counter;
            self.counter += 1;
            // A-set candidates occupy the x register, B-set candidates the y register.
            let (x, y) = if self.is_x { (candidate, 0) } else { (0, candidate) };
            let accepted = vm::Vm::load(
                self.filter,
                vm::Registers::load(x, y, self.n, self.p, self.k),
            )
            .run()
            .unwrap_bool();
            if accepted {
                if self.cache.len() < CACHE_SIZE_LIMIT {
                    self.cache.push(candidate);
                }
                return Some(candidate);
            }
        }
        // Every candidate below `top` has been examined; later passes may replay.
        self.cached = true;
        self.reset();
        None
    }
}
// See https://stackoverflow.com/a/27755938
// In turn from http://graphics.stanford.edu/~seander/bithacks.html#NextBitPermutation
/// Iterator over every value of at most `bits` bits with exactly `ones` bits set,
/// yielded in increasing order (the "next bit permutation" trick).
struct FixedHammingWeight {
    /// Next value to yield.
    permutation: u64,
    /// Largest value with `ones` set bits that fits in `bits` bits; yielded last.
    top: u64,
    /// Set once `top` has been yielded, or immediately when no such value exists.
    exhausted: bool,
}

impl FixedHammingWeight {
    /// Creates the iterator. If `ones > bits`, no such value exists and the iterator
    /// is empty. `ones == 0` yields the single value 0.
    ///
    /// # Panics
    /// If `bits` exceeds 64 (values would not fit in a `u64`).
    fn new(bits: u32, ones: u32) -> Self {
        assert!(bits <= u64::BITS, "at most {} bits are supported", u64::BITS);
        if ones > bits {
            return FixedHammingWeight {
                permutation: 0,
                top: 0,
                exhausted: true,
            };
        }
        // The `ones` lowest bits set. Shifting a u64 by 64 is out of range, so
        // checked shifts handle `ones == 0` (-> 0) and `ones == 64` (-> all ones).
        let lowest = u64::MAX.checked_shr(u64::BITS - ones).unwrap_or(0);
        FixedHammingWeight {
            permutation: lowest,
            top: lowest.checked_shl(bits - ones).unwrap_or(0),
            exhausted: false,
        }
    }
}

impl Iterator for FixedHammingWeight {
    type Item = u64;

    fn next(&mut self) -> Option<Self::Item> {
        if self.exhausted {
            return None;
        }
        let current = self.permutation;
        if current == self.top {
            self.exhausted = true;
        } else {
            // `current != top` implies `current != 0` and that a larger permutation
            // fits in 64 bits, so none of the arithmetic below can overflow.
            let t = current | (current - 1);
            self.permutation = (t + 1)
                | (((!t & (!t).wrapping_neg()) - 1) >> (current.trailing_zeros() + 1));
        }
        Some(current)
    }
}
#[pymethods]
impl Prover {
    /// Initializes a searcher for lower bounds.
    ///
    /// This object prepares the search for lower bounds according to Ambainis et al.'s
    /// adversary method. The bounds themselves for specific values of the parameters can
    /// be found with the associated method `find_bounds`.
    ///
    /// Arguments:
    ///     a_description   A description of elements that should belong to the A set,
    ///                     as specified discriminating on x [string]
    ///     b_description   A description of elements that should belong to the B set,
    ///                     as specified discriminating on y [string]
    ///     relationship    A description of a relation, in the form of a boolean expression
    ///                     depending on x and y, and resolving to true if x is related to
    ///                     y [string]
    ///
    /// Raises:
    ///     ValueError      If a description fails to parse, if `a_description`
    ///                     references y, or if `b_description` references x.
    #[new]
    pub fn py_new(
        a_description: String,
        b_description: String,
        relationship: String,
    ) -> PyResult<Self> {
        // Parse ASTs (`relationship` first, so its errors are reported first).
        let relationship = vm::parsing::parse_relation(&relationship).map_err(|msg| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "When parsing `relationship`:\n{}",
                msg
            ))
        })?;
        let a_description = vm::parsing::parse_relation(&a_description).map_err(|msg| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "When parsing `a_description`:\n{}",
                msg
            ))
        })?;
        let b_description = vm::parsing::parse_relation(&b_description).map_err(|msg| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "When parsing `b_description`:\n{}",
                msg
            ))
        })?;
        // Compile into bytecode
        let relationship = vm::compile_boolean(relationship);
        let a_description = vm::compile_boolean(a_description);
        let b_description = vm::compile_boolean(b_description);
        // A-set candidates are evaluated with y = 0 and B-set candidates with x = 0,
        // so each description may only reference its own variable.
        if a_description
            .any(|opcode| matches!(opcode, vm::OpCode::Variable(vm::parsing::ast::Variable::Y)))
        {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "`a_description` cannot reference variable y.",
            ));
        }
        if b_description
            .any(|opcode| matches!(opcode, vm::OpCode::Variable(vm::parsing::ast::Variable::X)))
        {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "`b_description` cannot reference variable x.",
            ));
        }
        Ok(Prover {
            relationship,
            a_description,
            b_description,
        })
    }

    /// Finds a lower bound to the query complexity according to the specified parameters.
    ///
    /// Arguments:
    ///     n   The maximum number of bits to consider when exhausting bitstrings
    ///         belonging to either of the specified sets [int]
    ///     p   The number of parallel queries [int]
    ///     k   Parameter accepted in set specifications and relations [int]
    ///
    /// Returns:
    ///     A `BoundsResult`. Every window (a mask with `p` of the low `n` bits set)
    ///     is tried, and the window with the largest joint relation count is kept.
    ///
    /// Raises:
    ///     ValueError  If `n`, `p` or `k` exceeds 63, or if `p` exceeds `n`.
    ///     Whatever exception a Python signal handler raises (e.g. KeyboardInterrupt)
    ///     if a signal arrives during the search.
    pub fn find_bounds(&self, n: u32, p: u32, k: u32) -> PyResult<BoundsResult> {
        if n > 63 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "More than 63 bits is not supported. (Because of `n` parameter.)",
            ));
        }
        if p > 63 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "More than 63 bits is not supported. (Because of `p` parameter.)",
            ));
        }
        if k > 63 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "More than 63 bits is not supported. (Because of `k` parameter.)",
            ));
        }
        if p > n {
            // A window selects `p` distinct positions among the low `n` bits.
            return Err(pyo3::exceptions::PyValueError::new_err(
                "`p` cannot exceed `n`: a window of `p` bits must fit within `n` bits.",
            ));
        }
        let mut a_set_iterator = CachedSetIterator::create_x(&self.a_description, n, p, k);
        let mut b_set_iterator = CachedSetIterator::create_y(&self.b_description, n, p, k);
        // Lets a pending Python signal (e.g. Ctrl-C) abort the search. Polled per window
        // and per outer-loop element so a long search stays interruptible even when the
        // best joint bound stops improving.
        let check_signals = || Python::with_gil(|py| py.check_signals());
        let mut min_x_relations = u64::MAX;
        let mut min_y_relations = u64::MAX;
        let mut max_joint_relations = 0_u128;
        let mut single_bounding_x = 0_u64;
        let mut single_bounding_y = 0_u64;
        let mut joint_bounding_x = 0_u64;
        let mut joint_bounding_y = 0_u64;
        let mut bounding_window = 0_u64;
        for window in FixedHammingWeight::new(n, p) {
            check_signals()?;
            let mut max_x_relations = 0_u64;
            let mut max_y_relations = 0_u64;
            let mut bounding_x_candidate = 0_u64;
            let mut bounding_y_candidate = 0_u64;
            // For each x in A: count the related y in B, and those among them that
            // differ from x somewhere inside the window.
            for x in &mut a_set_iterator {
                check_signals()?;
                let mut min_x_relations_candidate = 0_u64;
                let mut max_x_relations_candidate = 0_u64;
                for y in &mut b_set_iterator {
                    let related = vm::Vm::load(
                        &self.relationship,
                        vm::Registers::load(x, y, n, p, k),
                    )
                    .run()
                    .unwrap_bool();
                    if related {
                        min_x_relations_candidate += 1;
                        if (x & window) != (y & window) {
                            max_x_relations_candidate += 1;
                        }
                    }
                }
                if min_x_relations_candidate < min_x_relations {
                    min_x_relations = min_x_relations_candidate;
                    single_bounding_x = x;
                }
                if max_x_relations_candidate > max_x_relations {
                    max_x_relations = max_x_relations_candidate;
                    bounding_x_candidate = x;
                }
            }
            // Symmetric pass: for each y in B, count the related x in A.
            for y in &mut b_set_iterator {
                check_signals()?;
                let mut min_y_relations_candidate = 0_u64;
                let mut max_y_relations_candidate = 0_u64;
                for x in &mut a_set_iterator {
                    let related =
                        vm::Vm::load(&self.relationship, vm::Registers::load(x, y, n, p, k))
                            .run()
                            .unwrap_bool();
                    if related {
                        min_y_relations_candidate += 1;
                        if (x & window) != (y & window) {
                            max_y_relations_candidate += 1;
                        }
                    }
                }
                if min_y_relations_candidate < min_y_relations {
                    min_y_relations = min_y_relations_candidate;
                    single_bounding_y = y;
                }
                if max_y_relations_candidate > max_y_relations {
                    max_y_relations = max_y_relations_candidate;
                    bounding_y_candidate = y;
                }
            }
            // u128 because the product of two u64 counts can exceed u64.
            let max_joint_relations_candidate = max_x_relations as u128 * max_y_relations as u128;
            if max_joint_relations_candidate > max_joint_relations {
                max_joint_relations = max_joint_relations_candidate;
                joint_bounding_x = bounding_x_candidate;
                joint_bounding_y = bounding_y_candidate;
                bounding_window = window;
            }
        }
        Ok(BoundsResult {
            min_x_relations,
            min_y_relations,
            max_joint_relations,
            single_bounding_x,
            single_bounding_y,
            joint_bounding_x,
            joint_bounding_y,
            bounding_window,
        })
    }
}
/// A python module to brute force lower bounds per Ambainis & Co.'s adversarial
/// method.
///
/// Exposes the `Prover` search object and the `BoundsResult` record it returns.
#[pymodule]
fn adversary(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<BoundsResult>()?;
    // The last registration's outcome is the module initializer's outcome.
    m.add_class::<Prover>()
}
#[test]
fn test_cached_iterator() {
    let n = 10;
    let k = 3;
    let p = 5;
    let filter =
        vm::compile_boolean(vm::parsing::parse_relation("<= ham x k").expect("Valid expression"));
    let mut iterator = CachedSetIterator::create_x(&filter, n, p, k);
    // The first pass evaluates the filter and fills the cache; later passes replay it
    // and must yield exactly the same members in the same order.
    let first_pass: Vec<u64> = (&mut iterator).collect();
    let second_pass: Vec<u64> = (&mut iterator).collect();
    let third_pass: Vec<u64> = (&mut iterator).collect();
    for &x in &first_pass {
        assert!(x < 2_u64.pow(n));
        assert!(x.count_ones() <= k);
    }
    // Bitstrings of 10 bits with at most 3 ones: C(10,0)+C(10,1)+C(10,2)+C(10,3).
    assert_eq!(first_pass.len(), 1 + 10 + 45 + 120);
    assert_eq!(first_pass, second_pass);
    assert_eq!(first_pass, third_pass);
}
#[test]
fn test_hamming_strings() {
    // Every value must have exactly 3 of its low 10 bits set; values come out strictly
    // increasing (hence no duplicates), and there are C(10, 3) = 120 of them.
    let strings: Vec<u64> = FixedHammingWeight::new(10, 3).collect();
    assert_eq!(strings.len(), 120);
    for &x in &strings {
        assert_eq!(x.count_ones(), 3);
        assert!(x < 1 << 10);
    }
    assert!(strings.windows(2).all(|pair| pair[0] < pair[1]));
}
#[test]
fn test_full_run() {
    // Smoke test: build a prover and run a small search end to end under the GIL.
    pyo3::prepare_freethreaded_python();
    pyo3::Python::with_gil(|_py| {
        let prover = Prover::py_new(
            String::from("= (ham x) k"),
            String::from("= (ham y) (+ k 1)"),
            String::from("<= ham (^ x y) p"),
        )
        .unwrap();
        let outcome = prover.find_bounds(10, 5, 3);
        std::hint::black_box(outcome).expect("Success");
    })
}