ferrite-lua/src/compile.rs
2026-03-18 19:23:14 +02:00

744 lines
30 KiB
Rust

use std::collections::{HashMap, HashSet};
use crate::{
ast::{
AccessModifier, BinaryOperator, Block, Expression, IdentOrEllipsis, Literal, Statement,
UnaryOperator,
},
vm::{Constant, Instruction, LuaBool, LuaInteger},
};
/// Shared compilation state: the constant pool and the compiled function prototypes.
#[derive(Clone, Debug)]
pub struct State {
    /// Constant pool. Instructions refer to an entry by its index in this vector
    /// (see [`State::get_constant`]).
    pub constants: Vec<Constant>,
    /// Instruction sequences of compiled function bodies, in the order they were pushed.
    pub prototypes: Vec<Vec<Instruction>>,
}

impl State {
    /// Returns the index of the first entry in the constant pool that equals `value`.
    ///
    /// # Panics
    ///
    /// Panics if `value` is not present in `constants`. The constant pool is expected to
    /// have been populated (via the `find_constants` passes) before compilation starts, so
    /// a miss indicates that the two passes disagree; the panic message names the constant
    /// to make that mismatch easy to track down.
    pub fn get_constant(&self, value: &Constant) -> u32 {
        let index = self
            .constants
            .iter()
            .position(|candidate| candidate == value)
            .unwrap_or_else(|| {
                panic!("constant {:?} is missing from the constant pool", value)
            });
        // Instruction operands are 32-bit; the pool is not expected to grow past that.
        index as u32
    }
}
/// Compilation scope of the function (or nested block) currently being compiled.
///
/// Blocks compile against a clone of their parent's scope, and function
/// definitions start from a fresh default scope that is then seeded with
/// parameters and upvalues.
#[derive(Clone, Debug, Default)]
pub struct Scope {
/// Maps each local variable name to the register that holds it.
pub locals: HashMap<String, u16>,
/// Allocator that hands out registers for locals and temporaries.
pub register_counter: LocalCounter,
/// Offset added when numbering the upvalues of a nested function definition.
pub highest_upvalue: u16,
/// Maps each captured outer variable name to the upvalue index used by
/// `GetUpVal` / `SetUpVal`.
pub upvalues: HashMap<String, u16>,
/// Set when the function's parameter list contains `...`; checked before
/// compiling an ellipsis expression.
pub is_vararg: bool,
}
/// Register allocator for a single scope.
///
/// Field `0` is the next register index that has never been handed out; field `1`
/// is a stack of previously allocated registers that are free to be reused.
#[derive(Clone, Debug, Default)]
pub struct LocalCounter(u16, Vec<u16>);

impl LocalCounter {
    /// Returns a register for a temporary value, reusing the most recently freed
    /// register when one is available and allocating a brand-new one otherwise.
    pub fn next(&mut self) -> u16 {
        match self.1.pop() {
            Some(recycled) => recycled,
            None => self.new(),
        }
    }

    /// Allocates `amount` brand-new registers with consecutive indices.
    ///
    /// The free list is deliberately ignored: recycled registers are not guaranteed
    /// to be adjacent to each other.
    pub fn consecutive(&mut self, amount: usize) -> Vec<u16> {
        (0..amount).map(|_| self.new()).collect()
    }

    /// Allocates a brand-new register, never one from the free list.
    ///
    /// # Panics
    ///
    /// Panics if the counter cannot be advanced past `u16::MAX`. Without this check a
    /// release build would silently wrap around and hand out register `0` again.
    pub fn new(&mut self) -> u16 {
        let allocated = self.0;
        self.0 = allocated
            .checked_add(1)
            .expect("register counter overflow: more than u16::MAX registers requested");
        allocated
    }
}
impl Block {
    /// Collects every constant referenced by the statements of this block.
    ///
    /// `constants` seeds the result set. The statements are visited against a clone
    /// of `scope`, so locals registered while scanning the block do not leak into
    /// the caller's scope.
    pub(crate) fn find_constants(
        &self,
        scope: &mut Scope,
        constants: Vec<Constant>,
    ) -> HashSet<Constant> {
        // `constants` is owned, so move the seeds into the set instead of cloning each one.
        let mut found: HashSet<Constant> = constants.into_iter().collect();
        let mut inner_scope = scope.clone();
        for statement in &self.statements {
            found.extend(statement.kind.find_constants(&mut inner_scope));
        }
        found
    }

    /// Compiles the statements of this block, in order, into one instruction sequence.
    ///
    /// Like [`Block::find_constants`], this works on a clone of `scope`: locals and
    /// register allocations made inside the block are not written back to the caller.
    pub(crate) fn compile(&self, state: &mut State, scope: &mut Scope) -> Vec<Instruction> {
        let mut inner_scope = scope.clone();
        let mut instructions = Vec::new();
        for statement in &self.statements {
            instructions.extend(statement.kind.compile(state, &mut inner_scope));
        }
        instructions
    }
}
impl Statement {
    /// Collects the constants this statement needs in the constant pool.
    ///
    /// `local` assignments additionally register their names in `scope.locals`
    /// (with a placeholder register of `0`) so that later references to those names
    /// are not recorded as global-name string constants.
    fn find_constants(&self, scope: &mut Scope) -> HashSet<Constant> {
        match self {
            Statement::Assignment(access, names, expr_list) => {
                let mut constants = HashSet::new();
                for (name, indexes) in names {
                    match access {
                        Some(AccessModifier::Global) => {
                            constants.insert(Constant::String(name.kind.clone()));
                        }
                        Some(AccessModifier::Local) => {
                            scope.locals.insert(name.kind.clone(), 0);
                        }
                        None => {
                            // Only names that are not known locals are addressed by name.
                            if !scope.locals.contains_key(&name.kind) {
                                constants.insert(Constant::String(name.kind.clone()));
                            }
                        }
                    }
                    for index in indexes {
                        constants.extend(index.kind.find_constants(scope));
                    }
                }
                for expr in &expr_list.0 {
                    constants.extend(expr.kind.find_constants(scope));
                }
                constants
            }
            Statement::Return(expr_list) => {
                let mut constants = HashSet::new();
                for expr in &expr_list.0 {
                    constants.extend(expr.kind.find_constants(scope));
                }
                constants
            }
            Statement::If(cond, then) => {
                let mut constants = cond.kind.find_constants(scope);
                constants.extend(then.find_constants(scope, Vec::new()));
                constants
            }
            Statement::Expression(expr) => expr.kind.find_constants(scope),
        }
    }

