touchups
This commit is contained in:
parent
e85346e916
commit
39cfb20f13
30
src/lib.rs
30
src/lib.rs
|
@ -54,9 +54,8 @@ struct CachedSetIterator<'code> {
|
||||||
p: u32,
|
p: u32,
|
||||||
k: u32,
|
k: u32,
|
||||||
cached: bool,
|
cached: bool,
|
||||||
counter: u64,
|
counter: usize,
|
||||||
cache_counter: usize,
|
top: usize,
|
||||||
top: u64,
|
|
||||||
is_x: bool,
|
is_x: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,8 +70,7 @@ impl<'code> CachedSetIterator<'code> {
|
||||||
k,
|
k,
|
||||||
cached: false,
|
cached: false,
|
||||||
counter: 0,
|
counter: 0,
|
||||||
cache_counter: 0,
|
top: 2_usize.pow(n as u32),
|
||||||
top: 2_u64.pow(n as u32),
|
|
||||||
is_x: true,
|
is_x: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -87,15 +85,13 @@ impl<'code> CachedSetIterator<'code> {
|
||||||
k,
|
k,
|
||||||
cached: false,
|
cached: false,
|
||||||
counter: 0,
|
counter: 0,
|
||||||
cache_counter: 0,
|
top: 2_usize.pow(n as u32),
|
||||||
top: 2_u64.pow(n as u32),
|
|
||||||
is_x: false,
|
is_x: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reset(&mut self) {
|
fn reset(&mut self) {
|
||||||
self.counter = 0;
|
self.counter = 0;
|
||||||
self.cache_counter = 0;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -104,12 +100,12 @@ impl<'code> Iterator for &mut CachedSetIterator<'code> {
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
if self.cached {
|
if self.cached {
|
||||||
if self.cache_counter < self.cache.len() {
|
if self.counter < self.cache.len() {
|
||||||
let result = self.cache[self.cache_counter];
|
let result = self.cache[self.counter as usize];
|
||||||
self.cache_counter += 1;
|
self.counter += 1;
|
||||||
|
|
||||||
if self.cache_counter == CACHE_SIZE_LIMIT {
|
if self.counter == CACHE_SIZE_LIMIT {
|
||||||
self.counter = self.cache.last().unwrap() + 1;
|
self.counter = *self.cache.last().unwrap() as usize + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
return Some(result);
|
return Some(result);
|
||||||
|
@ -129,7 +125,7 @@ impl<'code> Iterator for &mut CachedSetIterator<'code> {
|
||||||
let included = if self.is_x {
|
let included = if self.is_x {
|
||||||
vm::Vm::load(
|
vm::Vm::load(
|
||||||
self.filter,
|
self.filter,
|
||||||
vm::Registers::load(self.counter, 0, self.n, self.p, self.k),
|
vm::Registers::load(self.counter as u64, 0, self.n, self.p, self.k),
|
||||||
&mut self.stack,
|
&mut self.stack,
|
||||||
)
|
)
|
||||||
.run()
|
.run()
|
||||||
|
@ -138,7 +134,7 @@ impl<'code> Iterator for &mut CachedSetIterator<'code> {
|
||||||
} else {
|
} else {
|
||||||
vm::Vm::load(
|
vm::Vm::load(
|
||||||
self.filter,
|
self.filter,
|
||||||
vm::Registers::load(0, self.counter, self.n, self.p, self.k),
|
vm::Registers::load(0, self.counter as u64, self.n, self.p, self.k),
|
||||||
&mut self.stack,
|
&mut self.stack,
|
||||||
)
|
)
|
||||||
.run()
|
.run()
|
||||||
|
@ -150,7 +146,7 @@ impl<'code> Iterator for &mut CachedSetIterator<'code> {
|
||||||
self.counter += 1;
|
self.counter += 1;
|
||||||
continue;
|
continue;
|
||||||
} else {
|
} else {
|
||||||
let result = self.counter;
|
let result = self.counter as u64;
|
||||||
if self.cache.len() < CACHE_SIZE_LIMIT {
|
if self.cache.len() < CACHE_SIZE_LIMIT {
|
||||||
self.cache.push(result);
|
self.cache.push(result);
|
||||||
}
|
}
|
||||||
|
@ -307,7 +303,7 @@ impl Prover {
|
||||||
/// belonging to either of the specified sets [int]
|
/// belonging to either of the specified sets [int]
|
||||||
/// p The number of parallel queries [int]
|
/// p The number of parallel queries [int]
|
||||||
/// k Parameter accepted in set specifications and relations [int]
|
/// k Parameter accepted in set specifications and relations [int]
|
||||||
pub fn find_bounds(&self, n: u32, p: u32, k: u32) -> PyResult<BoundsResult> {
|
pub fn find_bounds(&self, n: u32, k: u32, p: u32) -> PyResult<BoundsResult> {
|
||||||
if n > 63 {
|
if n > 63 {
|
||||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||||
"More than 63 bits is not supported. (Because of `n` parameter.)",
|
"More than 63 bits is not supported. (Because of `n` parameter.)",
|
||||||
|
|
229
src/vm/mod.rs
229
src/vm/mod.rs
|
@ -340,235 +340,6 @@ impl<'code> Vm<'code> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*struct Resolver<'f> {
|
|
||||||
f: &'f dyn Fn(&Resolver, usize) -> VmOutput,
|
|
||||||
}
|
|
||||||
|
|
||||||
let resolver = Resolver {
|
|
||||||
f: &|resolver: &Resolver, op_index: usize| -> VmOutput {
|
|
||||||
let (opcode, offset) = &code[op_index];
|
|
||||||
match opcode {
|
|
||||||
OpCode::UnaryBooleanOperator(op) => match op {
|
|
||||||
ast::UnaryBooleanOperator::Not => {
|
|
||||||
let value = (resolver.f)(resolver, op_index + 1);
|
|
||||||
match value {
|
|
||||||
VmOutput::Arithmetic(_) => panic!("Bad bytecode! Unary operator is followed by arithmetic resolution."),
|
|
||||||
VmOutput::Boolean(value) => VmOutput::Boolean(!value),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
OpCode::ComparisonOperator(op) => {
|
|
||||||
let left_operand = match (resolver.f)(resolver, op_index + 1) {
|
|
||||||
VmOutput::Arithmetic(value) => value,
|
|
||||||
VmOutput::Boolean(_) => panic!(
|
|
||||||
"Bad bytecode! Left operand of arithmetic operator is boolean."
|
|
||||||
),
|
|
||||||
};
|
|
||||||
let right_operand = match (resolver.f)(resolver, op_index + offset) {
|
|
||||||
VmOutput::Arithmetic(value) => value,
|
|
||||||
VmOutput::Boolean(_) => panic!(
|
|
||||||
"Bad bytecode! Right operand of arithmetic operator is boolean."
|
|
||||||
),
|
|
||||||
};
|
|
||||||
|
|
||||||
fn comparison<T: std::cmp::PartialOrd>(
|
|
||||||
op: &ComparisonOperator,
|
|
||||||
left_operand: T,
|
|
||||||
right_operand: T,
|
|
||||||
) -> VmOutput {
|
|
||||||
VmOutput::Boolean(match op {
|
|
||||||
ast::ComparisonOperator::GreaterOrEqual => {
|
|
||||||
left_operand.ge(&right_operand)
|
|
||||||
}
|
|
||||||
ast::ComparisonOperator::LessOrEqual => {
|
|
||||||
left_operand.le(&right_operand)
|
|
||||||
}
|
|
||||||
ast::ComparisonOperator::GreaterThan => {
|
|
||||||
left_operand.gt(&right_operand)
|
|
||||||
}
|
|
||||||
ast::ComparisonOperator::LessThan => {
|
|
||||||
left_operand.lt(&right_operand)
|
|
||||||
}
|
|
||||||
ast::ComparisonOperator::NotEqual => {
|
|
||||||
left_operand.ne(&right_operand)
|
|
||||||
}
|
|
||||||
ast::ComparisonOperator::Equal => left_operand.eq(&right_operand),
|
|
||||||
})
|
|
||||||
};
|
|
||||||
|
|
||||||
match left_operand {
|
|
||||||
ArithmeticValue::Integer(left_operand) => match right_operand {
|
|
||||||
ArithmeticValue::Integer(right_operand) => {
|
|
||||||
comparison(op, left_operand, right_operand)
|
|
||||||
}
|
|
||||||
ArithmeticValue::Floating(right_operand) => {
|
|
||||||
comparison(op, left_operand as f64, right_operand)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
ArithmeticValue::Floating(left_operand) => match right_operand {
|
|
||||||
ArithmeticValue::Integer(right_operand) => {
|
|
||||||
comparison(op, left_operand, right_operand as f64)
|
|
||||||
}
|
|
||||||
ArithmeticValue::Floating(right_operand) => {
|
|
||||||
comparison(op, left_operand, right_operand)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
OpCode::BinaryArithmeticOperator(op) => {
|
|
||||||
let left_operand = match (resolver.f)(resolver, op_index + 1) {
|
|
||||||
VmOutput::Arithmetic(value) => value,
|
|
||||||
VmOutput::Boolean(_) => panic!(
|
|
||||||
"Bad bytecode! Left operand of arithmetic operator is boolean."
|
|
||||||
),
|
|
||||||
};
|
|
||||||
let right_operand = match (resolver.f)(resolver, op_index + offset) {
|
|
||||||
VmOutput::Arithmetic(value) => value,
|
|
||||||
VmOutput::Boolean(_) => panic!(
|
|
||||||
"Bad bytecode! Right operand of arithmetic operator is boolean."
|
|
||||||
),
|
|
||||||
};
|
|
||||||
|
|
||||||
let operations_as_integer = |left_operand: i64, right_operand: i64| match op
|
|
||||||
{
|
|
||||||
ast::BinaryArithmeticOperator::Times => {
|
|
||||||
ArithmeticValue::Integer(left_operand * right_operand)
|
|
||||||
}
|
|
||||||
ast::BinaryArithmeticOperator::Divide => {
|
|
||||||
if left_operand % right_operand == 0 {
|
|
||||||
ArithmeticValue::Integer(left_operand / right_operand)
|
|
||||||
} else {
|
|
||||||
ArithmeticValue::Floating(
|
|
||||||
left_operand as f64 / right_operand as f64,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ast::BinaryArithmeticOperator::Plus => {
|
|
||||||
ArithmeticValue::Integer(left_operand + right_operand)
|
|
||||||
}
|
|
||||||
ast::BinaryArithmeticOperator::Minus => {
|
|
||||||
ArithmeticValue::Integer(left_operand - right_operand)
|
|
||||||
}
|
|
||||||
ast::BinaryArithmeticOperator::Xor => {
|
|
||||||
ArithmeticValue::Integer(left_operand ^ right_operand)
|
|
||||||
}
|
|
||||||
ast::BinaryArithmeticOperator::Pow => {
|
|
||||||
if right_operand > 0 {
|
|
||||||
ArithmeticValue::Integer(left_operand.pow(right_operand as u32))
|
|
||||||
} else {
|
|
||||||
ArithmeticValue::Floating(
|
|
||||||
(left_operand as f64).powi(right_operand as i32),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let operations_as_floating =
|
|
||||||
|left_operand: f64, right_operand: f64| match op {
|
|
||||||
ast::BinaryArithmeticOperator::Times => {
|
|
||||||
ArithmeticValue::Floating(left_operand * right_operand)
|
|
||||||
}
|
|
||||||
ast::BinaryArithmeticOperator::Divide => ArithmeticValue::Floating(
|
|
||||||
left_operand as f64 / right_operand as f64,
|
|
||||||
),
|
|
||||||
ast::BinaryArithmeticOperator::Plus => {
|
|
||||||
ArithmeticValue::Floating(left_operand + right_operand)
|
|
||||||
}
|
|
||||||
ast::BinaryArithmeticOperator::Minus => {
|
|
||||||
ArithmeticValue::Floating(left_operand - right_operand)
|
|
||||||
}
|
|
||||||
ast::BinaryArithmeticOperator::Xor => ArithmeticValue::Integer(
|
|
||||||
left_operand as i64 ^ right_operand as i64,
|
|
||||||
),
|
|
||||||
ast::BinaryArithmeticOperator::Pow => {
|
|
||||||
ArithmeticValue::Floating(left_operand.powf(right_operand))
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
VmOutput::Arithmetic(match left_operand {
|
|
||||||
ArithmeticValue::Integer(left_operand) => match right_operand {
|
|
||||||
ArithmeticValue::Integer(right_operand) => {
|
|
||||||
operations_as_integer(left_operand, right_operand)
|
|
||||||
}
|
|
||||||
ArithmeticValue::Floating(right_operand) => {
|
|
||||||
operations_as_floating(left_operand as f64, right_operand)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
ArithmeticValue::Floating(left_operand) => match right_operand {
|
|
||||||
ArithmeticValue::Integer(right_operand) => {
|
|
||||||
operations_as_floating(left_operand, right_operand as f64)
|
|
||||||
}
|
|
||||||
ArithmeticValue::Floating(right_operand) => {
|
|
||||||
operations_as_floating(left_operand, right_operand)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
OpCode::UnaryArithmeticOperator(op) => {
|
|
||||||
let value = match (resolver.f)(resolver, op_index + 1) {
|
|
||||||
VmOutput::Arithmetic(value) => value,
|
|
||||||
VmOutput::Boolean(_) => panic!("Bad bytecode! Arithmetic unary operator followed by boolean resolution."),
|
|
||||||
};
|
|
||||||
VmOutput::Arithmetic(match op {
|
|
||||||
ast::UnaryArithmeticOperator::Negative => match value {
|
|
||||||
ArithmeticValue::Integer(value) => ArithmeticValue::Integer(-value),
|
|
||||||
ArithmeticValue::Floating(value) => {
|
|
||||||
ArithmeticValue::Floating(-value)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
ast::UnaryArithmeticOperator::Ham => {
|
|
||||||
ArithmeticValue::Integer(match value {
|
|
||||||
ArithmeticValue::Integer(value) => value.count_ones() as i64,
|
|
||||||
ArithmeticValue::Floating(value) => {
|
|
||||||
(value as u64).count_ones() as i64
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
ast::UnaryArithmeticOperator::Sqrt => {
|
|
||||||
ArithmeticValue::Floating(match value {
|
|
||||||
ArithmeticValue::Integer(value) => (value as f64).sqrt(),
|
|
||||||
ArithmeticValue::Floating(value) => value.sqrt(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
OpCode::BinaryBooleanOperator(op) => {
|
|
||||||
let left_operand = match (resolver.f)(resolver, op_index + 1) {
|
|
||||||
VmOutput::Boolean(value) => value,
|
|
||||||
VmOutput::Arithmetic(_) => panic!(
|
|
||||||
"Bad bytecode! Left operand of boolean operator is arithmetic."
|
|
||||||
),
|
|
||||||
};
|
|
||||||
let right_operand = match (resolver.f)(resolver, op_index + offset) {
|
|
||||||
VmOutput::Boolean(value) => value,
|
|
||||||
VmOutput::Arithmetic(_) => panic!(
|
|
||||||
"Bad bytecode! Right operand of boolean operator is arithmetic."
|
|
||||||
),
|
|
||||||
};
|
|
||||||
|
|
||||||
VmOutput::Boolean(match op {
|
|
||||||
ast::BinaryBooleanOperator::And => left_operand & right_operand,
|
|
||||||
ast::BinaryBooleanOperator::Or => left_operand | right_operand,
|
|
||||||
ast::BinaryBooleanOperator::Xor => left_operand ^ right_operand,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
OpCode::Variable(var) => {
|
|
||||||
VmOutput::Arithmetic(ArithmeticValue::Integer(match var {
|
|
||||||
ast::Variable::X => registers.x,
|
|
||||||
ast::Variable::Y => registers.y,
|
|
||||||
ast::Variable::N => registers.n,
|
|
||||||
ast::Variable::P => registers.p,
|
|
||||||
ast::Variable::K => registers.k,
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
OpCode::Literal(lit) => VmOutput::Arithmetic(ArithmeticValue::Integer(*lit)),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
(resolver.f)(&resolver, 0)*/
|
|
||||||
|
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -323,6 +323,19 @@ fn rule_as_text(rule: &Rule) -> &'static str {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_bad() {
|
||||||
|
match parse_relation("< ham x cheese") {
|
||||||
|
Err(e) => {
|
||||||
|
println!("{}", e);
|
||||||
|
}
|
||||||
|
Ok(ast) => {
|
||||||
|
println!("{:?}", ast);
|
||||||
|
panic!();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn parse_test() {
|
fn parse_test() {
|
||||||
match parse_relation("< ham x 3") {
|
match parse_relation("< ham x 3") {
|
||||||
|
@ -331,7 +344,7 @@ fn parse_test() {
|
||||||
panic!();
|
panic!();
|
||||||
}
|
}
|
||||||
Ok(ast) => {
|
Ok(ast) => {
|
||||||
println!("{:?}", ast)
|
println!("{:#?}", ast)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue