first implementation of bounds search

This commit is contained in:
Miguel M 2023-02-23 02:07:40 +00:00
parent a5ea324114
commit 9bd4dcf010
5 changed files with 445 additions and 40 deletions

View File

@ -5,19 +5,207 @@ extern crate pest_derive;
mod vm;
use pyo3::prelude::*;
use vm::VmCode;
// We cache the A and B set up to a total of ~2Gb of memory,
// meaning ~ 10^7 elements in each set
const CACHE_SIZE_LIMIT: usize = 10_000_000;
#[pyclass]
struct Prover {}
struct BoundsResult {
#[pyo3(get)]
min_x_relations: u64,
#[pyo3(get)]
min_y_relations: u64,
#[pyo3(get)]
max_joint_relations: u128,
#[pyo3(get)]
single_bounding_x: u64,
#[pyo3(get)]
single_bounding_y: u64,
#[pyo3(get)]
joint_bounding_x: u64,
#[pyo3(get)]
joint_bounding_y: u64,
#[pyo3(get)]
bounding_window: u64,
}
#[pyclass]
struct Prover {
a_description: VmCode,
b_description: VmCode,
relationship: VmCode,
}
struct CachedSetIterator<'code> {
cache: Vec<u64>,
filter: &'code VmCode,
n: u8,
p: u8,
k: u8,
cached: bool,
counter: u64,
cache_counter: usize,
top: u64,
is_x: bool,
}
impl<'code> CachedSetIterator<'code> {
fn create_x(filter: &'code VmCode, n: u8, p: u8, k: u8) -> Self {
CachedSetIterator {
cache: Vec::with_capacity(CACHE_SIZE_LIMIT),
filter,
n,
p,
k,
cached: false,
counter: 0,
cache_counter: 0,
top: 2_u64.pow(n as u32),
is_x: true,
}
}
fn create_y(filter: &'code VmCode, n: u8, p: u8, k: u8) -> Self {
CachedSetIterator {
cache: Vec::with_capacity(CACHE_SIZE_LIMIT),
filter,
n,
p,
k,
cached: false,
counter: 0,
cache_counter: 0,
top: 2_u64.pow(n as u32),
is_x: false,
}
}
fn reset(&mut self) {
self.counter = 0;
self.cache_counter = 0;
}
}
impl<'code> Iterator for &mut CachedSetIterator<'code> {
type Item = u64;
fn next(&mut self) -> Option<Self::Item> {
if self.cached {
if self.cache_counter < self.cache.len() {
let result = self.cache[self.cache_counter];
self.cache_counter += 1;
if self.cache_counter == CACHE_SIZE_LIMIT {
self.counter = self.cache.last().unwrap() + 1;
}
return Some(result);
} else if self.cache.len() < CACHE_SIZE_LIMIT {
self.reset();
return None;
}
}
if self.counter == self.top {
self.cached = true;
self.reset();
return None;
}
let included = if self.is_x {
vm::Vm::load(
self.filter,
vm::Registers::load(self.counter, 0, self.n, self.p, self.k),
)
.run()
.unwrap_bool()
} else {
vm::Vm::load(
self.filter,
vm::Registers::load(0, self.counter, self.n, self.p, self.k),
)
.run()
.unwrap_bool()
};
if !included {
self.counter += 1;
return self.next();
}
let result = self.counter;
if self.cache.len() < CACHE_SIZE_LIMIT {
self.cache.push(result);
}
self.counter += 1;
Some(result)
}
}
// See https://stackoverflow.com/a/27755938
// In turn from http://graphics.stanford.edu/~seander/bithacks.html#NextBitPermutation
struct FixedHammingWeight {
permutation: u64,
top: u64,
exhausted: bool,
}
impl FixedHammingWeight {
fn new(bits: u8, ones: u8) -> Self {
FixedHammingWeight {
permutation: (1 << ones) - 1,
top: ((1 << ones) - 1) << (bits - ones),
exhausted: false,
}
}
}
impl Iterator for FixedHammingWeight {
type Item = u64;
fn next(&mut self) -> Option<Self::Item> {
if self.exhausted {
return None;
}
let to_yield = self.permutation;
if to_yield == self.top {
self.exhausted = true;
} else {
let t = self.permutation | (self.permutation - 1);
self.permutation = (t + 1)
| (((!t & (!t).wrapping_neg()) - 1) >> (self.permutation.trailing_zeros() + 1));
}
return Some(to_yield);
}
}
#[pymethods]
impl Prover {
/// Initializes a searcher for lower bounds.
///
/// This object prepares the search for lower bounds according to Ambainis et al.'s
/// adversary method. The bounds themselves for specific values of the parameters can
/// be found with the associated method `find_bounds`.
///
/// Arguments:
/// a_description A description of elements that should belong to the A set,
/// as specified discriminating on x [string]
/// b_description A description of elements that should belong to the B set,
/// as specified discriminating on y [string]
/// relationship A description of a relation, in the form of a boolean expression
/// depending on x and y, and resolving to true if x is related to
/// y [string]
#[new]
fn py_new(
a_set: String,
b_set: String,
a_description: String,
b_description: String,
relationship: String,
conjecture: String,
) -> PyResult<Self> {
// Parse ASTs
let relationship = match vm::parsing::parse_relation(&relationship) {
Ok(relationship) => relationship,
Err(msg) => {
@ -28,38 +216,171 @@ impl Prover {
}
};
let a_set = match vm::parsing::parse_relation(&a_set) {
Ok(a_set) => a_set,
let a_description = match vm::parsing::parse_relation(&a_description) {
Ok(a_description) => a_description,
Err(msg) => {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"When parsing `a_set`:\n{}",
"When parsing `a_description`:\n{}",
msg
)));
}
};
let b_set = match vm::parsing::parse_relation(&b_set) {
Ok(b_set) => b_set,
let b_description = match vm::parsing::parse_relation(&b_description) {
Ok(b_description) => b_description,
Err(msg) => {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"When parsing `b_set`:\n{}",
"When parsing `b_description`:\n{}",
msg
)));
}
};
let conjecture = match vm::parsing::parse_arithmetic(&conjecture) {
Ok(conjecture) => conjecture,
Err(msg) => {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"When parsing `conjecture`:\n{}",
msg
)));
}
// Compile into bytecode
let relationship = vm::compile_boolean(relationship);
let a_description = vm::compile_boolean(a_description);
let b_description = vm::compile_boolean(b_description);
// Ensure that a_description only references X, and b_description only references Y
if a_description
.any(|opcode| matches!(opcode, vm::OpCode::Variable(vm::parsing::ast::Variable::Y)))
{
return Err(pyo3::exceptions::PyValueError::new_err(
"`a_description` cannot reference variable y.",
));
}
if b_description
.any(|opcode| matches!(opcode, vm::OpCode::Variable(vm::parsing::ast::Variable::X)))
{
return Err(pyo3::exceptions::PyValueError::new_err(
"`b_description` cannot reference variable x.",
));
}
Ok(Prover {
relationship,
a_description,
b_description,
})
}
/// Finds a lower bound to the query complexity according to the specified parameters.
///
/// Arguments:
/// n The maximum number of bits to consider when exhausting bitstrings
/// belonging to either of the specified sets [int]
/// p The number of parallel queries [int]
/// k Parameter accepted in set specifications and relations [int]
fn find_bounds(&self, n: u8, p: u8, k: u8) -> PyResult<BoundsResult> {
if n > 63 {
return Err(pyo3::exceptions::PyValueError::new_err(
"More than 63 bits is not supported. (Because of `n` parameter.)",
));
}
if p > 63 {
return Err(pyo3::exceptions::PyValueError::new_err(
"More than 63 bits is not supported. (Because of `p` parameter.)",
));
}
if k > 63 {
return Err(pyo3::exceptions::PyValueError::new_err(
"More than 63 bits is not supported. (Because of `k` parameter.)",
));
}
let mut a_set_iterator = CachedSetIterator::create_x(&self.a_description, n, p, k);
let mut b_set_iterator = CachedSetIterator::create_y(&self.b_description, n, p, k);
let window_iterator = FixedHammingWeight::new(n, p);
let mut min_x_relations = u64::MAX;
let mut min_y_relations = u64::MAX;
let mut max_joint_relations = 0_u128;
let mut single_bounding_x = 0_u64;
let mut single_bounding_y = 0_u64;
let mut joint_bounding_x = 0_u64;
let mut joint_bounding_y = 0_u64;
let mut bounding_window = 0_u64;
let related = |x: u64, y: u64| {
vm::Vm::load(&self.relationship, vm::Registers::load(x, y, n, p, k))
.run()
.unwrap_bool()
};
todo!();
Ok(Prover {})
for window in window_iterator {
let mut max_x_relations = 0_u64;
let mut max_y_relations = 0_u64;
let mut bounding_x_candidate = 0_u64;
let mut bounding_y_candidate = 0_u64;
for x in &mut a_set_iterator {
let mut min_x_relations_candidate = 0_u64;
let mut max_x_relations_candidate = 0_u64;
for y in &mut b_set_iterator {
if related(x, y) {
min_x_relations_candidate += 1;
if (x & window) != (y & window) {
max_x_relations_candidate += 1;
}
}
}
if min_x_relations_candidate < min_x_relations {
min_x_relations = min_x_relations_candidate;
single_bounding_x = x;
}
if max_x_relations_candidate > max_x_relations {
max_x_relations = max_x_relations_candidate;
bounding_x_candidate = x;
}
}
for y in &mut b_set_iterator {
let mut min_y_relations_candidate = 0_u64;
let mut max_y_relations_candidate = 0_u64;
for x in &mut a_set_iterator {
if related(x, y) {
min_y_relations_candidate += 1;
if (x & window) != (y & window) {
max_y_relations_candidate += 1;
}
}
}
if min_y_relations_candidate < min_y_relations {
min_y_relations = min_y_relations_candidate;
single_bounding_y = y;
}
if max_y_relations_candidate > max_y_relations {
max_y_relations = max_y_relations_candidate;
bounding_y_candidate = y;
}
}
let max_joint_relations_candidate = max_x_relations as u128 * max_y_relations as u128;
if max_joint_relations_candidate > max_joint_relations {
max_joint_relations = max_joint_relations_candidate;
joint_bounding_x = bounding_x_candidate;
joint_bounding_y = bounding_y_candidate;
bounding_window = window;
}
}
Ok(BoundsResult {
min_x_relations,
min_y_relations,
max_joint_relations,
single_bounding_x,
single_bounding_y,
joint_bounding_x,
joint_bounding_y,
bounding_window,
})
}
}
@ -67,6 +388,31 @@ impl Prover {
/// method.
#[pymodule]
fn adversary(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<BoundsResult>()?;
m.add_class::<Prover>()?;
Ok(())
}
#[test]
fn test_cached_iterator() {
let n = 10;
let k = 3;
let p = 5;
let filter =
vm::compile_boolean(vm::parsing::parse_relation("<= ham x k").expect("Valid expression"));
let mut iterator = CachedSetIterator::create_x(&filter, n, p, k);
let first_count = iterator.count();
let second_count = iterator.count();
for x in &mut iterator {
assert!(x < 2_u64.pow(n as u32));
assert!(x.count_ones() <= k.into());
}
assert!(first_count == second_count);
}
#[test]
fn test_hamming_strings() {
for x in FixedHammingWeight::new(10, 3) {
assert!(x.count_ones() <= 3);
}
}

View File

@ -39,7 +39,7 @@ pub enum ArithmeticValue {
}
#[derive(Debug)]
enum OpCode {
pub enum OpCode {
UnaryBooleanOperator(ast::UnaryBooleanOperator),
BinaryArithmeticOperator(ast::BinaryArithmeticOperator),
ComparisonOperator(ast::ComparisonOperator),
@ -54,25 +54,61 @@ enum Expression {
Arithmetic(ast::ArithmeticExpression),
}
impl VmOutput {
pub fn unwrap_bool(self) -> bool {
match self {
VmOutput::Boolean(value) => value,
VmOutput::Arithmetic(_) => panic!("Expected boolean, got arithmetic value."),
}
}
pub fn unwrap_arithmetic(self) -> f64 {
match self {
VmOutput::Boolean(_) => panic!("Expected arithmetic, got boolean value."),
VmOutput::Arithmetic(value) => match value {
ArithmeticValue::Integer(value) => value as f64,
ArithmeticValue::Floating(value) => value,
},
}
}
}
impl Registers {
pub fn load(x: i64, y: i64, n: u32, p: u32, k: u32) -> Self {
Self { x, y, n, p, k }
pub fn load(x: u64, y: u64, n: u8, p: u8, k: u8) -> Self {
Self {
x: x as i64,
y: y as i64,
n: n as u32,
p: p as u32,
k: k as u32,
}
}
}
pub fn compile_boolean(expression: ast::BooleanExpression) -> VmCode {
let mut code = Vec::<Bytecode>::new();
compile_expression(Expression::Boolean(expression), &mut code);
VmCode { code }
}
pub fn compile_arithmetic(expression: ast::ArithmeticExpression) -> VmCode {
let mut code = Vec::<Bytecode>::new();
compile_expression(Expression::Arithmetic(expression), &mut code);
VmCode { code }
}
impl VmCode {
pub fn any<P: Fn(&OpCode) -> bool>(&self, predicate: P) -> bool {
for (opcode, _) in &self.code {
if predicate(opcode) {
return true;
}
}
false
}
}
impl<'code> Vm<'code> {
pub fn compile_boolean(expression: ast::BooleanExpression) -> VmCode {
let mut code = Vec::<Bytecode>::new();
compile_expression(Expression::Boolean(expression), &mut code);
VmCode { code }
}
pub fn compile_arithmetic(expression: ast::ArithmeticExpression) -> VmCode {
let mut code = Vec::<Bytecode>::new();
compile_expression(Expression::Arithmetic(expression), &mut code);
VmCode { code }
}
pub fn load(code: &'code VmCode, registers: Registers) -> Self {
Vm { code, registers }
}
@ -194,6 +230,15 @@ impl<'code> Vm<'code> {
ast::BinaryArithmeticOperator::Xor => {
ArithmeticValue::Integer(left_operand ^ right_operand)
}
ast::BinaryArithmeticOperator::Pow => {
if right_operand > 0 {
ArithmeticValue::Integer(left_operand.pow(right_operand as u32))
} else {
ArithmeticValue::Floating(
(left_operand as f64).powi(right_operand as i32),
)
}
}
};
let operations_as_floating =
@ -213,6 +258,9 @@ impl<'code> Vm<'code> {
ast::BinaryArithmeticOperator::Xor => ArithmeticValue::Integer(
left_operand as i64 ^ right_operand as i64,
),
ast::BinaryArithmeticOperator::Pow => {
ArithmeticValue::Floating(left_operand.powf(right_operand))
}
};
VmOutput::Arithmetic(match left_operand {
@ -254,6 +302,12 @@ impl<'code> Vm<'code> {
}
})
}
ast::UnaryArithmeticOperator::Sqrt => {
ArithmeticValue::Floating(match value {
ArithmeticValue::Integer(value) => (value as f64).sqrt(),
ArithmeticValue::Floating(value) => value.sqrt(),
})
}
})
}
OpCode::BinaryBooleanOperator(op) => {
@ -390,8 +444,9 @@ fn compile_expression(expression: Expression, code: &mut Vec<Bytecode>) -> usize
#[test]
fn boolean_compilation_test() {
let expression = parsing::parse_relation("and (and (< x 1) (!= 1 y)) (< ham x k)").expect("Valid AST");
println!("{:?}", Vm::compile_boolean(expression));
let expression =
parsing::parse_relation("and (and (< x 1) (!= 1 y)) (< ham x k)").expect("Valid AST");
println!("{:?}", compile_boolean(expression));
}
#[test]
@ -401,7 +456,7 @@ fn run_test() {
println!("{}", e);
panic!();
}
let code = Vm::compile_boolean(expression.unwrap());
let code = compile_boolean(expression.unwrap());
let vm = Vm::load(&code, Registers::load(0b_1011, 0b_1111, 10, 2, 3));
println!("{:?}", vm.run());
}

View File

@ -12,6 +12,7 @@ pub enum BinaryArithmeticOperator {
Plus,
Minus,
Xor,
Pow,
}
#[derive(Debug)]
@ -28,6 +29,7 @@ pub enum ComparisonOperator {
pub enum UnaryArithmeticOperator {
Negative,
Ham,
Sqrt,
}
#[derive(Debug)]

View File

@ -12,9 +12,9 @@ arithmetic_operand = { arithmetic_expression | number_literal }
arithmetic_expression = { "(" ~ (binary_arithmetic_conjunction | unary_arithmetic_conjunction | variable) ~ ")" |
binary_arithmetic_conjunction | unary_arithmetic_conjunction | variable }
binary_arithmetic_conjunction = { binary_arithmetic_operator ~ arithmetic_operand ~ arithmetic_operand }
binary_arithmetic_operator = { "*" | "/" | "+" | "-" | "^" }
binary_arithmetic_operator = { "*" | "/" | "+" | "-" | "^" | "pow" }
unary_arithmetic_conjunction = { unary_arithmetic_operator ~ arithmetic_operand }
unary_arithmetic_operator = { "neg" | "ham" }
unary_arithmetic_operator = { "neg" | "ham" | "sqrt" }
variable = { "x" | "y" | "n" | "p" | "k" }
number_literal = { ('0'..'9')+ }
WHITESPACE = _{ " " }

View File

@ -225,6 +225,7 @@ fn parse_unary_arithmetic_operator(
match rule.as_str().trim() {
"neg" => ast::UnaryArithmeticOperator::Negative,
"ham" => ast::UnaryArithmeticOperator::Ham,
"sqrt" => ast::UnaryArithmeticOperator::Sqrt,
_ => unreachable!(),
}
}
@ -247,6 +248,7 @@ fn parse_binary_arithmetic_operator(rule: pest::iterators::Pair<Rule>) -> Binary
"+" => BinaryArithmeticOperator::Plus,
"-" => BinaryArithmeticOperator::Minus,
"^" => BinaryArithmeticOperator::Xor,
"pow" => BinaryArithmeticOperator::Pow,
_ => unreachable!(),
}
}