quantum_queries/src/vm/mod.rs

337 lines
14 KiB
Rust

pub mod parsing;
use parsing::ast;
/// One bytecode unit: an opcode paired with an operand offset.
///
/// For binary operators the `usize` is the distance from the operator's own
/// index to the start of its right operand (the left operand always starts at
/// the next index). Every other opcode is emitted with an offset of `0`.
type Bytecode = (OpCode, usize);
/// A compiled program, produced by `Vm::compile_boolean` or
/// `Vm::compile_arithmetic`.
#[derive(Debug)]
pub struct VmCode {
// Bytecode units in prefix order: each operator precedes its operand(s).
code: Vec<Bytecode>,
}
/// An evaluator that pairs borrowed compiled code with one set of register
/// values.
#[derive(Debug)]
pub struct Vm<'code> {
// The program to evaluate; borrowed from the caller for the lifetime `'code`.
code: &'code VmCode,
// Values that `OpCode::Variable` opcodes read during evaluation.
registers: Registers,
}
/// Variable values available to a program, plus the number of set bits of
/// each value (as filled in by `Registers::load`).
#[derive(Debug)]
pub struct Registers {
// Values of the variables `x`, `y`, `n`, `p` and `k`.
x: u128,
y: u128,
n: u32,
p: u32,
k: u32,
// `count_ones()` of the corresponding value above.
xpop: u32,
ypop: u32,
npop: u32,
ppop: u32,
kpop: u32,
}
/// The value produced by evaluating a program or any sub-expression of it.
#[derive(Debug)]
pub enum VmOutput {
/// Result of a boolean expression (logical operators and comparisons).
Boolean(bool),
/// Result of an arithmetic expression; all arithmetic is carried out in `f64`.
Arithmetic(f64),
}
/// The operation stored in a single bytecode unit.
///
/// Operator variants wrap the corresponding operator type from the AST;
/// `Variable` and `Literal` are leaves with no operands.
#[derive(Debug)]
enum OpCode {
// Boolean operator with one boolean operand.
UnaryBooleanOperator(ast::UnaryBooleanOperator),
// Arithmetic operator with two arithmetic operands.
BinaryArithmeticOperator(ast::BinaryArithmeticOperator),
// Comparison of two arithmetic operands, yielding a boolean.
ComparisonOperator(ast::ComparisonOperator),
// Arithmetic operator with one arithmetic operand.
UnaryArithmeticOperator(ast::UnaryArithmeticOperator),
// Boolean operator with two boolean operands.
BinaryBooleanOperator(ast::BinaryBooleanOperator),
// Reads the named register.
Variable(ast::Variable),
// Unsigned integer constant; converted to `f64` when evaluated.
Literal(usize),
}
/// Either kind of AST expression, so that `compile_expression` can accept
/// boolean and arithmetic trees through a single parameter.
enum Expression {
Boolean(ast::BooleanExpression),
Arithmetic(ast::ArithmeticExpression),
}
impl Registers {
pub fn load(x: u128, y: u128, n: u32, p: u32, k: u32) -> Self {
Self {
x,
y,
n,
p,
k,
xpop: x.count_ones(),
ypop: y.count_ones(),
npop: n.count_ones(),
ppop: p.count_ones(),
kpop: k.count_ones(),
}
}
}
impl<'code> Vm<'code> {
    /// Compiles a boolean expression tree into bytecode.
    pub fn compile_boolean(expression: ast::BooleanExpression) -> VmCode {
        let mut code = Vec::<Bytecode>::new();
        compile_expression(Expression::Boolean(expression), &mut code);
        VmCode { code }
    }

    /// Compiles an arithmetic expression tree into bytecode.
    pub fn compile_arithmetic(expression: ast::ArithmeticExpression) -> VmCode {
        let mut code = Vec::<Bytecode>::new();
        compile_expression(Expression::Arithmetic(expression), &mut code);
        VmCode { code }
    }

    /// Creates a VM that evaluates `code` against `registers`.
    pub fn load(code: &'code VmCode, registers: Registers) -> Self {
        Vm { code, registers }
    }

    /// Evaluates the program, starting at its first opcode.
    ///
    /// Arithmetic is carried out in `f64`. Register values and intermediate
    /// results at or above 2^53 may therefore be rounded; the exception is
    /// `ham` applied directly to a variable, which is always exact.
    ///
    /// # Panics
    ///
    /// Panics on malformed bytecode: an operand of the wrong kind (boolean
    /// where arithmetic is required or vice versa) or an operand index past
    /// the end of the code.
    pub fn run(&self) -> VmOutput {
        self.resolve(0)
    }

    /// Recursively evaluates the sub-expression whose operator (or leaf) is
    /// stored at `op_index`.
    fn resolve(&self, op_index: usize) -> VmOutput {
        let (opcode, offset) = &self.code.code[op_index];
        // The first (or only) operand immediately follows its operator; the
        // second operand of a binary operator starts `offset` units after it.
        let first = op_index + 1;
        let second = op_index + offset;
        match opcode {
            OpCode::UnaryBooleanOperator(op) => {
                let operand = self.resolve_boolean(
                    first,
                    "Bad bytecode! Unary operator is followed by arithmetic resolution.",
                );
                match op {
                    ast::UnaryBooleanOperator::Not => VmOutput::Boolean(!operand),
                }
            }
            OpCode::ComparisonOperator(op) => {
                let left = self.resolve_arithmetic(
                    first,
                    "Bad bytecode! Left operand of comparison operator is boolean.",
                );
                let right = self.resolve_arithmetic(
                    second,
                    "Bad bytecode! Right operand of comparison operator is boolean.",
                );
                VmOutput::Boolean(match op {
                    ast::ComparisonOperator::GreaterOrEqual => left >= right,
                    ast::ComparisonOperator::LessOrEqual => left <= right,
                    ast::ComparisonOperator::GreaterThan => left > right,
                    ast::ComparisonOperator::LessThan => left < right,
                    ast::ComparisonOperator::NotEqual => left != right,
                    ast::ComparisonOperator::Equal => left == right,
                })
            }
            OpCode::BinaryArithmeticOperator(op) => {
                let left = self.resolve_arithmetic(
                    first,
                    "Bad bytecode! Left operand of arithmetic operator is boolean.",
                );
                let right = self.resolve_arithmetic(
                    second,
                    "Bad bytecode! Right operand of arithmetic operator is boolean.",
                );
                VmOutput::Arithmetic(match op {
                    ast::BinaryArithmeticOperator::Times => left * right,
                    ast::BinaryArithmeticOperator::Divide => left / right,
                    ast::BinaryArithmeticOperator::Plus => left + right,
                    ast::BinaryArithmeticOperator::Minus => left - right,
                    // Bitwise xor works on the integer part; the float -> int
                    // cast saturates (negative values and NaN become 0).
                    ast::BinaryArithmeticOperator::Xor => {
                        ((left as u128) ^ (right as u128)) as f64
                    }
                })
            }
            OpCode::UnaryArithmeticOperator(op) => {
                const NOT_ARITHMETIC: &str =
                    "Bad bytecode! Arithmetic unary operator followed by boolean resolution.";
                match op {
                    ast::UnaryArithmeticOperator::Negative => {
                        VmOutput::Arithmetic(-self.resolve_arithmetic(first, NOT_ARITHMETIC))
                    }
                    ast::UnaryArithmeticOperator::Ham => {
                        let ones = match self.code.code.get(first) {
                            // `ham <variable>`: use the exact bit count stored
                            // in the registers. Going through `f64` would
                            // round a `u128` register at or above 2^53 and
                            // yield the wrong population count.
                            Some((OpCode::Variable(variable), _)) => {
                                self.register_popcount(variable)
                            }
                            _ => (self.resolve_arithmetic(first, NOT_ARITHMETIC) as u128)
                                .count_ones(),
                        };
                        VmOutput::Arithmetic(ones as f64)
                    }
                }
            }
            OpCode::BinaryBooleanOperator(op) => {
                let left = self.resolve_boolean(
                    first,
                    "Bad bytecode! Left operand of boolean operator is arithmetic.",
                );
                let right = self.resolve_boolean(
                    second,
                    "Bad bytecode! Right operand of boolean operator is arithmetic.",
                );
                VmOutput::Boolean(match op {
                    ast::BinaryBooleanOperator::And => left & right,
                    ast::BinaryBooleanOperator::Or => left | right,
                    ast::BinaryBooleanOperator::Xor => left ^ right,
                })
            }
            OpCode::Variable(variable) => {
                let value = match variable {
                    ast::Variable::X => self.registers.x,
                    ast::Variable::Y => self.registers.y,
                    ast::Variable::N => self.registers.n as u128,
                    ast::Variable::P => self.registers.p as u128,
                    ast::Variable::K => self.registers.k as u128,
                };
                VmOutput::Arithmetic(value as f64)
            }
            OpCode::Literal(literal) => VmOutput::Arithmetic(*literal as f64),
        }
    }

    /// Evaluates the sub-expression at `op_index` and returns its arithmetic
    /// value, panicking with `message` if it resolves to a boolean.
    fn resolve_arithmetic(&self, op_index: usize, message: &'static str) -> f64 {
        match self.resolve(op_index) {
            VmOutput::Arithmetic(value) => value,
            VmOutput::Boolean(_) => panic!("{}", message),
        }
    }

