diff --git a/Cargo.lock b/Cargo.lock
index d28eaa5..da16b04 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -6,6 +6,7 @@ version = 3
name = "adversary"
version = "0.1.0"
dependencies = [
+ "bitflags",
"criterion",
"pest",
"pest_derive",
diff --git a/Cargo.toml b/Cargo.toml
index a022880..e25e8f4 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,6 +9,7 @@ name = "adversary"
crate-type = ["cdylib", "rlib"]
[dependencies]
+bitflags = "1.3.2"
pest = "2.5.5"
pest_derive = "2.5.5"
pyo3 = { version = "0.18.0", features = ["extension-module"] }
@@ -17,5 +18,5 @@ pyo3 = { version = "0.18.0", features = ["extension-module"] }
criterion = "0.4.0"
[[bench]]
-name = "search"
-harness = false
\ No newline at end of file
+name = "benches"
+harness = false
diff --git a/benches/benches.rs b/benches/benches.rs
new file mode 100644
index 0000000..faa2981
--- /dev/null
+++ b/benches/benches.rs
@@ -0,0 +1,65 @@
+#[macro_use]
+extern crate criterion;
+extern crate adversary;
+
+use criterion::{criterion_group, criterion_main, Criterion};
+
+fn vm(c: &mut Criterion) {
+ let expression =
+ adversary::vm::parsing::parse_relation("and (= (ham (^ x y)) (ham (+ 3 x))) (> (* x y) 5)");
+ if let Err(e) = expression {
+ println!("{}", e);
+ panic!();
+ }
+ let code = adversary::vm::compile_boolean(expression.unwrap());
+ let mut stack = adversary::vm::VmStack::from_code(&code);
+ c.bench_function("vm", |b| {
+ b.iter(|| {
+ for x in 0..(1 << 6) {
+ for y in 0..(1 << 6) {
+ criterion::black_box(
+ adversary::vm::Vm::load(
+ &code,
+ adversary::vm::Registers::load(x, y, 6, 0, 0),
+ &mut stack,
+ )
+ .run(),
+ );
+ }
+ }
+ })
+ });
+}
+
+fn search(c: &mut Criterion) {
+ pyo3::prepare_freethreaded_python();
+ pyo3::Python::with_gil(|_py| {
+ let obj = adversary::Prover::py_new(
+ "= (ham x) k".to_string(),
+ "= (ham y) (+ k 1)".to_string(),
+ "<= ham (^ x y) p".to_string(),
+ )
+ .unwrap();
+ c.bench_function("search", |b| {
+ b.iter(|| criterion::black_box(obj.find_bounds(10, 5, 3)));
+ });
+ })
+}
+
+fn search_big(c: &mut Criterion) {
+ pyo3::prepare_freethreaded_python();
+ pyo3::Python::with_gil(|_py| {
+ let obj = adversary::Prover::py_new(
+ "= (ham x) k".to_string(),
+ "= (ham y) (+ k 1)".to_string(),
+ "<= ham (^ x y) p".to_string(),
+ )
+ .unwrap();
+ c.bench_function("search_big", |b| {
+ b.iter(|| criterion::black_box(obj.find_bounds(12, 5, 8)));
+ });
+ })
+}
+
+criterion_group!(benches, vm, search, search_big);
+criterion_main!(benches);
diff --git a/benches/search.rs b/benches/search.rs
deleted file mode 100644
index 69a0957..0000000
--- a/benches/search.rs
+++ /dev/null
@@ -1,23 +0,0 @@
-#[macro_use]
-extern crate criterion;
-extern crate adversary;
-
-use criterion::{criterion_group, criterion_main, Criterion};
-
-pub fn criterion_benchmark(c: &mut Criterion) {
- pyo3::prepare_freethreaded_python();
- pyo3::Python::with_gil(|_py| {
- let obj = adversary::Prover::py_new(
- "= (ham x) k".to_string(),
- "= (ham y) (+ k 1)".to_string(),
- "<= ham (^ x y) p".to_string(),
- )
- .unwrap();
- c.bench_function("find_bounds", |b| {
- b.iter(|| criterion::black_box(obj.find_bounds(10, 5, 3)));
- });
- })
-}
-
-criterion_group!(benches, criterion_benchmark);
-criterion_main!(benches);
diff --git a/linear_vm_flamegraph.svg b/linear_vm_flamegraph.svg
new file mode 100644
index 0000000..35855a7
--- /dev/null
+++ b/linear_vm_flamegraph.svg
@@ -0,0 +1,491 @@
+
\ No newline at end of file
diff --git a/test/bounds.py b/pytest/bounds.py
similarity index 100%
rename from test/bounds.py
rename to pytest/bounds.py
diff --git a/pytest/timing.py b/pytest/timing.py
new file mode 100644
index 0000000..026a507
--- /dev/null
+++ b/pytest/timing.py
@@ -0,0 +1,6 @@
+import adversary
+import sys
+
+prover = adversary.Prover(
+ "= (ham x) k", "= (ham y) (+ k 1)", "= ham (^ x y) p").hint_symmetric()
+bounds = prover.find_bounds(int(sys.argv[1]), 5, 8)
diff --git a/flamegraph.svg b/recursive_vm_flamegraph.svg
similarity index 100%
rename from flamegraph.svg
rename to recursive_vm_flamegraph.svg
diff --git a/src/lib.rs b/src/lib.rs
index a79e3f0..c00fce8 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,11 +1,12 @@
extern crate pest;
#[macro_use]
extern crate pest_derive;
+#[macro_use]
+extern crate bitflags;
-mod vm;
+pub mod vm;
use pyo3::prelude::*;
-use vm::VmCode;
// We cache the A and B set up to a total of ~2Gb of memory,
// meaning ~ 10^7 elements in each set
@@ -33,14 +34,22 @@ pub struct BoundsResult {
#[pyclass]
pub struct Prover {
- a_description: VmCode,
- b_description: VmCode,
- relationship: VmCode,
+ a_description: vm::VmCode,
+ b_description: vm::VmCode,
+ relationship: vm::VmCode,
+ hints: ProofHints,
+}
+
+bitflags! {
+ struct ProofHints: u8 {
+ const SymmetricFn = 0b00000001;
+ }
}
struct CachedSetIterator<'code> {
cache: Vec,
- filter: &'code VmCode,
+ filter: &'code vm::VmCode,
+ stack: vm::VmStack,
n: u32,
p: u32,
k: u32,
@@ -52,10 +61,11 @@ struct CachedSetIterator<'code> {
}
impl<'code> CachedSetIterator<'code> {
- fn create_x(filter: &'code VmCode, n: u32, p: u32, k: u32) -> Self {
+ fn create_x(filter: &'code vm::VmCode, n: u32, p: u32, k: u32) -> Self {
CachedSetIterator {
cache: Vec::with_capacity(CACHE_SIZE_LIMIT),
filter,
+ stack: vm::VmStack::from_code(filter),
n,
p,
k,
@@ -67,10 +77,11 @@ impl<'code> CachedSetIterator<'code> {
}
}
- fn create_y(filter: &'code VmCode, n: u32, p: u32, k: u32) -> Self {
+ fn create_y(filter: &'code vm::VmCode, n: u32, p: u32, k: u32) -> Self {
CachedSetIterator {
cache: Vec::with_capacity(CACHE_SIZE_LIMIT),
filter,
+ stack: vm::VmStack::from_code(filter),
n,
p,
k,
@@ -119,16 +130,20 @@ impl<'code> Iterator for &mut CachedSetIterator<'code> {
vm::Vm::load(
self.filter,
vm::Registers::load(self.counter, 0, self.n, self.p, self.k),
+ &mut self.stack,
)
.run()
- .unwrap_bool()
+ .output_bool()
+ .unwrap()
} else {
vm::Vm::load(
self.filter,
vm::Registers::load(0, self.counter, self.n, self.p, self.k),
+ &mut self.stack,
)
.run()
- .unwrap_bool()
+ .output_bool()
+ .unwrap()
};
if !included {
@@ -150,6 +165,7 @@ impl<'code> Iterator for &mut CachedSetIterator<'code> {
// See https://stackoverflow.com/a/27755938
// In turn from http://graphics.stanford.edu/~seander/bithacks.html#NextBitPermutation
struct FixedHammingWeight {
+ ones: u32,
permutation: u64,
top: u64,
exhausted: bool,
@@ -158,11 +174,17 @@ struct FixedHammingWeight {
impl FixedHammingWeight {
fn new(bits: u32, ones: u32) -> Self {
FixedHammingWeight {
+ ones,
permutation: (1 << ones) - 1,
top: ((1 << ones) - 1) << (bits - ones),
exhausted: false,
}
}
+
+ fn reset(&mut self) {
+ self.permutation = (1 << self.ones) - 1;
+ self.exhausted = false;
+ }
}
impl Iterator for FixedHammingWeight {
@@ -170,6 +192,7 @@ impl Iterator for FixedHammingWeight {
fn next(&mut self) -> Option {
if self.exhausted {
+ self.reset();
return None;
}
let to_yield = self.permutation;
@@ -263,9 +286,20 @@ impl Prover {
relationship,
a_description,
b_description,
+ hints: ProofHints::empty(),
})
}
+ /// Hints to the prover that the function in question is totally symmetric,
+ /// or, equivalently, that `a_description` only depends on the Hamming weight
+ /// of x, and `b_description` only depends on the Hamming weight of y.
+ ///
+ /// If you l your results will lower bound the upper bound.
+ pub fn hint_symmetric<'s>(mut py_self: PyRefMut<'s, Self>, py: Python) -> PyRefMut<'s, Self> {
+ (*py_self).hints |= ProofHints::SymmetricFn;
+ py_self
+ }
+
/// Finds a lower bound to the query complexity according to the specified parameters.
///
/// Arguments:
@@ -290,9 +324,10 @@ impl Prover {
));
}
+ let mut vm_stack = vm::VmStack::from_code(&self.relationship);
+
let mut a_set_iterator = CachedSetIterator::create_x(&self.a_description, n, p, k);
let mut b_set_iterator = CachedSetIterator::create_y(&self.b_description, n, p, k);
- let window_iterator = FixedHammingWeight::new(n, p);
let mut min_x_relations = u64::MAX;
let mut min_y_relations = u64::MAX;
@@ -303,7 +338,15 @@ impl Prover {
let mut joint_bounding_y = 0_u64;
let mut bounding_window = 0_u64;
- for window in window_iterator {
+ let mut window_iterator = FixedHammingWeight::new(n, p).into_iter();
+ let mut first_entry = [(1_u64 << p) - 1].into_iter();
+ let effective_window_iterator = if self.hints.contains(ProofHints::SymmetricFn) {
+ &mut first_entry as &mut dyn Iterator-
+ } else {
+ &mut window_iterator as &mut dyn Iterator
-
+ };
+
+ for window in effective_window_iterator {
let mut max_x_relations = 0_u64;
let mut max_y_relations = 0_u64;
let mut bounding_x_candidate = 0_u64;
@@ -317,9 +360,11 @@ impl Prover {
let related = vm::Vm::load(
&self.relationship,
vm::Registers::load(x, y, n, p, k),
+ &mut vm_stack,
)
.run()
- .unwrap_bool();
+ .output_bool()
+ .unwrap();
if related {
min_x_relations_candidate += 1;
@@ -344,10 +389,14 @@ impl Prover {
let mut max_y_relations_candidate = 0_u64;
for x in &mut a_set_iterator {
- let related =
- vm::Vm::load(&self.relationship, vm::Registers::load(x, y, n, p, k))
- .run()
- .unwrap_bool();
+ let related = vm::Vm::load(
+ &self.relationship,
+ vm::Registers::load(x, y, n, p, k),
+ &mut vm_stack,
+ )
+ .run()
+ .output_bool()
+ .unwrap();
if related {
min_y_relations_candidate += 1;
@@ -374,9 +423,9 @@ impl Prover {
joint_bounding_x = bounding_x_candidate;
joint_bounding_y = bounding_y_candidate;
bounding_window = window;
-
- Python::with_gil(|py| py.check_signals())?;
}
+
+ Python::with_gil(|py| py.check_signals())?;
}
Ok(BoundsResult {
@@ -437,4 +486,4 @@ fn test_full_run() {
.unwrap();
std::hint::black_box(obj.find_bounds(10, 5, 3)).expect("Success");
})
-}
\ No newline at end of file
+}
diff --git a/src/vm/mod.rs b/src/vm/mod.rs
index 991a53d..d00972f 100644
--- a/src/vm/mod.rs
+++ b/src/vm/mod.rs
@@ -4,7 +4,8 @@ use parsing::ast;
use crate::vm::parsing::ast::ComparisonOperator;
-type Bytecode = (OpCode, usize);
+// type Bytecode = (OpCode, usize);
+type Bytecode = OpCode;
#[derive(Debug)]
pub struct VmCode {
@@ -15,6 +16,13 @@ pub struct VmCode {
pub struct Vm<'code> {
code: &'code VmCode,
registers: Registers,
+ stack: &'code mut VmStack,
+}
+
+#[derive(Debug)]
+pub struct VmStack {
+ boolean_stack: Vec,
+ arithmetic_stack: Vec,
}
#[derive(Debug)]
@@ -26,11 +34,12 @@ pub struct Registers {
k: i64,
}
+/*
#[derive(Debug)]
pub enum VmOutput {
Boolean(bool),
Arithmetic(ArithmeticValue),
-}
+} */
#[derive(Debug)]
pub enum ArithmeticValue {
@@ -54,7 +63,23 @@ enum Expression {
Arithmetic(ast::ArithmeticExpression),
}
-impl VmOutput {
+impl VmStack {
+ pub fn from_code(code: &VmCode) -> Self {
+ // Each op_code can produce at most one value onto the stack
+ let code = &code.code;
+ VmStack {
+ boolean_stack: Vec::with_capacity(code.len()),
+ arithmetic_stack: Vec::with_capacity(code.len()),
+ }
+ }
+
+ fn reset(&mut self) {
+ self.boolean_stack.clear();
+ self.arithmetic_stack.clear();
+ }
+}
+
+/*impl VmOutput {
pub fn unwrap_bool(self) -> bool {
match self {
VmOutput::Boolean(value) => value,
@@ -73,6 +98,7 @@ impl VmOutput {
}
}
}
+*/
impl Registers {
pub fn load(x: u64, y: u64, n: u32, p: u32, k: u32) -> Self {
@@ -101,7 +127,7 @@ pub fn compile_arithmetic(expression: ast::ArithmeticExpression) -> VmCode {
impl VmCode {
pub fn any bool>(&self, predicate: P) -> bool {
- for (opcode, _) in &self.code {
+ for opcode in &self.code {
if predicate(opcode) {
return true;
}
@@ -111,16 +137,211 @@ impl VmCode {
}
impl<'code> Vm<'code> {
- pub fn load(code: &'code VmCode, registers: Registers) -> Self {
- Vm { code, registers }
+ pub fn load(code: &'code VmCode, registers: Registers, stack: &'code mut VmStack) -> Self {
+ Vm {
+ code,
+ registers,
+ stack,
+ }
}
- pub fn run(&self) -> VmOutput {
+ pub fn run(&mut self) -> &mut Self {
+ // Arithmetic operations for `ArithmeticValue`s
+ fn operations_as_integer(
+ op: &ast::BinaryArithmeticOperator,
+ left_operand: i64,
+ right_operand: i64,
+ ) -> ArithmeticValue {
+ match op {
+ ast::BinaryArithmeticOperator::Times => {
+ ArithmeticValue::Integer(left_operand * right_operand)
+ }
+ ast::BinaryArithmeticOperator::Divide => {
+ if left_operand % right_operand == 0 {
+ ArithmeticValue::Integer(left_operand / right_operand)
+ } else {
+ ArithmeticValue::Floating(left_operand as f64 / right_operand as f64)
+ }
+ }
+ ast::BinaryArithmeticOperator::Plus => {
+ ArithmeticValue::Integer(left_operand + right_operand)
+ }
+ ast::BinaryArithmeticOperator::Minus => {
+ ArithmeticValue::Integer(left_operand - right_operand)
+ }
+ ast::BinaryArithmeticOperator::Xor => {
+ ArithmeticValue::Integer(left_operand ^ right_operand)
+ }
+ ast::BinaryArithmeticOperator::Pow => {
+ if right_operand > 0 {
+ ArithmeticValue::Integer(left_operand.pow(right_operand as u32))
+ } else {
+ ArithmeticValue::Floating((left_operand as f64).powi(right_operand as i32))
+ }
+ }
+ }
+ }
+
+ fn operations_as_floating(
+ op: &ast::BinaryArithmeticOperator,
+ left_operand: f64,
+ right_operand: f64,
+ ) -> ArithmeticValue {
+ match op {
+ ast::BinaryArithmeticOperator::Times => {
+ ArithmeticValue::Floating(left_operand * right_operand)
+ }
+ ast::BinaryArithmeticOperator::Divide => {
+ ArithmeticValue::Floating(left_operand as f64 / right_operand as f64)
+ }
+ ast::BinaryArithmeticOperator::Plus => {
+ ArithmeticValue::Floating(left_operand + right_operand)
+ }
+ ast::BinaryArithmeticOperator::Minus => {
+ ArithmeticValue::Floating(left_operand - right_operand)
+ }
+ ast::BinaryArithmeticOperator::Xor => {
+ ArithmeticValue::Integer(left_operand as i64 ^ right_operand as i64)
+ }
+ ast::BinaryArithmeticOperator::Pow => {
+ ArithmeticValue::Floating(left_operand.powf(right_operand))
+ }
+ }
+ }
+
+ fn comparison(
+ op: &ComparisonOperator,
+ left_operand: T,
+ right_operand: T,
+ ) -> bool {
+ match op {
+ ast::ComparisonOperator::GreaterOrEqual => left_operand.ge(&right_operand),
+ ast::ComparisonOperator::LessOrEqual => left_operand.le(&right_operand),
+ ast::ComparisonOperator::GreaterThan => left_operand.gt(&right_operand),
+ ast::ComparisonOperator::LessThan => left_operand.lt(&right_operand),
+ ast::ComparisonOperator::NotEqual => left_operand.ne(&right_operand),
+ ast::ComparisonOperator::Equal => left_operand.eq(&right_operand),
+ }
+ }
+
// Alias for convenience
let code = &self.code.code;
let registers = &self.registers;
+ let stack = &mut self.stack;
- struct Resolver<'f> {
+ stack.reset();
+
+ for opcode in code.iter().rev() {
+ match opcode {
+ OpCode::UnaryBooleanOperator(op) => {
+ let operand = stack.boolean_stack.pop().unwrap();
+ match op {
+ ast::UnaryBooleanOperator::Not => {
+ stack.boolean_stack.push(!operand);
+ }
+ }
+ }
+ OpCode::BinaryArithmeticOperator(op) => {
+ let left_operand = stack.arithmetic_stack.pop().unwrap();
+ let right_operand = stack.arithmetic_stack.pop().unwrap();
+ let value = match left_operand {
+ ArithmeticValue::Integer(left_operand) => match right_operand {
+ ArithmeticValue::Integer(right_operand) => {
+ operations_as_integer(op, left_operand, right_operand)
+ }
+ ArithmeticValue::Floating(right_operand) => {
+ operations_as_floating(op, left_operand as f64, right_operand)
+ }
+ },
+ ArithmeticValue::Floating(left_operand) => match right_operand {
+ ArithmeticValue::Integer(right_operand) => {
+ operations_as_floating(op, left_operand, right_operand as f64)
+ }
+ ArithmeticValue::Floating(right_operand) => {
+ operations_as_floating(op, left_operand, right_operand)
+ }
+ },
+ };
+ stack.arithmetic_stack.push(value);
+ }
+ OpCode::ComparisonOperator(op) => {
+ let left_operand = stack.arithmetic_stack.pop().unwrap();
+ let right_operand = stack.arithmetic_stack.pop().unwrap();
+ let value = match left_operand {
+ ArithmeticValue::Integer(left_operand) => match right_operand {
+ ArithmeticValue::Integer(right_operand) => {
+ comparison(op, left_operand, right_operand)
+ }
+ ArithmeticValue::Floating(right_operand) => {
+ comparison(op, left_operand as f64, right_operand)
+ }
+ },
+ ArithmeticValue::Floating(left_operand) => match right_operand {
+ ArithmeticValue::Integer(right_operand) => {
+ comparison(op, left_operand, right_operand as f64)
+ }
+ ArithmeticValue::Floating(right_operand) => {
+ comparison(op, left_operand, right_operand)
+ }
+ },
+ };
+ stack.boolean_stack.push(value);
+ }
+ OpCode::UnaryArithmeticOperator(op) => {
+ let operand = stack.arithmetic_stack.pop().unwrap();
+ let value = match op {
+ ast::UnaryArithmeticOperator::Negative => match operand {
+ ArithmeticValue::Integer(operand) => ArithmeticValue::Integer(-operand),
+ ArithmeticValue::Floating(operand) => {
+ ArithmeticValue::Floating(-operand)
+ }
+ },
+ ast::UnaryArithmeticOperator::Ham => {
+ ArithmeticValue::Integer(match operand {
+ ArithmeticValue::Integer(operand) => operand.count_ones() as i64,
+ ArithmeticValue::Floating(operand) => {
+ (operand.round() as i64).count_ones() as i64
+ }
+ })
+ }
+ ast::UnaryArithmeticOperator::Sqrt => {
+ ArithmeticValue::Floating(match operand {
+ ArithmeticValue::Integer(operand) => (operand as f64).sqrt(),
+ ArithmeticValue::Floating(operand) => operand.sqrt(),
+ })
+ }
+ };
+ stack.arithmetic_stack.push(value);
+ }
+ OpCode::BinaryBooleanOperator(op) => {
+ let left_operand = stack.boolean_stack.pop().unwrap();
+ let right_operand = stack.boolean_stack.pop().unwrap();
+ let value = match op {
+ ast::BinaryBooleanOperator::And => left_operand & right_operand,
+ ast::BinaryBooleanOperator::Or => left_operand | right_operand,
+ ast::BinaryBooleanOperator::Xor => left_operand ^ right_operand,
+ };
+ stack.boolean_stack.push(value);
+ }
+ OpCode::Variable(variable) => {
+ let value = ArithmeticValue::Integer(match variable {
+ ast::Variable::X => registers.x,
+ ast::Variable::Y => registers.y,
+ ast::Variable::N => registers.n,
+ ast::Variable::P => registers.p,
+ ast::Variable::K => registers.k,
+ });
+ stack.arithmetic_stack.push(value);
+ }
+ OpCode::Literal(literal) => {
+ stack
+ .arithmetic_stack
+ .push(ArithmeticValue::Integer(*literal));
+ }
+ }
+ }
+
+ /*struct Resolver<'f> {
f: &'f dyn Fn(&Resolver, usize) -> VmOutput,
}
@@ -346,7 +567,18 @@ impl<'code> Vm<'code> {
},
};
- (resolver.f)(&resolver, 0)
+ (resolver.f)(&resolver, 0)*/
+
+ self
+ }
+
+ pub fn output_bool(&mut self) -> Result {
+ self.stack.boolean_stack.pop().ok_or(())
+ }
+
+ #[allow(unused)]
+ pub fn output_arithmetic(&mut self) -> Result {
+ self.stack.arithmetic_stack.pop().ok_or(())
}
}
@@ -356,28 +588,32 @@ fn compile_expression(expression: Expression, code: &mut Vec) -> usize
Expression::Boolean(expression) => match expression {
ast::BooleanExpression::BinaryBooleanConjunction(expression) => {
let expression = *expression;
- code.push((OpCode::BinaryBooleanOperator(expression.operator), 0));
+ // code.push((OpCode::BinaryBooleanOperator(expression.operator), 0));
+ code.push(OpCode::BinaryBooleanOperator(expression.operator));
let index = code.len() - 1;
let left_operand_size =
compile_expression(Expression::Boolean(expression.left_operand), code);
let right_operand_size =
compile_expression(Expression::Boolean(expression.right_operand), code);
- code[index].1 = left_operand_size + 1;
+ // code[index].1 = left_operand_size + 1;
return 1 + left_operand_size + right_operand_size;
}
ast::BooleanExpression::UnaryBooleanConjunction(expression) => {
let expression = *expression;
- code.push((OpCode::UnaryBooleanOperator(expression.operator), 0));
+ // code.push((OpCode::UnaryBooleanOperator(expression.operator), 0));
+ code.push(OpCode::UnaryBooleanOperator(expression.operator));
let operand_size =
compile_expression(Expression::Boolean(expression.operand), code);
return 1 + operand_size;
}
ast::BooleanExpression::ComparisonConjunction(expression) => {
- code.push((OpCode::ComparisonOperator(expression.operator), 0));
+ // code.push((OpCode::ComparisonOperator(expression.operator), 0));
+ code.push(OpCode::ComparisonOperator(expression.operator));
let index = code.len() - 1;
let left_operand_size = match expression.left_operand {
ast::ArithmeticOperand::Literal(literal) => {
- code.push((OpCode::Literal(literal), 0));
+ // code.push((OpCode::Literal(literal), 0));
+ code.push(OpCode::Literal(literal));
1_usize
}
ast::ArithmeticOperand::Expression(expression) => {
@@ -386,28 +622,32 @@ fn compile_expression(expression: Expression, code: &mut Vec) -> usize
};
let right_operand_size = match expression.right_operand {
ast::ArithmeticOperand::Literal(literal) => {
- code.push((OpCode::Literal(literal), 0));
+ // code.push((OpCode::Literal(literal), 0));
+ code.push(OpCode::Literal(literal));
1_usize
}
ast::ArithmeticOperand::Expression(expression) => {
compile_expression(Expression::Arithmetic(*expression), code)
}
};
- code[index].1 = left_operand_size + 1;
+ // code[index].1 = left_operand_size + 1;
return 1 + left_operand_size + right_operand_size;
}
},
Expression::Arithmetic(expression) => match expression {
ast::ArithmeticExpression::Variable(variable) => {
- code.push((OpCode::Variable(variable), 0));
+ // code.push((OpCode::Variable(variable), 0));
+ code.push(OpCode::Variable(variable));
return 1;
}
ast::ArithmeticExpression::UnaryArithmeticConjunction(expression) => {
let expression = *expression;
- code.push((OpCode::UnaryArithmeticOperator(expression.operator), 0));
+ // code.push((OpCode::UnaryArithmeticOperator(expression.operator), 0));
+ code.push(OpCode::UnaryArithmeticOperator(expression.operator));
let operand_size = match expression.operand {
ast::ArithmeticOperand::Literal(literal) => {
- code.push((OpCode::Literal(literal), 0));
+ // code.push((OpCode::Literal(literal), 0));
+ code.push(OpCode::Literal(literal));
1_usize
}
ast::ArithmeticOperand::Expression(expression) => {
@@ -417,11 +657,13 @@ fn compile_expression(expression: Expression, code: &mut Vec) -> usize
return 1 + operand_size;
}
ast::ArithmeticExpression::BinaryArithmeticConjunction(expression) => {
- code.push((OpCode::BinaryArithmeticOperator(expression.operator), 0));
+ // code.push((OpCode::BinaryArithmeticOperator(expression.operator), 0));
+ code.push(OpCode::BinaryArithmeticOperator(expression.operator));
let index = code.len() - 1;
let left_operand_size = match expression.left_operand {
ast::ArithmeticOperand::Literal(literal) => {
- code.push((OpCode::Literal(literal), 0));
+ // code.push((OpCode::Literal(literal), 0));
+ code.push(OpCode::Literal(literal));
1_usize
}
ast::ArithmeticOperand::Expression(expression) => {
@@ -430,14 +672,15 @@ fn compile_expression(expression: Expression, code: &mut Vec) -> usize
};
let right_operand_size = match expression.right_operand {
ast::ArithmeticOperand::Literal(literal) => {
- code.push((OpCode::Literal(literal), 0));
+ // code.push((OpCode::Literal(literal), 0));
+ code.push(OpCode::Literal(literal));
1_usize
}
ast::ArithmeticOperand::Expression(expression) => {
compile_expression(Expression::Arithmetic(*expression), code)
}
};
- code[index].1 = left_operand_size + 1;
+ // code[index].1 = left_operand_size + 1;
return 1 + left_operand_size + right_operand_size;
}
},
@@ -452,13 +695,48 @@ fn boolean_compilation_test() {
}
#[test]
-fn run_test() {
+fn test_vm_simple() {
let expression = parsing::parse_relation("(= (ham (^ x y)) (ham 1))");
if let Err(e) = expression {
println!("{}", e);
panic!();
}
let code = compile_boolean(expression.unwrap());
- let vm = Vm::load(&code, Registers::load(0b_1011, 0b_1111, 10, 2, 3));
- println!("{:?}", vm.run());
+ let mut stack = VmStack::from_code(&code);
+ let mut vm = Vm::load(
+ &code,
+ Registers::load(0b_1011, 0b_1111, 10, 2, 3),
+ &mut stack,
+ );
+ let output = vm.run().output_bool().unwrap();
+ assert_eq!(output, true);
+ let mut vm = Vm::load(
+ &code,
+ Registers::load(0b_1011, 0b_1110, 10, 2, 3),
+ &mut stack,
+ );
+ let output = vm.run().output_bool().unwrap();
+ assert_eq!(output, false);
+}
+
+#[test]
+fn test_vm_contrived() {
+ let expression = parsing::parse_relation("and (= (ham (^ x y)) (ham (+ 3 x))) (> (* x y) 5)");
+ if let Err(e) = expression {
+ println!("{}", e);
+ panic!();
+ }
+ let code = compile_boolean(expression.unwrap());
+ let mut stack = VmStack::from_code(&code);
+
+ let as_rust = |x: u64, y: u64| ((x + 3).count_ones() == (x ^ y).count_ones()) && (x * y > 5);
+
+ for x in 0..(1 << 4) {
+ for y in 0..(1 << 4) {
+ let mut vm = Vm::load(&code, Registers::load(x, y, 10, 2, 3), &mut stack);
+ let output = vm.run().output_bool().unwrap();
+ let expected = as_rust(x, y);
+ assert_eq!(output, expected);
+ }
+ }
}
diff --git a/time.ps1 b/time.ps1
new file mode 100644
index 0000000..ed460fe
--- /dev/null
+++ b/time.ps1
@@ -0,0 +1,2 @@
+venv/Scripts/activate
+python test/bounds.py