ferrite-lua/src/compile.rs
2026-03-19 16:39:51 +02:00

867 lines
35 KiB
Rust

use std::collections::{HashMap, HashSet};
use crate::{
ast::{
AccessModifier, BinaryOperator, Block, Expression, IdentOrEllipsis, Literal, Statement,
UnaryOperator,
},
vm::{Constant, Instruction, LuaBool, LuaInteger, Prototype},
};
/// Output tables shared by the whole compilation.
#[derive(Clone, Debug)]
pub struct State {
    /// Constant pool; instructions refer to entries by index.
    pub constants: Vec<Constant>,
    /// Compiled function prototypes, appended as function definitions are compiled.
    pub prototypes: Vec<Prototype>,
}

impl State {
    /// Returns the index of `value` within the constant pool.
    ///
    /// # Panics
    ///
    /// Panics when `value` is not present in `constants`.
    pub fn get_constant(&self, value: &Constant) -> u32 {
        let index = self
            .constants
            .iter()
            .position(|candidate| candidate == value)
            .unwrap();
        index as u32
    }
}
/// Name bindings and register allocation state used while compiling a block or function body.
#[derive(Clone, Debug, Default)]
pub struct Scope {
/// Local variable names mapped to the register that holds each one.
pub locals: HashMap<String, u16>,
/// Allocator that hands out registers for locals and temporaries.
pub register_counter: LocalCounter,
/// Offset added to a parent's register numbers when they are exposed as upvalues to a nested function.
pub highest_upvalue: u16,
/// Names visible from enclosing functions, mapped to the index given to `GetUpVal` / `SetUpVal`.
pub upvalues: HashMap<String, u16>,
/// Set when the function's parameter list contains `...`; `...` expressions panic at compile time otherwise.
pub is_vararg: bool,
}
/// Register allocator.
///
/// Field `0` is the lowest register number that has never been handed out;
/// field `1` lists previously used registers that are free for reuse.
#[derive(Clone, Debug, Default)]
pub struct LocalCounter(u16, Vec<u16>);

impl LocalCounter {
    /// Returns a single free register, preferring a recycled one over a fresh one.
    pub fn next(&mut self) -> u16 {
        self.1.pop().unwrap_or_else(|| self.new())
    }

    /// Returns exactly `amount` registers with consecutive, ascending numbers.
    ///
    /// The free list is searched (in its stored order) for a starting register
    /// from which a full run of `amount` free registers exists. If one is found,
    /// only those registers are removed from the free list; every other free
    /// register stays available. Otherwise `amount` fresh registers are allocated.
    pub fn consecutive(&mut self, amount: usize) -> Vec<u16> {
        if amount == 0 {
            return Vec::new();
        }

        let free = &self.1;
        let run = free.iter().find_map(|&start| {
            let candidate: Vec<u16> = (0..amount)
                .map_while(|offset| {
                    // `checked_add` keeps a run that would pass `u16::MAX` from wrapping.
                    let reg = start.checked_add(u16::try_from(offset).ok()?)?;
                    free.contains(&reg).then_some(reg)
                })
                .collect();
            (candidate.len() == amount).then_some(candidate)
        });

        if let Some(run) = run {
            self.1.retain(|reg| !run.contains(reg));
            return run;
        }

        (0..amount).map(|_| self.new()).collect()
    }

    /// Allocates a register that has never been used before.
    pub fn new(&mut self) -> u16 {
        let temp = self.0;
        self.0 += 1;
        temp
    }
}
/// Intermediate instruction form produced by the `compile` methods before final lowering.
pub(crate) enum PreInstr {
/// A finished VM instruction.
Instr(Instruction),
/// Presumably a placeholder for a `break` statement (TODO confirm); `process_pre_instrs` panics if it meets one.
Break,
}
/// Lowers a sequence of pre-instructions into VM instructions, preserving order.
///
/// # Panics
///
/// Panics on the first `PreInstr::Break` encountered.
pub(crate) fn process_pre_instrs(pre_instructions: Vec<PreInstr>) -> Vec<Instruction> {
    pre_instructions
        .into_iter()
        .map(|pre_instr| match pre_instr {
            PreInstr::Instr(instruction) => instruction,
            PreInstr::Break => panic!(),
        })
        .collect()
}
impl Block {
    /// Collects every constant the block's statements need, seeded with `constants`.
    ///
    /// The statements are scanned against a copy of `scope`, so locals recorded
    /// while scanning do not reach the caller's scope.
    pub(crate) fn find_constants(
        &self,
        scope: &mut Scope,
        constants: Vec<Constant>,
    ) -> HashSet<Constant> {
        let mut block_scope = scope.clone();
        let seed: HashSet<Constant> = constants.into_iter().collect();
        self.statements.iter().fold(seed, |mut found, statement| {
            found.extend(statement.kind.find_constants(&mut block_scope));
            found
        })
    }

