use std::collections::{HashMap, HashSet}; use crate::{ ast::{AccessModifier, BinaryOperator, Block, Expression, Literal, Statement}, vm::{Constant, Instruction, VMNumber}, }; pub struct State { pub constants: Vec, pub prototypes: Vec>, } impl State { pub fn get_constant(&self, value: &Constant) -> u32 { self.constants .iter() .enumerate() .find(|(i, c)| *c == value) .unwrap() .0 as u32 } } #[derive(Clone, Debug, Default)] pub struct Scope { pub locals: HashMap, pub register_counter: LocalCounter, pub highest_upvalue: u16, pub upvalues: HashMap, } #[derive(Clone, Debug, Default)] pub struct LocalCounter(u16); impl LocalCounter { pub fn next(&mut self) -> u16 { let temp = self.0; self.0 += 1; temp } } impl Block { pub fn find_constants(&self, scope: &mut Scope) -> HashSet { let mut constants = HashSet::new(); let mut inner_scope = scope.clone(); for statement in &self.statements { constants.extend(statement.kind.find_constants(&mut inner_scope)); } constants } pub fn compile(&self, state: &mut State, scope: &mut Scope) -> Vec { let mut instructions = Vec::new(); let mut inner_scope = scope.clone(); for statement in &self.statements { instructions.extend(statement.kind.compile(state, &mut inner_scope)); } instructions } } impl Statement { pub fn find_constants(&self, scope: &mut Scope) -> HashSet { match self { Statement::Assignment(access, names, expr_list) => { let mut constants = HashSet::new(); if *access == AccessModifier::Global { for name in names { constants.insert(Constant::String(name.kind.clone())); } } else { for name in names { scope.locals.insert(name.kind.clone(), 0); } } for expr in &expr_list.0 { constants.extend(expr.kind.find_constants(scope)); } constants } Statement::Return(expr_list) => { let mut constants = HashSet::new(); for expr in &expr_list.0 { constants.extend(expr.kind.find_constants(scope)); } constants } Statement::If(cond, then) => { let mut constants = HashSet::new(); constants.extend(cond.kind.find_constants(scope)); constants.extend(then.find_constants(scope)); constants } } } pub fn compile(&self, state: &mut State, scope: &mut Scope) -> Vec { let mut instructions = Vec::new(); match self { Statement::Assignment(access_modifier, names, expr_list) => { instructions.push(Instruction::LoadNil( scope.register_counter.0 + 1, scope.register_counter.0 + names.len() as u16, )); let mut expr_regs = Vec::new(); for expr in &expr_list.0 { let (instr, regs) = expr.kind.compile(state, scope, Some(1)); instructions.extend(instr); expr_regs.extend(regs); } match access_modifier { AccessModifier::Local => { for (name, reg) in names.iter().zip(expr_regs) { scope.locals.insert(name.kind.clone(), reg); } } AccessModifier::Global => { for (name, reg) in names.iter().zip(expr_regs) { let global = state.get_constant(&Constant::String(name.kind.clone())); instructions.push(Instruction::SetGlobal(reg, global)); } } } } Statement::Return(expr_list) => { let mut ret_registers = Vec::new(); for expr in &expr_list.0 { let (instr, registers) = expr.kind.compile( state, scope, if expr_list.0.len() == 1 { None } else { Some(1) }, ); instructions.extend(instr); ret_registers.extend(registers); } let first_ret_register = ret_registers .iter() .cloned() .next() .unwrap_or(scope.register_counter.0); for (i, ret_register) in ret_registers.iter_mut().enumerate() { let new_reg = first_ret_register + i as u16; if *ret_register != new_reg { instructions.push(Instruction::Move(new_reg, *ret_register)); } *ret_register = new_reg; } dbg!(&ret_registers); instructions.push(Instruction::Return( *ret_registers.first().unwrap_or(&scope.register_counter.0), *ret_registers.last().unwrap_or(&0), )); } Statement::If(node, block) => todo!(), } instructions } } impl Expression { pub fn find_constants(&self, scope: &mut Scope) -> HashSet { match self { Expression::ValueRef(name) => { let mut constants = HashSet::new(); if !scope.locals.contains_key(name) { constants.insert(Constant::String(name.clone())); } constants } Expression::BinOp(_, lhs, rhs) => { let mut constants = HashSet::new(); constants.extend(lhs.kind.find_constants(scope)); constants.extend(rhs.kind.find_constants(scope)); constants } Expression::FunctionDefinition(_, block) => block.find_constants(scope), Expression::FunctionCall(expr, params) => { let mut constants = HashSet::new(); constants.extend(expr.kind.find_constants(scope)); for param in ¶ms.kind.0 { constants.extend(param.kind.find_constants(scope)); } constants } Expression::Literal(literal) => match literal { Literal::Number(value) => { let mut constants = HashSet::new(); constants.insert(Constant::Number(value.to_bits())); constants } }, } } pub fn compile( &self, state: &mut State, scope: &mut Scope, expected_values: Option, ) -> (Vec, Vec) { match self { Expression::ValueRef(name) => { if let Some(reg) = scope.locals.get(name) { (Vec::new(), vec![*reg]) } else if let Some(upvalue) = scope.upvalues.get(name) { let local = scope.register_counter.next(); (vec![Instruction::GetUpVal(local, *upvalue)], vec![local]) } else { let mut instructions = Vec::new(); let reg = scope.register_counter.next(); instructions.push(Instruction::GetGlobal( reg, state.get_constant(&Constant::String(name.clone())), )); (instructions, vec![reg]) } } Expression::BinOp(binary_operator, lhs, rhs) => { let mut instructions = Vec::new(); let (instr, lhs) = lhs.kind.compile(state, scope, Some(1)); instructions.extend(instr); let (instr, rhs) = rhs.kind.compile(state, scope, Some(1)); instructions.extend(instr); let reg = scope.register_counter.next(); match binary_operator { BinaryOperator::LessThan => { instructions.push(Instruction::LessThan( reg, *lhs.get(0).unwrap(), *rhs.get(0).unwrap(), )); } BinaryOperator::Gt => { instructions.push(Instruction::LessThan( reg, *rhs.get(0).unwrap(), *lhs.get(0).unwrap(), )); } BinaryOperator::Add => { instructions.push(Instruction::Add( reg, *lhs.get(0).unwrap(), *rhs.get(0).unwrap(), )); } }; (instructions, vec![reg]) } Expression::FunctionDefinition(params, block) => { let mut inner_scope = Scope::default(); for param in params { inner_scope .locals .insert(param.kind.clone(), inner_scope.register_counter.next()); } inner_scope.highest_upvalue = scope.highest_upvalue + inner_scope.register_counter.0; inner_scope.upvalues = scope.upvalues.clone(); for (name, reg) in &scope.locals { let new_reg = *reg + inner_scope.highest_upvalue + 1; inner_scope.upvalues.insert(name.clone(), new_reg); } let instructions = block.compile(state, &mut inner_scope); state.prototypes.push(instructions); let mut instructions = Vec::new(); instructions.push(Instruction::Close(scope.register_counter.0)); let local = scope.register_counter.next(); instructions.push(Instruction::Closure(local, state.prototypes.len() as u32)); (instructions, vec![local]) } Expression::FunctionCall(expr, params) => { let mut instructions = Vec::new(); let (instr, registers) = expr.kind.compile(state, scope, Some(1)); instructions.extend(instr); let old_function_reg = registers.first().unwrap(); let mut param_scope = scope.clone(); let mut original_param_regs = Vec::new(); for param in params.kind.0.iter() { let (instr, registers) = param.kind.compile( state, &mut param_scope, if params.kind.0.len() == 1 { None } else { Some(1) }, ); instructions.extend(instr); original_param_regs.extend(registers); } let function_reg = scope.register_counter.next(); let mut param_regs = Vec::new(); for _ in &original_param_regs { param_regs.push(scope.register_counter.next()); } for (i, param_reg) in original_param_regs.iter().enumerate().rev() { let new_reg = param_regs.get(i).unwrap(); if param_reg != new_reg { instructions.push(Instruction::Move(*new_reg, *param_reg)); } } if function_reg != *old_function_reg { instructions.push(Instruction::Move(function_reg, *old_function_reg)); } let last_param_reg = param_regs.last().unwrap_or(&function_reg); let mut return_regs = Vec::new(); if let Some(expected_values) = expected_values { for i in 0..expected_values { let return_reg = i as u16 + function_reg; if return_reg > *last_param_reg { return_regs.push(scope.register_counter.next()); } else { return_regs.push(return_reg); } } } instructions.push(Instruction::Call( *&function_reg, param_regs.len() as u16, if return_regs.len() == 0 { 0 } else { return_regs.len() as u16 + 1 }, )); (instructions, return_regs) } Expression::Literal(literal) => { let mut instructions = Vec::new(); let reg = scope.register_counter.next(); instructions.push(Instruction::LoadK( reg, state.get_constant(&match literal { Literal::Number(value) => Constant::Number(value.to_bits()), }), )); (instructions, vec![reg]) } } } }