    /// Compiles this statement into instructions.
    ///
    /// After the statement has been compiled, every register below the scope's
    /// register counter that is not bound to a local is put on the counter's free
    /// list so that following statements can reuse it for temporaries.
    ///
    /// # Panics
    ///
    /// Panics if an assignment has more targets than value registers, if a needed
    /// constant is missing from `state`, or (via `todo!()`) for indexed targets in
    /// `local` / `global` assignments, which are not implemented.
    fn compile(&self, state: &mut State, scope: &mut Scope) -> Vec<Instruction> {
        let mut instructions = Vec::new();
        match self {
            Statement::Assignment(access_modifier, names, expr_list) => {
                // New locals get their own nil-initialised registers up front.
                // NOTE(review): the register at the old counter value is skipped
                // (allocation starts at counter + 1) — confirm this is intended.
                let new_registers: Vec<u16> =
                    if *access_modifier == Some(AccessModifier::Local) {
                        let min_reg = scope.register_counter.0 + 1;
                        let max_reg = scope.register_counter.0 + names.len() as u16;
                        instructions.push(Instruction::LoadNil(min_reg, max_reg));
                        scope.register_counter.0 += names.len() as u16 + 1;
                        (min_reg..=max_reg).collect()
                    } else {
                        Vec::new()
                    };

                // A single right-hand expression may have to supply a value for every
                // target; with several expressions each one supplies exactly one.
                let expected_values = if expr_list.0.len() == 1 {
                    names.len()
                } else {
                    1
                };
                let mut expr_regs = Vec::new();
                for expr in &expr_list.0 {
                    let (instr, regs) = expr.kind.compile(state, scope, Some(expected_values));
                    instructions.extend(instr);
                    expr_regs.extend(regs);
                }

                match access_modifier {
                    Some(AccessModifier::Local) => {
                        for (i, (name, indexes)) in names.iter().enumerate() {
                            let target = new_registers[i];
                            let value_reg = *expr_regs
                                .get(i)
                                .expect("assignment has more targets than value registers");
                            instructions.push(Instruction::Move(target, value_reg));
                            if !indexes.is_empty() {
                                todo!()
                            }
                            scope.locals.insert(name.kind.clone(), target);
                        }
                    }
                    Some(AccessModifier::Global) => {
                        for (i, (name, indexes)) in names.iter().enumerate() {
                            if !indexes.is_empty() {
                                todo!()
                            }
                            let value_reg = *expr_regs
                                .get(i)
                                .expect("assignment has more targets than value registers");
                            let global = state.get_constant(&Constant::String(name.kind.clone()));
                            instructions.push(Instruction::SetGlobal(value_reg, global));
                        }
                    }
                    None => {
                        for (i, (name, indexes)) in names.iter().enumerate() {
                            if !indexes.is_empty() {
                                // Resolve the base table: local, then upvalue, then global.
                                let local_reg = scope.locals.get(&name.kind).copied();
                                let upvalue = scope.upvalues.get(&name.kind).copied();
                                let mut table_reg = if let Some(reg) = local_reg {
                                    reg
                                } else if let Some(upval_reg) = upvalue {
                                    let local = scope.register_counter.next();
                                    instructions.push(Instruction::GetUpVal(local, upval_reg));
                                    local
                                } else {
                                    let global = state
                                        .get_constant(&Constant::String(name.kind.clone()));
                                    let local = scope.register_counter.next();
                                    instructions.push(Instruction::GetGlobal(local, global));
                                    local
                                };
                                // When the base is a local, `table_reg` is the local's own
                                // register and must not be overwritten while walking the
                                // index chain; the first intermediate lookup then goes to
                                // a scratch register instead.
                                let mut owns_table_reg = local_reg.is_none();

                                // Every index except the last one is a lookup.
                                for index in indexes.iter().take(indexes.len() - 1) {
                                    let (instr, index_reg) =
                                        index.kind.compile(state, scope, Some(1));
                                    instructions.extend(instr);
                                    let dest = if owns_table_reg {
                                        table_reg
                                    } else {
                                        owns_table_reg = true;
                                        scope.register_counter.next()
                                    };
                                    instructions.push(Instruction::GetTable(
                                        dest,
                                        table_reg,
                                        *index_reg.first().unwrap(),
                                    ));
                                    table_reg = dest;
                                }

                                // The last index is the slot being written.
                                let (instr, index_reg) =
                                    indexes.last().unwrap().kind.compile(state, scope, Some(1));
                                instructions.extend(instr);
                                let value_reg = *expr_regs
                                    .get(i)
                                    .expect("assignment has more targets than value registers");
                                instructions.push(Instruction::SetTable(
                                    table_reg,
                                    *index_reg.first().unwrap(),
                                    value_reg,
                                ));
                            } else {
                                let value_reg = *expr_regs
                                    .get(i)
                                    .expect("assignment has more targets than value registers");
                                if let Some(reg) = scope.locals.get(&name.kind) {
                                    instructions.push(Instruction::Move(*reg, value_reg));
                                } else if let Some(upval_reg) = scope.upvalues.get(&name.kind) {
                                    instructions.push(Instruction::SetUpVal(*upval_reg, value_reg));
                                } else {
                                    let global = state
                                        .get_constant(&Constant::String(name.kind.clone()));
                                    instructions.push(Instruction::SetGlobal(value_reg, global));
                                }
                            }
                        }
                    }
                }
            }
            Statement::Return(expr_list) => {
                let mut ret_registers = Vec::new();
                let mut vararg = false;
                let expr_count = expr_list.0.len();
                for (i, expr) in expr_list.0.iter().enumerate() {
                    // Only the last expression may yield a variable number of values.
                    let expected = if i + 1 == expr_count { None } else { Some(1) };
                    let (instr, registers) = expr.kind.compile(state, scope, expected);
                    instructions.extend(instr);
                    // An expression that reports no registers leaves its results to be
                    // picked up by `MoveRetValues` below.
                    if registers.is_empty() {
                        vararg = true;
                    }
                    ret_registers.extend(registers);
                }

                // Pack the returned values into consecutive registers starting at the
                // first one.
                let first_ret_register = ret_registers
                    .first()
                    .copied()
                    .unwrap_or(scope.register_counter.0);
                for (i, ret_register) in ret_registers.iter_mut().enumerate() {
                    let new_reg = first_ret_register + i as u16;
                    if *ret_register != new_reg {
                        instructions.push(Instruction::Move(new_reg, *ret_register));
                    }
                    *ret_register = new_reg;
                }
                if vararg {
                    instructions.push(Instruction::MoveRetValues(
                        first_ret_register + ret_registers.len() as u16,
                    ));
                }
                instructions.push(Instruction::Return(
                    first_ret_register,
                    if vararg {
                        0
                    } else {
                        *ret_registers.last().unwrap_or(&0)
                    },
                ));
            }
            Statement::If(expr, block) => {
                let (instr, regs) = expr.kind.compile(state, scope, Some(1));
                instructions.extend(instr);
                let local = scope.register_counter.next();
                instructions.push(Instruction::Test(local, *regs.first().unwrap(), 0));
                // The jump distance is the length of the compiled block, so the jump
                // lands just past it.
                let block_instructions = block.compile(state, scope);
                instructions.push(Instruction::Jmp(block_instructions.len() as i32));
                instructions.extend(block_instructions);
            }
            Statement::Expression(expr) => {
                let (instr, _) = expr.kind.compile(state, scope, None);
                instructions.extend(instr);
            }
        }

        // Release every register that is neither bound to a local nor already free.
        for reg in 0..scope.register_counter.0 {
            let held_by_local = scope.locals.values().any(|used| *used == reg);
            let already_free = scope.register_counter.1.contains(&reg);
            if !held_by_local && !already_free {
                scope.register_counter.1.push(reg);
            }
        }
        instructions
    }
}
impl Expression {
fn find_constants(&self, scope: &mut Scope) -> HashSet<Constant> {
match self {
Expression::ValueRef(name) => {
let mut constants = HashSet::new();
if !scope.locals.contains_key(name) {
constants.insert(Constant::String(name.clone()));
}
constants
}
Expression::BinOp(_, lhs, rhs) => {
let mut constants = HashSet::new();
constants.extend(lhs.kind.find_constants(scope));
constants.extend(rhs.kind.find_constants(scope));
constants
}
Expression::UnOp(_, expr) => expr.kind.find_constants(scope),
Expression::FunctionDefinition(_, block) => block.find_constants(scope, Vec::new()),
Expression::FunctionCall(expr, params) => {
let mut constants = HashSet::new();
constants.extend(expr.kind.find_constants(scope));
for param in &params.kind.0 {
constants.extend(param.kind.find_constants(scope));
}
constants
}
Expression::Literal(literal) => match literal {
Literal::Float(value) => {
let mut constants = HashSet::new();
constants.insert(Constant::Float(value.vm_number()));
constants
}
Literal::Integer(value) => {
let mut constants = HashSet::new();
constants.insert(Constant::Integer(*value));
constants
}
Literal::String(value) => {
let mut constants = HashSet::new();
constants.insert(Constant::String(value.clone()));
constants
}
Literal::Bool(value) => {
let mut constants = HashSet::new();
constants.insert(Constant::Bool(LuaBool(*value)));
constants
}
Literal::Nil => {
let mut constants = HashSet::new();
constants.insert(Constant::Nil);
constants
}
},
Expression::TableConstructor(entries) => {
let mut constants = HashSet::new();
let mut counter = 1;
for (key, value) in entries {
if let Some(key) = key {
constants.extend(key.kind.find_constants(scope));
} else {
constants.insert(Constant::Integer(LuaInteger(counter)));
counter += 1;
}
constants.extend(value.kind.find_constants(scope));
}
constants
}
Expression::IndexedAccess(expr, index) => {
let mut constants = HashSet::new();
constants.extend(expr.kind.find_constants(scope));
constants.extend(index.kind.find_constants(scope));
constants
}
Expression::Ellipsis => HashSet::new(),
}
}
fn compile(
&self,
state: &mut State,
scope: &mut Scope,
expected_values: Option<usize>,
) -> (Vec<Instruction>, Vec<u16>) {
match self {
Expression::ValueRef(name) => {
if let Some(reg) = scope.locals.get(name) {
(Vec::new(), vec![*reg])
} else if let Some(upvalue) = scope.upvalues.get(name) {
let local = scope.register_counter.next();
(vec![Instruction::GetUpVal(local, *upvalue)], vec![local])
} else {
let mut instructions = Vec::new();
let reg = scope.register_counter.next();
instructions.push(Instruction::GetGlobal(
reg,
state.get_constant(&Constant::String(name.clone())),
));
(instructions, vec![reg])
}
}
Expression::BinOp(binary_operator, lhs, rhs) => {
let mut instructions = Vec::new();
let (instr, lhs) = lhs.kind.compile(state, scope, Some(1));
instructions.extend(instr);
let (instr, rhs) = rhs.kind.compile(state, scope, Some(1));
instructions.extend(instr);
let reg = scope.register_counter.next();
match binary_operator {
BinaryOperator::Equal => {
instructions.push(Instruction::Equal(
reg,
*lhs.get(0).unwrap(),
*rhs.get(0).unwrap(),
));
}
BinaryOperator::LessThan => {
instructions.push(Instruction::LessThan(
reg,
*lhs.get(0).unwrap(),
*rhs.get(0).unwrap(),
));
}
BinaryOperator::LessThanOrEqual => {
instructions.push(Instruction::LessThanOrEqual(
reg,
*lhs.get(0).unwrap(),
*rhs.get(0).unwrap(),
));
}
BinaryOperator::GreaterThan => {
instructions.push(Instruction::LessThan(
reg,
*rhs.get(0).unwrap(),
*lhs.get(0).unwrap(),
));
}
BinaryOperator::GreaterThanOrEqual => {
instructions.push(Instruction::LessThanOrEqual(
reg,
*rhs.get(0).unwrap(),
*lhs.get(0).unwrap(),
));
}
BinaryOperator::Add => {
instructions.push(Instruction::Add(
reg,
*lhs.get(0).unwrap(),
*rhs.get(0).unwrap(),
));
}
BinaryOperator::Sub => {
instructions.push(Instruction::Unm(reg, *rhs.get(0).unwrap()));
instructions.push(Instruction::Add(reg, *lhs.get(0).unwrap(), reg));
}
BinaryOperator::And => {
instructions.push(Instruction::And(
reg,
*lhs.get(0).unwrap(),
*rhs.get(0).unwrap(),
));
}
BinaryOperator::Or => {
instructions.push(Instruction::Or(
reg,
*lhs.get(0).unwrap(),
*rhs.get(0).unwrap(),
));
}
};
(instructions, vec![reg])
}
Expression::UnOp(op, expr) => {
let mut instructions = Vec::new();
let (instr, registers) = expr.kind.compile(state, scope, Some(1));
instructions.extend(instr);
for reg in &registers {
match op {
UnaryOperator::Negation => instructions.push(Instruction::Unm(*reg, *reg)),
UnaryOperator::Length => instructions.push(Instruction::Len(*reg, *reg)),
}
}
(instructions, registers)
}
Expression::FunctionDefinition(params, block) => {
let mut inner_scope = Scope::default();
for param in params {
match &param.kind {
IdentOrEllipsis::Ident(name) => {
inner_scope
.locals
.insert(name.clone(), inner_scope.register_counter.next());
}
IdentOrEllipsis::Ellipsis => {
inner_scope.is_vararg = true;
}
}
}
inner_scope.highest_upvalue = scope.highest_upvalue + scope.register_counter.0;
inner_scope.upvalues = scope.upvalues.clone();
for (name, reg) in &scope.locals {
let new_reg = *reg + scope.highest_upvalue + 1;
inner_scope.upvalues.insert(name.clone(), new_reg);
}
let instructions = block.compile(state, &mut inner_scope);
state.prototypes.push(instructions);
let mut instructions = Vec::new();
instructions.push(Instruction::Close(scope.register_counter.0));
let local = scope.register_counter.next();
instructions.push(Instruction::Closure(local, state.prototypes.len() as u32));
(instructions, vec![local])
}
Expression::FunctionCall(expr, params) => {
let mut instructions = Vec::new();
let (instr, registers) = expr.kind.compile(state, scope, Some(1));
instructions.extend(instr);
let old_function_reg = registers.first().unwrap();
let mut registers = scope
.register_counter
.consecutive(params.kind.0.len() + 1)
.into_iter();
let mut param_scope = scope.clone();
let mut original_param_regs = Vec::new();
let mut vararg = false;
for (i, param) in params.kind.0.iter().enumerate() {
let (instr, registers) = param.kind.compile(
state,
&mut param_scope,
if i == params.kind.0.len() - 1 {
None
} else {
Some(1)
},
);
instructions.extend(instr);
if registers.len() > 0 {
original_param_regs.push(*registers.first().unwrap());
} else {
vararg = true;
}
}
let function_reg = registers.next().unwrap();
let mut param_regs = Vec::new();
for _ in &original_param_regs {
param_regs.push(registers.next().unwrap());
}
for (i, param_reg) in original_param_regs.iter().enumerate().rev() {
let new_reg = param_regs.get(i).unwrap();
if param_reg != new_reg {
instructions.push(Instruction::Move(*new_reg, *param_reg));
}
}
if function_reg != *old_function_reg {
instructions.push(Instruction::Move(function_reg, *old_function_reg));
}
if vararg {
let last_reg = param_regs.last().unwrap_or(&function_reg) + 1;
instructions.push(Instruction::MoveRetValues(last_reg));
}
let last_param_reg = param_regs.last().unwrap_or(&function_reg);
let mut return_regs = Vec::new();
if let Some(expected_values) = expected_values {
for i in 0..expected_values {
let return_reg = i as u16 + function_reg;
if return_reg > *last_param_reg {
return_regs.push(scope.register_counter.next());
} else {
return_regs.push(return_reg);
}
}
}
instructions.push(Instruction::Call(
*&function_reg,
if vararg { 0 } else { param_regs.len() as u16 },
if return_regs.len() == 0 {
0
} else {
return_regs.len() as u16 + 1
},
));
(instructions, return_regs)
}
Expression::Literal(literal) => {
let mut instructions = Vec::new();
let reg = scope.register_counter.next();
instructions.push(Instruction::LoadK(
reg,
state.get_constant(&match literal {
Literal::Float(value) => Constant::Float(value.vm_number()),
Literal::String(value) => Constant::String(value.clone()),
Literal::Integer(lua_integer) => Constant::Integer(*lua_integer),
Literal::Bool(value) => Constant::Bool(LuaBool(*value)),
Literal::Nil => Constant::Nil,
}),
));
(instructions, vec![reg])
}
Expression::TableConstructor(entries) => {
let mut instructions = Vec::new();
let reg = scope.register_counter.next();
instructions.push(Instruction::NewTable(reg));
let mut counter = 1;
for (i, (key, value)) in entries.iter().enumerate() {
if let Some(key) = key {
let (instr, key_regs) = key.kind.compile(state, scope, Some(1));
instructions.extend(instr);
let (instr, value_regs) = value.kind.compile(state, scope, Some(1));
instructions.extend(instr);
instructions.push(Instruction::SetTable(
reg,
*key_regs.first().unwrap(),
*value_regs.first().unwrap(),
));
} else {
let (instr, value_regs) = value.kind.compile(
state,
scope,
if i == entries.len() - 1 {
None
} else {
Some(1)
},
);
instructions.extend(instr);
if value_regs.len() > 0 {
let key_reg = scope.register_counter.next();
instructions.push(Instruction::LoadK(
key_reg,
state.get_constant(&Constant::Integer(LuaInteger(counter))),
));
instructions.push(Instruction::SetTable(
reg,
key_reg,
*value_regs.first().unwrap(),
));
counter += 1;
} else {
instructions.push(Instruction::SetList(reg, counter as u32));
}
}
}
(instructions, vec![reg])
}
Expression::IndexedAccess(expr, index) => {
let mut instructions = Vec::new();
let (instr, expr_regs) = expr.kind.compile(state, scope, Some(1));
instructions.extend(instr);
let (instr, index_regs) = index.kind.compile(state, scope, Some(1));
instructions.extend(instr);
let local = scope.register_counter.next();
instructions.push(Instruction::GetTable(
local,
*expr_regs.first().unwrap(),
*index_regs.first().unwrap(),
));
(instructions, vec![local])
}
Expression::Ellipsis => {
if !scope.is_vararg {
panic!("Function is not vararg!");
}
let mut instructions = Vec::new();
let new_reg = scope.register_counter.new();
instructions.push(Instruction::Vararg(new_reg));
(instructions, Vec::new())
}
}
}
}