Final commit, in principle.
Got things running at about ~1ns per iteration, seems good enough. Adopted a linear strategy to evaluating the bytecode, rather than a recursive or even imperative evaluation strategy; this also lets me elide the offsets and store the bytecode in half the size. Looking forward to finding out that formulas are evaluated wrongly, but couldn't find a counterexample. Also restructured things a bit to avoid multiple alocations when evaluating by this strategy.
This commit is contained in:
parent
3a6c511ed5
commit
b899ddd63a
|
@ -6,6 +6,7 @@ version = 3
|
||||||
name = "adversary"
|
name = "adversary"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"bitflags",
|
||||||
"criterion",
|
"criterion",
|
||||||
"pest",
|
"pest",
|
||||||
"pest_derive",
|
"pest_derive",
|
||||||
|
|
|
@ -9,6 +9,7 @@ name = "adversary"
|
||||||
crate-type = ["cdylib", "rlib"]
|
crate-type = ["cdylib", "rlib"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
bitflags = "1.3.2"
|
||||||
pest = "2.5.5"
|
pest = "2.5.5"
|
||||||
pest_derive = "2.5.5"
|
pest_derive = "2.5.5"
|
||||||
pyo3 = { version = "0.18.0", features = ["extension-module"] }
|
pyo3 = { version = "0.18.0", features = ["extension-module"] }
|
||||||
|
@ -17,5 +18,5 @@ pyo3 = { version = "0.18.0", features = ["extension-module"] }
|
||||||
criterion = "0.4.0"
|
criterion = "0.4.0"
|
||||||
|
|
||||||
[[bench]]
|
[[bench]]
|
||||||
name = "search"
|
name = "benches"
|
||||||
harness = false
|
harness = false
|
||||||
|
|
|
@ -0,0 +1,65 @@
|
||||||
|
#[macro_use]
|
||||||
|
extern crate criterion;
|
||||||
|
extern crate adversary;
|
||||||
|
|
||||||
|
use criterion::{criterion_group, criterion_main, Criterion};
|
||||||
|
|
||||||
|
fn vm(c: &mut Criterion) {
|
||||||
|
let expression =
|
||||||
|
adversary::vm::parsing::parse_relation("and (= (ham (^ x y)) (ham (+ 3 x))) (> (* x y) 5)");
|
||||||
|
if let Err(e) = expression {
|
||||||
|
println!("{}", e);
|
||||||
|
panic!();
|
||||||
|
}
|
||||||
|
let code = adversary::vm::compile_boolean(expression.unwrap());
|
||||||
|
let mut stack = adversary::vm::VmStack::from_code(&code);
|
||||||
|
c.bench_function("vm", |b| {
|
||||||
|
b.iter(|| {
|
||||||
|
for x in 0..(1 << 6) {
|
||||||
|
for y in 0..(1 << 6) {
|
||||||
|
criterion::black_box(
|
||||||
|
adversary::vm::Vm::load(
|
||||||
|
&code,
|
||||||
|
adversary::vm::Registers::load(x, y, 6, 0, 0),
|
||||||
|
&mut stack,
|
||||||
|
)
|
||||||
|
.run(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
fn search(c: &mut Criterion) {
|
||||||
|
pyo3::prepare_freethreaded_python();
|
||||||
|
pyo3::Python::with_gil(|_py| {
|
||||||
|
let obj = adversary::Prover::py_new(
|
||||||
|
"= (ham x) k".to_string(),
|
||||||
|
"= (ham y) (+ k 1)".to_string(),
|
||||||
|
"<= ham (^ x y) p".to_string(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
c.bench_function("search", |b| {
|
||||||
|
b.iter(|| criterion::black_box(obj.find_bounds(10, 5, 3)));
|
||||||
|
});
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn search_big(c: &mut Criterion) {
|
||||||
|
pyo3::prepare_freethreaded_python();
|
||||||
|
pyo3::Python::with_gil(|_py| {
|
||||||
|
let obj = adversary::Prover::py_new(
|
||||||
|
"= (ham x) k".to_string(),
|
||||||
|
"= (ham y) (+ k 1)".to_string(),
|
||||||
|
"<= ham (^ x y) p".to_string(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
c.bench_function("search_big", |b| {
|
||||||
|
b.iter(|| criterion::black_box(obj.find_bounds(12, 5, 8)));
|
||||||
|
});
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
criterion_group!(benches, vm, search, search_big);
|
||||||
|
criterion_main!(benches);
|
|
@ -1,23 +0,0 @@
|
||||||
#[macro_use]
|
|
||||||
extern crate criterion;
|
|
||||||
extern crate adversary;
|
|
||||||
|
|
||||||
use criterion::{criterion_group, criterion_main, Criterion};
|
|
||||||
|
|
||||||
pub fn criterion_benchmark(c: &mut Criterion) {
|
|
||||||
pyo3::prepare_freethreaded_python();
|
|
||||||
pyo3::Python::with_gil(|_py| {
|
|
||||||
let obj = adversary::Prover::py_new(
|
|
||||||
"= (ham x) k".to_string(),
|
|
||||||
"= (ham y) (+ k 1)".to_string(),
|
|
||||||
"<= ham (^ x y) p".to_string(),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
c.bench_function("find_bounds", |b| {
|
|
||||||
b.iter(|| criterion::black_box(obj.find_bounds(10, 5, 3)));
|
|
||||||
});
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
criterion_group!(benches, criterion_benchmark);
|
|
||||||
criterion_main!(benches);
|
|
File diff suppressed because one or more lines are too long
After Width: | Height: | Size: 484 KiB |
|
@ -0,0 +1,6 @@
|
||||||
|
import adversary
|
||||||
|
import sys
|
||||||
|
|
||||||
|
prover = adversary.Prover(
|
||||||
|
"= (ham x) k", "= (ham y) (+ k 1)", "= ham (^ x y) p").hint_symmetric()
|
||||||
|
bounds = prover.find_bounds(int(sys.argv[1]), 5, 8)
|
Before Width: | Height: | Size: 399 KiB After Width: | Height: | Size: 399 KiB |
89
src/lib.rs
89
src/lib.rs
|
@ -1,11 +1,12 @@
|
||||||
extern crate pest;
|
extern crate pest;
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
extern crate pest_derive;
|
extern crate pest_derive;
|
||||||
|
#[macro_use]
|
||||||
|
extern crate bitflags;
|
||||||
|
|
||||||
mod vm;
|
pub mod vm;
|
||||||
|
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
use vm::VmCode;
|
|
||||||
|
|
||||||
// We cache the A and B set up to a total of ~2Gb of memory,
|
// We cache the A and B set up to a total of ~2Gb of memory,
|
||||||
// meaning ~ 10^7 elements in each set
|
// meaning ~ 10^7 elements in each set
|
||||||
|
@ -33,14 +34,22 @@ pub struct BoundsResult {
|
||||||
|
|
||||||
#[pyclass]
|
#[pyclass]
|
||||||
pub struct Prover {
|
pub struct Prover {
|
||||||
a_description: VmCode,
|
a_description: vm::VmCode,
|
||||||
b_description: VmCode,
|
b_description: vm::VmCode,
|
||||||
relationship: VmCode,
|
relationship: vm::VmCode,
|
||||||
|
hints: ProofHints,
|
||||||
|
}
|
||||||
|
|
||||||
|
bitflags! {
|
||||||
|
struct ProofHints: u8 {
|
||||||
|
const SymmetricFn = 0b00000001;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct CachedSetIterator<'code> {
|
struct CachedSetIterator<'code> {
|
||||||
cache: Vec<u64>,
|
cache: Vec<u64>,
|
||||||
filter: &'code VmCode,
|
filter: &'code vm::VmCode,
|
||||||
|
stack: vm::VmStack,
|
||||||
n: u32,
|
n: u32,
|
||||||
p: u32,
|
p: u32,
|
||||||
k: u32,
|
k: u32,
|
||||||
|
@ -52,10 +61,11 @@ struct CachedSetIterator<'code> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'code> CachedSetIterator<'code> {
|
impl<'code> CachedSetIterator<'code> {
|
||||||
fn create_x(filter: &'code VmCode, n: u32, p: u32, k: u32) -> Self {
|
fn create_x(filter: &'code vm::VmCode, n: u32, p: u32, k: u32) -> Self {
|
||||||
CachedSetIterator {
|
CachedSetIterator {
|
||||||
cache: Vec::with_capacity(CACHE_SIZE_LIMIT),
|
cache: Vec::with_capacity(CACHE_SIZE_LIMIT),
|
||||||
filter,
|
filter,
|
||||||
|
stack: vm::VmStack::from_code(filter),
|
||||||
n,
|
n,
|
||||||
p,
|
p,
|
||||||
k,
|
k,
|
||||||
|
@ -67,10 +77,11 @@ impl<'code> CachedSetIterator<'code> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_y(filter: &'code VmCode, n: u32, p: u32, k: u32) -> Self {
|
fn create_y(filter: &'code vm::VmCode, n: u32, p: u32, k: u32) -> Self {
|
||||||
CachedSetIterator {
|
CachedSetIterator {
|
||||||
cache: Vec::with_capacity(CACHE_SIZE_LIMIT),
|
cache: Vec::with_capacity(CACHE_SIZE_LIMIT),
|
||||||
filter,
|
filter,
|
||||||
|
stack: vm::VmStack::from_code(filter),
|
||||||
n,
|
n,
|
||||||
p,
|
p,
|
||||||
k,
|
k,
|
||||||
|
@ -119,16 +130,20 @@ impl<'code> Iterator for &mut CachedSetIterator<'code> {
|
||||||
vm::Vm::load(
|
vm::Vm::load(
|
||||||
self.filter,
|
self.filter,
|
||||||
vm::Registers::load(self.counter, 0, self.n, self.p, self.k),
|
vm::Registers::load(self.counter, 0, self.n, self.p, self.k),
|
||||||
|
&mut self.stack,
|
||||||
)
|
)
|
||||||
.run()
|
.run()
|
||||||
.unwrap_bool()
|
.output_bool()
|
||||||
|
.unwrap()
|
||||||
} else {
|
} else {
|
||||||
vm::Vm::load(
|
vm::Vm::load(
|
||||||
self.filter,
|
self.filter,
|
||||||
vm::Registers::load(0, self.counter, self.n, self.p, self.k),
|
vm::Registers::load(0, self.counter, self.n, self.p, self.k),
|
||||||
|
&mut self.stack,
|
||||||
)
|
)
|
||||||
.run()
|
.run()
|
||||||
.unwrap_bool()
|
.output_bool()
|
||||||
|
.unwrap()
|
||||||
};
|
};
|
||||||
|
|
||||||
if !included {
|
if !included {
|
||||||
|
@ -150,6 +165,7 @@ impl<'code> Iterator for &mut CachedSetIterator<'code> {
|
||||||
// See https://stackoverflow.com/a/27755938
|
// See https://stackoverflow.com/a/27755938
|
||||||
// In turn from http://graphics.stanford.edu/~seander/bithacks.html#NextBitPermutation
|
// In turn from http://graphics.stanford.edu/~seander/bithacks.html#NextBitPermutation
|
||||||
struct FixedHammingWeight {
|
struct FixedHammingWeight {
|
||||||
|
ones: u32,
|
||||||
permutation: u64,
|
permutation: u64,
|
||||||
top: u64,
|
top: u64,
|
||||||
exhausted: bool,
|
exhausted: bool,
|
||||||
|
@ -158,11 +174,17 @@ struct FixedHammingWeight {
|
||||||
impl FixedHammingWeight {
|
impl FixedHammingWeight {
|
||||||
fn new(bits: u32, ones: u32) -> Self {
|
fn new(bits: u32, ones: u32) -> Self {
|
||||||
FixedHammingWeight {
|
FixedHammingWeight {
|
||||||
|
ones,
|
||||||
permutation: (1 << ones) - 1,
|
permutation: (1 << ones) - 1,
|
||||||
top: ((1 << ones) - 1) << (bits - ones),
|
top: ((1 << ones) - 1) << (bits - ones),
|
||||||
exhausted: false,
|
exhausted: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn reset(&mut self) {
|
||||||
|
self.permutation = (1 << self.ones) - 1;
|
||||||
|
self.exhausted = false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Iterator for FixedHammingWeight {
|
impl Iterator for FixedHammingWeight {
|
||||||
|
@ -170,6 +192,7 @@ impl Iterator for FixedHammingWeight {
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
if self.exhausted {
|
if self.exhausted {
|
||||||
|
self.reset();
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
let to_yield = self.permutation;
|
let to_yield = self.permutation;
|
||||||
|
@ -263,9 +286,20 @@ impl Prover {
|
||||||
relationship,
|
relationship,
|
||||||
a_description,
|
a_description,
|
||||||
b_description,
|
b_description,
|
||||||
|
hints: ProofHints::empty(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Hints to the prover that the function in question is totally symmetric,
|
||||||
|
/// or, equivalently, that `a_description` only depends on the Hamming weight
|
||||||
|
/// of x, and `b_description` only depends on the Hamming weight of y.
|
||||||
|
///
|
||||||
|
/// If you l your results will lower bound the upper bound.
|
||||||
|
pub fn hint_symmetric<'s>(mut py_self: PyRefMut<'s, Self>, py: Python) -> PyRefMut<'s, Self> {
|
||||||
|
(*py_self).hints |= ProofHints::SymmetricFn;
|
||||||
|
py_self
|
||||||
|
}
|
||||||
|
|
||||||
/// Finds a lower bound to the query complexity according to the specified parameters.
|
/// Finds a lower bound to the query complexity according to the specified parameters.
|
||||||
///
|
///
|
||||||
/// Arguments:
|
/// Arguments:
|
||||||
|
@ -290,9 +324,10 @@ impl Prover {
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let mut vm_stack = vm::VmStack::from_code(&self.relationship);
|
||||||
|
|
||||||
let mut a_set_iterator = CachedSetIterator::create_x(&self.a_description, n, p, k);
|
let mut a_set_iterator = CachedSetIterator::create_x(&self.a_description, n, p, k);
|
||||||
let mut b_set_iterator = CachedSetIterator::create_y(&self.b_description, n, p, k);
|
let mut b_set_iterator = CachedSetIterator::create_y(&self.b_description, n, p, k);
|
||||||
let window_iterator = FixedHammingWeight::new(n, p);
|
|
||||||
|
|
||||||
let mut min_x_relations = u64::MAX;
|
let mut min_x_relations = u64::MAX;
|
||||||
let mut min_y_relations = u64::MAX;
|
let mut min_y_relations = u64::MAX;
|
||||||
|
@ -303,7 +338,15 @@ impl Prover {
|
||||||
let mut joint_bounding_y = 0_u64;
|
let mut joint_bounding_y = 0_u64;
|
||||||
let mut bounding_window = 0_u64;
|
let mut bounding_window = 0_u64;
|
||||||
|
|
||||||
for window in window_iterator {
|
let mut window_iterator = FixedHammingWeight::new(n, p).into_iter();
|
||||||
|
let mut first_entry = [(1_u64 << p) - 1].into_iter();
|
||||||
|
let effective_window_iterator = if self.hints.contains(ProofHints::SymmetricFn) {
|
||||||
|
&mut first_entry as &mut dyn Iterator<Item = u64>
|
||||||
|
} else {
|
||||||
|
&mut window_iterator as &mut dyn Iterator<Item = u64>
|
||||||
|
};
|
||||||
|
|
||||||
|
for window in effective_window_iterator {
|
||||||
let mut max_x_relations = 0_u64;
|
let mut max_x_relations = 0_u64;
|
||||||
let mut max_y_relations = 0_u64;
|
let mut max_y_relations = 0_u64;
|
||||||
let mut bounding_x_candidate = 0_u64;
|
let mut bounding_x_candidate = 0_u64;
|
||||||
|
@ -317,9 +360,11 @@ impl Prover {
|
||||||
let related = vm::Vm::load(
|
let related = vm::Vm::load(
|
||||||
&self.relationship,
|
&self.relationship,
|
||||||
vm::Registers::load(x, y, n, p, k),
|
vm::Registers::load(x, y, n, p, k),
|
||||||
|
&mut vm_stack,
|
||||||
)
|
)
|
||||||
.run()
|
.run()
|
||||||
.unwrap_bool();
|
.output_bool()
|
||||||
|
.unwrap();
|
||||||
if related {
|
if related {
|
||||||
min_x_relations_candidate += 1;
|
min_x_relations_candidate += 1;
|
||||||
|
|
||||||
|
@ -344,10 +389,14 @@ impl Prover {
|
||||||
let mut max_y_relations_candidate = 0_u64;
|
let mut max_y_relations_candidate = 0_u64;
|
||||||
|
|
||||||
for x in &mut a_set_iterator {
|
for x in &mut a_set_iterator {
|
||||||
let related =
|
let related = vm::Vm::load(
|
||||||
vm::Vm::load(&self.relationship, vm::Registers::load(x, y, n, p, k))
|
&self.relationship,
|
||||||
.run()
|
vm::Registers::load(x, y, n, p, k),
|
||||||
.unwrap_bool();
|
&mut vm_stack,
|
||||||
|
)
|
||||||
|
.run()
|
||||||
|
.output_bool()
|
||||||
|
.unwrap();
|
||||||
if related {
|
if related {
|
||||||
min_y_relations_candidate += 1;
|
min_y_relations_candidate += 1;
|
||||||
|
|
||||||
|
@ -374,9 +423,9 @@ impl Prover {
|
||||||
joint_bounding_x = bounding_x_candidate;
|
joint_bounding_x = bounding_x_candidate;
|
||||||
joint_bounding_y = bounding_y_candidate;
|
joint_bounding_y = bounding_y_candidate;
|
||||||
bounding_window = window;
|
bounding_window = window;
|
||||||
|
|
||||||
Python::with_gil(|py| py.check_signals())?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Python::with_gil(|py| py.check_signals())?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(BoundsResult {
|
Ok(BoundsResult {
|
||||||
|
@ -437,4 +486,4 @@ fn test_full_run() {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
std::hint::black_box(obj.find_bounds(10, 5, 3)).expect("Success");
|
std::hint::black_box(obj.find_bounds(10, 5, 3)).expect("Success");
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
330
src/vm/mod.rs
330
src/vm/mod.rs
|
@ -4,7 +4,8 @@ use parsing::ast;
|
||||||
|
|
||||||
use crate::vm::parsing::ast::ComparisonOperator;
|
use crate::vm::parsing::ast::ComparisonOperator;
|
||||||
|
|
||||||
type Bytecode = (OpCode, usize);
|
// type Bytecode = (OpCode, usize);
|
||||||
|
type Bytecode = OpCode;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct VmCode {
|
pub struct VmCode {
|
||||||
|
@ -15,6 +16,13 @@ pub struct VmCode {
|
||||||
pub struct Vm<'code> {
|
pub struct Vm<'code> {
|
||||||
code: &'code VmCode,
|
code: &'code VmCode,
|
||||||
registers: Registers,
|
registers: Registers,
|
||||||
|
stack: &'code mut VmStack,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct VmStack {
|
||||||
|
boolean_stack: Vec<bool>,
|
||||||
|
arithmetic_stack: Vec<ArithmeticValue>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
@ -26,11 +34,12 @@ pub struct Registers {
|
||||||
k: i64,
|
k: i64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum VmOutput {
|
pub enum VmOutput {
|
||||||
Boolean(bool),
|
Boolean(bool),
|
||||||
Arithmetic(ArithmeticValue),
|
Arithmetic(ArithmeticValue),
|
||||||
}
|
} */
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum ArithmeticValue {
|
pub enum ArithmeticValue {
|
||||||
|
@ -54,7 +63,23 @@ enum Expression {
|
||||||
Arithmetic(ast::ArithmeticExpression),
|
Arithmetic(ast::ArithmeticExpression),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl VmOutput {
|
impl VmStack {
|
||||||
|
pub fn from_code(code: &VmCode) -> Self {
|
||||||
|
// Each op_code can produce at most one value onto the stack
|
||||||
|
let code = &code.code;
|
||||||
|
VmStack {
|
||||||
|
boolean_stack: Vec::with_capacity(code.len()),
|
||||||
|
arithmetic_stack: Vec::with_capacity(code.len()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reset(&mut self) {
|
||||||
|
self.boolean_stack.clear();
|
||||||
|
self.arithmetic_stack.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*impl VmOutput {
|
||||||
pub fn unwrap_bool(self) -> bool {
|
pub fn unwrap_bool(self) -> bool {
|
||||||
match self {
|
match self {
|
||||||
VmOutput::Boolean(value) => value,
|
VmOutput::Boolean(value) => value,
|
||||||
|
@ -73,6 +98,7 @@ impl VmOutput {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
impl Registers {
|
impl Registers {
|
||||||
pub fn load(x: u64, y: u64, n: u32, p: u32, k: u32) -> Self {
|
pub fn load(x: u64, y: u64, n: u32, p: u32, k: u32) -> Self {
|
||||||
|
@ -101,7 +127,7 @@ pub fn compile_arithmetic(expression: ast::ArithmeticExpression) -> VmCode {
|
||||||
|
|
||||||
impl VmCode {
|
impl VmCode {
|
||||||
pub fn any<P: Fn(&OpCode) -> bool>(&self, predicate: P) -> bool {
|
pub fn any<P: Fn(&OpCode) -> bool>(&self, predicate: P) -> bool {
|
||||||
for (opcode, _) in &self.code {
|
for opcode in &self.code {
|
||||||
if predicate(opcode) {
|
if predicate(opcode) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -111,16 +137,211 @@ impl VmCode {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'code> Vm<'code> {
|
impl<'code> Vm<'code> {
|
||||||
pub fn load(code: &'code VmCode, registers: Registers) -> Self {
|
pub fn load(code: &'code VmCode, registers: Registers, stack: &'code mut VmStack) -> Self {
|
||||||
Vm { code, registers }
|
Vm {
|
||||||
|
code,
|
||||||
|
registers,
|
||||||
|
stack,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn run(&self) -> VmOutput {
|
pub fn run(&mut self) -> &mut Self {
|
||||||
|
// Arithmetic operations for `ArithmeticValue`s
|
||||||
|
fn operations_as_integer(
|
||||||
|
op: &ast::BinaryArithmeticOperator,
|
||||||
|
left_operand: i64,
|
||||||
|
right_operand: i64,
|
||||||
|
) -> ArithmeticValue {
|
||||||
|
match op {
|
||||||
|
ast::BinaryArithmeticOperator::Times => {
|
||||||
|
ArithmeticValue::Integer(left_operand * right_operand)
|
||||||
|
}
|
||||||
|
ast::BinaryArithmeticOperator::Divide => {
|
||||||
|
if left_operand % right_operand == 0 {
|
||||||
|
ArithmeticValue::Integer(left_operand / right_operand)
|
||||||
|
} else {
|
||||||
|
ArithmeticValue::Floating(left_operand as f64 / right_operand as f64)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ast::BinaryArithmeticOperator::Plus => {
|
||||||
|
ArithmeticValue::Integer(left_operand + right_operand)
|
||||||
|
}
|
||||||
|
ast::BinaryArithmeticOperator::Minus => {
|
||||||
|
ArithmeticValue::Integer(left_operand - right_operand)
|
||||||
|
}
|
||||||
|
ast::BinaryArithmeticOperator::Xor => {
|
||||||
|
ArithmeticValue::Integer(left_operand ^ right_operand)
|
||||||
|
}
|
||||||
|
ast::BinaryArithmeticOperator::Pow => {
|
||||||
|
if right_operand > 0 {
|
||||||
|
ArithmeticValue::Integer(left_operand.pow(right_operand as u32))
|
||||||
|
} else {
|
||||||
|
ArithmeticValue::Floating((left_operand as f64).powi(right_operand as i32))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn operations_as_floating(
|
||||||
|
op: &ast::BinaryArithmeticOperator,
|
||||||
|
left_operand: f64,
|
||||||
|
right_operand: f64,
|
||||||
|
) -> ArithmeticValue {
|
||||||
|
match op {
|
||||||
|
ast::BinaryArithmeticOperator::Times => {
|
||||||
|
ArithmeticValue::Floating(left_operand * right_operand)
|
||||||
|
}
|
||||||
|
ast::BinaryArithmeticOperator::Divide => {
|
||||||
|
ArithmeticValue::Floating(left_operand as f64 / right_operand as f64)
|
||||||
|
}
|
||||||
|
ast::BinaryArithmeticOperator::Plus => {
|
||||||
|
ArithmeticValue::Floating(left_operand + right_operand)
|
||||||
|
}
|
||||||
|
ast::BinaryArithmeticOperator::Minus => {
|
||||||
|
ArithmeticValue::Floating(left_operand - right_operand)
|
||||||
|
}
|
||||||
|
ast::BinaryArithmeticOperator::Xor => {
|
||||||
|
ArithmeticValue::Integer(left_operand as i64 ^ right_operand as i64)
|
||||||
|
}
|
||||||
|
ast::BinaryArithmeticOperator::Pow => {
|
||||||
|
ArithmeticValue::Floating(left_operand.powf(right_operand))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn comparison<T: std::cmp::PartialOrd>(
|
||||||
|
op: &ComparisonOperator,
|
||||||
|
left_operand: T,
|
||||||
|
right_operand: T,
|
||||||
|
) -> bool {
|
||||||
|
match op {
|
||||||
|
ast::ComparisonOperator::GreaterOrEqual => left_operand.ge(&right_operand),
|
||||||
|
ast::ComparisonOperator::LessOrEqual => left_operand.le(&right_operand),
|
||||||
|
ast::ComparisonOperator::GreaterThan => left_operand.gt(&right_operand),
|
||||||
|
ast::ComparisonOperator::LessThan => left_operand.lt(&right_operand),
|
||||||
|
ast::ComparisonOperator::NotEqual => left_operand.ne(&right_operand),
|
||||||
|
ast::ComparisonOperator::Equal => left_operand.eq(&right_operand),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Alias for convenience
|
// Alias for convenience
|
||||||
let code = &self.code.code;
|
let code = &self.code.code;
|
||||||
let registers = &self.registers;
|
let registers = &self.registers;
|
||||||
|
let stack = &mut self.stack;
|
||||||
|
|
||||||
struct Resolver<'f> {
|
stack.reset();
|
||||||
|
|
||||||
|
for opcode in code.iter().rev() {
|
||||||
|
match opcode {
|
||||||
|
OpCode::UnaryBooleanOperator(op) => {
|
||||||
|
let operand = stack.boolean_stack.pop().unwrap();
|
||||||
|
match op {
|
||||||
|
ast::UnaryBooleanOperator::Not => {
|
||||||
|
stack.boolean_stack.push(!operand);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
OpCode::BinaryArithmeticOperator(op) => {
|
||||||
|
let left_operand = stack.arithmetic_stack.pop().unwrap();
|
||||||
|
let right_operand = stack.arithmetic_stack.pop().unwrap();
|
||||||
|
let value = match left_operand {
|
||||||
|
ArithmeticValue::Integer(left_operand) => match right_operand {
|
||||||
|
ArithmeticValue::Integer(right_operand) => {
|
||||||
|
operations_as_integer(op, left_operand, right_operand)
|
||||||
|
}
|
||||||
|
ArithmeticValue::Floating(right_operand) => {
|
||||||
|
operations_as_floating(op, left_operand as f64, right_operand)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
ArithmeticValue::Floating(left_operand) => match right_operand {
|
||||||
|
ArithmeticValue::Integer(right_operand) => {
|
||||||
|
operations_as_floating(op, left_operand, right_operand as f64)
|
||||||
|
}
|
||||||
|
ArithmeticValue::Floating(right_operand) => {
|
||||||
|
operations_as_floating(op, left_operand, right_operand)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
stack.arithmetic_stack.push(value);
|
||||||
|
}
|
||||||
|
OpCode::ComparisonOperator(op) => {
|
||||||
|
let left_operand = stack.arithmetic_stack.pop().unwrap();
|
||||||
|
let right_operand = stack.arithmetic_stack.pop().unwrap();
|
||||||
|
let value = match left_operand {
|
||||||
|
ArithmeticValue::Integer(left_operand) => match right_operand {
|
||||||
|
ArithmeticValue::Integer(right_operand) => {
|
||||||
|
comparison(op, left_operand, right_operand)
|
||||||
|
}
|
||||||
|
ArithmeticValue::Floating(right_operand) => {
|
||||||
|
comparison(op, left_operand as f64, right_operand)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
ArithmeticValue::Floating(left_operand) => match right_operand {
|
||||||
|
ArithmeticValue::Integer(right_operand) => {
|
||||||
|
comparison(op, left_operand, right_operand as f64)
|
||||||
|
}
|
||||||
|
ArithmeticValue::Floating(right_operand) => {
|
||||||
|
comparison(op, left_operand, right_operand)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
stack.boolean_stack.push(value);
|
||||||
|
}
|
||||||
|
OpCode::UnaryArithmeticOperator(op) => {
|
||||||
|
let operand = stack.arithmetic_stack.pop().unwrap();
|
||||||
|
let value = match op {
|
||||||
|
ast::UnaryArithmeticOperator::Negative => match operand {
|
||||||
|
ArithmeticValue::Integer(operand) => ArithmeticValue::Integer(-operand),
|
||||||
|
ArithmeticValue::Floating(operand) => {
|
||||||
|
ArithmeticValue::Floating(-operand)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
ast::UnaryArithmeticOperator::Ham => {
|
||||||
|
ArithmeticValue::Integer(match operand {
|
||||||
|
ArithmeticValue::Integer(operand) => operand.count_ones() as i64,
|
||||||
|
ArithmeticValue::Floating(operand) => {
|
||||||
|
(operand.round() as i64).count_ones() as i64
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
ast::UnaryArithmeticOperator::Sqrt => {
|
||||||
|
ArithmeticValue::Floating(match operand {
|
||||||
|
ArithmeticValue::Integer(operand) => (operand as f64).sqrt(),
|
||||||
|
ArithmeticValue::Floating(operand) => operand.sqrt(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
};
|
||||||
|
stack.arithmetic_stack.push(value);
|
||||||
|
}
|
||||||
|
OpCode::BinaryBooleanOperator(op) => {
|
||||||
|
let left_operand = stack.boolean_stack.pop().unwrap();
|
||||||
|
let right_operand = stack.boolean_stack.pop().unwrap();
|
||||||
|
let value = match op {
|
||||||
|
ast::BinaryBooleanOperator::And => left_operand & right_operand,
|
||||||
|
ast::BinaryBooleanOperator::Or => left_operand | right_operand,
|
||||||
|
ast::BinaryBooleanOperator::Xor => left_operand ^ right_operand,
|
||||||
|
};
|
||||||
|
stack.boolean_stack.push(value);
|
||||||
|
}
|
||||||
|
OpCode::Variable(variable) => {
|
||||||
|
let value = ArithmeticValue::Integer(match variable {
|
||||||
|
ast::Variable::X => registers.x,
|
||||||
|
ast::Variable::Y => registers.y,
|
||||||
|
ast::Variable::N => registers.n,
|
||||||
|
ast::Variable::P => registers.p,
|
||||||
|
ast::Variable::K => registers.k,
|
||||||
|
});
|
||||||
|
stack.arithmetic_stack.push(value);
|
||||||
|
}
|
||||||
|
OpCode::Literal(literal) => {
|
||||||
|
stack
|
||||||
|
.arithmetic_stack
|
||||||
|
.push(ArithmeticValue::Integer(*literal));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*struct Resolver<'f> {
|
||||||
f: &'f dyn Fn(&Resolver, usize) -> VmOutput,
|
f: &'f dyn Fn(&Resolver, usize) -> VmOutput,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -346,7 +567,18 @@ impl<'code> Vm<'code> {
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
(resolver.f)(&resolver, 0)
|
(resolver.f)(&resolver, 0)*/
|
||||||
|
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn output_bool(&mut self) -> Result<bool, ()> {
|
||||||
|
self.stack.boolean_stack.pop().ok_or(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
|
pub fn output_arithmetic(&mut self) -> Result<ArithmeticValue, ()> {
|
||||||
|
self.stack.arithmetic_stack.pop().ok_or(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -356,28 +588,32 @@ fn compile_expression(expression: Expression, code: &mut Vec<Bytecode>) -> usize
|
||||||
Expression::Boolean(expression) => match expression {
|
Expression::Boolean(expression) => match expression {
|
||||||
ast::BooleanExpression::BinaryBooleanConjunction(expression) => {
|
ast::BooleanExpression::BinaryBooleanConjunction(expression) => {
|
||||||
let expression = *expression;
|
let expression = *expression;
|
||||||
code.push((OpCode::BinaryBooleanOperator(expression.operator), 0));
|
// code.push((OpCode::BinaryBooleanOperator(expression.operator), 0));
|
||||||
|
code.push(OpCode::BinaryBooleanOperator(expression.operator));
|
||||||
let index = code.len() - 1;
|
let index = code.len() - 1;
|
||||||
let left_operand_size =
|
let left_operand_size =
|
||||||
compile_expression(Expression::Boolean(expression.left_operand), code);
|
compile_expression(Expression::Boolean(expression.left_operand), code);
|
||||||
let right_operand_size =
|
let right_operand_size =
|
||||||
compile_expression(Expression::Boolean(expression.right_operand), code);
|
compile_expression(Expression::Boolean(expression.right_operand), code);
|
||||||
code[index].1 = left_operand_size + 1;
|
// code[index].1 = left_operand_size + 1;
|
||||||
return 1 + left_operand_size + right_operand_size;
|
return 1 + left_operand_size + right_operand_size;
|
||||||
}
|
}
|
||||||
ast::BooleanExpression::UnaryBooleanConjunction(expression) => {
|
ast::BooleanExpression::UnaryBooleanConjunction(expression) => {
|
||||||
let expression = *expression;
|
let expression = *expression;
|
||||||
code.push((OpCode::UnaryBooleanOperator(expression.operator), 0));
|
// code.push((OpCode::UnaryBooleanOperator(expression.operator), 0));
|
||||||
|
code.push(OpCode::UnaryBooleanOperator(expression.operator));
|
||||||
let operand_size =
|
let operand_size =
|
||||||
compile_expression(Expression::Boolean(expression.operand), code);
|
compile_expression(Expression::Boolean(expression.operand), code);
|
||||||
return 1 + operand_size;
|
return 1 + operand_size;
|
||||||
}
|
}
|
||||||
ast::BooleanExpression::ComparisonConjunction(expression) => {
|
ast::BooleanExpression::ComparisonConjunction(expression) => {
|
||||||
code.push((OpCode::ComparisonOperator(expression.operator), 0));
|
// code.push((OpCode::ComparisonOperator(expression.operator), 0));
|
||||||
|
code.push(OpCode::ComparisonOperator(expression.operator));
|
||||||
let index = code.len() - 1;
|
let index = code.len() - 1;
|
||||||
let left_operand_size = match expression.left_operand {
|
let left_operand_size = match expression.left_operand {
|
||||||
ast::ArithmeticOperand::Literal(literal) => {
|
ast::ArithmeticOperand::Literal(literal) => {
|
||||||
code.push((OpCode::Literal(literal), 0));
|
// code.push((OpCode::Literal(literal), 0));
|
||||||
|
code.push(OpCode::Literal(literal));
|
||||||
1_usize
|
1_usize
|
||||||
}
|
}
|
||||||
ast::ArithmeticOperand::Expression(expression) => {
|
ast::ArithmeticOperand::Expression(expression) => {
|
||||||
|
@ -386,28 +622,32 @@ fn compile_expression(expression: Expression, code: &mut Vec<Bytecode>) -> usize
|
||||||
};
|
};
|
||||||
let right_operand_size = match expression.right_operand {
|
let right_operand_size = match expression.right_operand {
|
||||||
ast::ArithmeticOperand::Literal(literal) => {
|
ast::ArithmeticOperand::Literal(literal) => {
|
||||||
code.push((OpCode::Literal(literal), 0));
|
// code.push((OpCode::Literal(literal), 0));
|
||||||
|
code.push(OpCode::Literal(literal));
|
||||||
1_usize
|
1_usize
|
||||||
}
|
}
|
||||||
ast::ArithmeticOperand::Expression(expression) => {
|
ast::ArithmeticOperand::Expression(expression) => {
|
||||||
compile_expression(Expression::Arithmetic(*expression), code)
|
compile_expression(Expression::Arithmetic(*expression), code)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
code[index].1 = left_operand_size + 1;
|
// code[index].1 = left_operand_size + 1;
|
||||||
return 1 + left_operand_size + right_operand_size;
|
return 1 + left_operand_size + right_operand_size;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
Expression::Arithmetic(expression) => match expression {
|
Expression::Arithmetic(expression) => match expression {
|
||||||
ast::ArithmeticExpression::Variable(variable) => {
|
ast::ArithmeticExpression::Variable(variable) => {
|
||||||
code.push((OpCode::Variable(variable), 0));
|
// code.push((OpCode::Variable(variable), 0));
|
||||||
|
code.push(OpCode::Variable(variable));
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
ast::ArithmeticExpression::UnaryArithmeticConjunction(expression) => {
|
ast::ArithmeticExpression::UnaryArithmeticConjunction(expression) => {
|
||||||
let expression = *expression;
|
let expression = *expression;
|
||||||
code.push((OpCode::UnaryArithmeticOperator(expression.operator), 0));
|
// code.push((OpCode::UnaryArithmeticOperator(expression.operator), 0));
|
||||||
|
code.push(OpCode::UnaryArithmeticOperator(expression.operator));
|
||||||
let operand_size = match expression.operand {
|
let operand_size = match expression.operand {
|
||||||
ast::ArithmeticOperand::Literal(literal) => {
|
ast::ArithmeticOperand::Literal(literal) => {
|
||||||
code.push((OpCode::Literal(literal), 0));
|
// code.push((OpCode::Literal(literal), 0));
|
||||||
|
code.push(OpCode::Literal(literal));
|
||||||
1_usize
|
1_usize
|
||||||
}
|
}
|
||||||
ast::ArithmeticOperand::Expression(expression) => {
|
ast::ArithmeticOperand::Expression(expression) => {
|
||||||
|
@ -417,11 +657,13 @@ fn compile_expression(expression: Expression, code: &mut Vec<Bytecode>) -> usize
|
||||||
return 1 + operand_size;
|
return 1 + operand_size;
|
||||||
}
|
}
|
||||||
ast::ArithmeticExpression::BinaryArithmeticConjunction(expression) => {
|
ast::ArithmeticExpression::BinaryArithmeticConjunction(expression) => {
|
||||||
code.push((OpCode::BinaryArithmeticOperator(expression.operator), 0));
|
// code.push((OpCode::BinaryArithmeticOperator(expression.operator), 0));
|
||||||
|
code.push(OpCode::BinaryArithmeticOperator(expression.operator));
|
||||||
let index = code.len() - 1;
|
let index = code.len() - 1;
|
||||||
let left_operand_size = match expression.left_operand {
|
let left_operand_size = match expression.left_operand {
|
||||||
ast::ArithmeticOperand::Literal(literal) => {
|
ast::ArithmeticOperand::Literal(literal) => {
|
||||||
code.push((OpCode::Literal(literal), 0));
|
// code.push((OpCode::Literal(literal), 0));
|
||||||
|
code.push(OpCode::Literal(literal));
|
||||||
1_usize
|
1_usize
|
||||||
}
|
}
|
||||||
ast::ArithmeticOperand::Expression(expression) => {
|
ast::ArithmeticOperand::Expression(expression) => {
|
||||||
|
@ -430,14 +672,15 @@ fn compile_expression(expression: Expression, code: &mut Vec<Bytecode>) -> usize
|
||||||
};
|
};
|
||||||
let right_operand_size = match expression.right_operand {
|
let right_operand_size = match expression.right_operand {
|
||||||
ast::ArithmeticOperand::Literal(literal) => {
|
ast::ArithmeticOperand::Literal(literal) => {
|
||||||
code.push((OpCode::Literal(literal), 0));
|
// code.push((OpCode::Literal(literal), 0));
|
||||||
|
code.push(OpCode::Literal(literal));
|
||||||
1_usize
|
1_usize
|
||||||
}
|
}
|
||||||
ast::ArithmeticOperand::Expression(expression) => {
|
ast::ArithmeticOperand::Expression(expression) => {
|
||||||
compile_expression(Expression::Arithmetic(*expression), code)
|
compile_expression(Expression::Arithmetic(*expression), code)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
code[index].1 = left_operand_size + 1;
|
// code[index].1 = left_operand_size + 1;
|
||||||
return 1 + left_operand_size + right_operand_size;
|
return 1 + left_operand_size + right_operand_size;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -452,13 +695,48 @@ fn boolean_compilation_test() {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn run_test() {
|
fn test_vm_simple() {
|
||||||
let expression = parsing::parse_relation("(= (ham (^ x y)) (ham 1))");
|
let expression = parsing::parse_relation("(= (ham (^ x y)) (ham 1))");
|
||||||
if let Err(e) = expression {
|
if let Err(e) = expression {
|
||||||
println!("{}", e);
|
println!("{}", e);
|
||||||
panic!();
|
panic!();
|
||||||
}
|
}
|
||||||
let code = compile_boolean(expression.unwrap());
|
let code = compile_boolean(expression.unwrap());
|
||||||
let vm = Vm::load(&code, Registers::load(0b_1011, 0b_1111, 10, 2, 3));
|
let mut stack = VmStack::from_code(&code);
|
||||||
println!("{:?}", vm.run());
|
let mut vm = Vm::load(
|
||||||
|
&code,
|
||||||
|
Registers::load(0b_1011, 0b_1111, 10, 2, 3),
|
||||||
|
&mut stack,
|
||||||
|
);
|
||||||
|
let output = vm.run().output_bool().unwrap();
|
||||||
|
assert_eq!(output, true);
|
||||||
|
let mut vm = Vm::load(
|
||||||
|
&code,
|
||||||
|
Registers::load(0b_1011, 0b_1110, 10, 2, 3),
|
||||||
|
&mut stack,
|
||||||
|
);
|
||||||
|
let output = vm.run().output_bool().unwrap();
|
||||||
|
assert_eq!(output, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_vm_contrived() {
|
||||||
|
let expression = parsing::parse_relation("and (= (ham (^ x y)) (ham (+ 3 x))) (> (* x y) 5)");
|
||||||
|
if let Err(e) = expression {
|
||||||
|
println!("{}", e);
|
||||||
|
panic!();
|
||||||
|
}
|
||||||
|
let code = compile_boolean(expression.unwrap());
|
||||||
|
let mut stack = VmStack::from_code(&code);
|
||||||
|
|
||||||
|
let as_rust = |x: u64, y: u64| ((x + 3).count_ones() == (x ^ y).count_ones()) && (x * y > 5);
|
||||||
|
|
||||||
|
for x in 0..(1 << 4) {
|
||||||
|
for y in 0..(1 << 4) {
|
||||||
|
let mut vm = Vm::load(&code, Registers::load(x, y, 10, 2, 3), &mut stack);
|
||||||
|
let output = vm.run().output_bool().unwrap();
|
||||||
|
let expected = as_rust(x, y);
|
||||||
|
assert_eq!(output, expected);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue