ferrite-lua/src/compile.rs

372 lines
14 KiB
Rust

use std::collections::{HashMap, HashSet};
use crate::{
ast::{AccessModifier, BinaryOperator, Block, Expression, Literal, Statement},
vm::{Constant, Instruction, VMNumber},
};
/// Chunk-wide compilation state: the constant pool and the compiled
/// function prototypes.
pub struct State {
    /// Constant pool; instructions refer to entries by their index.
    pub constants: Vec<Constant>,
    /// Instruction lists of compiled function bodies, appended as each
    /// function definition is compiled.
    pub prototypes: Vec<Vec<Instruction>>,
}

impl State {
    /// Returns the index of `value` in the constant pool.
    ///
    /// # Panics
    ///
    /// Panics if `value` is not in the pool. Every constant referenced while
    /// compiling is expected to have been registered beforehand, so a missing
    /// entry indicates a compiler bug rather than a user error.
    pub fn get_constant(&self, value: &Constant) -> u32 {
        let index = self
            .constants
            .iter()
            .position(|constant| constant == value)
            .expect("constant was not registered in the constant pool");
        // Instruction operands address constants with a u32 index.
        index as u32
    }
}
/// Name-resolution and register-allocation state used while compiling a
/// function body or block.
#[derive(Clone, Debug, Default)]
pub struct Scope {
/// Local variable name -> register holding it. `find_constants` only needs
/// membership here and stores 0 as a placeholder register.
pub locals: HashMap<String, u16>,
/// Allocator handing out the next free register.
pub register_counter: LocalCounter,
/// Offset used when numbering the upvalues of nested function definitions.
/// NOTE(review): exact meaning depends on the VM's upvalue layout — confirm.
pub highest_upvalue: u16,
/// Captured variable name -> upvalue index, loaded with `GetUpVal`.
pub upvalues: HashMap<String, u16>,
}
/// Sequential allocator for register numbers; starts at 0 by default.
#[derive(Clone, Debug, Default)]
pub struct LocalCounter(u16);

impl LocalCounter {
    /// Returns the current register number and advances the counter by one.
    pub fn next(&mut self) -> u16 {
        let allocated = self.0;
        self.0 = allocated + 1;
        allocated
    }
}
impl Block {
    /// Collects the constants referenced by every statement in this block.
    ///
    /// Statements are scanned against a clone of `scope`, so locals declared
    /// inside the block are not visible in the caller's scope afterwards.
    pub fn find_constants(&self, scope: &mut Scope) -> HashSet<Constant> {
        let mut block_scope = scope.clone();
        self.statements
            .iter()
            .flat_map(|stmt| stmt.kind.find_constants(&mut block_scope))
            .collect()
    }

    /// Compiles every statement of this block, in order, into one instruction
    /// list.
    ///
    /// Like `find_constants`, this works on a clone of `scope`: locals and
    /// registers allocated inside the block do not propagate to the caller.
    pub fn compile(&self, state: &mut State, scope: &mut Scope) -> Vec<Instruction> {
        let mut block_scope = scope.clone();
        let mut out = Vec::new();
        for stmt in self.statements.iter() {
            let emitted = stmt.kind.compile(state, &mut block_scope);
            out.extend(emitted);
        }
        out
    }
}
impl Statement {
    /// Collects the constants this statement needs in the constant pool.
    ///
    /// Global assignment targets contribute their names as `Constant::String`.
    /// Local assignment targets are recorded in `scope.locals` so that later
    /// references to them are not mistaken for globals.
    pub fn find_constants(&self, scope: &mut Scope) -> HashSet<Constant> {
        match self {
            Statement::Assignment(access, names, expr_list) => {
                let mut constants = HashSet::new();
                // Scan the right-hand side *before* declaring the new locals,
                // mirroring `compile`, which compiles the expressions first.
                // Otherwise `local x = x` would see the right-hand `x` as a local
                // here but as a global in `compile`, and the global-name
                // constant would be missing from the pool.
                for expr in &expr_list.0 {
                    constants.extend(expr.kind.find_constants(scope));
                }
                if *access == AccessModifier::Global {
                    for name in names {
                        constants.insert(Constant::String(name.kind.clone()));
                    }
                } else {
                    for name in names {
                        // The real register is only known during compilation;
                        // only membership matters here.
                        scope.locals.insert(name.kind.clone(), 0);
                    }
                }
                constants
            }
            Statement::Return(expr_list) => {
                let mut constants = HashSet::new();
                for expr in &expr_list.0 {
                    constants.extend(expr.kind.find_constants(scope));
                }
                constants
            }
            Statement::If(cond, then) => {
                let mut constants = HashSet::new();
                constants.extend(cond.kind.find_constants(scope));
                constants.extend(then.find_constants(scope));
                constants
            }
        }
    }

