From 05da3db5e65399f5ff52a76e921be9834897123e Mon Sep 17 00:00:00 2001 From: Sofia Date: Sun, 15 Mar 2026 17:54:48 +0200 Subject: [PATCH] Implement returning expression lists --- examples/test.lua | 10 ++++-- src/compile.rs | 80 ++++++++++++++++++++++++++++++++++++----------- src/vm.rs | 46 ++++++++++++++++++++++++--- 3 files changed, 111 insertions(+), 25 deletions(-) diff --git a/examples/test.lua b/examples/test.lua index 9cbc050..61d42c9 100644 --- a/examples/test.lua +++ b/examples/test.lua @@ -1,5 +1,11 @@ function add(x) - return x + return function (y) + return x + y, 1, 2 + end end -global c = print(add(5)) \ No newline at end of file +function test() + return add(10)(15) +end + +global c = print(test()) \ No newline at end of file diff --git a/src/compile.rs b/src/compile.rs index d869835..36b144f 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -98,7 +98,7 @@ impl Statement { match self { Statement::Assignment(access_modifier, name, expr) => { - let (instr, regs) = expr.kind.compile(state, scope, 1); + let (instr, regs) = expr.kind.compile(state, scope, Some(1)); instructions.extend(instr); match access_modifier { AccessModifier::Local => { @@ -115,13 +115,36 @@ impl Statement { Statement::Return(expr_list) => { let mut ret_registers = Vec::new(); for expr in &expr_list.0 { - let (instr, registers) = expr.kind.compile(state, scope, 1); + let (instr, registers) = expr.kind.compile( + state, + scope, + if expr_list.0.len() == 1 { + None + } else { + Some(1) + }, + ); instructions.extend(instr); ret_registers.extend(registers); } + + let first_ret_register = ret_registers + .iter() + .cloned() + .next() + .unwrap_or(scope.register_counter.0); + for (i, ret_register) in ret_registers.iter_mut().enumerate() { + let new_reg = first_ret_register + i as u16; + if *ret_register != new_reg { + instructions.push(Instruction::Move(new_reg, *ret_register)); + } + *ret_register = new_reg; + } + + dbg!(&ret_registers); instructions.push(Instruction::Return( - *ret_registers.first().unwrap(), - *ret_registers.last().unwrap(), + *ret_registers.first().unwrap_or(&scope.register_counter.0), + *ret_registers.last().unwrap_or(&0), )); } Statement::If(node, block) => todo!(), @@ -170,7 +193,7 @@ impl Expression { &self, state: &mut State, scope: &mut Scope, - expected_values: usize, + expected_values: Option, ) -> (Vec, Vec) { match self { Expression::ValueRef(name) => { @@ -191,9 +214,9 @@ impl Expression { } Expression::BinOp(binary_operator, lhs, rhs) => { let mut instructions = Vec::new(); - let (instr, lhs) = lhs.kind.compile(state, scope, 1); + let (instr, lhs) = lhs.kind.compile(state, scope, Some(1)); instructions.extend(instr); - let (instr, rhs) = rhs.kind.compile(state, scope, 1); + let (instr, rhs) = rhs.kind.compile(state, scope, Some(1)); instructions.extend(instr); let reg = scope.register_counter.next(); match binary_operator { @@ -233,7 +256,9 @@ impl Expression { inner_scope.upvalues = scope.upvalues.clone(); for (name, reg) in &scope.locals { - inner_scope.upvalues.insert(name.clone(), *reg + 1); + inner_scope + .upvalues + .insert(name.clone(), *reg + highest_upvalue + 1); } let instructions = block.compile(state, &mut inner_scope); @@ -248,17 +273,28 @@ impl Expression { } Expression::FunctionCall(expr, params) => { let mut instructions = Vec::new(); - let (instr, registers) = expr.kind.compile(state, scope, 1); + + let (instr, registers) = expr.kind.compile(state, scope, Some(1)); instructions.extend(instr); let old_function_reg = registers.first().unwrap(); + let mut param_scope = scope.clone(); let mut original_param_regs = Vec::new(); for param in params.kind.0.iter() { - let (instr, registers) = param.kind.compile(state, &mut scope.clone(), 1); + let (instr, registers) = param.kind.compile( + state, + &mut param_scope, + if params.kind.0.len() == 1 { + None + } else { + Some(1) + }, + ); instructions.extend(instr); original_param_regs.extend(registers); } + let function_reg = scope.register_counter.next(); let mut param_regs = Vec::new(); for _ in &original_param_regs { @@ -271,24 +307,32 @@ impl Expression { instructions.push(Instruction::Move(*new_reg, *param_reg)); } } - instructions.push(Instruction::Move(function_reg, *old_function_reg)); + if function_reg != *old_function_reg { + instructions.push(Instruction::Move(function_reg, *old_function_reg)); + } let last_param_reg = param_regs.last().unwrap_or(&function_reg); let mut return_regs = Vec::new(); - for i in 0..expected_values { - let return_reg = i as u16 + function_reg; - if return_reg > *last_param_reg { - return_regs.push(scope.register_counter.next()); - } else { - return_regs.push(return_reg); + if let Some(expected_values) = expected_values { + for i in 0..expected_values { + let return_reg = i as u16 + function_reg; + if return_reg > *last_param_reg { + return_regs.push(scope.register_counter.next()); + } else { + return_regs.push(return_reg); + } } } instructions.push(Instruction::Call( *&function_reg, param_regs.len() as u16, - return_regs.len() as u16 + 1, + if return_regs.len() == 0 { + 0 + } else { + return_regs.len() as u16 + 1 + }, )); (instructions, return_regs) diff --git a/src/vm.rs b/src/vm.rs index f256563..0b056bd 100644 --- a/src/vm.rs +++ b/src/vm.rs @@ -157,7 +157,9 @@ impl Closure { program_counter: 0, stack, inner: None, + function_register: 0, return_registers: Vec::new(), + top: 0, } } } @@ -167,7 +169,9 @@ pub struct ClosureRunner { pub program_counter: usize, pub stack: HashMap, pub inner: Option>, + pub function_register: u16, pub return_registers: Vec, + pub top: u16, } impl ClosureRunner { @@ -175,6 +179,13 @@ impl ClosureRunner { if let Some(inner) = &mut self.inner { if let Some(ret_values) = inner.next() { self.inner = None; + if self.return_registers.len() == 0 { + for (i, value) in ret_values.iter().enumerate() { + self.stack + .insert(self.function_register + i as u16 + 1, value.clone()); + } + self.top = self.function_register + ret_values.len() as u16; + } for (i, reg) in self.return_registers.iter().enumerate() { self.stack .insert(*reg, ret_values.get(i).cloned().unwrap_or(Value::Nil)); @@ -226,15 +237,29 @@ impl ClosureRunner { } } Instruction::Call(func_reg, param_len, ret_len) => { + let param_start_func_reg = if *param_len == 0 { + self.function_register + } else { + *func_reg + }; + + let param_len = if *param_len == 0 { + self.top - self.top.min(param_start_func_reg) + } else { + *param_len + }; + self.function_register = *func_reg; + let mut params = Vec::new(); - for i in 0..*param_len { + for i in 0..param_len { params.push( self.stack - .get(&(func_reg + i + 1)) + .get(&(param_start_func_reg + i + 1)) .unwrap_or(&Value::Nil) .clone(), ); } + let value = self.stack.get(func_reg).unwrap_or(&Value::Nil); match value { Value::RustFunction(func) => { @@ -248,8 +273,10 @@ impl ClosureRunner { } Value::Function(closure) => { self.return_registers = Vec::new(); - for i in 0..=(*ret_len - 2) { - self.return_registers.push(*func_reg + i); + if *ret_len != 0 { + for i in 0..=(*ret_len - 2) { + self.return_registers.push(*func_reg + i); + } } self.inner = Some(Box::new(closure.run(params))); } @@ -292,7 +319,16 @@ impl ClosureRunner { Instruction::Return(reg_start, reg_end) => { self.program_counter += 1; let mut ret_values = Vec::new(); - for i in *reg_start..=*reg_end { + let (reg_start, reg_end) = if *reg_end == 0 { + if self.function_register > 0 && self.top > 0 { + (self.function_register + 1, self.top) + } else { + (*reg_start, *reg_end) + } + } else { + (*reg_start, *reg_end) + }; + for i in reg_start..=reg_end { ret_values.push(self.stack.get(&i).cloned().unwrap_or(Value::Nil)); } dbg!(&self.closure.upvalues);