use std::collections::{HashMap, HashSet}; use crate::{ ast::{ AccessModifier, BinaryOperator, Block, Expression, IdentOrEllipsis, Literal, Statement, UnaryOperator, }, vm::{Constant, Instruction, LuaBool, LuaInteger}, }; #[derive(Clone, Debug)] pub struct State { pub constants: Vec, pub prototypes: Vec>, } impl State { pub fn get_constant(&self, value: &Constant) -> u32 { self.constants .iter() .enumerate() .find(|(_, c)| *c == value) .unwrap() .0 as u32 } } #[derive(Clone, Debug, Default)] pub struct Scope { pub locals: HashMap, pub register_counter: LocalCounter, pub highest_upvalue: u16, pub upvalues: HashMap, pub is_vararg: bool, } #[derive(Clone, Debug, Default)] pub struct LocalCounter(u16, Vec); impl LocalCounter { pub fn next(&mut self) -> u16 { if let Some(reg) = self.1.pop() { reg } else { self.new() } } pub fn consecutive(&mut self, amount: usize) -> Vec { let mut returned = Vec::new(); for _ in 0..amount { returned.push(self.new()); } returned } pub fn new(&mut self) -> u16 { let temp = self.0; self.0 += 1; temp } } impl Block { pub(crate) fn find_constants( &self, scope: &mut Scope, constants: Vec, ) -> HashSet { let mut constants = constants.iter().cloned().collect::>(); let mut inner_scope = scope.clone(); for statement in &self.statements { constants.extend(statement.kind.find_constants(&mut inner_scope)); } constants } pub(crate) fn compile(&self, state: &mut State, scope: &mut Scope) -> Vec { let mut instructions = Vec::new(); let mut inner_scope = scope.clone(); for statement in &self.statements { instructions.extend(statement.kind.compile(state, &mut inner_scope)); } instructions } } impl Statement { fn find_constants(&self, scope: &mut Scope) -> HashSet { match self { Statement::Assignment(access, names, expr_list) => { let mut constants = HashSet::new(); if *access == Some(AccessModifier::Global) { for (name, indexes) in names { constants.insert(Constant::String(name.kind.clone())); for index in indexes { constants.extend(index.kind.find_constants(scope)); } } } else if *access == None { for (name, indexes) in names { if !scope.locals.contains_key(&name.kind) { constants.insert(Constant::String(name.kind.clone())); } for index in indexes { constants.extend(index.kind.find_constants(scope)); } } } else { for (name, indexes) in names { scope.locals.insert(name.kind.clone(), 0); for index in indexes { constants.extend(index.kind.find_constants(scope)); } } } for expr in &expr_list.0 { constants.extend(expr.kind.find_constants(scope)); } constants } Statement::Return(expr_list) => { let mut constants = HashSet::new(); for expr in &expr_list.0 { constants.extend(expr.kind.find_constants(scope)); } constants } Statement::If(cond, then) => { let mut constants = HashSet::new(); constants.extend(cond.kind.find_constants(scope)); constants.extend(then.find_constants(scope, Vec::new())); constants } Statement::Expression(expr) => expr.kind.find_constants(scope), } } fn compile(&self, state: &mut State, scope: &mut Scope) -> Vec { let mut instructions = Vec::new(); match self { Statement::Assignment(access_modifier, names, expr_list) => { let new_registers = if *access_modifier == Some(AccessModifier::Local) { let min_reg = scope.register_counter.0 + 1; let max_reg = scope.register_counter.0 + names.len() as u16; instructions.push(Instruction::LoadNil(min_reg, max_reg)); scope.register_counter.0 += names.len() as u16 + 1; let mut new_registers = Vec::new(); for i in min_reg..=max_reg { new_registers.push(i); } new_registers } else { Vec::new() }; let mut expr_regs = Vec::new(); for expr in &expr_list.0 { let (instr, regs) = expr.kind.compile( state, scope, if expr_list.0.len() == 1 { Some(names.len()) } else { Some(1) }, ); instructions.extend(instr); expr_regs.extend(regs); } match access_modifier { Some(AccessModifier::Local) => { for (i, (name, indexes)) in names.iter().enumerate() { instructions.push(Instruction::Move( *new_registers.get(i).unwrap(), expr_regs.get(i).cloned().unwrap(), )); if indexes.len() > 0 { todo!() } scope .locals .insert(name.kind.clone(), *new_registers.get(i).unwrap()); } } Some(AccessModifier::Global) => { for (i, (name, indexes)) in names.iter().enumerate() { if indexes.len() > 0 { todo!() } let global = state.get_constant(&Constant::String(name.kind.clone())); instructions.push(Instruction::SetGlobal( expr_regs.get(i).cloned().unwrap(), global, )); } } None => { for (i, (name, indexes)) in names.iter().enumerate() { if indexes.len() > 0 { let table_reg = if let Some(reg) = scope.locals.get(&name.kind) { *reg } else if let Some(upval_reg) = scope.upvalues.get(&name.kind) { let local = scope.register_counter.next(); instructions.push(Instruction::GetUpVal(local, *upval_reg)); local } else { let global = state.get_constant(&Constant::String(name.kind.clone())); let local = scope.register_counter.next(); instructions.push(Instruction::GetGlobal(local, global)); local }; if indexes.len() > 1 { for (_, index) in indexes .iter() .enumerate() .take_while(|(i, _)| *i != indexes.len() - 1) { let (instr, index_reg) = index.kind.compile(state, scope, Some(1)); instructions.extend(instr); instructions.push(Instruction::GetTable( table_reg, table_reg, *index_reg.first().unwrap(), )); } } let (instr, index_reg) = indexes.last().unwrap().kind.compile(state, scope, Some(1)); instructions.extend(instr); instructions.push(Instruction::SetTable( table_reg, *index_reg.first().unwrap(), expr_regs.get(i).cloned().unwrap(), )); } else { if let Some(reg) = scope.locals.get(&name.kind) { instructions.push(Instruction::Move( *reg, expr_regs.get(i).cloned().unwrap(), )); } else if let Some(upval_reg) = scope.upvalues.get(&name.kind) { instructions.push(Instruction::SetUpVal( *upval_reg, expr_regs.get(i).cloned().unwrap(), )); } else { let global = state.get_constant(&Constant::String(name.kind.clone())); instructions.push(Instruction::SetGlobal( expr_regs.get(i).cloned().unwrap(), global, )); } } } } } } Statement::Return(expr_list) => { let mut ret_registers = Vec::new(); let mut vararg = false; for (i, expr) in expr_list.0.iter().enumerate() { let (instr, registers) = expr.kind.compile( state, scope, if i == expr_list.0.len() - 1 { None } else { Some(1) }, ); instructions.extend(instr); if registers.len() == 0 { vararg = true; } ret_registers.extend(registers); } let first_ret_register = ret_registers .iter() .cloned() .next() .unwrap_or(scope.register_counter.0); for (i, ret_register) in ret_registers.iter_mut().enumerate() { let new_reg = first_ret_register + i as u16; if *ret_register != new_reg { instructions.push(Instruction::Move(new_reg, *ret_register)); } *ret_register = new_reg; } if vararg { instructions.push(Instruction::MoveRetValues( first_ret_register + ret_registers.len() as u16, )); } dbg!(&first_ret_register); instructions.push(Instruction::Return( first_ret_register, if vararg { 0 } else { *ret_registers.last().unwrap_or(&0) }, )); } Statement::If(expr, block) => { let (instr, regs) = expr.kind.compile(state, scope, Some(1)); instructions.extend(instr); let local = scope.register_counter.next(); instructions.push(Instruction::Test(local, *regs.first().unwrap(), 0)); let block_instructions = block.compile(state, scope); instructions.push(Instruction::Jmp(block_instructions.len() as i32)); instructions.extend(block_instructions); } Statement::Expression(expr) => { let (instr, _) = expr.kind.compile(state, scope, None); instructions.extend(instr); } } for reg in 0..scope.register_counter.0 { if scope.locals.values().find(|r| **r == reg).is_some() { // Register is still in use continue; } if scope.register_counter.1.contains(®) { // Register is already in the list of unused registers continue; } scope.register_counter.1.push(reg); } instructions } } impl Expression { fn find_constants(&self, scope: &mut Scope) -> HashSet { match self { Expression::ValueRef(name) => { let mut constants = HashSet::new(); if !scope.locals.contains_key(name) { constants.insert(Constant::String(name.clone())); } constants } Expression::BinOp(_, lhs, rhs) => { let mut constants = HashSet::new(); constants.extend(lhs.kind.find_constants(scope)); constants.extend(rhs.kind.find_constants(scope)); constants } Expression::UnOp(_, expr) => expr.kind.find_constants(scope), Expression::FunctionDefinition(_, block) => block.find_constants(scope, Vec::new()), Expression::FunctionCall(expr, params) => { let mut constants = HashSet::new(); constants.extend(expr.kind.find_constants(scope)); for param in ¶ms.kind.0 { constants.extend(param.kind.find_constants(scope)); } constants } Expression::Literal(literal) => match literal { Literal::Float(value) => { let mut constants = HashSet::new(); constants.insert(Constant::Float(value.vm_number())); constants } Literal::Integer(value) => { let mut constants = HashSet::new(); constants.insert(Constant::Integer(*value)); constants } Literal::String(value) => { let mut constants = HashSet::new(); constants.insert(Constant::String(value.clone())); constants } Literal::Bool(value) => { let mut constants = HashSet::new(); constants.insert(Constant::Bool(LuaBool(*value))); constants } Literal::Nil => { let mut constants = HashSet::new(); constants.insert(Constant::Nil); constants } }, Expression::TableConstructor(entries) => { let mut constants = HashSet::new(); let mut counter = 1; for (key, value) in entries { if let Some(key) = key { constants.extend(key.kind.find_constants(scope)); } else { constants.insert(Constant::Integer(LuaInteger(counter))); counter += 1; } constants.extend(value.kind.find_constants(scope)); } constants } Expression::IndexedAccess(expr, index) => { let mut constants = HashSet::new(); constants.extend(expr.kind.find_constants(scope)); constants.extend(index.kind.find_constants(scope)); constants } Expression::Ellipsis => HashSet::new(), } } fn compile( &self, state: &mut State, scope: &mut Scope, expected_values: Option, ) -> (Vec, Vec) { match self { Expression::ValueRef(name) => { if let Some(reg) = scope.locals.get(name) { (Vec::new(), vec![*reg]) } else if let Some(upvalue) = scope.upvalues.get(name) { let local = scope.register_counter.next(); (vec![Instruction::GetUpVal(local, *upvalue)], vec![local]) } else { let mut instructions = Vec::new(); let reg = scope.register_counter.next(); instructions.push(Instruction::GetGlobal( reg, state.get_constant(&Constant::String(name.clone())), )); (instructions, vec![reg]) } } Expression::BinOp(binary_operator, lhs, rhs) => { let mut instructions = Vec::new(); let (instr, lhs) = lhs.kind.compile(state, scope, Some(1)); instructions.extend(instr); let (instr, rhs) = rhs.kind.compile(state, scope, Some(1)); instructions.extend(instr); let reg = scope.register_counter.next(); match binary_operator { BinaryOperator::Equal => { instructions.push(Instruction::Equal( reg, *lhs.get(0).unwrap(), *rhs.get(0).unwrap(), )); } BinaryOperator::LessThan => { instructions.push(Instruction::LessThan( reg, *lhs.get(0).unwrap(), *rhs.get(0).unwrap(), )); } BinaryOperator::LessThanOrEqual => { instructions.push(Instruction::LessThanOrEqual( reg, *lhs.get(0).unwrap(), *rhs.get(0).unwrap(), )); } BinaryOperator::GreaterThan => { instructions.push(Instruction::LessThan( reg, *rhs.get(0).unwrap(), *lhs.get(0).unwrap(), )); } BinaryOperator::GreaterThanOrEqual => { instructions.push(Instruction::LessThanOrEqual( reg, *rhs.get(0).unwrap(), *lhs.get(0).unwrap(), )); } BinaryOperator::Add => { instructions.push(Instruction::Add( reg, *lhs.get(0).unwrap(), *rhs.get(0).unwrap(), )); } BinaryOperator::Sub => { instructions.push(Instruction::Unm(reg, *rhs.get(0).unwrap())); instructions.push(Instruction::Add(reg, *lhs.get(0).unwrap(), reg)); } BinaryOperator::And => { instructions.push(Instruction::And( reg, *lhs.get(0).unwrap(), *rhs.get(0).unwrap(), )); } BinaryOperator::Or => { instructions.push(Instruction::Or( reg, *lhs.get(0).unwrap(), *rhs.get(0).unwrap(), )); } }; (instructions, vec![reg]) } Expression::UnOp(op, expr) => { let mut instructions = Vec::new(); let (instr, registers) = expr.kind.compile(state, scope, Some(1)); instructions.extend(instr); for reg in ®isters { match op { UnaryOperator::Negation => instructions.push(Instruction::Unm(*reg, *reg)), UnaryOperator::Length => instructions.push(Instruction::Len(*reg, *reg)), } } (instructions, registers) } Expression::FunctionDefinition(params, block) => { let mut inner_scope = Scope::default(); for param in params { match ¶m.kind { IdentOrEllipsis::Ident(name) => { inner_scope .locals .insert(name.clone(), inner_scope.register_counter.next()); } IdentOrEllipsis::Ellipsis => { inner_scope.is_vararg = true; } } } inner_scope.highest_upvalue = scope.highest_upvalue + scope.register_counter.0; inner_scope.upvalues = scope.upvalues.clone(); for (name, reg) in &scope.locals { let new_reg = *reg + scope.highest_upvalue + 1; inner_scope.upvalues.insert(name.clone(), new_reg); } let instructions = block.compile(state, &mut inner_scope); state.prototypes.push(instructions); let mut instructions = Vec::new(); instructions.push(Instruction::Close(scope.register_counter.0)); let local = scope.register_counter.next(); instructions.push(Instruction::Closure(local, state.prototypes.len() as u32)); (instructions, vec![local]) } Expression::FunctionCall(expr, params) => { let mut instructions = Vec::new(); let (instr, registers) = expr.kind.compile(state, scope, Some(1)); instructions.extend(instr); let old_function_reg = registers.first().unwrap(); let mut registers = scope .register_counter .consecutive(params.kind.0.len() + 1) .into_iter(); let mut param_scope = scope.clone(); let mut original_param_regs = Vec::new(); let mut vararg = false; for (i, param) in params.kind.0.iter().enumerate() { let (instr, registers) = param.kind.compile( state, &mut param_scope, if i == params.kind.0.len() - 1 { None } else { Some(1) }, ); instructions.extend(instr); if registers.len() > 0 { original_param_regs.push(*registers.first().unwrap()); } else { vararg = true; } } let function_reg = registers.next().unwrap(); let mut param_regs = Vec::new(); for _ in &original_param_regs { param_regs.push(registers.next().unwrap()); } for (i, param_reg) in original_param_regs.iter().enumerate().rev() { let new_reg = param_regs.get(i).unwrap(); if param_reg != new_reg { instructions.push(Instruction::Move(*new_reg, *param_reg)); } } if function_reg != *old_function_reg { instructions.push(Instruction::Move(function_reg, *old_function_reg)); } if vararg { let last_reg = param_regs.last().unwrap_or(&function_reg) + 1; instructions.push(Instruction::MoveRetValues(last_reg)); } let last_param_reg = param_regs.last().unwrap_or(&function_reg); let mut return_regs = Vec::new(); if let Some(expected_values) = expected_values { for i in 0..expected_values { let return_reg = i as u16 + function_reg; if return_reg > *last_param_reg { return_regs.push(scope.register_counter.next()); } else { return_regs.push(return_reg); } } } instructions.push(Instruction::Call( *&function_reg, if vararg { 0 } else { param_regs.len() as u16 }, if return_regs.len() == 0 { 0 } else { return_regs.len() as u16 + 1 }, )); (instructions, return_regs) } Expression::Literal(literal) => { let mut instructions = Vec::new(); let reg = scope.register_counter.next(); instructions.push(Instruction::LoadK( reg, state.get_constant(&match literal { Literal::Float(value) => Constant::Float(value.vm_number()), Literal::String(value) => Constant::String(value.clone()), Literal::Integer(lua_integer) => Constant::Integer(*lua_integer), Literal::Bool(value) => Constant::Bool(LuaBool(*value)), Literal::Nil => Constant::Nil, }), )); (instructions, vec![reg]) } Expression::TableConstructor(entries) => { let mut instructions = Vec::new(); let reg = scope.register_counter.next(); instructions.push(Instruction::NewTable(reg)); let mut counter = 1; for (i, (key, value)) in entries.iter().enumerate() { if let Some(key) = key { let (instr, key_regs) = key.kind.compile(state, scope, Some(1)); instructions.extend(instr); let (instr, value_regs) = value.kind.compile(state, scope, Some(1)); instructions.extend(instr); instructions.push(Instruction::SetTable( reg, *key_regs.first().unwrap(), *value_regs.first().unwrap(), )); } else { let (instr, value_regs) = value.kind.compile( state, scope, if i == entries.len() - 1 { None } else { Some(1) }, ); instructions.extend(instr); if value_regs.len() > 0 { let key_reg = scope.register_counter.next(); instructions.push(Instruction::LoadK( key_reg, state.get_constant(&Constant::Integer(LuaInteger(counter))), )); instructions.push(Instruction::SetTable( reg, key_reg, *value_regs.first().unwrap(), )); counter += 1; } else { instructions.push(Instruction::SetList(reg, counter as u32)); } } } (instructions, vec![reg]) } Expression::IndexedAccess(expr, index) => { let mut instructions = Vec::new(); let (instr, expr_regs) = expr.kind.compile(state, scope, Some(1)); instructions.extend(instr); let (instr, index_regs) = index.kind.compile(state, scope, Some(1)); instructions.extend(instr); let local = scope.register_counter.next(); instructions.push(Instruction::GetTable( local, *expr_regs.first().unwrap(), *index_regs.first().unwrap(), )); (instructions, vec![local]) } Expression::Ellipsis => { if !scope.is_vararg { panic!("Function is not vararg!"); } let mut instructions = Vec::new(); let new_reg = scope.register_counter.new(); instructions.push(Instruction::Vararg(new_reg)); (instructions, Vec::new()) } } } }