    /// Compiles this statement into VM instructions.
    ///
    /// Local assignments bind names in `scope.locals`; global assignments emit
    /// `SetGlobal`. Registers are allocated from `scope.register_counter`.
    ///
    /// # Panics
    ///
    /// `if` statements are not supported yet and hit `todo!()`.
    pub fn compile(&self, state: &mut State, scope: &mut Scope) -> Vec<Instruction> {
        let mut instructions = Vec::new();
        match self {
            Statement::Assignment(access_modifier, names, expr_list) => {
                instructions.push(Instruction::LoadNil(
                    scope.register_counter.0 + 1,
                    scope.register_counter.0 + names.len() as u16,
                ));
                let mut expr_regs = Vec::new();
                for expr in &expr_list.0 {
                    let (instr, regs) = expr.kind.compile(state, scope, Some(1));
                    instructions.extend(instr);
                    expr_regs.extend(regs);
                }
                // Lua assigns nil to every target that has no matching value
                // (`local a, b = 1`). Give each such name its own nil register so
                // it is still bound (locals) or written (globals) instead of
                // being silently dropped by the `zip` below.
                // NOTE(review): assumes `LoadNil(a, b)` clears the inclusive
                // register range a..=b, as the leading `LoadNil` implies.
                while expr_regs.len() < names.len() {
                    let reg = scope.register_counter.next();
                    instructions.push(Instruction::LoadNil(reg, reg));
                    expr_regs.push(reg);
                }
                match access_modifier {
                    AccessModifier::Local => {
                        for (name, reg) in names.iter().zip(expr_regs) {
                            scope.locals.insert(name.kind.clone(), reg);
                        }
                    }
                    AccessModifier::Global => {
                        for (name, reg) in names.iter().zip(expr_regs) {
                            let global = state.get_constant(&Constant::String(name.kind.clone()));
                            instructions.push(Instruction::SetGlobal(reg, global));
                        }
                    }
                }
            }
            Statement::Return(expr_list) => {
                // A lone expression is compiled with `None` (a function call then
                // requests no bound results); with several, each yields one value.
                let expected_values = if expr_list.0.len() == 1 {
                    None
                } else {
                    Some(1)
                };
                let mut ret_registers = Vec::new();
                for expr in &expr_list.0 {
                    let (instr, registers) = expr.kind.compile(state, scope, expected_values);
                    instructions.extend(instr);
                    ret_registers.extend(registers);
                }
                // `Return` is given only the first and last register, so the
                // values must occupy consecutive registers.
                let contiguous = ret_registers
                    .windows(2)
                    .all(|pair| pair[0].checked_add(1) == Some(pair[1]));
                if !contiguous {
                    // Copy every value into freshly allocated registers. Fresh
                    // registers come from `register_counter`, past every register
                    // this scope has handed out, so no move can overwrite a live
                    // local or a value that still has to be copied. Packing in
                    // place could, e.g. `return a, c, b` yielding `a, c, c`.
                    let mut packed = Vec::with_capacity(ret_registers.len());
                    for &source in &ret_registers {
                        let target = scope.register_counter.next();
                        instructions.push(Instruction::Move(target, source));
                        packed.push(target);
                    }
                    ret_registers = packed;
                }
                instructions.push(Instruction::Return(
                    *ret_registers.first().unwrap_or(&scope.register_counter.0),
                    *ret_registers.last().unwrap_or(&0),
                ));
            }
            Statement::If(_condition, _block) => todo!("`if` statements are not compiled yet"),
        }
        instructions
    }
}
impl Expression {
/// Collects the constants this expression needs in the constant pool.
///
/// Names not present in `scope.locals` are treated as globals and contribute
/// their name as a `Constant::String`. Number literals contribute a
/// `Constant::Number` built from the value's bit pattern.
pub fn find_constants(&self, scope: &mut Scope) -> HashSet<Constant> {
match self {
Expression::ValueRef(name) => {
let mut constants = HashSet::new();
// Only names that are not known locals need a global-name constant.
if !scope.locals.contains_key(name) {
constants.insert(Constant::String(name.clone()));
}
constants
}
Expression::BinOp(_, lhs, rhs) => {
let mut constants = HashSet::new();
constants.extend(lhs.kind.find_constants(scope));
constants.extend(rhs.kind.find_constants(scope));
constants
}
// NOTE(review): the function's parameters are not added to `scope` here,
// so a parameter referenced in the body is collected as a global-name
// constant even though `compile` resolves it as a local (likely just an
// unused pool entry).
Expression::FunctionDefinition(_, block) => block.find_constants(scope),
Expression::FunctionCall(expr, params) => {
let mut constants = HashSet::new();
constants.extend(expr.kind.find_constants(scope));
for param in &params.kind.0 {
constants.extend(param.kind.find_constants(scope));
}
constants
}
Expression::Literal(literal) => match literal {
Literal::Number(value) => {
let mut constants = HashSet::new();
// Keyed by bit pattern (`to_bits`), presumably so the number can be
// hashed and compared exactly inside `Constant`.
constants.insert(Constant::Number(value.to_bits()));
constants
}
},
}
}
/// Compiles this expression into instructions.
///
/// Returns the emitted instructions together with the registers holding the
/// expression's result(s). `expected_values` is only consulted by function
/// calls: `Some(n)` reserves `n` result registers, while `None` reserves none
/// and emits a `Call` whose result operand is 0. Every other expression kind
/// yields exactly one register.
pub fn compile(
&self,
state: &mut State,
scope: &mut Scope,
expected_values: Option<usize>,
) -> (Vec<Instruction>, Vec<u16>) {
match self {
// Resolution order: local register, then upvalue, then global.
Expression::ValueRef(name) => {
if let Some(reg) = scope.locals.get(name) {
(Vec::new(), vec![*reg])
} else if let Some(upvalue) = scope.upvalues.get(name) {
// Upvalues are loaded into a newly allocated register.
let local = scope.register_counter.next();
(vec![Instruction::GetUpVal(local, *upvalue)], vec![local])
} else {
let mut instructions = Vec::new();
let reg = scope.register_counter.next();
// `State::get_constant` unwraps, so this panics if the name was not
// collected into the constant pool.
instructions.push(Instruction::GetGlobal(
reg,
state.get_constant(&Constant::String(name.clone())),
));
(instructions, vec![reg])
}
}
Expression::BinOp(binary_operator, lhs, rhs) => {
let mut instructions = Vec::new();
// Both operands are forced to a single value, left then right.
let (instr, lhs) = lhs.kind.compile(state, scope, Some(1));
instructions.extend(instr);
let (instr, rhs) = rhs.kind.compile(state, scope, Some(1));
instructions.extend(instr);
// The result always goes to a newly allocated register.
let reg = scope.register_counter.next();
match binary_operator {
BinaryOperator::LessThan => {
instructions.push(Instruction::LessThan(
reg,
*lhs.get(0).unwrap(),
*rhs.get(0).unwrap(),
));
}
// `a > b` is emitted as `b < a` by swapping the operands.
BinaryOperator::Gt => {
instructions.push(Instruction::LessThan(
reg,
*rhs.get(0).unwrap(),
*lhs.get(0).unwrap(),
));
}
BinaryOperator::Add => {
instructions.push(Instruction::Add(
reg,
*lhs.get(0).unwrap(),
*rhs.get(0).unwrap(),
));
}
};
(instructions, vec![reg])
}
Expression::FunctionDefinition(params, block) => {
// The body gets a fresh scope where the parameters occupy
// registers 0, 1, ... in declaration order.
let mut inner_scope = Scope::default();
for param in params {
inner_scope
.locals
.insert(param.kind.clone(), inner_scope.register_counter.next());
}
// Every local of the enclosing scope becomes an upvalue of the new
// function, numbered `reg + highest_upvalue + 1`, added on top of the
// upvalues the enclosing scope already had.
// NOTE(review): this numbering must match the VM's upvalue layout — confirm.
inner_scope.highest_upvalue =
scope.highest_upvalue + inner_scope.register_counter.0;
inner_scope.upvalues = scope.upvalues.clone();
for (name, reg) in &scope.locals {
let new_reg = *reg + inner_scope.highest_upvalue + 1;
inner_scope.upvalues.insert(name.clone(), new_reg);
}
// The compiled body is stored as a new prototype.
let instructions = block.compile(state, &mut inner_scope);
state.prototypes.push(instructions);
let mut instructions = Vec::new();
// NOTE(review): presumably closes open upvalues from the first
// unallocated register upward — confirm against the VM.
instructions.push(Instruction::Close(scope.register_counter.0));
let local = scope.register_counter.next();
// NOTE(review): the prototype index is `len()` *after* the push, i.e. one
// past the new prototype's 0-based position — confirm the VM expects this.
instructions.push(Instruction::Closure(local, state.prototypes.len() as u32));
(instructions, vec![local])
}
Expression::FunctionCall(expr, params) => {
let mut instructions = Vec::new();
// The callee is evaluated into a single register.
let (instr, registers) = expr.kind.compile(state, scope, Some(1));
instructions.extend(instr);
let old_function_reg = registers.first().unwrap();
// Arguments are compiled against a clone of the scope, so their
// temporaries do not advance `scope.register_counter`. A lone argument
// is compiled with `None`; with several, each yields one value.
let mut param_scope = scope.clone();
let mut original_param_regs = Vec::new();
for param in params.kind.0.iter() {
let (instr, registers) = param.kind.compile(
state,
&mut param_scope,
if params.kind.0.len() == 1 {
None
} else {
Some(1)
},
);
instructions.extend(instr);
original_param_regs.extend(registers);
}
// Allocate a consecutive window from the real scope: the function
// register followed by one register per argument value.
let function_reg = scope.register_counter.next();
let mut param_regs = Vec::new();
for _ in &original_param_regs {
param_regs.push(scope.register_counter.next());
}
// Move arguments into their slots, last first. NOTE(review): the reverse
// order presumably keeps an upward shift from overwriting an argument
// that has not been moved yet — confirm it covers all overlap cases.
for (i, param_reg) in original_param_regs.iter().enumerate().rev() {
let new_reg = param_regs.get(i).unwrap();
if param_reg != new_reg {
instructions.push(Instruction::Move(*new_reg, *param_reg));
}
}
// Finally move the callee into the window's first register.
if function_reg != *old_function_reg {
instructions.push(Instruction::Move(function_reg, *old_function_reg));
}
// Result registers are function_reg, function_reg + 1, ...: slots up to
// the last argument register are reused, further ones are newly
// allocated (continuing right after the argument window). With
// `expected_values == None` no result registers are reserved.
let last_param_reg = param_regs.last().unwrap_or(&function_reg);
let mut return_regs = Vec::new();
if let Some(expected_values) = expected_values {
for i in 0..expected_values {
let return_reg = i as u16 + function_reg;
if return_reg > *last_param_reg {
return_regs.push(scope.register_counter.next());
} else {
return_regs.push(return_reg);
}
}
}
// Call(function register, argument count, results operand), where the
// last operand is the number of reserved results + 1, or 0 if none.
instructions.push(Instruction::Call(
*&function_reg,
param_regs.len() as u16,
if return_regs.len() == 0 {
0
} else {
return_regs.len() as u16 + 1
},
));
(instructions, return_regs)
}
Expression::Literal(literal) => {
let mut instructions = Vec::new();
// Load the literal from the constant pool into a new register.
let reg = scope.register_counter.next();
instructions.push(Instruction::LoadK(
reg,
state.get_constant(&match literal {
Literal::Number(value) => Constant::Number(value.to_bits()),
}),
));
(instructions, vec![reg])
}
}
}
}