    /// Evaluates the sub-expression at `op_index` and returns its boolean
    /// value, panicking with `message` if it resolves to an arithmetic value.
    fn resolve_boolean(&self, op_index: usize, message: &'static str) -> bool {
        match self.resolve(op_index) {
            VmOutput::Boolean(value) => value,
            VmOutput::Arithmetic(_) => panic!("{}", message),
        }
    }

    /// Returns the precomputed number of set bits of the given register.
    fn register_popcount(&self, variable: &ast::Variable) -> u32 {
        match variable {
            ast::Variable::X => self.registers.xpop,
            ast::Variable::Y => self.registers.ypop,
            ast::Variable::N => self.registers.npop,
            ast::Variable::P => self.registers.ppop,
            ast::Variable::K => self.registers.kpop,
        }
    }
}
/// Appends the bytecode for `expression` to `code` in prefix order: an
/// operator is emitted first, followed by its operand(s).
///
/// Binary operators record, in the `usize` half of their bytecode unit, the
/// distance from the operator to the start of the right operand.
///
/// Returns the number of bytecode units that were appended.
fn compile_expression(expression: Expression, code: &mut Vec<Bytecode>) -> usize {
    /// Emits a single arithmetic operand (a literal leaf or a nested
    /// expression) and returns the number of units it occupies.
    fn compile_operand(operand: ast::ArithmeticOperand, code: &mut Vec<Bytecode>) -> usize {
        match operand {
            ast::ArithmeticOperand::Literal(literal) => {
                code.push((OpCode::Literal(literal), 0));
                1
            }
            ast::ArithmeticOperand::Expression(inner) => {
                compile_expression(Expression::Arithmetic(*inner), code)
            }
        }
    }

    match expression {
        Expression::Boolean(ast::BooleanExpression::BinaryBooleanConjunction(node)) => {
            let node = *node;
            // Remember where the operator lands so its offset can be patched
            // once the size of the left operand is known.
            let operator_slot = code.len();
            code.push((OpCode::BinaryBooleanOperator(node.operator), 0));
            let left_size = compile_expression(Expression::Boolean(node.left_operand), code);
            code[operator_slot].1 = left_size + 1;
            let right_size = compile_expression(Expression::Boolean(node.right_operand), code);
            1 + left_size + right_size
        }
        Expression::Boolean(ast::BooleanExpression::UnaryBooleanConjunction(node)) => {
            let node = *node;
            code.push((OpCode::UnaryBooleanOperator(node.operator), 0));
            1 + compile_expression(Expression::Boolean(node.operand), code)
        }
        Expression::Boolean(ast::BooleanExpression::ComparisonConjunction(node)) => {
            let operator_slot = code.len();
            code.push((OpCode::ComparisonOperator(node.operator), 0));
            let left_size = compile_operand(node.left_operand, code);
            code[operator_slot].1 = left_size + 1;
            let right_size = compile_operand(node.right_operand, code);
            1 + left_size + right_size
        }
        Expression::Arithmetic(ast::ArithmeticExpression::Variable(variable)) => {
            code.push((OpCode::Variable(variable), 0));
            1
        }
        Expression::Arithmetic(ast::ArithmeticExpression::UnaryArithmeticConjunction(node)) => {
            let node = *node;
            code.push((OpCode::UnaryArithmeticOperator(node.operator), 0));
            1 + compile_operand(node.operand, code)
        }
        Expression::Arithmetic(ast::ArithmeticExpression::BinaryArithmeticConjunction(node)) => {
            let operator_slot = code.len();
            code.push((OpCode::BinaryArithmeticOperator(node.operator), 0));
            let left_size = compile_operand(node.left_operand, code);
            code[operator_slot].1 = left_size + 1;
            let right_size = compile_operand(node.right_operand, code);
            1 + left_size + right_size
        }
    }
}
/// Smoke test: a nested prefix-notation relation parses and compiles; the
/// resulting bytecode is printed for manual inspection.
#[test]
fn boolean_compilation_test() {
    let source = "and and < x 1 != 1 y < ham x k";
    let relation = parsing::parse_relation(source).expect("Valid AST");
    let program = Vm::compile_boolean(relation);
    println!("{:?}", program);
}
/// Smoke test: parses, compiles and runs a relation against fixed register
/// values, printing the outcome for manual inspection.
#[test]
fn run_test() {
    // On a parse failure, show the parser's message before failing the test.
    let relation = match parsing::parse_relation("(= (ham (^ x y)) 1)") {
        Ok(relation) => relation,
        Err(error) => {
            println!("{}", error);
            panic!();
        }
    };
    let program = Vm::compile_boolean(relation);
    let registers = Registers::load(0b_1001, 0b_1111, 10, 2, 3);
    let machine = Vm::load(&program, registers);
    let outcome = machine.run();
    println!("{:?}", outcome);
}