use std::{cell::RefCell, collections::HashMap, fmt::Debug, rc::Rc}; use crate::ast::LuaNumber; pub type VMNumber = u64; #[derive(Clone, Hash, PartialEq, Eq)] pub enum Constant { String(String), Number(VMNumber), } impl Debug for Constant { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::String(arg0) => f.debug_tuple("String").field(arg0).finish(), Self::Number(arg0) => f .debug_tuple("Number") .field(&LuaNumber::from_bits(*arg0)) .finish(), } } } #[derive(Clone, Copy)] pub enum Instruction { /// R(A) := R(B) Move(u16, u16), /// R(A) := K(Bx) LoadK(u16, u32), /// R(A), ..., R(B) := nil LoadNil(u16, u16), /// G[K(Bx)] := R(A) SetGlobal(u16, u32), /// R(A) := G[K(Bx)] GetGlobal(u16, u32), /// R(A) := U[B] GetUpVal(u16, u16), /// U[B] := R(A) SetUpVal(u16, u16), /// R(A) := R(B) + R(C) Add(u16, u16, u16), /// R(A) := R(B) < R(C) LessThan(u16, u16, u16), /// R(A) := R(B) < R(C) LessThanOrEqual(u16, u16, u16), /// PC += sAx Jmp(i32), /// if (R(B) <=> C) then R(A) := R(B) else PC++ Test(u16, u16, u16), /// [func] [params.len()] [ret_regs.len()] /// R(A), ... R(A+C-2) := R(A)(R(A+1), ... R(A+B-1)) Call(u16, u16, u16), /// return R(A), ... , R(B) Return(u16, u16), /// close stack variables up to R(A) Close(u16), /// R(A) := closure(KPROTO[Bx], R(A), ..., R(A+n)) Closure(u16, u32), } impl Debug for Instruction { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Instruction::Move(arg0, arg1) => write!(f, "MOVE {} {}", arg0, arg1), Instruction::LoadK(arg0, arg1) => write!(f, "LOADK {} {}", arg0, arg1), Instruction::SetGlobal(arg0, arg1) => write!(f, "SETGLOBAL {} {}", arg0, arg1), Instruction::GetGlobal(arg0, arg1) => write!(f, "GETGLOBAL {} {}", arg0, arg1), Instruction::GetUpVal(arg0, arg1) => write!(f, "GETUPVAL {} {}", arg0, arg1), Instruction::SetUpVal(arg0, arg1) => write!(f, "SETUPVAL {} {}", arg0, arg1), Instruction::Jmp(arg0) => write!(f, "JMP {}", arg0), Instruction::Test(arg0, arg1, arg2) => write!(f, "TEST {} {} {}", arg0, arg1, arg2), Instruction::Call(arg0, arg1, arg2) => write!(f, "CALL {} {} {}", arg0, arg1, arg2), Instruction::Close(arg0) => write!(f, "CLOSE {}", arg0), Instruction::Closure(arg0, arg1) => write!(f, "CLOSURE {} {}", arg0, arg1), Instruction::Return(arg0, arg1) => write!(f, "RETURN {} {}", arg0, arg1), Instruction::LessThan(arg0, arg1, arg2) => write!(f, "LT {} {} {}", arg0, arg1, arg2), Instruction::LessThanOrEqual(arg0, arg1, arg2) => { write!(f, "LE {} {} {}", arg0, arg1, arg2) } Instruction::Add(arg0, arg1, arg2) => write!(f, "ADD {} {} {}", arg0, arg1, arg2), Instruction::LoadNil(arg0, arg1) => write!(f, "LOADNIL {} {}", arg0, arg1), } } } #[derive(Debug, Clone, Default)] pub struct Environment { pub globals: HashMap, } #[derive(Clone)] pub enum Value { Number(VMNumber), RustFunction(Rc>), Function(Closure), Nil, } impl Value { pub fn add(&self, other: &Value) -> Value { match (self, other) { (Value::Number(lhs), Value::Number(rhs)) => { let res = LuaNumber::from_bits(*lhs) + LuaNumber::from_bits(*rhs); Value::Number(res.to_bits()) } _ => Value::Nil, } } pub fn lt(&self, other: &Value) -> Value { match (self, other) { (Value::Number(lhs), Value::Number(rhs)) => { let res = LuaNumber::from_bits(*lhs) < LuaNumber::from_bits(*rhs); Value::Number((res as u64 as f64).to_bits()) } _ => Value::Nil, } } pub fn lte(&self, other: &Value) -> Value { match (self, other) { (Value::Number(lhs), Value::Number(rhs)) => { let res = LuaNumber::from_bits(*lhs) <= LuaNumber::from_bits(*rhs); Value::Number((res as u64 as f64).to_bits()) } _ => Value::Nil, } } } impl Debug for Value { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Number(arg0) => f .debug_tuple("Number") .field(&LuaNumber::from_bits(*arg0)) .finish(), Self::RustFunction(arg0) => f.debug_tuple("RustFunction").field(arg0).finish(), Self::Function(_) => f.debug_tuple("Function").finish(), Self::Nil => write!(f, "Nil"), } } } pub trait RustFunction: Debug { fn execute(&self, parameters: Vec) -> Vec; } #[derive(Debug, Clone)] pub struct VirtualMachine { pub environment: Rc>, pub constants: Vec, pub prototypes: HashMap>, } impl VirtualMachine { pub fn create_closure(&self, prototype: u32) -> Closure { Closure { vm: self.clone(), prototype, environment: self.environment.clone(), upvalues: HashMap::new(), } } } #[derive(Debug, Clone)] pub struct Closure { pub vm: VirtualMachine, pub prototype: u32, pub environment: Rc>, pub upvalues: HashMap>>, } impl Closure { pub fn run(&self, params: Vec) -> ClosureRunner { let mut stack = HashMap::new(); for (i, param) in params.iter().enumerate() { stack.insert(i as u16, Rc::new(RefCell::new(param.clone()))); } ClosureRunner { closure: self.clone(), program_counter: 0, stack, inner: None, function_register: 0, return_registers: Vec::new(), top: 0, } } } pub struct ClosureRunner { pub closure: Closure, pub program_counter: usize, pub stack: HashMap>>, pub inner: Option>, pub function_register: u16, pub return_registers: Vec, pub top: u16, } impl ClosureRunner { pub fn set_stack(&mut self, idx: u16, value: Value) { if let Some(stack_slot) = self.stack.get(&idx) { *stack_slot.borrow_mut() = value; } else { self.stack.insert(idx, Rc::new(RefCell::new(value))); } } pub fn next(&mut self) -> Option> { if let Some(inner) = &mut self.inner { if let Some(ret_values) = inner.next() { self.inner = None; if self.return_registers.len() == 0 { for (i, value) in ret_values.iter().enumerate() { self.stack.insert( self.function_register + i as u16 + 1, Rc::new(RefCell::new(value.clone())), ); } self.top = self.function_register + ret_values.len() as u16; } for (i, reg) in self.return_registers.iter().enumerate() { self.stack.insert( *reg, Rc::new(RefCell::new( ret_values.get(i).cloned().unwrap_or(Value::Nil), )), ); } } else { return None; } } let instructions = self .closure .vm .prototypes .get(&self.closure.prototype) .unwrap() .clone(); let constants = &self.closure.vm.constants; if let Some(instr) = instructions.get(self.program_counter) { match instr { Instruction::Move(a, b) => { self.set_stack( *a, self.stack .get(b) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil) .clone(), ); } Instruction::LoadK(reg, constant) => { self.set_stack( *reg, match constants.get(*constant as usize).unwrap() { Constant::String(_) => todo!(), Constant::Number(value) => Value::Number(*value), }, ); } Instruction::LoadNil(from_reg, to_reg) => { for i in *from_reg..=*to_reg { self.stack.insert(i, Rc::new(RefCell::new(Value::Nil))); } } Instruction::SetGlobal(reg, global) => { self.closure.environment.borrow_mut().globals.insert( constants.get(*global as usize).unwrap().clone(), self.stack .get(reg) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil) .clone(), ); } Instruction::GetGlobal(reg, global) => { if let Some(global) = self .closure .environment .borrow() .globals .get(constants.get(*global as usize).unwrap()) { self.stack .insert(*reg, Rc::new(RefCell::new(global.clone()))); } else { todo!("Global not found: {:?}", constants.get(*global as usize)) } } Instruction::GetUpVal(reg, upvalreg) => { self.stack .insert(*reg, self.closure.upvalues.get(upvalreg).unwrap().clone()); } Instruction::SetUpVal(upvalreg, reg) => { *self.closure.upvalues.get(upvalreg).unwrap().borrow_mut() = self .stack .get(reg) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); } Instruction::Jmp(b) => { self.program_counter = (self.program_counter as i32 + *b) as usize } Instruction::Test(a, b, c) => { let is_true = match self .stack .get(b) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil) { Value::Number(val) => (LuaNumber::from_bits(val) as u16) == *c, _ => false, }; if is_true { self.set_stack( *a, self.stack .get(b) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil), ); } else { self.program_counter += 1; } } Instruction::Call(func_reg, param_len, ret_len) => { let param_start_func_reg = if *param_len == 0 { self.function_register } else { *func_reg }; let param_len = if *param_len == 0 { self.top - self.top.min(param_start_func_reg) } else { *param_len }; self.function_register = *func_reg; let mut params = Vec::new(); for i in 0..param_len { params.push( self.stack .get(&(param_start_func_reg + i + 1)) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil) .clone(), ); } let value = self .stack .get(func_reg) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); match value { Value::RustFunction(func) => { let ret_values = func.borrow_mut().execute(params); if *ret_len != 0 { for i in 0..=(*ret_len - 2) { self.set_stack( *func_reg + i, ret_values.get(i as usize).cloned().unwrap_or(Value::Nil), ); } } else { for (i, value) in ret_values.iter().enumerate() { self.set_stack(*func_reg + i as u16 + 1, value.clone()); } self.top = *func_reg + ret_values.len() as u16; } } Value::Function(closure) => { self.return_registers = Vec::new(); if *ret_len != 0 { for i in 0..=(*ret_len - 2) { self.return_registers.push(*func_reg + i); } } self.inner = Some(Box::new(closure.run(params))); } _ => { if *ret_len > 0 { for i in 0..=(*ret_len - 2) { self.set_stack(*func_reg + i, Value::Nil); } } } } } Instruction::Close(_) => {} Instruction::Closure(reg, protok) => { let highest_upvalue = self .closure .upvalues .iter() .map(|(v, _)| *v) .max() .unwrap_or(0); let mut upvalues = self.closure.upvalues.clone(); for (reg, value) in &self.stack { upvalues.insert(reg + highest_upvalue + 1, value.clone()); } self.set_stack( *reg, Value::Function(Closure { vm: self.closure.vm.clone(), prototype: *protok, environment: self.closure.environment.clone(), upvalues, }), ); } Instruction::Return(reg_start, reg_end) => { self.program_counter += 1; let mut ret_values = Vec::new(); let (reg_start, reg_end) = if *reg_end == 0 { if self.function_register > 0 && self.top > 0 { (self.function_register + 1, self.top) } else { (*reg_start, *reg_end) } } else { (*reg_start, *reg_end) }; for i in reg_start..=reg_end { ret_values.push( self.stack .get(&i) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil), ); } return Some(ret_values); } Instruction::Add(res, lhs, rhs) => { let lhs = self .stack .get(lhs) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); let rhs = self .stack .get(rhs) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); self.set_stack(*res, lhs.add(&rhs)); } Instruction::LessThan(res, lhs, rhs) => { let lhs = self .stack .get(lhs) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); let rhs = self .stack .get(rhs) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); self.set_stack(*res, lhs.lt(&rhs)); } Instruction::LessThanOrEqual(res, lhs, rhs) => { let lhs = self .stack .get(lhs) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); let rhs = self .stack .get(rhs) .map(|v| v.borrow().clone()) .unwrap_or(Value::Nil); self.set_stack(*res, lhs.lte(&rhs)); } }; self.program_counter += 1; None } else { Some(Vec::new()) } } }