use thiserror::Error; use std::{ cell::{RefCell, RefMut}, collections::HashMap, fmt::Debug, hash::Hash, rc::Rc, }; use crate::{ ast::{BinaryOperator, LuaNumber, UnaryOperator}, compile, }; pub type VMNumber = u64; #[derive(Clone, Hash, PartialEq, Eq)] pub enum Constant { String(String), Number(VMNumber), } impl Debug for Constant { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::String(arg0) => f.debug_tuple("String").field(arg0).finish(), Self::Number(arg0) => f .debug_tuple("Number") .field(&LuaNumber::from_bits(*arg0)) .finish(), } } } #[derive(Clone, Copy)] pub enum Instruction { /// R(A) := R(B) Move(u16, u16), /// R(A) := K(Bx) LoadK(u16, u32), /// R(A), ..., R(B) := nil LoadNil(u16, u16), /// G[K(Bx)] := R(A) SetGlobal(u16, u32), /// R(A) := G[K(Bx)] GetGlobal(u16, u32), /// R(A) := U[B] GetUpVal(u16, u16), /// U[B] := R(A) SetUpVal(u16, u16), /// R(A)[R(B)] := R(C) SetTable(u16, u16, u16), /// R(A) := {} NewTable(u16), /// R(A) := R(B) + R(C) Add(u16, u16, u16), /// R(A) := -R(B) Unm(u16, u16), /// R(A) := R(B) == R(C) Equal(u16, u16, u16), /// R(A) := R(B) < R(C) LessThan(u16, u16, u16), /// R(A) := R(B) <= R(C) LessThanOrEqual(u16, u16, u16), /// R(A) := R(B) or R(C) Or(u16, u16, u16), /// R(A) := R(B) and R(C) And(u16, u16, u16), /// PC += sAx Jmp(i32), /// if (R(B) <=> C) then R(A) := R(B) else PC++ Test(u16, u16, u16), /// [func] [params.len()] [ret_regs.len()] /// R(A), ... R(A+C-2) := R(A)(R(A+1), ... R(A+B-1)) Call(u16, u16, u16), /// return R(A), ... , R(B) Return(u16, u16), /// close stack variables up to R(A) Close(u16), /// R(A) := closure(KPROTO[Bx], R(A), ..., R(A+n)) Closure(u16, u32), } impl Debug for Instruction { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Instruction::Move(arg0, arg1) => write!(f, "MOVE {} {}", arg0, arg1), Instruction::LoadK(arg0, arg1) => write!(f, "LOADK {} {}", arg0, arg1), Instruction::SetGlobal(arg0, arg1) => write!(f, "SETGLOBAL {} {}", arg0, arg1), Instruction::GetGlobal(arg0, arg1) => write!(f, "GETGLOBAL {} {}", arg0, arg1), Instruction::GetUpVal(arg0, arg1) => write!(f, "GETUPVAL {} {}", arg0, arg1), Instruction::SetUpVal(arg0, arg1) => write!(f, "SETUPVAL {} {}", arg0, arg1), Instruction::SetTable(arg0, arg1, arg2) => { write!(f, "SETTABLE {} {} {}", arg0, arg1, arg2) } Instruction::NewTable(arg0) => write!(f, "NEWTABLE {}", arg0), Instruction::Jmp(arg0) => write!(f, "JMP {}", arg0), Instruction::Test(arg0, arg1, arg2) => write!(f, "TEST {} {} {}", arg0, arg1, arg2), Instruction::Call(arg0, arg1, arg2) => write!(f, "CALL {} {} {}", arg0, arg1, arg2), Instruction::Close(arg0) => write!(f, "CLOSE {}", arg0), Instruction::Closure(arg0, arg1) => write!(f, "CLOSURE {} {}", arg0, arg1), Instruction::Return(arg0, arg1) => write!(f, "RETURN {} {}", arg0, arg1), Instruction::Equal(arg0, arg1, arg2) => write!(f, "EQ {} {} {}", arg0, arg1, arg2), Instruction::LessThan(arg0, arg1, arg2) => write!(f, "LT {} {} {}", arg0, arg1, arg2), Instruction::LessThanOrEqual(arg0, arg1, arg2) => { write!(f, "LE {} {} {}", arg0, arg1, arg2) } Instruction::Add(arg0, arg1, arg2) => write!(f, "ADD {} {} {}", arg0, arg1, arg2), Instruction::LoadNil(arg0, arg1) => write!(f, "LOADNIL {} {}", arg0, arg1), Instruction::Unm(arg0, arg1) => write!(f, "UNM {} {}", arg0, arg1), Instruction::Or(arg0, arg1, arg2) => write!(f, "OR {} {} {}", arg0, arg1, arg2), Instruction::And(arg0, arg1, arg2) => write!(f, "AND {} {} {}", arg0, arg1, arg2), } } } #[derive(Error, Debug)] pub enum RuntimeError { #[error("Unable to perform {0:?} operator between {1:?} and {2:?}")] InvalidOperands(BinaryOperator, Value, Value), #[error("Unable to perform {0:?} operator to {1:?}")] InvalidOperand(UnaryOperator, Value), #[error("Tried calling a non-function: {0:?}")] TriedCallingNonFunction(Value), #[error("Global not found: {0:?}")] GlobalNotFound(Option), #[error("Unable to index tables with {0:?}")] InvalidTableIndex(Value), #[error("Value is not a table: {0:?}")] NotTable(Value), #[error("{0}")] Custom(String), } #[derive(Debug, Clone, Default)] pub struct Environment { pub globals: HashMap>>, } impl Environment { pub fn get_global(&mut self, key: &Constant) -> Option { let value = self.globals.get_mut(key)?; Some(match &*value.borrow() { _ => StackValue::Value(value.borrow().clone()), }) } pub fn set_global(&mut self, key: Constant, value: StackValue) { if let Some(existing) = self.globals.get_mut(&key) { match value { StackValue::Value(value) => { *existing.borrow_mut() = value; } StackValue::Ref(reference) => { *existing = reference; } } } else { match value { StackValue::Value(value) => { self.globals.insert(key, Rc::new(RefCell::new(value))); } StackValue::Ref(reference) => { self.globals.insert(key, reference); } } } } } #[derive(Clone)] pub enum Value { String(String), Number(VMNumber), RustFunction(Rc>), Function(Closure), Nil, Table(Rc>>), } impl Value { pub fn as_indexable(self) -> Result { match self { Value::String(value) => Ok(IndexableValue::String(value)), Value::Number(value) => Ok(IndexableValue::Number(value)), Value::RustFunction(value) => { Ok(IndexableValue::RustFunction(value.borrow().as_indexable())) } Value::Function(closure) => Ok(IndexableValue::Function(closure.prototype)), Value::Nil => Err(RuntimeError::InvalidTableIndex(self)), Value::Table(_) => Err(RuntimeError::InvalidTableIndex(self)), } } } #[derive(Debug, PartialEq, Eq, Hash, Clone)] pub enum IndexableValue { String(String), Number(VMNumber), RustFunction(String), Function(u32), } impl Value { pub fn add(&self, other: &Value) -> Result { match (self, other) { (Value::Number(lhs), Value::Number(rhs)) => { let res = LuaNumber::from_bits(*lhs) + LuaNumber::from_bits(*rhs); Ok(Value::Number(res.to_bits())) } _ => Err(RuntimeError::InvalidOperands( BinaryOperator::Add, self.clone(), other.clone(), )), } } pub fn eq(&self, other: &Value) -> Result { match (self, other) { (Value::Number(lhs), Value::Number(rhs)) => { let res = lhs == rhs; Ok(Value::Number((res as u64 as f64).to_bits())) } _ => Err(RuntimeError::InvalidOperands( BinaryOperator::Equal, self.clone(), other.clone(), )), } } pub fn lt(&self, other: &Value) -> Result { match (self, other) { (Value::Number(lhs), Value::Number(rhs)) => { let res = LuaNumber::from_bits(*lhs) < LuaNumber::from_bits(*rhs); Ok(Value::Number((res as u64 as f64).to_bits())) } _ => Err(RuntimeError::InvalidOperands( BinaryOperator::LessThan, self.clone(), other.clone(), )), } } pub fn lte(&self, other: &Value) -> Result { match (self, other) { (Value::Number(lhs), Value::Number(rhs)) => { let res = LuaNumber::from_bits(*lhs) <= LuaNumber::from_bits(*rhs); Ok(Value::Number((res as u64 as f64).to_bits())) } _ => Err(RuntimeError::InvalidOperands( BinaryOperator::LessThanOrEqual, self.clone(), other.clone(), )), } } pub fn unm(&self) -> Result { match self { Value::Number(lhs) => { let res = -LuaNumber::from_bits(*lhs); Ok(Value::Number(res.to_bits())) } _ => Err(RuntimeError::InvalidOperand( UnaryOperator::Negation, self.clone(), )), } } pub fn and(&self, other: &Value) -> Result { match (self, other) { (Value::Nil, _) | (_, Value::Nil) => Ok(Value::Nil), (Value::Number(lhs), Value::Number(rhs)) => { let res = LuaNumber::from_bits(*lhs) > 0. && LuaNumber::from_bits(*rhs) > 0.; Ok(Value::Number((res as u64 as f64).to_bits())) } (Value::Number(value), _) | (_, Value::Number(value)) => { let res = LuaNumber::from_bits(*value) > 0.; Ok(Value::Number((res as u64 as f64).to_bits())) } _ => Ok(Value::Nil), } } pub fn or(&self, other: &Value) -> Result { match (self, other) { (Value::Nil, value) | (value, Value::Nil) => Ok(value.clone()), (Value::Number(lhs), Value::Number(rhs)) => { let res = LuaNumber::from_bits(*lhs) > 0. || LuaNumber::from_bits(*rhs) > 0.; Ok(Value::Number((res as u64 as f64).to_bits())) } (Value::Number(value), other) => { if LuaNumber::from_bits(*value) > 0. { Ok(Value::Number(*value)) } else { Ok(other.clone()) } } (value, _) => Ok(value.clone()), _ => Ok(Value::Nil), } } } impl Debug for Value { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Value::Number(arg0) => f .debug_tuple("Number") .field(&LuaNumber::from_bits(*arg0)) .finish(), Value::String(value) => f.debug_tuple("String").field(value).finish(), Value::RustFunction(arg0) => f.debug_tuple("RustFunction").field(arg0).finish(), Value::Function(closure) => f .debug_tuple(&format!("Function({})", closure.prototype)) .finish(), Value::Nil => write!(f, "Nil"), Value::Table(hash_map) => { let mut table = f.debug_struct("Table"); for (key, value) in hash_map.borrow().iter() { table.field(&format!("{:?}", key), value); } table.finish() } } } } pub trait RustFunction: Debug { fn execute(&self, parameters: Vec) -> Result, RuntimeError>; fn as_indexable(&self) -> String; } #[derive(Debug, Clone)] pub struct VirtualMachine { pub environment: Rc>, pub constants: Vec, pub prototypes: HashMap>, pub proto_counter: u32, } impl VirtualMachine { pub fn new_prototype(&mut self, instructions: Vec) -> u32 { let proto_id = self.proto_counter; self.proto_counter += 1; self.prototypes.insert(proto_id, instructions); proto_id } } impl VirtualMachine { pub fn create_closure(&self, prototype: u32) -> Closure { Closure { vm: self.clone(), prototype, environment: self.environment.clone(), upvalues: HashMap::new(), } } } #[derive(Debug, Clone)] pub struct Closure { pub vm: VirtualMachine, pub prototype: u32, pub environment: Rc>, pub upvalues: HashMap>>, } impl Closure { pub fn run(&self, params: Vec) -> ClosureRunner { let mut stack = HashMap::new(); for (i, param) in params.iter().enumerate() { stack.insert(i as u16, Rc::new(RefCell::new(param.clone()))); } ClosureRunner { closure: self.clone(), program_counter: 0, stack, inner: None, function_register: 0, return_registers: Vec::new(), top: 0, } } fn get_upvalue(&mut self, idx: u16) -> StackValue { let value = self.upvalues.get(&idx); if let Some(value) = value { match &*value.borrow() { _ => StackValue::Value(value.borrow().clone()), } } else { StackValue::Value(Value::Nil) } } } pub struct ClosureRunner { pub closure: Closure, pub program_counter: usize, pub stack: HashMap>>, pub inner: Option>, pub function_register: u16, pub return_registers: Vec, pub top: u16, } #[derive(Clone, Debug)] pub enum StackValue { Value(Value), Ref(Rc>), } impl ClosureRunner { pub fn set_stack(&mut self, idx: u16, value: StackValue) { if let Some(stack_slot) = self.stack.get_mut(&idx) { match value { StackValue::Value(value) => { *stack_slot.borrow_mut() = value; } StackValue::Ref(ref_cell) => *stack_slot = ref_cell, } } else { match value { StackValue::Value(value) => { self.stack.insert(idx, Rc::new(RefCell::new(value))); } StackValue::Ref(reference) => { self.stack.insert(idx, reference); } } } } pub fn get_stack(&mut self, idx: u16) -> StackValue { let value = self.stack.get(&idx); if let Some(value) = value { match &*value.borrow() { _ => StackValue::Value(value.borrow().clone()), } } else { StackValue::Value(Value::Nil) } } fn close_upvalues(&self) -> HashMap>> { let highest_upvalue = self .closure .upvalues .iter() .map(|(v, _)| *v) .max() .unwrap_or(0); let mut upvalues = self.closure.upvalues.clone(); for (reg, value) in &self.stack { upvalues.insert(reg + highest_upvalue + 1, value.clone()); } upvalues } pub fn execute( &mut self, instructions: Vec, state: compile::State, constants: Vec, ) -> ClosureRunner { let mut vm = self.closure.vm.clone(); vm.constants = constants; let proto_id = vm.new_prototype(instructions); for prototype in state.prototypes { vm.new_prototype(prototype); } let closure = Closure { vm, prototype: proto_id, environment: self.closure.environment.clone(), upvalues: self.close_upvalues(), }; closure.run(Vec::new()) } pub fn next(&mut self) -> Result>, RuntimeError> { if let Some(inner) = &mut self.inner { match inner.next() { Ok(ret_values) => { if let Some(ret_values) = ret_values { self.inner = None; if self.return_registers.len() == 0 { for (i, value) in ret_values.iter().enumerate() { self.stack.insert( self.function_register + i as u16 + 1, Rc::new(RefCell::new(value.clone())), ); } self.top = self.function_register + ret_values.len() as u16; } for (i, reg) in self.return_registers.iter().enumerate() { self.stack.insert( *reg, Rc::new(RefCell::new( ret_values.get(i).cloned().unwrap_or(Value::Nil), )), ); } } else { return Ok(None); } } Err(e) => return Err(e), } } let instructions = self .closure .vm .prototypes .get(&self.closure.prototype) .unwrap() .clone(); let constants = self.closure.vm.constants.clone(); if let Some(instr) = instructions.get(self.program_counter) { match instr { Instruction::Move(a, b) => { let b = self.get_stack(*b); self.set_stack(*a, b); } Instruction::LoadK(reg, constant) => { self.set_stack( *reg, StackValue::Value(match constants.get(*constant as usize).unwrap() { Constant::String(value) => Value::String(value.clone()), Constant::Number(value) => Value::Number(*value), }), ); } Instruction::LoadNil(from_reg, to_reg) => { for i in *from_reg..=*to_reg { self.stack.insert(i, Rc::new(RefCell::new(Value::Nil))); } } Instruction::SetGlobal(reg, global) => { let value = self.get_stack(*reg); dbg!(&value); self.closure .environment .borrow_mut() .set_global(constants.get(*global as usize).unwrap().clone(), value); dbg!(&self.closure.environment); } Instruction::GetGlobal(reg, global) => { let glob = self .closure .environment .borrow_mut() .get_global(constants.get(*global as usize).unwrap()); if let Some(global) = glob { self.set_stack(*reg, global); } else { return Err(RuntimeError::GlobalNotFound( constants.get(*global as usize).cloned(), )); } } Instruction::GetUpVal(reg, upvalreg) => { let upvalue = self.closure.get_upvalue(*upvalreg); self.set_stack(*reg, upvalue); } Instruction::SetUpVal(upvalreg, reg) => { *self.closure.upvalues.get(upvalreg).unwrap().borrow_mut() = self .stack .get(reg) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); } Instruction::SetTable(tablereg, indexreg, valuereg) => { let table = self.stack.get(tablereg); match table { Some(value) => { let mut table = value.borrow_mut(); if let Value::Table(table) = &mut *table { let index_value = self .stack .get(indexreg) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil) .as_indexable()?; let value = self .stack .get(valuereg) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); match value { Value::Nil => { table.borrow_mut().remove(&index_value); } _ => { table.borrow_mut().insert(index_value, value); } } } else { return Err(RuntimeError::NotTable(table.clone())); } } None => todo!(), } } Instruction::NewTable(reg) => { self.set_stack( *reg, StackValue::Value(Value::Table(Rc::new(RefCell::new(HashMap::new())))), ); } Instruction::Jmp(b) => { self.program_counter = (self.program_counter as i32 + *b) as usize } Instruction::Test(a, b, c) => { let is_true = match self .stack .get(b) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil) { Value::Number(val) => (LuaNumber::from_bits(val) as u16) == *c, _ => false, }; if is_true { let b = self.get_stack(*b); self.set_stack(*a, b); } else { self.program_counter += 1; } } Instruction::Call(func_reg, param_len, ret_len) => { let param_start_func_reg = if *param_len == 0 { self.function_register } else { *func_reg }; let param_len = if *param_len == 0 { self.top - self.top.min(param_start_func_reg) } else { *param_len }; self.function_register = *func_reg; let mut params = Vec::new(); for i in 0..param_len { params.push( self.stack .get(&(param_start_func_reg + i + 1)) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil) .clone(), ); } let value = self .stack .get(func_reg) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); match value { Value::RustFunction(func) => { let ret_values = func.borrow_mut().execute(params)?; if *ret_len != 0 { for i in 0..=(*ret_len - 2) { self.set_stack( *func_reg + i, StackValue::Value( ret_values .get(i as usize) .cloned() .unwrap_or(Value::Nil), ), ); } } else { for (i, value) in ret_values.iter().enumerate() { self.set_stack( *func_reg + i as u16 + 1, StackValue::Value(value.clone()), ); } self.top = *func_reg + ret_values.len() as u16; } } Value::Function(closure) => { self.return_registers = Vec::new(); if *ret_len != 0 { for i in 0..=(*ret_len - 2) { self.return_registers.push(*func_reg + i); } } self.inner = Some(Box::new(closure.run(params))); } _ => return Err(RuntimeError::TriedCallingNonFunction(value.clone())), } } Instruction::Close(_) => {} Instruction::Closure(reg, protok) => { self.set_stack( *reg, StackValue::Value(Value::Function(Closure { vm: self.closure.vm.clone(), prototype: *protok, environment: self.closure.environment.clone(), upvalues: self.close_upvalues(), })), ); } Instruction::Return(reg_start, reg_end) => { self.program_counter += 1; let mut ret_values = Vec::new(); let (reg_start, reg_end) = if *reg_end == 0 { if self.function_register > 0 && self.top > 0 { (self.function_register + 1, self.top) } else { (*reg_start, *reg_end) } } else { (*reg_start, *reg_end) }; for i in reg_start..=reg_end { ret_values.push( self.stack .get(&i) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil), ); } return Ok(Some(ret_values)); } Instruction::Add(res, lhs, rhs) => { let lhs = self .stack .get(lhs) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); let rhs = self .stack .get(rhs) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); self.set_stack(*res, StackValue::Value(lhs.add(&rhs)?)); } Instruction::Equal(res, lhs, rhs) => { let lhs = self .stack .get(lhs) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); let rhs = self .stack .get(rhs) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); self.set_stack(*res, StackValue::Value(lhs.eq(&rhs)?)); } Instruction::LessThan(res, lhs, rhs) => { let lhs = self .stack .get(lhs) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); let rhs = self .stack .get(rhs) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); self.set_stack(*res, StackValue::Value(lhs.lt(&rhs)?)); } Instruction::LessThanOrEqual(res, lhs, rhs) => { let lhs = self .stack .get(lhs) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); let rhs = self .stack .get(rhs) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); self.set_stack(*res, StackValue::Value(lhs.lte(&rhs)?)); } Instruction::Unm(res, reg) => { self.set_stack( *res, StackValue::Value( self.stack .get(reg) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil) .unm()?, ), ); } Instruction::Or(res, lhs, rhs) => { let lhs = self .stack .get(lhs) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); let rhs = self .stack .get(rhs) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); self.set_stack(*res, StackValue::Value(lhs.or(&rhs)?)); } Instruction::And(res, lhs, rhs) => { let lhs = self .stack .get(lhs) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); let rhs = self .stack .get(rhs) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); self.set_stack(*res, StackValue::Value(lhs.and(&rhs)?)); } }; self.program_counter += 1; Ok(None) } else { Ok(Some(Vec::new())) } } }