    /// Compiles the block's statements in order and concatenates their pre-instructions.
    ///
    /// Compilation runs against a copy of `scope`, so bindings and register
    /// bookkeeping made inside the block are discarded when it ends.
    pub(crate) fn compile(&self, state: &mut State, scope: &mut Scope) -> Vec<PreInstr> {
        let mut block_scope = scope.clone();
        self.statements
            .iter()
            .flat_map(|statement| statement.kind.compile(state, &mut block_scope))
            .collect()
    }
}
impl Statement {
/// Collects the constants this statement will look up through `State::get_constant`
/// when it is compiled.
///
/// `scope.locals` is used only to decide whether a name is a local (no constant
/// needed) or must be addressed by name; the register numbers stored here are
/// placeholders.
fn find_constants(&self, scope: &mut Scope) -> HashSet<Constant> {
    match self {
        Statement::Assignment(access, names, expr_list) => {
            let mut constants = HashSet::new();

            // Scan the right-hand side BEFORE registering new locals. `compile`
            // evaluates the expressions before it binds the declared names, so in
            // `local x = x` the right-hand `x` is compiled as a global read and its
            // name must be present in the constant pool.
            for expr in &expr_list.0 {
                constants.extend(expr.kind.find_constants(scope));
            }

            for (name, indexes) in names {
                match access {
                    Some(AccessModifier::Global) => {
                        constants.insert(Constant::String(name.kind.clone()));
                    }
                    None => {
                        if !scope.locals.contains_key(&name.kind) {
                            constants.insert(Constant::String(name.kind.clone()));
                        }
                    }
                    Some(AccessModifier::Local) => {
                        // Register number is irrelevant here; only the name matters.
                        scope.locals.insert(name.kind.clone(), 0);
                    }
                }
                for index in indexes {
                    constants.extend(index.kind.find_constants(scope));
                }
            }
            constants
        }
        Statement::Return(expr_list) => {
            let mut constants = HashSet::new();
            for expr in &expr_list.0 {
                constants.extend(expr.kind.find_constants(scope));
            }
            constants
        }
        Statement::If(cond, then) => {
            let mut constants = cond.kind.find_constants(scope);
            constants.extend(then.find_constants(scope, Vec::new()));
            constants
        }
        Statement::Expression(expr) => expr.kind.find_constants(scope),
        Statement::NumericalFor(_, init, end, step, block) => {
            let mut constants = HashSet::new();
            constants.extend(init.kind.find_constants(scope));
            constants.extend(end.kind.find_constants(scope));
            constants.extend(step.kind.find_constants(scope));
            constants.extend(block.find_constants(scope, Vec::new()));
            constants
        }
    }
}
/// Compiles this statement into pre-instructions.
///
/// `state` supplies the constant pool (and receives nested prototypes);
/// `scope` holds the current name bindings and the register allocator. After
/// the statement has been emitted, every register below the allocator's
/// high-water mark that is not bound to a local is returned to the free list.
///
/// # Panics
///
/// Panics when a needed constant is missing from `state`, when an assignment
/// target other than a `local` has no matching value register, and via
/// `todo!()` for indexed `local` / `global` targets.
fn compile(&self, state: &mut State, scope: &mut Scope) -> Vec<PreInstr> {
    let mut instructions = Vec::new();
    match self {
        Statement::Assignment(access_modifier, names, expr_list) => {
            // `local` declarations receive a fresh run of registers, pre-filled with nil.
            let new_registers: Vec<u16> = if *access_modifier == Some(AccessModifier::Local) {
                let min_reg = scope.register_counter.0 + 1;
                let max_reg = scope.register_counter.0 + names.len() as u16;
                instructions.push(PreInstr::Instr(Instruction::LoadNil(min_reg, max_reg)));
                scope.register_counter.0 += names.len() as u16 + 1;
                (min_reg..=max_reg).collect()
            } else {
                Vec::new()
            };

            // A lone right-hand expression is asked for one value per target;
            // with several expressions each is asked for exactly one.
            let expected = if expr_list.0.len() == 1 {
                names.len()
            } else {
                1
            };
            let mut expr_regs = Vec::new();
            for expr in &expr_list.0 {
                let (instr, regs) = expr.kind.compile(state, scope, Some(expected));
                instructions.extend(instr);
                expr_regs.extend(regs);
            }

            match access_modifier {
                Some(AccessModifier::Local) => {
                    let mut vararg_applied = false;
                    for (i, (name, indexes)) in names.iter().enumerate() {
                        let target = new_registers[i];
                        if let Some(expr_reg) = expr_regs.get(i) {
                            instructions
                                .push(PreInstr::Instr(Instruction::Move(target, *expr_reg)));
                        } else if !vararg_applied {
                            // First target without its own value register: take
                            // the pending return values from here on.
                            instructions
                                .push(PreInstr::Instr(Instruction::MoveRetValues(target)));
                            vararg_applied = true;
                        }
                        if !indexes.is_empty() {
                            todo!()
                        }
                        // Bound only now, so the right-hand side above never sees the new name.
                        scope.locals.insert(name.kind.clone(), target);
                    }
                }
                Some(AccessModifier::Global) => {
                    for (i, (name, indexes)) in names.iter().enumerate() {
                        if !indexes.is_empty() {
                            todo!()
                        }
                        let global = state.get_constant(&Constant::String(name.kind.clone()));
                        instructions.push(PreInstr::Instr(Instruction::SetGlobal(
                            expr_regs.get(i).copied().unwrap(),
                            global,
                        )));
                    }
                }
                None => {
                    for (i, (name, indexes)) in names.iter().enumerate() {
                        if let Some((last_index, parents)) = indexes.split_last() {
                            // Indexed target: resolve the base table first.
                            let local_reg = scope.locals.get(&name.kind).copied();
                            let mut table_reg = if let Some(reg) = local_reg {
                                reg
                            } else if let Some(upval_reg) =
                                scope.upvalues.get(&name.kind).copied()
                            {
                                let local = scope.register_counter.next();
                                instructions.push(PreInstr::Instr(Instruction::GetUpVal(
                                    local, upval_reg,
                                )));
                                local
                            } else {
                                let global =
                                    state.get_constant(&Constant::String(name.kind.clone()));
                                let local = scope.register_counter.next();
                                instructions.push(PreInstr::Instr(Instruction::GetGlobal(
                                    local, global,
                                )));
                                local
                            };

                            // Walk every index but the last. A local's own register
                            // must not be overwritten by the intermediate tables, so
                            // the first hop from a local goes into a temporary.
                            let mut table_is_temp = local_reg.is_none();
                            for index in parents {
                                let (instr, index_reg) =
                                    index.kind.compile(state, scope, Some(1));
                                instructions.extend(instr);
                                let dst = if table_is_temp {
                                    table_reg
                                } else {
                                    table_is_temp = true;
                                    scope.register_counter.next()
                                };
                                instructions.push(PreInstr::Instr(Instruction::GetTable(
                                    dst,
                                    table_reg,
                                    *index_reg.first().unwrap(),
                                )));
                                table_reg = dst;
                            }

                            let (instr, index_reg) =
                                last_index.kind.compile(state, scope, Some(1));
                            instructions.extend(instr);
                            instructions.push(PreInstr::Instr(Instruction::SetTable(
                                table_reg,
                                *index_reg.first().unwrap(),
                                expr_regs.get(i).copied().unwrap(),
                            )));
                        } else if let Some(reg) = scope.locals.get(&name.kind) {
                            instructions.push(PreInstr::Instr(Instruction::Move(
                                *reg,
                                expr_regs.get(i).copied().unwrap(),
                            )));
                        } else if let Some(upval_reg) = scope.upvalues.get(&name.kind) {
                            instructions.push(PreInstr::Instr(Instruction::SetUpVal(
                                *upval_reg,
                                expr_regs.get(i).copied().unwrap(),
                            )));
                        } else {
                            let global =
                                state.get_constant(&Constant::String(name.kind.clone()));
                            instructions.push(PreInstr::Instr(Instruction::SetGlobal(
                                expr_regs.get(i).copied().unwrap(),
                                global,
                            )));
                        }
                    }
                }
            }
        }
        Statement::Return(expr_list) => {
            let mut ret_registers = Vec::new();
            let mut vararg = false;
            for (i, expr) in expr_list.0.iter().enumerate() {
                // Only the last expression may contribute a variable number of values.
                let expected = if i == expr_list.0.len() - 1 {
                    None
                } else {
                    Some(1)
                };
                let (instr, registers) = expr.kind.compile(state, scope, expected);
                instructions.extend(instr);
                if registers.is_empty() {
                    vararg = true;
                }
                ret_registers.extend(registers);
            }

            // Pack the values into consecutive registers starting at the first one.
            // NOTE(review): the packing targets are not reserved, so a move may
            // overwrite a later return value or a live local — confirm and fix
            // separately.
            let first_ret_register = ret_registers
                .first()
                .copied()
                .unwrap_or(scope.register_counter.0);
            for (i, ret_register) in ret_registers.iter_mut().enumerate() {
                let new_reg = first_ret_register + i as u16;
                if *ret_register != new_reg {
                    instructions.push(PreInstr::Instr(Instruction::Move(new_reg, *ret_register)));
                }
                *ret_register = new_reg;
            }
            if vararg {
                instructions.push(PreInstr::Instr(Instruction::MoveRetValues(
                    first_ret_register + ret_registers.len() as u16,
                )));
            }
            instructions.push(PreInstr::Instr(Instruction::Return(
                first_ret_register,
                if vararg {
                    0
                } else {
                    *ret_registers.last().unwrap_or(&0)
                },
            )));
        }
        Statement::If(expr, block) => {
            let (instr, regs) = expr.kind.compile(state, scope, Some(1));
            instructions.extend(instr);
            let local = scope.register_counter.next();
            instructions.push(PreInstr::Instr(Instruction::Test(
                local,
                *regs.first().unwrap(),
                0,
            )));
            // The jump that follows the test skips over the whole body.
            let block_instructions = block.compile(state, scope);
            instructions.push(PreInstr::Instr(Instruction::Jmp(
                block_instructions.len() as i32,
            )));
            instructions.extend(block_instructions);
        }
        Statement::Expression(expr) => {
            let (instr, _) = expr.kind.compile(state, scope, None);
            instructions.extend(instr);
        }
        Statement::NumericalFor(counter, init, end, step, block) => {
            let (instr, init_regs) = init.kind.compile(state, scope, Some(1));
            instructions.extend(instr);
            let (instr, end_regs) = end.kind.compile(state, scope, Some(1));
            instructions.extend(instr);
            let (instr, step_regs) = step.kind.compile(state, scope, Some(1));
            instructions.extend(instr);
            let mut init_reg = *init_regs.first().unwrap();
            let end_reg = *end_regs.first().unwrap();
            let step_reg = *step_regs.first().unwrap();

            // The counter register is incremented on every iteration. If the
            // start expression was just a local variable, work on a copy so the
            // loop does not modify that variable.
            if scope.locals.values().any(|reg| *reg == init_reg) {
                let fresh = scope.register_counter.next();
                instructions.push(PreInstr::Instr(Instruction::Move(fresh, init_reg)));
                init_reg = fresh;
            }

            // The counter and the hidden bound/step bindings exist only inside
            // the loop body; binding them in the enclosing scope would leave the
            // counter name shadowing outer variables after the loop.
            // NOTE(review): nested loops share the "_END"/"_STEP" keys, so an
            // inner loop replaces the outer loop's entries within its body —
            // verify register liveness there.
            let mut inner_scope = scope.clone();
            inner_scope.locals.insert(counter.kind.clone(), init_reg);
            inner_scope.locals.insert("_END".to_owned(), end_reg);
            inner_scope.locals.insert("_STEP".to_owned(), step_reg);

            instructions.push(PreInstr::Instr(Instruction::ForTest(
                init_reg, end_reg, step_reg,
            )));
            let instr = block.compile(state, &mut inner_scope);
            let instr_len = instr.len() as i32;
            // Layout: ForTest, Jmp (past the loop), body, Add, Jmp (back to ForTest).
            instructions.push(PreInstr::Instr(Instruction::Jmp(instr_len + 2)));
            instructions.extend(instr);
            instructions.push(PreInstr::Instr(Instruction::Add(
                init_reg, init_reg, step_reg,
            )));
            instructions.push(PreInstr::Instr(Instruction::Jmp(-(instr_len + 4))));
        }
    }

    // Release every register that is not bound to a local so later statements can reuse it.
    let live: HashSet<u16> = scope.locals.values().copied().collect();
    for reg in 0..scope.register_counter.0 {
        if !live.contains(&reg) && !scope.register_counter.1.contains(&reg) {
            scope.register_counter.1.push(reg);
        }
    }
    instructions
}
}
impl Expression {
/// Collects the constants this expression (and its sub-expressions) refers to.
///
/// A name counts as a constant whenever it is not a key of `scope.locals`;
/// every literal contributes its value, and each positional table entry
/// contributes its 1-based integer key.
fn find_constants(&self, scope: &mut Scope) -> HashSet<Constant> {
    let mut found = HashSet::new();
    match self {
        Expression::ValueRef(name) => {
            if !scope.locals.contains_key(name) {
                found.insert(Constant::String(name.clone()));
            }
        }
        Expression::BinOp(_, lhs, rhs) => {
            found.extend(lhs.kind.find_constants(scope));
            found.extend(rhs.kind.find_constants(scope));
        }
        Expression::UnOp(_, operand) => {
            found.extend(operand.kind.find_constants(scope));
        }
        Expression::FunctionDefinition(_, body) => {
            found.extend(body.find_constants(scope, Vec::new()));
        }
        Expression::FunctionCall(callee, params) => {
            found.extend(callee.kind.find_constants(scope));
            for argument in &params.kind.0 {
                found.extend(argument.kind.find_constants(scope));
            }
        }
        Expression::Literal(literal) => {
            let constant = match literal {
                Literal::Float(value) => Constant::Float(value.vm_number()),
                Literal::Integer(value) => Constant::Integer(*value),
                Literal::String(value) => Constant::String(value.clone()),
                Literal::Bool(value) => Constant::Bool(LuaBool(*value)),
                Literal::Nil => Constant::Nil,
            };
            found.insert(constant);
        }
        Expression::TableConstructor(entries) => {
            // Positional entries are keyed 1, 2, 3, ... in order of appearance.
            let mut next_position = 1;
            for (key, value) in entries {
                match key {
                    Some(key) => found.extend(key.kind.find_constants(scope)),
                    None => {
                        found.insert(Constant::Integer(LuaInteger(next_position)));
                        next_position += 1;
                    }
                }
                found.extend(value.kind.find_constants(scope));
            }
        }
        Expression::IndexedAccess(table, index) => {
            found.extend(table.kind.find_constants(scope));
            found.extend(index.kind.find_constants(scope));
        }
        Expression::Ellipsis => {}
    }
    found
}
fn compile(
&self,
state: &mut State,
scope: &mut Scope,
expected_values: Option<usize>,
) -> (Vec<PreInstr>, Vec<u16>) {
match self {
Expression::ValueRef(name) => {
if let Some(reg) = scope.locals.get(name) {
(Vec::new(), vec![*reg])
} else if let Some(upvalue) = scope.upvalues.get(name) {
let local = scope.register_counter.next();
(
vec![PreInstr::Instr(Instruction::GetUpVal(local, *upvalue))],
vec![local],
)
} else {
let mut instructions = Vec::new();
let reg = scope.register_counter.next();
instructions.push(PreInstr::Instr(Instruction::GetGlobal(
reg,
state.get_constant(&Constant::String(name.clone())),
)));
(instructions, vec![reg])
}
}
Expression::BinOp(binary_operator, lhs, rhs) => {
let mut instructions = Vec::new();
let (instr, lhs) = lhs.kind.compile(state, scope, Some(1));
instructions.extend(instr);
let (instr, rhs) = rhs.kind.compile(state, scope, Some(1));
instructions.extend(instr);
let reg = scope.register_counter.next();
match binary_operator {
BinaryOperator::Equal => {
instructions.push(PreInstr::Instr(Instruction::Equal(
reg,
*lhs.get(0).unwrap(),
*rhs.get(0).unwrap(),
)));
}
BinaryOperator::LessThan => {
instructions.push(PreInstr::Instr(Instruction::LessThan(
reg,
*lhs.get(0).unwrap(),
*rhs.get(0).unwrap(),
)));
}
BinaryOperator::LessThanOrEqual => {
instructions.push(PreInstr::Instr(Instruction::LessThanOrEqual(
reg,
*lhs.get(0).unwrap(),
*rhs.get(0).unwrap(),
)));
}
BinaryOperator::GreaterThan => {
instructions.push(PreInstr::Instr(Instruction::LessThan(
reg,
*rhs.get(0).unwrap(),
*lhs.get(0).unwrap(),
)));
}
BinaryOperator::GreaterThanOrEqual => {
instructions.push(PreInstr::Instr(Instruction::LessThanOrEqual(
reg,
*rhs.get(0).unwrap(),
*lhs.get(0).unwrap(),
)));
}
BinaryOperator::Add => {
instructions.push(PreInstr::Instr(Instruction::Add(
reg,
*lhs.get(0).unwrap(),
*rhs.get(0).unwrap(),
)));
}
BinaryOperator::Sub => {
instructions
.push(PreInstr::Instr(Instruction::Unm(reg, *rhs.get(0).unwrap())));
instructions.push(PreInstr::Instr(Instruction::Add(
reg,
*lhs.get(0).unwrap(),
reg,
)));
}
BinaryOperator::And => {
instructions.push(PreInstr::Instr(Instruction::And(
reg,
*lhs.get(0).unwrap(),
*rhs.get(0).unwrap(),
)));
}
BinaryOperator::Or => {
instructions.push(PreInstr::Instr(Instruction::Or(
reg,
*lhs.get(0).unwrap(),
*rhs.get(0).unwrap(),
)));
}
};
(instructions, vec![reg])
}
Expression::UnOp(op, expr) => {
let mut instructions = Vec::new();
let (instr, registers) = expr.kind.compile(state, scope, Some(1));
instructions.extend(instr);
for reg in &registers {
match op {
UnaryOperator::Negation => {
instructions.push(PreInstr::Instr(Instruction::Unm(*reg, *reg)))
}
UnaryOperator::Length => {
instructions.push(PreInstr::Instr(Instruction::Len(*reg, *reg)))
}
}
}
(instructions, registers)
}
Expression::FunctionDefinition(params, block) => {
let mut inner_scope = Scope::default();
for param in params {
match &param.kind {
IdentOrEllipsis::Ident(name) => {
inner_scope
.locals
.insert(name.clone(), inner_scope.register_counter.next());
}
IdentOrEllipsis::Ellipsis => {
inner_scope.is_vararg = true;
}
}
}
inner_scope.highest_upvalue = scope.highest_upvalue + scope.register_counter.0;
inner_scope.upvalues = scope.upvalues.clone();
for (name, reg) in &scope.locals {
let new_reg = *reg + scope.highest_upvalue + 1;
inner_scope.upvalues.insert(name.clone(), new_reg);
}
let instructions = block.compile(state, &mut inner_scope);
state.prototypes.push(Prototype {
instructions: process_pre_instrs(instructions),
parameters: if inner_scope.is_vararg {
params.len() - 1
} else {
params.len()
},
});
let mut instructions = Vec::new();
instructions.push(PreInstr::Instr(Instruction::Close(
scope.register_counter.0,
)));
let local = scope.register_counter.next();
instructions.push(PreInstr::Instr(Instruction::Closure(
local,
state.prototypes.len() as u32,
)));
(instructions, vec![local])
}
Expression::FunctionCall(expr, params) => {
let mut instructions = Vec::new();
let (instr, registers) = expr.kind.compile(state, scope, Some(1));
instructions.extend(instr);
let old_function_reg = registers.first().unwrap();
let mut registers = scope
.register_counter
.consecutive(params.kind.0.len() + 1)
.into_iter();
let mut param_scope = scope.clone();
let mut original_param_regs = Vec::new();
let mut vararg = false;
for (i, param) in params.kind.0.iter().enumerate() {
let (instr, registers) = param.kind.compile(
state,
&mut param_scope,
if i == params.kind.0.len() - 1 {
None
} else {
Some(1)
},
);
instructions.extend(instr);
if registers.len() > 0 {
original_param_regs.push(*registers.first().unwrap());
} else {
vararg = true;
}
}
let function_reg = registers.next().unwrap();
let mut param_regs = Vec::new();
for _ in &original_param_regs {
param_regs.push(registers.next().unwrap());
}
for (i, param_reg) in original_param_regs.iter().enumerate().rev() {
let new_reg = param_regs.get(i).unwrap();
if param_reg != new_reg {
instructions.push(PreInstr::Instr(Instruction::Move(*new_reg, *param_reg)));
}
}
if function_reg != *old_function_reg {
instructions.push(PreInstr::Instr(Instruction::Move(
function_reg,
*old_function_reg,
)));
}
if vararg {
let last_reg = param_regs.last().unwrap_or(&function_reg) + 1;
instructions.push(PreInstr::Instr(Instruction::MoveRetValues(last_reg)));
}
let last_param_reg = param_regs.last().unwrap_or(&function_reg);
let mut return_regs = Vec::new();
if let Some(expected_values) = expected_values {
for i in 0..expected_values {
let return_reg = i as u16 + function_reg;
if return_reg > *last_param_reg {
return_regs.push(scope.register_counter.next());
} else {
return_regs.push(return_reg);
}
}
}
instructions.push(PreInstr::Instr(Instruction::Call(
*&function_reg,
if vararg { 0 } else { param_regs.len() as u16 },
if return_regs.len() == 0 {
0
} else {
return_regs.len() as u16 + 1
},
)));
(instructions, return_regs)
}
Expression::Literal(literal) => {
let mut instructions = Vec::new();
let reg = scope.register_counter.next();
instructions.push(PreInstr::Instr(Instruction::LoadK(
reg,
state.get_constant(&match literal {
Literal::Float(value) => Constant::Float(value.vm_number()),
Literal::String(value) => Constant::String(value.clone()),
Literal::Integer(lua_integer) => Constant::Integer(*lua_integer),
Literal::Bool(value) => Constant::Bool(LuaBool(*value)),
Literal::Nil => Constant::Nil,
}),
)));
(instructions, vec![reg])
}
Expression::TableConstructor(entries) => {
let mut instructions = Vec::new();
let reg = scope.register_counter.next();
instructions.push(PreInstr::Instr(Instruction::NewTable(reg)));
let mut counter = 1;
for (i, (key, value)) in entries.iter().enumerate() {
if let Some(key) = key {
let (instr, key_regs) = key.kind.compile(state, scope, Some(1));
instructions.extend(instr);
let (instr, value_regs) = value.kind.compile(state, scope, Some(1));
instructions.extend(instr);
instructions.push(PreInstr::Instr(Instruction::SetTable(
reg,
*key_regs.first().unwrap(),
*value_regs.first().unwrap(),
)));
} else {
let (instr, value_regs) = value.kind.compile(
state,
scope,
if i == entries.len() - 1 {
None
} else {
Some(1)
},
);
instructions.extend(instr);
if value_regs.len() > 0 {
let key_reg = scope.register_counter.next();
instructions.push(PreInstr::Instr(Instruction::LoadK(
key_reg,
state.get_constant(&Constant::Integer(LuaInteger(counter))),
)));
instructions.push(PreInstr::Instr(Instruction::SetTable(
reg,
key_reg,
*value_regs.first().unwrap(),
)));
counter += 1;
} else {
instructions
.push(PreInstr::Instr(Instruction::SetList(reg, counter as u32)));
}
}
}
(instructions, vec![reg])
}
Expression::IndexedAccess(expr, index) => {
let mut instructions = Vec::new();
let (instr, expr_regs) = expr.kind.compile(state, scope, Some(1));
instructions.extend(instr);
let (instr, index_regs) = index.kind.compile(state, scope, Some(1));
instructions.extend(instr);
let local = scope.register_counter.next();
instructions.push(PreInstr::Instr(Instruction::GetTable(
local,
*expr_regs.first().unwrap(),
*index_regs.first().unwrap(),
)));
(instructions, vec![local])
}
Expression::Ellipsis => {
if !scope.is_vararg {
panic!("Function is not vararg!");
}
let mut instructions = Vec::new();
let new_reg = scope.register_counter.new();
if expected_values == None {
instructions.push(PreInstr::Instr(Instruction::Vararg(new_reg, 0)));
(instructions, Vec::new())
} else {
instructions.push(PreInstr::Instr(Instruction::Vararg(new_reg, 3)));
(instructions, vec![new_reg])
}
}
}
}
}