From 84eb8375340ce9fe11fffbe8cbdb24ec8fcff6cd Mon Sep 17 00:00:00 2001 From: Sofia Date: Fri, 20 Mar 2026 00:07:26 +0200 Subject: [PATCH] Add for-in, fix a bunch of bugs --- examples/test.lua | 87 +++++-------------------- src/ast.rs | 64 ++++++++++++------- src/compile.rs | 130 ++++++++++++++++++++++++++++++++++---- src/token_stream/lexer.rs | 3 + src/vm.rs | 28 ++++---- 5 files changed, 193 insertions(+), 119 deletions(-) diff --git a/examples/test.lua b/examples/test.lua index b8f7015..17a9fc7 100644 --- a/examples/test.lua +++ b/examples/test.lua @@ -1,77 +1,20 @@ -global b = 5 +local table = {10, 20, 30} -function add(x) - return function (y) - x = x + 1 - b = b + 1 - return x + y, 1, 2, b - end +function ipairs(t) + print("inside!") + local i = 0 + return function (state, control) + print(state, control) + i = i + 1 + if i > #table then + return nil, nil + end + return i, t[i] + end, "otus", "potus" end -function min(x, y) - local m = x - if y < x then - m = y - end - return m -end - - -function f(x, ...) - local b = {10, ..., add(10)(15)} - return x + 5, b -end - -global sometable = {} -sometable["hello"] = { 100, 150, add(10)(15) } -print(#sometable["hello"]) -sometable["hello"].there = "my dude" -print(sometable.hello.there) - -print(max(11.12345, 9)) -print(add(10)(15)) -print(add(10)(15)) -print(b) -print(min(11, 9)) -print(10 - 15) -print("hello there!") -print(true or 0) - -global value, table = f(10, 11, 12) - -print("hello") -for i=1,#table do - print(table[i]) - if i > 2 then - goto test - end -end -::test:: - -local test = table[1] -if test == 10 then - print("first") -elseif test == 11 then - print("second") -else - print("third") -end -print("after if/elseif/else") - - -local i = 0 print("before") -while i < 10 do - i = i + 1 - print(i) +for k, v in ipairs(table) do + print(k, v) end -print("after while") - - -local i = 0 -print("before") -repeat - i = i + 1 - print(i) -until i >= 10 -print("after repeat") \ No newline at end of file +print("after!") diff --git a/src/ast.rs b/src/ast.rs index 451486a..4796297 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -272,6 +272,7 @@ pub enum Statement { Node, Block, ), + GenericFor(Vec>, Node, Block), While(Node, Block), Repeat(Block, Node), Break, @@ -377,32 +378,49 @@ impl Parse for Statement { } else if stream.peek() == Some(Token::Keyword(Keyword::For)) { stream.next(); let counter_name = stream.parse()?; - stream.expect_symbol('=')?; - let init = stream.parse()?; - stream.expect_symbol(',')?; - let end = stream.parse()?; - let step = if let Some(Token::Symbol(',')) = stream.peek() { - stream.next(); - stream.parse()? - } else { - Node { - kind: Expression::Literal(Literal::Integer(LuaInteger(1))), - meta: Metadata::empty(), + if let Some(Token::Symbol(',') | Token::Keyword(Keyword::In)) = stream.peek() { + let mut counters = vec![counter_name]; + while let Some(Token::Symbol(',')) = stream.peek() { + stream.next(); + counters.push(stream.parse()?); } - }; - stream.expect(Token::Keyword(Keyword::Do))?; - let block = stream.parse()?; - stream.expect(Token::Keyword(Keyword::End))?; + stream.expect(Token::Keyword(Keyword::In))?; + let expr_list = stream.parse()?; + stream.expect(Token::Keyword(Keyword::Do))?; + let block = stream.parse()?; + stream.expect(Token::Keyword(Keyword::End))?; - Ok(Statement::NumericalFor( - counter_name, - init, - end, - step, - block, - )) + Ok(Self::GenericFor(counters, expr_list, block)) + } else { + stream.expect_symbol('=')?; + let init = stream.parse()?; + stream.expect_symbol(',')?; + let end = stream.parse()?; + + let step = if let Some(Token::Symbol(',')) = stream.peek() { + stream.next(); + stream.parse()? + } else { + Node { + kind: Expression::Literal(Literal::Integer(LuaInteger(1))), + meta: Metadata::empty(), + } + }; + + stream.expect(Token::Keyword(Keyword::Do))?; + let block = stream.parse()?; + stream.expect(Token::Keyword(Keyword::End))?; + + Ok(Statement::NumericalFor( + counter_name, + init, + end, + step, + block, + )) + } } else if let Some(Token::Keyword(Keyword::While)) = stream.peek() { stream.next(); let expr = stream.parse()?; @@ -496,6 +514,8 @@ pub enum Expression { TableConstructor(Vec<(Option>, Node)>), IndexedAccess(Box>, Box>), Ellipsis, + /// Raw access to a register + Register(u16), } impl Parse for Expression { diff --git a/src/compile.rs b/src/compile.rs index 1ba8b5d..69fcdb6 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -207,6 +207,14 @@ impl Statement { constants.extend(block.find_constants(scope, Vec::new())); constants } + Statement::GenericFor(_, expr_list, block) => { + let mut constants = HashSet::new(); + for expr in &expr_list.kind.0 { + constants.extend(expr.kind.find_constants(scope)); + } + constants.extend(block.find_constants(scope, Vec::new())); + constants + } Statement::While(node, block) => { let mut constants = HashSet::new(); constants.extend(node.kind.find_constants(scope)); @@ -385,11 +393,8 @@ impl Statement { ret_registers.extend(registers); } - let first_ret_register = ret_registers - .iter() - .cloned() - .next() - .unwrap_or(scope.register_counter.0); + let new_ret_registers = scope.register_counter.consecutive(ret_registers.len() + 1); + let first_ret_register = new_ret_registers.first().unwrap(); for (i, ret_register) in ret_registers.iter_mut().enumerate() { let new_reg = first_ret_register + i as u16; if *ret_register != new_reg { @@ -406,7 +411,7 @@ impl Statement { } instructions.push(PreInstr::Instr(Instruction::Return( - first_ret_register, + *first_ret_register, if vararg { 0 } else { @@ -481,6 +486,111 @@ impl Statement { ))); instructions.push(PreInstr::Instr(Instruction::Jmp(-(instr_len + 4)))); } + Statement::GenericFor(names, expr_list, block) => { + let mut expr_regs = Vec::new(); + for (i, expr) in expr_list.kind.0.iter().enumerate() { + let (instr, regs) = expr.kind.compile( + state, + scope, + if i == expr_list.kind.0.len() - 1 { + Some(4 - expr_list.kind.0.len() + 1) + } else { + Some(1) + }, + ); + instructions.extend(instr); + expr_regs.extend(regs); + } + + dbg!(&expr_regs); + + let mut inner_scope = scope.clone(); + + let iterator_reg = *expr_regs.get(0).unwrap(); + let state_reg = *expr_regs.get(1).unwrap(); + let initial_value_reg = *expr_regs.get(2).unwrap(); + let closing_value_reg = *expr_regs.get(3).unwrap(); + inner_scope + .locals + .insert("_ITERATOR".to_owned(), iterator_reg); + inner_scope.locals.insert("_STATE".to_owned(), state_reg); + inner_scope + .locals + .insert("_INIT_VALUE".to_owned(), initial_value_reg); + inner_scope + .locals + .insert("_CLOSING_VAL".to_owned(), closing_value_reg); + + let (instr, res_regs) = compile_function_call( + Node::empty(Expression::Register(iterator_reg)), + Node::empty(ExpressionList(vec![ + Node::empty(Expression::Register(state_reg)), + Node::empty(Expression::Register(initial_value_reg)), + ])), + state, + &mut inner_scope, + Some(names.len()), + ); + instructions.extend(instr); + + let mut counter_regs = Vec::new(); + for (i, name) in names.iter().enumerate() { + let reg = inner_scope.register_counter.next(); + counter_regs.push(reg); + inner_scope.locals.insert(name.kind.clone(), reg); + instructions.push(PreInstr::Instr(Instruction::Move( + reg, + *res_regs.get(i).unwrap(), + ))); + } + + let eql_res = inner_scope.register_counter.next(); + let nil_reg = inner_scope.register_counter.next(); + instructions.push(PreInstr::Instr(Instruction::LoadNil(nil_reg, nil_reg))); + instructions.push(PreInstr::Instr(Instruction::Equal( + eql_res, + *counter_regs.first().unwrap(), + nil_reg, + ))); + instructions.push(PreInstr::Instr(Instruction::Test( + inner_scope.register_counter.next(), + eql_res, + 1, + ))); + + let block_instr = block.compile(state, &mut inner_scope); + let block_instr_len = block_instr.len() as i32; + + let (func_instr, res_regs) = compile_function_call( + Node::empty(Expression::Register(iterator_reg)), + Node::empty(ExpressionList(vec![ + Node::empty(Expression::Register(state_reg)), + Node::empty(Expression::Register(initial_value_reg)), + ])), + state, + scope, + Some(names.len()), + ); + let func_instr_len = func_instr.len() as i32; + + instructions.push(PreInstr::Instr(Instruction::Jmp( + block_instr_len + func_instr_len + counter_regs.len() as i32 + 2, + ))); + instructions.extend(block_instr); + instructions.push(PreInstr::Instr(Instruction::Move( + initial_value_reg, + *counter_regs.first().unwrap(), + ))); + instructions.extend(func_instr); + + for (counter_reg, res_reg) in counter_regs.iter().zip(res_regs) { + instructions.push(PreInstr::Instr(Instruction::Move(*counter_reg, res_reg))); + } + + instructions.push(PreInstr::Instr(Instruction::Jmp( + -(block_instr_len + func_instr_len + counter_regs.len() as i32 + 6), + ))) + } Statement::While(expr, block) => { let (instr, expr_regs) = expr.kind.compile(state, scope, Some(1)); let expr_instr_len = instr.len() as i32; @@ -619,6 +729,7 @@ impl Expression { constants } Expression::Ellipsis => HashSet::new(), + Expression::Register(_) => HashSet::new(), } } @@ -883,6 +994,7 @@ impl Expression { (instructions, vec![new_reg]) } } + Expression::Register(reg) => (Vec::new(), vec![*reg]), } } } @@ -957,11 +1069,7 @@ fn compile_function_call( if let Some(expected_values) = expected_values { for i in 0..expected_values { let return_reg = i as u16 + function_reg; - if return_reg > *last_param_reg { - return_regs.push(scope.register_counter.next()); - } else { - return_regs.push(return_reg); - } + return_regs.push(return_reg); } } diff --git a/src/token_stream/lexer.rs b/src/token_stream/lexer.rs index db05c28..955958a 100644 --- a/src/token_stream/lexer.rs +++ b/src/token_stream/lexer.rs @@ -23,6 +23,7 @@ pub enum Keyword { Nil, Not, For, + In, While, Repeat, Until, @@ -48,6 +49,7 @@ impl Keyword { "nil" => Keyword::Nil, "not" => Keyword::Not, "for" => Keyword::For, + "in" => Keyword::In, "do" => Keyword::Do, "break" => Keyword::Break, "goto" => Keyword::GoTo, @@ -76,6 +78,7 @@ impl ToString for Keyword { Keyword::Nil => "nil", Keyword::Not => "not", Keyword::For => "for", + Keyword::In => "in", Keyword::While => "while", Keyword::Repeat => "repeat", Keyword::Until => "until", diff --git a/src/vm.rs b/src/vm.rs index 6dccbc9..3a87950 100644 --- a/src/vm.rs +++ b/src/vm.rs @@ -356,11 +356,11 @@ impl Value { let res = LuaBool(self.as_float()?.0 == other.as_float()?.0); Ok(Value::Boolean(res)) } - _ => Err(RuntimeError::InvalidOperands( - BinaryOperator::Equal, - self.clone(), - other.clone(), - )), + (Value::Nil, Value::Nil) => Ok(Value::Boolean(LuaBool(true))), + (Value::Nil, _) | (_, Value::Nil) => Ok(Value::Boolean(LuaBool(false))), + _ => Ok(Value::Boolean(LuaBool( + self.clone().as_indexable()? == other.clone().as_indexable()?, + ))), } } @@ -587,6 +587,7 @@ impl Closure { for (i, param) in params.iter().enumerate() { stack.insert(i as u16, Rc::new(RefCell::new(param.clone()))); } + ClosureRunner { closure: self.clone(), program_counter: 0, @@ -699,19 +700,17 @@ impl ClosureRunner { self.inner = None; if self.return_registers.len() == 0 { for (i, value) in ret_values.iter().enumerate() { - self.stack.insert( + self.set_stack( self.function_register + i as u16 + 1, - Rc::new(RefCell::new(value.clone())), + StackValue::Value(value.clone()), ); } self.top = self.function_register + ret_values.len() as u16; } - for (i, reg) in self.return_registers.iter().enumerate() { - self.stack.insert( + for (i, reg) in self.return_registers.clone().iter().enumerate() { + self.set_stack( *reg, - Rc::new(RefCell::new( - ret_values.get(i).cloned().unwrap_or(Value::Nil), - )), + StackValue::Value(ret_values.get(i).cloned().unwrap_or(Value::Nil)), ); } } else { @@ -765,7 +764,7 @@ impl ClosureRunner { } Instruction::LoadNil(from_reg, to_reg) => { for i in *from_reg..=*to_reg { - self.stack.insert(i, Rc::new(RefCell::new(Value::Nil))); + self.set_stack(i, StackValue::Value(Value::Nil)); } } Instruction::SetGlobal(reg, global) => { @@ -1000,13 +999,14 @@ impl ClosureRunner { } Instruction::Close(_) => {} Instruction::Closure(reg, protok) => { + let upvalues = self.close_upvalues(); self.set_stack( *reg, StackValue::Value(Value::Function(Closure { vm: self.closure.vm.clone(), prototype: *protok, environment: self.closure.environment.clone(), - upvalues: self.close_upvalues(), + upvalues, })), ); }