// ferrite-lua/src/vm.rs
// 2026-03-16 16:13:11 +02:00
//
// 495 lines
// 18 KiB
// Rust

use std::{cell::RefCell, collections::HashMap, fmt::Debug, rc::Rc};
use crate::ast::LuaNumber;
/// Storage type for VM numbers: the raw bit pattern of a `LuaNumber`.
///
/// Values are converted with `LuaNumber::from_bits` / `to_bits` wherever arithmetic or
/// printing is needed. Presumably numbers are kept as bits so that `Constant` can derive
/// `Hash`/`Eq` (it is used as a global-table key) — TODO confirm.
pub type VMNumber = u64;
/// An entry of the constant pool (`VirtualMachine::constants`), referenced by index from
/// `LoadK`, `SetGlobal` and `GetGlobal`.
///
/// Constants are also the keys of `Environment::globals`, hence `Hash + Eq`.
#[derive(Clone, Hash, PartialEq, Eq)]
pub enum Constant {
/// A string constant. `LoadK` does not support loading these into registers yet (`todo!`).
String(String),
/// A numeric constant, stored as the bit pattern of a `LuaNumber`.
Number(VMNumber),
}
impl Debug for Constant {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::String(arg0) => f.debug_tuple("String").field(arg0).finish(),
Self::Number(arg0) => f
.debug_tuple("Number")
.field(&LuaNumber::from_bits(*arg0))
.finish(),
}
}
}
/// A single VM instruction.
///
/// Operand notation: `R(x)` is register `x` of the running closure, `K(x)` is entry `x` of
/// `VirtualMachine::constants`, `G[k]` is the global table and `U[x]` is upvalue `x` of the
/// running closure. Register ranges written `R(A), ..., R(B)` are inclusive.
///
/// There is no boolean value type: comparisons store the number 1 (true) or 0 (false).
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum Instruction {
    /// R(A) := R(B)
    Move(u16, u16),
    /// R(A) := K(Bx)
    LoadK(u16, u32),
    /// R(A), ..., R(B) := nil
    LoadNil(u16, u16),
    /// G[K(Bx)] := R(A)
    SetGlobal(u16, u32),
    /// R(A) := G[K(Bx)]
    GetGlobal(u16, u32),
    /// R(A) := U[B]
    GetUpVal(u16, u16),
    /// U[A] := R(B)
    ///
    /// NOTE(review): the VM treats the *first* operand as the upvalue index. Lua 5.1's
    /// SETUPVAL uses the opposite order (`U[B] := R(A)`); confirm the compiler emits
    /// upvalue-first.
    SetUpVal(u16, u16),
    /// R(A) := R(B) + R(C)
    Add(u16, u16, u16),
    /// R(A) := R(B) < R(C)   (1 or 0)
    LessThan(u16, u16, u16),
    /// R(A) := R(B) <= R(C)   (1 or 0)
    LessThanOrEqual(u16, u16, u16),
    /// PC += sAx, relative to the *following* instruction (the target is `pc + 1 + sAx`).
    Jmp(i32),
    /// if (R(B) <=> C) then R(A) := R(B) else PC++
    ///
    /// `R(B) <=> C` holds when R(B) is a number whose value converted with `as u16` equals C;
    /// non-numbers never match.
    Test(u16, u16, u16),
    /// [func] [params.len()] [ret_regs.len()]
    /// R(A), ... R(A+C-2) := R(A)(R(A+1), ... R(A+B))
    ///
    /// B == 0: the arguments are the results of the previous variable-result call.
    /// C == 0: all results are stored in R(A+1), ... and their end is recorded as `top`.
    Call(u16, u16, u16),
    /// return R(A), ... , R(B)
    ///
    /// B == 0 right after a variable-result call returns that call's results instead.
    Return(u16, u16),
    /// close stack variables up to R(A) (currently a no-op in the VM)
    Close(u16),
    /// R(A) := closure(KPROTO[Bx])
    ///
    /// `Bx` is a key of `VirtualMachine::prototypes`. The new closure inherits the running
    /// closure's upvalues and captures every current register cell by reference.
    Closure(u16, u32),
}
impl Debug for Instruction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Instruction::Move(arg0, arg1) => write!(f, "MOVE {} {}", arg0, arg1),
Instruction::LoadK(arg0, arg1) => write!(f, "LOADK {} {}", arg0, arg1),
Instruction::SetGlobal(arg0, arg1) => write!(f, "SETGLOBAL {} {}", arg0, arg1),
Instruction::GetGlobal(arg0, arg1) => write!(f, "GETGLOBAL {} {}", arg0, arg1),
Instruction::GetUpVal(arg0, arg1) => write!(f, "GETUPVAL {} {}", arg0, arg1),
Instruction::SetUpVal(arg0, arg1) => write!(f, "SETUPVAL {} {}", arg0, arg1),
Instruction::Jmp(arg0) => write!(f, "JMP {}", arg0),
Instruction::Test(arg0, arg1, arg2) => write!(f, "TEST {} {} {}", arg0, arg1, arg2),
Instruction::Call(arg0, arg1, arg2) => write!(f, "CALL {} {} {}", arg0, arg1, arg2),
Instruction::Close(arg0) => write!(f, "CLOSE {}", arg0),
Instruction::Closure(arg0, arg1) => write!(f, "CLOSURE {} {}", arg0, arg1),
Instruction::Return(arg0, arg1) => write!(f, "RETURN {} {}", arg0, arg1),
Instruction::LessThan(arg0, arg1, arg2) => write!(f, "LT {} {} {}", arg0, arg1, arg2),
Instruction::LessThanOrEqual(arg0, arg1, arg2) => {
write!(f, "LE {} {} {}", arg0, arg1, arg2)
}
Instruction::Add(arg0, arg1, arg2) => write!(f, "ADD {} {} {}", arg0, arg1, arg2),
Instruction::LoadNil(arg0, arg1) => write!(f, "LOADNIL {} {}", arg0, arg1),
}
}
}
/// Global state. The VM holds it in an `Rc<RefCell<_>>` that is shared with every closure
/// created from it.
#[derive(Debug, Clone, Default)]
pub struct Environment {
/// Global variables, keyed by the constant that names them (written by `SetGlobal`, read
/// by `GetGlobal`).
pub globals: HashMap<Constant, Value>,
}
/// A runtime value, as held in registers, upvalues and globals.
///
/// There is no boolean variant: comparisons (`Value::lt`, `Value::lte`) produce the numbers
/// 1 and 0.
#[derive(Clone)]
pub enum Value {
/// A number, stored as the raw bit pattern of a `LuaNumber`.
Number(VMNumber),
/// A host function implemented in Rust.
RustFunction(Rc<RefCell<dyn RustFunction>>),
/// A Lua function: a prototype plus its captured upvalues.
Function(Closure),
/// The absent value. It is also what unset registers read as and what unsupported
/// arithmetic/comparison operands produce.
Nil,
}
impl Value {
pub fn add(&self, other: &Value) -> Value {
match (self, other) {
(Value::Number(lhs), Value::Number(rhs)) => {
let res = LuaNumber::from_bits(*lhs) + LuaNumber::from_bits(*rhs);
Value::Number(res.to_bits())
}
_ => Value::Nil,
}
}
pub fn lt(&self, other: &Value) -> Value {
match (self, other) {
(Value::Number(lhs), Value::Number(rhs)) => {
let res = LuaNumber::from_bits(*lhs) < LuaNumber::from_bits(*rhs);
Value::Number((res as u64 as f64).to_bits())
}
_ => Value::Nil,
}
}
pub fn lte(&self, other: &Value) -> Value {
match (self, other) {
(Value::Number(lhs), Value::Number(rhs)) => {
let res = LuaNumber::from_bits(*lhs) <= LuaNumber::from_bits(*rhs);
Value::Number((res as u64 as f64).to_bits())
}
_ => Value::Nil,
}
}
}
impl Debug for Value {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Number(arg0) => f
.debug_tuple("Number")
.field(&LuaNumber::from_bits(*arg0))
.finish(),
Self::RustFunction(arg0) => f.debug_tuple("RustFunction").field(arg0).finish(),
Self::Function(_) => f.debug_tuple("Function").finish(),
Self::Nil => write!(f, "Nil"),
}
}
}
/// A function implemented in Rust that Lua code can call.
pub trait RustFunction: Debug {
/// Runs the function with the call's arguments and returns its results.
///
/// When the call site asks for a fixed number of results, the VM ignores extra values and
/// fills missing ones with `Value::Nil`.
fn execute(&self, parameters: Vec<Value>) -> Vec<Value>;
}
/// Program-wide state: shared globals, the constant pool and the instruction list of every
/// function prototype.
#[derive(Debug, Clone)]
pub struct VirtualMachine {
/// Global environment, shared (via `Rc`) with every closure created from this VM.
pub environment: Rc<RefCell<Environment>>,
/// Constant pool indexed by `LoadK`, `SetGlobal` and `GetGlobal`.
pub constants: Vec<Constant>,
/// Instruction lists keyed by prototype id (the `Bx` operand of `Closure`).
pub prototypes: HashMap<u32, Vec<Instruction>>,
}
impl VirtualMachine {
pub fn create_closure(&self, prototype: u32) -> Closure {
Closure {
vm: self.clone(),
prototype,
environment: self.environment.clone(),
upvalues: HashMap::new(),
}
}
}
/// A callable Lua function: a prototype together with the upvalues it captured.
#[derive(Debug, Clone)]
pub struct Closure {
/// VM providing constants and prototypes (each closure holds its own clone).
pub vm: VirtualMachine,
/// Key into `VirtualMachine::prototypes` for this function's instructions.
pub prototype: u32,
/// Global environment, shared via `Rc`.
pub environment: Rc<RefCell<Environment>>,
/// Captured variables by upvalue index. The cells are shared (`Rc`) with the registers of
/// the frame that created the closure.
pub upvalues: HashMap<u16, Rc<RefCell<Value>>>,
}
impl Closure {
    /// Starts a call to this closure with `params` as its arguments.
    ///
    /// Argument `i` is placed in register `i` of a fresh register file. The returned runner
    /// executes one instruction per `ClosureRunner::next` call.
    pub fn run(&self, params: Vec<Value>) -> ClosureRunner {
        let stack: HashMap<u16, Rc<RefCell<Value>>> = params
            .into_iter()
            .enumerate()
            .map(|(slot, value)| (slot as u16, Rc::new(RefCell::new(value))))
            .collect();
        ClosureRunner {
            closure: self.clone(),
            program_counter: 0,
            stack,
            inner: None,
            function_register: 0,
            return_registers: Vec::new(),
            top: 0,
        }
    }
}
/// Execution state of one call to a `Closure`, advanced one instruction per `next()` call.
pub struct ClosureRunner {
/// The closure being executed.
pub closure: Closure,
/// Index of the next instruction to execute in the closure's prototype.
pub program_counter: usize,
/// Register file. Each register is an `Rc` cell so closures created here can capture it.
pub stack: HashMap<u16, Rc<RefCell<Value>>>,
/// Runner of a Lua function call in progress. While set, `next()` drives it instead of
/// this closure.
pub inner: Option<Box<ClosureRunner>>,
/// Register `A` of the most recent `Call`.
pub function_register: u16,
/// Registers that receive the pending inner call's results. Empty means "store all results
/// from R(A+1) on and update `top`".
pub return_registers: Vec<u16>,
/// Last result register written by the most recent variable-result (`C == 0`) call.
pub top: u16,
}
impl ClosureRunner {
pub fn set_stack(&mut self, idx: u16, value: Value) {
if let Some(stack_slot) = self.stack.get(&idx) {
*stack_slot.borrow_mut() = value;
} else {
self.stack.insert(idx, Rc::new(RefCell::new(value)));
}
}
pub fn next(&mut self) -> Option<Vec<Value>> {
if let Some(inner) = &mut self.inner {
if let Some(ret_values) = inner.next() {
self.inner = None;
if self.return_registers.len() == 0 {
for (i, value) in ret_values.iter().enumerate() {
self.stack.insert(
self.function_register + i as u16 + 1,
Rc::new(RefCell::new(value.clone())),
);
}
self.top = self.function_register + ret_values.len() as u16;
}
for (i, reg) in self.return_registers.iter().enumerate() {
self.stack.insert(
*reg,
Rc::new(RefCell::new(
ret_values.get(i).cloned().unwrap_or(Value::Nil),
)),
);
}
} else {
return None;
}
}
let instructions = self
.closure
.vm
.prototypes
.get(&self.closure.prototype)
.unwrap()
.clone();
let constants = &self.closure.vm.constants;
if let Some(instr) = instructions.get(self.program_counter) {
match instr {
Instruction::Move(a, b) => {
self.set_stack(
*a,
self.stack
.get(b)
.map(|v| v.borrow().clone())
.unwrap_or(Value::Nil)
.clone(),
);
}
Instruction::LoadK(reg, constant) => {
self.set_stack(
*reg,
match constants.get(*constant as usize).unwrap() {
Constant::String(_) => todo!(),
Constant::Number(value) => Value::Number(*value),
},
);
}
Instruction::LoadNil(from_reg, to_reg) => {
for i in *from_reg..=*to_reg {
self.stack.insert(i, Rc::new(RefCell::new(Value::Nil)));
}
}
Instruction::SetGlobal(reg, global) => {
self.closure.environment.borrow_mut().globals.insert(
constants.get(*global as usize).unwrap().clone(),
self.stack
.get(reg)
.map(|v| v.borrow().clone())
.unwrap_or(Value::Nil)
.clone(),
);
}
Instruction::GetGlobal(reg, global) => {
if let Some(global) = self
.closure
.environment
.borrow()
.globals
.get(constants.get(*global as usize).unwrap())
{
self.stack
.insert(*reg, Rc::new(RefCell::new(global.clone())));
} else {
todo!("Global not found: {:?}", constants.get(*global as usize))
}
}
Instruction::GetUpVal(reg, upvalreg) => {
self.stack
.insert(*reg, self.closure.upvalues.get(upvalreg).unwrap().clone());
}
Instruction::SetUpVal(upvalreg, reg) => {
*self.closure.upvalues.get(upvalreg).unwrap().borrow_mut() = self
.stack
.get(reg)
.map(|v| v.borrow().clone())
.unwrap_or(Value::Nil);
}
Instruction::Jmp(b) => {
self.program_counter = (self.program_counter as i32 + *b) as usize
}
Instruction::Test(a, b, c) => {
let is_true = match self
.stack
.get(b)
.map(|v| v.borrow().clone())
.unwrap_or(Value::Nil)
{
Value::Number(val) => (LuaNumber::from_bits(val) as u16) == *c,
_ => false,
};
if is_true {
self.set_stack(
*a,
self.stack
.get(b)
.map(|v| v.borrow().clone())
.unwrap_or(Value::Nil),
);
} else {
self.program_counter += 1;
}
}
Instruction::Call(func_reg, param_len, ret_len) => {
let param_start_func_reg = if *param_len == 0 {
self.function_register
} else {
*func_reg
};
let param_len = if *param_len == 0 {
self.top - self.top.min(param_start_func_reg)
} else {
*param_len
};
self.function_register = *func_reg;
let mut params = Vec::new();
for i in 0..param_len {
params.push(
self.stack
.get(&(param_start_func_reg + i + 1))
.map(|v| v.borrow().clone())
.unwrap_or(Value::Nil)
.clone(),
);
}
let value = self
.stack
.get(func_reg)
.map(|v| v.borrow().clone())
.unwrap_or(Value::Nil);
match value {
Value::RustFunction(func) => {
let ret_values = func.borrow_mut().execute(params);
if *ret_len != 0 {
for i in 0..=(*ret_len - 2) {
self.set_stack(
*func_reg + i,
ret_values.get(i as usize).cloned().unwrap_or(Value::Nil),
);
}
} else {
for (i, value) in ret_values.iter().enumerate() {
self.set_stack(*func_reg + i as u16 + 1, value.clone());
}
self.top = *func_reg + ret_values.len() as u16;
}
}
Value::Function(closure) => {
self.return_registers = Vec::new();
if *ret_len != 0 {
for i in 0..=(*ret_len - 2) {
self.return_registers.push(*func_reg + i);
}
}
self.inner = Some(Box::new(closure.run(params)));
}
_ => {
if *ret_len > 0 {
for i in 0..=(*ret_len - 2) {
self.set_stack(*func_reg + i, Value::Nil);
}
}
}
}
}
Instruction::Close(_) => {}
Instruction::Closure(reg, protok) => {
let highest_upvalue = self
.closure
.upvalues
.iter()
.map(|(v, _)| *v)
.max()
.unwrap_or(0);
let mut upvalues = self.closure.upvalues.clone();
for (reg, value) in &self.stack {
upvalues.insert(reg + highest_upvalue + 1, value.clone());
}
self.set_stack(
*reg,
Value::Function(Closure {
vm: self.closure.vm.clone(),
prototype: *protok,
environment: self.closure.environment.clone(),
upvalues,
}),
);
}
Instruction::Return(reg_start, reg_end) => {
self.program_counter += 1;
let mut ret_values = Vec::new();
let (reg_start, reg_end) = if *reg_end == 0 {
if self.function_register > 0 && self.top > 0 {
(self.function_register + 1, self.top)
} else {
(*reg_start, *reg_end)
}
} else {
(*reg_start, *reg_end)
};
for i in reg_start..=reg_end {
ret_values.push(
self.stack
.get(&i)
.map(|v| v.borrow().clone())
.unwrap_or(Value::Nil),
);
}
return Some(ret_values);
}
Instruction::Add(res, lhs, rhs) => {
let lhs = self
.stack
.get(lhs)
.map(|v| v.borrow().clone())
.unwrap_or(Value::Nil);
let rhs = self
.stack
.get(rhs)
.map(|v| v.borrow().clone())
.unwrap_or(Value::Nil);
self.set_stack(*res, lhs.add(&rhs));
}
Instruction::LessThan(res, lhs, rhs) => {
let lhs = self
.stack
.get(lhs)
.map(|v| v.borrow().clone())
.unwrap_or(Value::Nil);
let rhs = self
.stack
.get(rhs)
.map(|v| v.borrow().clone())
.unwrap_or(Value::Nil);
self.set_stack(*res, lhs.lt(&rhs));
}
Instruction::LessThanOrEqual(res, lhs, rhs) => {
let lhs = self
.stack
.get(lhs)
.map(|v| v.borrow().clone())
.unwrap_or(Value::Nil);
let rhs = self
.stack
.get(rhs)
.map(|v| v.borrow().clone())
.unwrap_or(Value::Nil);
self.set_stack(*res, lhs.lte(&rhs));
}
};
self.program_counter += 1;
None
} else {
Some(Vec::new())
}
}
}