From 641aa878bbe7c67da066dc4f9d59738415cb4ba2 Mon Sep 17 00:00:00 2001 From: Sofia Date: Sat, 21 Mar 2026 20:16:00 +0200 Subject: [PATCH] Add UserData generic everywhere --- examples/test.rs | 25 +++++++--- src/lib.rs | 23 +++++---- src/vm/mod.rs | 127 +++++++++++++++++++++++++---------------------- src/vm/value.rs | 111 +++++++++++++++++++++++++---------------- 4 files changed, 167 insertions(+), 119 deletions(-) diff --git a/examples/test.rs b/examples/test.rs index 4df8b64..dc3ea9c 100644 --- a/examples/test.rs +++ b/examples/test.rs @@ -1,14 +1,22 @@ use ferrite_lua::{ compile, - vm::{RuntimeError, VirtualMachine, value}, + vm::{ + RuntimeError, VirtualMachine, + value::{self, AsValue}, + }, }; static TEST: &str = include_str!("../examples/test.lua"); +type UserData = (); + #[derive(Debug, PartialEq, Eq)] pub struct Max; -impl value::RustFunction for Max { - fn execute(&self, parameters: Vec) -> Result, RuntimeError> { +impl value::RustFunction for Max { + fn execute( + &self, + parameters: Vec>, + ) -> Result>, RuntimeError> { let lhs = parameters.get(0).cloned().unwrap_or(value::Value::Nil); let rhs = parameters.get(1).cloned().unwrap_or(value::Value::Nil); match lhs.lt(&rhs)? { @@ -23,8 +31,11 @@ impl value::RustFunction for Max { } #[derive(Debug, PartialEq, Eq)] pub struct Print; -impl value::RustFunction for Print { - fn execute(&self, parameters: Vec) -> Result, RuntimeError> { +impl value::RustFunction for Print { + fn execute( + &self, + parameters: Vec>, + ) -> Result>, RuntimeError> { println!( "{}", parameters @@ -46,8 +57,8 @@ fn main() { let mut vm = VirtualMachine::default(); - vm.set_global("max".into(), Max.into()).unwrap(); - vm.set_global("print".into(), Print.into()).unwrap(); + vm.set_global("max".into(), Max.as_value()).unwrap(); + vm.set_global("print".into(), Print.as_value()).unwrap(); dbg!(&compilation_unit); diff --git a/src/lib.rs b/src/lib.rs index 3631d52..405646a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,13 @@ //! Usage example: //! ```rust -//! use ferrite_lua::{compile, vm::{self, value}}; +//! use ferrite_lua::{compile, vm::{self, value::{self, AsValue}}}; +//! +//! type UserData = (); //! //! #[derive(Debug, PartialEq, Eq)] //! pub struct Print; -//! impl value::RustFunction for Print { -//! fn execute(&self, parameters: Vec) -> Result, vm::RuntimeError> { +//! impl value::RustFunction for Print { +//! fn execute(&self, parameters: Vec>) -> Result>, vm::RuntimeError> { //! println!("{:?}", parameters); //! Ok(Vec::new()) //! } @@ -18,7 +20,7 @@ //! let compilation_unit = compile("print(\"hello\")", None).unwrap(); //! //! let mut vm = vm::VirtualMachine::default(); -//! vm.set_global("print".into(), Print.into()).unwrap(); +//! vm.set_global("print".into(), Print.as_value()).unwrap(); //! //! let mut runner = compilation_unit.with_virtual_machine(&mut vm).execute(); //! @@ -70,7 +72,10 @@ impl Debug for CompilationUnit { } impl CompilationUnit { - pub fn with_virtual_machine<'a>(&self, vm: &'a mut VirtualMachine) -> ExecutionUnit<'a> { + pub fn with_virtual_machine<'a, UserData: Clone>( + &self, + vm: &'a mut VirtualMachine, + ) -> ExecutionUnit<'a, UserData> { let chunk_id = vm.new_prototype(vm::Prototype { instructions: self.instructions.clone(), parameters: 0, @@ -85,13 +90,13 @@ impl CompilationUnit { } } -pub struct ExecutionUnit<'a> { +pub struct ExecutionUnit<'a, UserData: Clone> { chunk_id: u32, - vm: &'a mut VirtualMachine, + vm: &'a mut VirtualMachine, } -impl<'a> ExecutionUnit<'a> { - pub fn execute(self) -> ClosureRunner { +impl<'a, UserData: Clone> ExecutionUnit<'a, UserData> { + pub fn execute(self) -> ClosureRunner { let closure = self.vm.create_closure(self.chunk_id); closure.run(Vec::new()) } diff --git a/src/vm/mod.rs b/src/vm/mod.rs index 81c34fb..2e0561e 100644 --- a/src/vm/mod.rs +++ b/src/vm/mod.rs @@ -12,8 +12,11 @@ pub mod value; #[derive(Debug)] pub struct SetMetatable; -impl RustFunction for SetMetatable { - fn execute(&self, parameters: Vec) -> Result, RuntimeError> { +impl RustFunction for SetMetatable { + fn execute( + &self, + parameters: Vec>, + ) -> Result>, RuntimeError> { let table = parameters .get(0) .ok_or(RuntimeError::NotTable(Value::Nil))?; @@ -215,35 +218,35 @@ impl Debug for Instruction { } #[derive(Error, Debug)] -pub enum RuntimeError { +pub enum RuntimeError { #[error("Unable to perform {0:?} operator between {1:?} and {2:?}")] - InvalidOperands(BinaryOperator, Value, Value), + InvalidOperands(BinaryOperator, Value, Value), #[error("Unable to call metamethod {0} for values {1:?} and {2:?}")] - InvalidBinop(String, Value, Value), + InvalidBinop(String, Value, Value), #[error("Unable to call metamethod {0} for value {1:?}")] - InvalidUnop(String, Value), + InvalidUnop(String, Value), #[error("Metatable is not a table: {0:?}")] - MetatableNotTable(Value), + MetatableNotTable(Value), #[error("Metafunction not found: {0} for {1:?}")] - MetafunctionNotFound(String, Value), + MetafunctionNotFound(String, Value), #[error("Metafunction is not callable: {0:?}")] - MetafunctionNotCallable(Value), + MetafunctionNotCallable(Value), #[error("Unable to perform {0:?} operator to {1:?}")] - InvalidOperand(UnaryOperator, Value), + InvalidOperand(UnaryOperator, Value), #[error("Tried calling a non-function: {0:?}")] - TriedCallingNonFunction(Value), + TriedCallingNonFunction(Value), #[error("Global not found: {0:?}")] GlobalNotFound(Option), #[error("Unable to index tables with {0:?}")] - InvalidTableIndex(Value), + InvalidTableIndex(Value), #[error("Value is not a table: {0:?}")] - NotTable(Value), + NotTable(Value), #[error("Value is not coercable to a float: {0:?}")] - NotFloatable(Value), + NotFloatable(Value), #[error("Value is not coercable to bits: {0:?}")] - NotBittable(Value), + NotBittable(Value), #[error("Value does not have a length: {0:?}")] - NotLengthable(Value), + NotLengthable(Value), #[error("{0}")] Custom(String), } @@ -255,19 +258,19 @@ pub struct Prototype { } #[derive(Debug, Clone)] -pub struct VirtualMachine { - pub(super) environment: Table, +pub struct VirtualMachine { + pub(super) environment: Table, pub(super) constants: Vec, pub(super) prototypes: HashMap, pub(super) proto_counter: u32, } -impl Default for VirtualMachine { +impl Default for VirtualMachine { fn default() -> Self { - let environment: Table = Default::default(); + let environment: Table = Default::default(); environment.borrow_mut().insert( IndexableValue::String("SETMETATABLE".into()), - SetMetatable.into(), + Value::RustFunction(Rc::new(RefCell::new(SetMetatable))), ); Self { environment, @@ -278,7 +281,7 @@ impl Default for VirtualMachine { } } -impl VirtualMachine { +impl VirtualMachine { pub fn new_prototype(&mut self, instructions: Prototype) -> u32 { let proto_id = self.proto_counter; self.proto_counter += 1; @@ -286,7 +289,7 @@ impl VirtualMachine { proto_id } - pub(crate) fn create_closure(&self, prototype: u32) -> Closure { + pub(crate) fn create_closure(&self, prototype: u32) -> Closure { Closure { vm: self.clone(), prototype, @@ -295,14 +298,18 @@ impl VirtualMachine { } } - pub fn set_global(&mut self, key: Value, value: Value) -> Result<(), RuntimeError> { + pub fn set_global( + &mut self, + key: Value, + value: Value, + ) -> Result<(), RuntimeError> { self.environment .borrow_mut() .insert(key.as_indexable()?, value); Ok(()) } - pub fn get_globals(&self) -> HashMap { + pub fn get_globals(&self) -> HashMap> { let mut values = HashMap::new(); for (key, value) in self.environment.borrow().iter() { values.insert(key.clone(), value.clone()); @@ -313,15 +320,15 @@ impl VirtualMachine { } #[derive(Debug, Clone)] -pub struct Closure { - vm: VirtualMachine, +pub struct Closure { + vm: VirtualMachine, pub(crate) prototype: u32, - environment: Table, - upvalues: HashMap>>, + environment: Table, + upvalues: HashMap>>>, } -impl Closure { - pub fn run(&self, params: Vec) -> ClosureRunner { +impl Closure { + pub fn run(&self, params: Vec>) -> ClosureRunner { let mut stack = HashMap::new(); for (i, param) in params.iter().enumerate() { stack.insert(i as u16, Rc::new(RefCell::new(param.clone()))); @@ -340,7 +347,7 @@ impl Closure { } } - fn get_upvalue(&mut self, idx: u16) -> StackValue { + fn get_upvalue(&mut self, idx: u16) -> StackValue { let value = self.upvalues.get(&idx); if let Some(value) = value { match &*value.borrow() { @@ -352,25 +359,25 @@ impl Closure { } } -pub struct ClosureRunner { - closure: Closure, +pub struct ClosureRunner { + closure: Closure, program_counter: usize, - stack: HashMap>>, - inner: Option>, + stack: HashMap>>>, + inner: Option>>, function_register: u16, return_registers: Vec, top: u16, - parameters: Vec, + parameters: Vec>, to_close_upvalues: u16, } #[derive(Clone, Debug)] -enum StackValue { - Value(Value), +enum StackValue { + Value(Value), } -impl ClosureRunner { - fn set_stack(&mut self, idx: u16, value: StackValue) { +impl ClosureRunner { + fn set_stack(&mut self, idx: u16, value: StackValue) { if let Some(stack_slot) = self.stack.get_mut(&idx) { match value { StackValue::Value(value) => { @@ -386,7 +393,7 @@ impl ClosureRunner { } } - fn get_stack(&mut self, idx: u16) -> StackValue { + fn get_stack(&mut self, idx: u16) -> StackValue { let value = self.stack.get(&idx); if let Some(value) = value { match &*value.borrow() { @@ -397,7 +404,7 @@ impl ClosureRunner { } } - fn close_upvalues(&self) -> HashMap>> { + fn close_upvalues(&self) -> HashMap>>> { let highest_upvalue = self .closure .upvalues @@ -417,7 +424,7 @@ impl ClosureRunner { upvalues } - pub fn execute(&mut self, unit: &CompilationUnit) -> ClosureRunner { + pub fn execute(&mut self, unit: &CompilationUnit) -> ClosureRunner { let mut vm = self.closure.vm.clone(); vm.constants = unit.constants.clone(); let proto_id = vm.new_prototype(Prototype { @@ -437,7 +444,7 @@ impl ClosureRunner { closure.run(Vec::new()) } - pub fn next(&mut self) -> Result>, RuntimeError> { + pub fn next(&mut self) -> Result>>, RuntimeError> { if let Some(inner) = &mut self.inner { match inner.next() { Ok(ret_values) => { @@ -815,7 +822,7 @@ impl ClosureRunner { Value::Table { metatable, .. } => { let mut metamethod_params = vec![value.clone()]; metamethod_params.extend(params); - let ret_values = + let ret_values: Vec> = self.call_metamethod(&metatable, "__call", metamethod_params)??; if *ret_len != 0 { @@ -1195,10 +1202,10 @@ impl ClosureRunner { fn result_or_metamethod_unop( &self, - value: Result, + value: Result, RuntimeError>, metamethod: &str, - param: Value, - ) -> Result { + param: Value, + ) -> Result, RuntimeError> { let metatable = extract_metatable(¶m); if let Some(metatable) = metatable { @@ -1213,11 +1220,11 @@ impl ClosureRunner { fn result_or_metamethod_binop( &self, - value: Result, + value: Result, RuntimeError>, metamethod: &str, - lhs: Value, - rhs: Value, - ) -> Result { + lhs: Value, + rhs: Value, + ) -> Result, RuntimeError> { let metatable = extract_metatable(&lhs).or(extract_metatable(&rhs)); if let Some(metatable) = metatable { @@ -1232,10 +1239,10 @@ impl ClosureRunner { fn call_metamethod( &self, - metatable: &Table, + metatable: &Table, metamethod: &str, - params: Vec, - ) -> Result, RuntimeError>, RuntimeError> { + params: Vec>, + ) -> Result>, RuntimeError>, RuntimeError> { if let Some(value) = metatable .borrow() .get(&IndexableValue::String(metamethod.to_owned())) @@ -1268,7 +1275,7 @@ impl ClosureRunner { } } - fn lhs_and_rhs(&self, lhs: &u16, rhs: &u16) -> (Value, Value) { + fn lhs_and_rhs(&self, lhs: &u16, rhs: &u16) -> (Value, Value) { ( self.stack .get(lhs) @@ -1282,13 +1289,15 @@ impl ClosureRunner { } } -fn extract_metatable(value: &Value) -> Option<&Table> { +fn extract_metatable(value: &Value) -> Option<&Table> { match value { Value::Table { metatable, .. } => Some(metatable), _ => None, } } -fn extract_ret_value(values: Result, RuntimeError>) -> Result { +fn extract_ret_value( + values: Result>, RuntimeError>, +) -> Result, RuntimeError> { values.map(|v| v.into_iter().next().unwrap()) } diff --git a/src/vm/value.rs b/src/vm/value.rs index 852f6a6..a15fb6f 100644 --- a/src/vm/value.rs +++ b/src/vm/value.rs @@ -82,16 +82,24 @@ impl From<&LuaBool> for LuaInteger { } } -pub trait RustFunction: Debug { - fn execute(&self, parameters: Vec) -> Result, RuntimeError>; +pub trait RustFunction: Debug { + fn execute( + &self, + parameters: Vec>, + ) -> Result>, RuntimeError>; fn as_indexable(&self) -> String; } -impl From for Value { - fn from(value: T) -> Self { - Self::RustFunction(Rc::new(RefCell::new(value))) +pub trait AsValue { + fn as_value(self) -> Value; +} + +impl + 'static> AsValue for T { + fn as_value(self) -> Value { + Value::RustFunction(Rc::new(RefCell::new(self))) } } + #[derive(Clone, Hash, PartialEq, Eq)] pub enum Constant { String(String), @@ -102,7 +110,7 @@ pub enum Constant { } impl Constant { - pub fn as_value(self) -> Value { + pub fn as_value(self) -> Value { match self { Constant::String(value) => Value::String(value), Constant::Float(vmfloat) => Value::Float(vmfloat), @@ -125,23 +133,26 @@ impl Debug for Constant { } } -pub type Table = Rc>; -pub type TableMap = HashMap; +pub type Table = Rc>>; +pub type TableMap = HashMap>; #[derive(Clone)] -pub enum Value { +pub enum Value { String(String), Float(VMFloat), Integer(LuaInteger), Boolean(LuaBool), - RustFunction(Rc>), - Function(Closure), + RustFunction(Rc>>), + Function(Closure), Nil, - Table { contents: Table, metatable: Table }, + Table { + contents: Table, + metatable: Table, + }, } -impl Value { - pub fn as_indexable(self) -> Result { +impl Value { + pub fn as_indexable(self) -> Result> { match self { Value::String(value) => Ok(IndexableValue::String(value)), Value::Float(value) => Ok(IndexableValue::Float(value)), @@ -152,11 +163,11 @@ impl Value { } Value::Function(closure) => Ok(IndexableValue::Function(closure.prototype)), Value::Nil => Err(RuntimeError::InvalidTableIndex(self)), - Value::Table { contents, .. } => Ok(IndexableValue::Table(contents.as_ptr())), + Value::Table { contents, .. } => Ok(IndexableValue::Table(contents.as_ptr() as u64)), } } - pub fn as_float(&self) -> Result { + pub fn as_float(&self) -> Result> { match self { Value::Float(vmfloat) => Ok(vmfloat.lua_number()), Value::Integer(lua_integer) => Ok(lua_integer.into()), @@ -165,7 +176,7 @@ impl Value { } } - pub fn as_integer(&self) -> Result { + pub fn as_integer(&self) -> Result> { match self { Value::Integer(lua_integer) => Ok(*lua_integer), Value::Boolean(lua_boolean) => Ok(LuaInteger(lua_boolean.0 as i64)), @@ -173,7 +184,7 @@ impl Value { } } - pub fn as_bits(&self) -> Result { + pub fn as_bits(&self) -> Result> { match self { Value::Float(vmfloat) => Ok(LuaInteger(vmfloat.lua_number().0 as i64)), Value::Integer(lua_integer) => Ok(LuaInteger(lua_integer.0 as i64)), @@ -184,13 +195,13 @@ impl Value { } } -impl From<&str> for Value { +impl From<&str> for Value { fn from(value: &str) -> Self { Value::String(value.to_owned()) } } -impl Display for Value { +impl Display for Value { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Value::String(value) => Display::fmt(value, f), @@ -207,8 +218,11 @@ impl Display for Value { } } -impl Value { - pub fn concat(&self, other: &Value) -> Result { +impl Value { + pub fn concat( + &self, + other: &Value, + ) -> Result, RuntimeError> { match (self, other) { ( Value::String(_) | Value::Integer(_) | Value::Boolean(_), @@ -229,7 +243,7 @@ impl Value { } } - pub fn add(&self, other: &Value) -> Result { + pub fn add(&self, other: &Value) -> Result, RuntimeError> { match (self, other) { (Value::Integer(_) | Value::Boolean(_), Value::Integer(_) | Value::Boolean(_)) => { let res = LuaInteger(self.as_integer()?.0 + other.as_integer()?.0); @@ -250,7 +264,7 @@ impl Value { } } - pub fn mult(&self, other: &Value) -> Result { + pub fn mult(&self, other: &Value) -> Result, RuntimeError> { match (self, other) { (Value::Integer(_) | Value::Boolean(_), Value::Integer(_) | Value::Boolean(_)) => { let res = LuaInteger(self.as_integer()?.0 * other.as_integer()?.0); @@ -271,7 +285,7 @@ impl Value { } } - pub fn div(&self, other: &Value) -> Result { + pub fn div(&self, other: &Value) -> Result, RuntimeError> { match (self, other) { (Value::Integer(_) | Value::Boolean(_), Value::Integer(_) | Value::Boolean(_)) => { let res = LuaFloat(self.as_integer()?.0 as f64 / other.as_integer()?.0 as f64); @@ -292,7 +306,7 @@ impl Value { } } - pub fn idiv(&self, other: &Value) -> Result { + pub fn idiv(&self, other: &Value) -> Result, RuntimeError> { match (self, other) { (Value::Integer(_) | Value::Boolean(_), Value::Integer(_) | Value::Boolean(_)) => { let res = LuaInteger(self.as_integer()?.0 / other.as_integer()?.0); @@ -313,7 +327,10 @@ impl Value { } } - pub fn r#mod(&self, other: &Value) -> Result { + pub fn r#mod( + &self, + other: &Value, + ) -> Result, RuntimeError> { match (self, other) { (Value::Integer(_) | Value::Boolean(_), Value::Integer(_) | Value::Boolean(_)) => { let res = LuaInteger(self.as_integer()?.0 % other.as_integer()?.0); @@ -334,7 +351,7 @@ impl Value { } } - pub fn exp(&self, other: &Value) -> Result { + pub fn exp(&self, other: &Value) -> Result, RuntimeError> { match (self, other) { (Value::Integer(_) | Value::Boolean(_), Value::Integer(_) | Value::Boolean(_)) => { let res = LuaInteger(self.as_integer()?.0.pow(other.as_integer()?.0 as u32)); @@ -355,7 +372,7 @@ impl Value { } } - pub fn band(&self, other: &Value) -> Result { + pub fn band(&self, other: &Value) -> Result, RuntimeError> { match (self, other) { ( Value::Float(_) | Value::Integer(_) | Value::Boolean(_) | Value::Nil, @@ -372,7 +389,7 @@ impl Value { } } - pub fn bor(&self, other: &Value) -> Result { + pub fn bor(&self, other: &Value) -> Result, RuntimeError> { match (self, other) { ( Value::Float(_) | Value::Integer(_) | Value::Boolean(_) | Value::Nil, @@ -389,7 +406,7 @@ impl Value { } } - pub fn bxor(&self, other: &Value) -> Result { + pub fn bxor(&self, other: &Value) -> Result, RuntimeError> { match (self, other) { ( Value::Float(_) | Value::Integer(_) | Value::Boolean(_) | Value::Nil, @@ -406,7 +423,10 @@ impl Value { } } - pub fn bsleft(&self, other: &Value) -> Result { + pub fn bsleft( + &self, + other: &Value, + ) -> Result, RuntimeError> { match (self, other) { ( Value::Float(_) | Value::Integer(_) | Value::Boolean(_) | Value::Nil, @@ -423,7 +443,10 @@ impl Value { } } - pub fn bsright(&self, other: &Value) -> Result { + pub fn bsright( + &self, + other: &Value, + ) -> Result, RuntimeError> { match (self, other) { ( Value::Float(_) | Value::Integer(_) | Value::Boolean(_) | Value::Nil, @@ -440,7 +463,7 @@ impl Value { } } - pub fn eq(&self, other: &Value) -> Result { + pub fn eq(&self, other: &Value) -> Result, RuntimeError> { match (self, other) { (Value::Integer(lhs), Value::Integer(rhs)) => { Ok(Value::Boolean(LuaBool(lhs.0 == rhs.0))) @@ -457,7 +480,7 @@ impl Value { } } - pub fn lt(&self, other: &Value) -> Result { + pub fn lt(&self, other: &Value) -> Result, RuntimeError> { match (self, other) { (Value::Integer(lhs), Value::Integer(rhs)) => { Ok(Value::Boolean(LuaBool(lhs.0 < rhs.0))) @@ -474,7 +497,7 @@ impl Value { } } - pub fn lte(&self, other: &Value) -> Result { + pub fn lte(&self, other: &Value) -> Result, RuntimeError> { match (self, other) { (Value::Integer(lhs), Value::Integer(rhs)) => { Ok(Value::Boolean(LuaBool(lhs.0 <= rhs.0))) @@ -491,7 +514,7 @@ impl Value { } } - pub fn unm(&self) -> Result { + pub fn unm(&self) -> Result, RuntimeError> { match self { Value::Integer(lhs) => { let res = LuaInteger(-lhs.0); @@ -512,7 +535,7 @@ impl Value { } } - pub fn len(&self) -> Result { + pub fn len(&self) -> Result, RuntimeError> { match self { Value::String(value) => Ok(Self::Integer(LuaInteger(value.len() as i64))), Value::Table { contents, .. } => { @@ -553,7 +576,7 @@ impl Value { } } - pub fn and(&self, other: &Value) -> Result { + pub fn and(&self, other: &Value) -> Result, RuntimeError> { if self.is_truthy() { Ok(self.clone()) } else { @@ -561,7 +584,7 @@ impl Value { } } - pub fn or(&self, other: &Value) -> Result { + pub fn or(&self, other: &Value) -> Result, RuntimeError> { if self.is_truthy() { Ok(self.clone()) } else { @@ -569,7 +592,7 @@ impl Value { } } - pub fn not(&self) -> Result { + pub fn not(&self) -> Result, RuntimeError> { if self.is_truthy() { Ok(Value::Boolean(LuaBool(false))) } else { @@ -591,7 +614,7 @@ impl Value { } } -impl Debug for Value { +impl Debug for Value { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Value::Float(arg0) => f.debug_tuple("Float").field(&arg0.lua_number()).finish(), @@ -622,11 +645,11 @@ pub enum IndexableValue { Bool(LuaBool), RustFunction(String), Function(u32), - Table(*mut TableMap), + Table(u64), } impl From<&str> for IndexableValue { fn from(value: &str) -> Self { - Value::from(value).as_indexable().unwrap() + Value::<()>::from(value).as_indexable().unwrap() } }