514 lines
20 KiB
Rust
514 lines
20 KiB
Rust
pub mod parsing;
|
|
|
|
use parsing::ast;
|
|
|
|
use crate::vm::parsing::ast::ComparisonOperator;
|
|
|
|
// type Bytecode = (OpCode, usize);
|
|
type Bytecode = OpCode;
|
|
|
|
#[derive(Debug)]
|
|
pub struct VmCode {
|
|
code: Vec<Bytecode>,
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct Vm<'code> {
|
|
code: &'code VmCode,
|
|
registers: Registers,
|
|
stack: &'code mut VmStack,
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct VmStack {
|
|
boolean_stack: Vec<bool>,
|
|
arithmetic_stack: Vec<ArithmeticValue>,
|
|
}
|
|
|
|
/// Values of the expression variables, widened to `i64` (see [`Registers::load`]).
#[derive(Debug)]
pub struct Registers {
    /// Value read by the variable `x`.
    x: i64,
    /// Value read by the variable `y`.
    y: i64,
    /// Value read by the variable `n`.
    n: i64,
    /// Value read by the variable `p`.
    p: i64,
    /// Value read by the variable `k`.
    k: i64,
}
|
|
|
|
/*
|
|
#[derive(Debug)]
|
|
pub enum VmOutput {
|
|
Boolean(bool),
|
|
Arithmetic(ArithmeticValue),
|
|
} */
|
|
|
|
/// A numeric value on the VM's arithmetic stack.
///
/// Operations that cannot produce an exact integer (for example inexact division or square
/// roots) yield `Floating`; otherwise values stay `Integer`.
///
/// `Clone`, `Copy` and `PartialEq` are derived so results (e.g. from `Vm::output_arithmetic`)
/// can be copied and compared directly. Note that `Integer(1) != Floating(1.0)` under the
/// derived equality: the variants are compared structurally, not numerically.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ArithmeticValue {
    /// An exact integer value.
    Integer(i64),
    /// A floating-point value.
    Floating(f64),
}
|
|
|
|
#[derive(Debug)]
|
|
pub enum OpCode {
|
|
UnaryBooleanOperator(ast::UnaryBooleanOperator),
|
|
BinaryArithmeticOperator(ast::BinaryArithmeticOperator),
|
|
ComparisonOperator(ast::ComparisonOperator),
|
|
UnaryArithmeticOperator(ast::UnaryArithmeticOperator),
|
|
BinaryBooleanOperator(ast::BinaryBooleanOperator),
|
|
Variable(ast::Variable),
|
|
Literal(i64),
|
|
}
|
|
|
|
enum Expression {
|
|
Boolean(ast::BooleanExpression),
|
|
Arithmetic(ast::ArithmeticExpression),
|
|
}
|
|
|
|
impl VmStack {
|
|
pub fn from_code(code: &VmCode) -> Self {
|
|
// Each op_code can produce at most one value onto the stack
|
|
let code = &code.code;
|
|
VmStack {
|
|
boolean_stack: Vec::with_capacity(code.len()),
|
|
arithmetic_stack: Vec::with_capacity(code.len()),
|
|
}
|
|
}
|
|
|
|
fn reset(&mut self) {
|
|
self.boolean_stack.clear();
|
|
self.arithmetic_stack.clear();
|
|
}
|
|
}
|
|
|
|
/*impl VmOutput {
|
|
pub fn unwrap_bool(self) -> bool {
|
|
match self {
|
|
VmOutput::Boolean(value) => value,
|
|
VmOutput::Arithmetic(_) => panic!("Expected boolean, got arithmetic value."),
|
|
}
|
|
}
|
|
|
|
#[allow(unused)]
|
|
pub fn unwrap_arithmetic(self) -> f64 {
|
|
match self {
|
|
VmOutput::Boolean(_) => panic!("Expected arithmetic, got boolean value."),
|
|
VmOutput::Arithmetic(value) => match value {
|
|
ArithmeticValue::Integer(value) => value as f64,
|
|
ArithmeticValue::Floating(value) => value,
|
|
},
|
|
}
|
|
}
|
|
}
|
|
*/
|
|
|
|
impl Registers {
|
|
pub fn load(x: u64, y: u64, n: u32, p: u32, k: u32) -> Self {
|
|
Self {
|
|
x: x as i64,
|
|
y: y as i64,
|
|
n: n as i64,
|
|
p: p as i64,
|
|
k: k as i64,
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn compile_boolean(expression: ast::BooleanExpression) -> VmCode {
|
|
let mut code = Vec::<Bytecode>::new();
|
|
compile_expression(Expression::Boolean(expression), &mut code);
|
|
VmCode { code }
|
|
}
|
|
|
|
#[allow(unused)]
|
|
pub fn compile_arithmetic(expression: ast::ArithmeticExpression) -> VmCode {
|
|
let mut code = Vec::<Bytecode>::new();
|
|
compile_expression(Expression::Arithmetic(expression), &mut code);
|
|
VmCode { code }
|
|
}
|
|
|
|
impl VmCode {
|
|
pub fn any<P: Fn(&OpCode) -> bool>(&self, predicate: P) -> bool {
|
|
for opcode in &self.code {
|
|
if predicate(opcode) {
|
|
return true;
|
|
}
|
|
}
|
|
false
|
|
}
|
|
}
|
|
|
|
impl<'code> Vm<'code> {
|
|
pub fn load(code: &'code VmCode, registers: Registers, stack: &'code mut VmStack) -> Self {
|
|
Vm {
|
|
code,
|
|
registers,
|
|
stack,
|
|
}
|
|
}
|
|
|
|
pub fn run(&mut self) -> &mut Self {
|
|
// Arithmetic operations for `ArithmeticValue`s
|
|
fn operations_as_integer(
|
|
op: &ast::BinaryArithmeticOperator,
|
|
left_operand: i64,
|
|
right_operand: i64,
|
|
) -> ArithmeticValue {
|
|
match op {
|
|
ast::BinaryArithmeticOperator::Times => {
|
|
ArithmeticValue::Integer(left_operand * right_operand)
|
|
}
|
|
ast::BinaryArithmeticOperator::Divide => {
|
|
if left_operand % right_operand == 0 {
|
|
ArithmeticValue::Integer(left_operand / right_operand)
|
|
} else {
|
|
ArithmeticValue::Floating(left_operand as f64 / right_operand as f64)
|
|
}
|
|
}
|
|
ast::BinaryArithmeticOperator::Plus => {
|
|
ArithmeticValue::Integer(left_operand + right_operand)
|
|
}
|
|
ast::BinaryArithmeticOperator::Minus => {
|
|
ArithmeticValue::Integer(left_operand - right_operand)
|
|
}
|
|
ast::BinaryArithmeticOperator::Xor => {
|
|
ArithmeticValue::Integer(left_operand ^ right_operand)
|
|
}
|
|
ast::BinaryArithmeticOperator::Pow => {
|
|
if right_operand > 0 {
|
|
ArithmeticValue::Integer(left_operand.pow(right_operand as u32))
|
|
} else {
|
|
ArithmeticValue::Floating((left_operand as f64).powi(right_operand as i32))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn operations_as_floating(
|
|
op: &ast::BinaryArithmeticOperator,
|
|
left_operand: f64,
|
|
right_operand: f64,
|
|
) -> ArithmeticValue {
|
|
match op {
|
|
ast::BinaryArithmeticOperator::Times => {
|
|
ArithmeticValue::Floating(left_operand * right_operand)
|
|
}
|
|
ast::BinaryArithmeticOperator::Divide => {
|
|
ArithmeticValue::Floating(left_operand as f64 / right_operand as f64)
|
|
}
|
|
ast::BinaryArithmeticOperator::Plus => {
|
|
ArithmeticValue::Floating(left_operand + right_operand)
|
|
}
|
|
ast::BinaryArithmeticOperator::Minus => {
|
|
ArithmeticValue::Floating(left_operand - right_operand)
|
|
}
|
|
ast::BinaryArithmeticOperator::Xor => {
|
|
ArithmeticValue::Integer(left_operand as i64 ^ right_operand as i64)
|
|
}
|
|
ast::BinaryArithmeticOperator::Pow => {
|
|
ArithmeticValue::Floating(left_operand.powf(right_operand))
|
|
}
|
|
}
|
|
}
|
|
|
|
fn comparison<T: std::cmp::PartialOrd>(
|
|
op: &ComparisonOperator,
|
|
left_operand: T,
|
|
right_operand: T,
|
|
) -> bool {
|
|
match op {
|
|
ast::ComparisonOperator::GreaterOrEqual => left_operand.ge(&right_operand),
|
|
ast::ComparisonOperator::LessOrEqual => left_operand.le(&right_operand),
|
|
ast::ComparisonOperator::GreaterThan => left_operand.gt(&right_operand),
|
|
ast::ComparisonOperator::LessThan => left_operand.lt(&right_operand),
|
|
ast::ComparisonOperator::NotEqual => left_operand.ne(&right_operand),
|
|
ast::ComparisonOperator::Equal => left_operand.eq(&right_operand),
|
|
}
|
|
}
|
|
|
|
// Alias for convenience
|
|
let code = &self.code.code;
|
|
let registers = &self.registers;
|
|
let stack = &mut self.stack;
|
|
|
|
stack.reset();
|
|
|
|
for opcode in code.iter().rev() {
|
|
match opcode {
|
|
OpCode::UnaryBooleanOperator(op) => {
|
|
let operand = stack.boolean_stack.pop().unwrap();
|
|
match op {
|
|
ast::UnaryBooleanOperator::Not => {
|
|
stack.boolean_stack.push(!operand);
|
|
}
|
|
}
|
|
}
|
|
OpCode::BinaryArithmeticOperator(op) => {
|
|
let left_operand = stack.arithmetic_stack.pop().unwrap();
|
|
let right_operand = stack.arithmetic_stack.pop().unwrap();
|
|
let value = match left_operand {
|
|
ArithmeticValue::Integer(left_operand) => match right_operand {
|
|
ArithmeticValue::Integer(right_operand) => {
|
|
operations_as_integer(op, left_operand, right_operand)
|
|
}
|
|
ArithmeticValue::Floating(right_operand) => {
|
|
operations_as_floating(op, left_operand as f64, right_operand)
|
|
}
|
|
},
|
|
ArithmeticValue::Floating(left_operand) => match right_operand {
|
|
ArithmeticValue::Integer(right_operand) => {
|
|
operations_as_floating(op, left_operand, right_operand as f64)
|
|
}
|
|
ArithmeticValue::Floating(right_operand) => {
|
|
operations_as_floating(op, left_operand, right_operand)
|
|
}
|
|
},
|
|
};
|
|
stack.arithmetic_stack.push(value);
|
|
}
|
|
OpCode::ComparisonOperator(op) => {
|
|
let left_operand = stack.arithmetic_stack.pop().unwrap();
|
|
let right_operand = stack.arithmetic_stack.pop().unwrap();
|
|
let value = match left_operand {
|
|
ArithmeticValue::Integer(left_operand) => match right_operand {
|
|
ArithmeticValue::Integer(right_operand) => {
|
|
comparison(op, left_operand, right_operand)
|
|
}
|
|
ArithmeticValue::Floating(right_operand) => {
|
|
comparison(op, left_operand as f64, right_operand)
|
|
}
|
|
},
|
|
ArithmeticValue::Floating(left_operand) => match right_operand {
|
|
ArithmeticValue::Integer(right_operand) => {
|
|
comparison(op, left_operand, right_operand as f64)
|
|
}
|
|
ArithmeticValue::Floating(right_operand) => {
|
|
comparison(op, left_operand, right_operand)
|
|
}
|
|
},
|
|
};
|
|
stack.boolean_stack.push(value);
|
|
}
|
|
OpCode::UnaryArithmeticOperator(op) => {
|
|
let operand = stack.arithmetic_stack.pop().unwrap();
|
|
let value = match op {
|
|
ast::UnaryArithmeticOperator::Negative => match operand {
|
|
ArithmeticValue::Integer(operand) => ArithmeticValue::Integer(-operand),
|
|
ArithmeticValue::Floating(operand) => {
|
|
ArithmeticValue::Floating(-operand)
|
|
}
|
|
},
|
|
ast::UnaryArithmeticOperator::Ham => {
|
|
ArithmeticValue::Integer(match operand {
|
|
ArithmeticValue::Integer(operand) => operand.count_ones() as i64,
|
|
ArithmeticValue::Floating(operand) => {
|
|
(operand.round() as i64).count_ones() as i64
|
|
}
|
|
})
|
|
}
|
|
ast::UnaryArithmeticOperator::Sqrt => {
|
|
ArithmeticValue::Floating(match operand {
|
|
ArithmeticValue::Integer(operand) => (operand as f64).sqrt(),
|
|
ArithmeticValue::Floating(operand) => operand.sqrt(),
|
|
})
|
|
}
|
|
};
|
|
stack.arithmetic_stack.push(value);
|
|
}
|
|
OpCode::BinaryBooleanOperator(op) => {
|
|
let left_operand = stack.boolean_stack.pop().unwrap();
|
|
let right_operand = stack.boolean_stack.pop().unwrap();
|
|
let value = match op {
|
|
ast::BinaryBooleanOperator::And => left_operand & right_operand,
|
|
ast::BinaryBooleanOperator::Or => left_operand | right_operand,
|
|
ast::BinaryBooleanOperator::Xor => left_operand ^ right_operand,
|
|
};
|
|
stack.boolean_stack.push(value);
|
|
}
|
|
OpCode::Variable(variable) => {
|
|
let value = ArithmeticValue::Integer(match variable {
|
|
ast::Variable::X => registers.x,
|
|
ast::Variable::Y => registers.y,
|
|
ast::Variable::N => registers.n,
|
|
ast::Variable::P => registers.p,
|
|
ast::Variable::K => registers.k,
|
|
});
|
|
stack.arithmetic_stack.push(value);
|
|
}
|
|
OpCode::Literal(literal) => {
|
|
stack
|
|
.arithmetic_stack
|
|
.push(ArithmeticValue::Integer(*literal));
|
|
}
|
|
}
|
|
}
|
|
self
|
|
}
|
|
|
|
pub fn output_bool(&mut self) -> Result<bool, ()> {
|
|
self.stack.boolean_stack.pop().ok_or(())
|
|
}
|
|
|
|
#[allow(unused)]
|
|
pub fn output_arithmetic(&mut self) -> Result<ArithmeticValue, ()> {
|
|
self.stack.arithmetic_stack.pop().ok_or(())
|
|
}
|
|
}
|
|
|
|
/// Returns how many code-points (Bytecode units) were emitted
|
|
fn compile_expression(expression: Expression, code: &mut Vec<Bytecode>) -> usize {
|
|
match expression {
|
|
Expression::Boolean(expression) => match expression {
|
|
ast::BooleanExpression::BinaryBooleanConjunction(expression) => {
|
|
let expression = *expression;
|
|
// code.push((OpCode::BinaryBooleanOperator(expression.operator), 0));
|
|
code.push(OpCode::BinaryBooleanOperator(expression.operator));
|
|
let index = code.len() - 1;
|
|
let left_operand_size =
|
|
compile_expression(Expression::Boolean(expression.left_operand), code);
|
|
let right_operand_size =
|
|
compile_expression(Expression::Boolean(expression.right_operand), code);
|
|
// code[index].1 = left_operand_size + 1;
|
|
return 1 + left_operand_size + right_operand_size;
|
|
}
|
|
ast::BooleanExpression::UnaryBooleanConjunction(expression) => {
|
|
let expression = *expression;
|
|
// code.push((OpCode::UnaryBooleanOperator(expression.operator), 0));
|
|
code.push(OpCode::UnaryBooleanOperator(expression.operator));
|
|
let operand_size =
|
|
compile_expression(Expression::Boolean(expression.operand), code);
|
|
return 1 + operand_size;
|
|
}
|
|
ast::BooleanExpression::ComparisonConjunction(expression) => {
|
|
// code.push((OpCode::ComparisonOperator(expression.operator), 0));
|
|
code.push(OpCode::ComparisonOperator(expression.operator));
|
|
let index = code.len() - 1;
|
|
let left_operand_size = match expression.left_operand {
|
|
ast::ArithmeticOperand::Literal(literal) => {
|
|
// code.push((OpCode::Literal(literal), 0));
|
|
code.push(OpCode::Literal(literal));
|
|
1_usize
|
|
}
|
|
ast::ArithmeticOperand::Expression(expression) => {
|
|
compile_expression(Expression::Arithmetic(*expression), code)
|
|
}
|
|
};
|
|
let right_operand_size = match expression.right_operand {
|
|
ast::ArithmeticOperand::Literal(literal) => {
|
|
// code.push((OpCode::Literal(literal), 0));
|
|
code.push(OpCode::Literal(literal));
|
|
1_usize
|
|
}
|
|
ast::ArithmeticOperand::Expression(expression) => {
|
|
compile_expression(Expression::Arithmetic(*expression), code)
|
|
}
|
|
};
|
|
// code[index].1 = left_operand_size + 1;
|
|
return 1 + left_operand_size + right_operand_size;
|
|
}
|
|
},
|
|
Expression::Arithmetic(expression) => match expression {
|
|
ast::ArithmeticExpression::Variable(variable) => {
|
|
// code.push((OpCode::Variable(variable), 0));
|
|
code.push(OpCode::Variable(variable));
|
|
return 1;
|
|
}
|
|
ast::ArithmeticExpression::UnaryArithmeticConjunction(expression) => {
|
|
let expression = *expression;
|
|
// code.push((OpCode::UnaryArithmeticOperator(expression.operator), 0));
|
|
code.push(OpCode::UnaryArithmeticOperator(expression.operator));
|
|
let operand_size = match expression.operand {
|
|
ast::ArithmeticOperand::Literal(literal) => {
|
|
// code.push((OpCode::Literal(literal), 0));
|
|
code.push(OpCode::Literal(literal));
|
|
1_usize
|
|
}
|
|
ast::ArithmeticOperand::Expression(expression) => {
|
|
compile_expression(Expression::Arithmetic(*expression), code)
|
|
}
|
|
};
|
|
return 1 + operand_size;
|
|
}
|
|
ast::ArithmeticExpression::BinaryArithmeticConjunction(expression) => {
|
|
// code.push((OpCode::BinaryArithmeticOperator(expression.operator), 0));
|
|
code.push(OpCode::BinaryArithmeticOperator(expression.operator));
|
|
let index = code.len() - 1;
|
|
let left_operand_size = match expression.left_operand {
|
|
ast::ArithmeticOperand::Literal(literal) => {
|
|
// code.push((OpCode::Literal(literal), 0));
|
|
code.push(OpCode::Literal(literal));
|
|
1_usize
|
|
}
|
|
ast::ArithmeticOperand::Expression(expression) => {
|
|
compile_expression(Expression::Arithmetic(*expression), code)
|
|
}
|
|
};
|
|
let right_operand_size = match expression.right_operand {
|
|
ast::ArithmeticOperand::Literal(literal) => {
|
|
// code.push((OpCode::Literal(literal), 0));
|
|
code.push(OpCode::Literal(literal));
|
|
1_usize
|
|
}
|
|
ast::ArithmeticOperand::Expression(expression) => {
|
|
compile_expression(Expression::Arithmetic(*expression), code)
|
|
}
|
|
};
|
|
// code[index].1 = left_operand_size + 1;
|
|
return 1 + left_operand_size + right_operand_size;
|
|
}
|
|
},
|
|
}
|
|
}
|
|
|
|
#[test]
fn boolean_compilation_test() {
    let expression =
        parsing::parse_relation("and (and (< x 1) (!= 1 y)) (< ham x k)").expect("Valid AST");
    let code = compile_boolean(expression);
    println!("{:?}", code);

    // Code is emitted in prefix order, so the outermost `and` must be the first opcode.
    assert!(matches!(
        code.code.first(),
        Some(OpCode::BinaryBooleanOperator(ast::BinaryBooleanOperator::And))
    ));
    // Every register named in the relation must be loaded somewhere in the code.
    assert!(code.any(|op| matches!(op, OpCode::Variable(ast::Variable::X))));
    assert!(code.any(|op| matches!(op, OpCode::Variable(ast::Variable::Y))));
    assert!(code.any(|op| matches!(op, OpCode::Variable(ast::Variable::K))));
}
|
|
|
|
#[test]
fn test_vm_simple() {
    let expression = parsing::parse_relation("(= (ham (^ x y)) (ham 1))")
        .unwrap_or_else(|e| panic!("{}", e));
    let code = compile_boolean(expression);
    let mut stack = VmStack::from_code(&code);

    // ham(0b1011 ^ 0b1111) = ham(0b0100) = 1 = ham(1)
    let mut vm = Vm::load(
        &code,
        Registers::load(0b_1011, 0b_1111, 10, 2, 3),
        &mut stack,
    );
    assert!(vm.run().output_bool().unwrap());

    // ham(0b1011 ^ 0b1110) = ham(0b0101) = 2 != ham(1)
    let mut vm = Vm::load(
        &code,
        Registers::load(0b_1011, 0b_1110, 10, 2, 3),
        &mut stack,
    );
    assert!(!vm.run().output_bool().unwrap());
}
|
|
|
|
#[test]
fn test_vm_contrived() {
    let expression =
        parsing::parse_relation("and (= (ham (^ x y)) (ham (+ 3 x))) (> (* x y) 5)")
            .unwrap_or_else(|e| panic!("{}", e));
    let code = compile_boolean(expression);
    let mut stack = VmStack::from_code(&code);

    // The same relation written directly in Rust, used as the oracle.
    let as_rust = |x: u64, y: u64| ((x + 3).count_ones() == (x ^ y).count_ones()) && (x * y > 5);

    for x in 0..(1 << 4) {
        for y in 0..(1 << 4) {
            let mut vm = Vm::load(&code, Registers::load(x, y, 10, 2, 3), &mut stack);
            vm.run();
            let output = vm.output_bool().unwrap();
            assert_eq!(output, as_rust(x, y), "x = {}, y = {}", x, y);
            // A relation must leave exactly one boolean and no arithmetic values behind.
            assert!(vm.output_bool().is_err());
            assert!(vm.output_arithmetic().is_err());
        }
    }
}
|