c-compiler/src/codegen.cpp

255 lines
8.6 KiB
C++

#include "codegen.h"
#include "ast.h"
#include "types.h"
#include "errors.h"
#include <llvm/IR/Constants.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Verifier.h>
#include <llvm/Support/Casting.h>
#include <llvm/Support/raw_ostream.h>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
namespace codegen {
/// Builds a copy of this scope that shares the same binop table and value
/// bindings but has `is_lvalue` set. Identifier references evaluated in the
/// returned scope hand back their scope entry directly instead of loading
/// through it (see ValueReferenceExpression::codegen).
Scope Scope::with_lvalue() {
    Scope lvalue_scope{ this->binops, this->values, true };
    return lvalue_scope;
}
}
namespace AST {
/// Emits an integer literal as an i32 constant, typed as the C `int` fundamental type.
codegen::StackValue IntLiteralExpression::codegen(codegen::Builder& builder, codegen::Scope&) {
    auto ty = builder.builder->getInt32Ty();
    return codegen::StackValue{
        llvm::ConstantInt::get(ty, this->m_value),
        std::unique_ptr<types::Type>{ std::make_unique<types::FundamentalType>(types::FundamentalTypeKind::Int) },
    };
}

/// Emits a string literal as a global string (via CreateGlobalString) and
/// yields the resulting pointer, typed as `char*`.
codegen::StackValue StringLiteralExpression::codegen(codegen::Builder& builder, codegen::Scope&) {
    auto stack_type = std::make_unique<types::PointerType>(
        std::make_unique<types::FundamentalType>(types::FundamentalTypeKind::Char));
    // Use the full std::string length: building the StringRef from c_str() would
    // stop at the first embedded NUL byte and silently truncate the literal.
    auto str = llvm::StringRef{ this->m_value };
    return codegen::StackValue{
        builder.builder->CreateGlobalString(str),
        std::unique_ptr<types::Type>{ std::move(stack_type) },
    };
}

/// Resolves an identifier in `scope`.
/// In an lvalue context the scope entry (its storage) is returned unchanged;
/// otherwise the entry's type loads the current value from that storage.
/// @throws CompileError if the name is not bound in `scope`.
codegen::StackValue ValueReferenceExpression::codegen(codegen::Builder& builder, codegen::Scope& scope) {
    auto found = scope.values.find(this->m_name);
    if (found == scope.values.end()) {
        throw CompileError("Value " + this->m_name + " not found", this->m_meta);
    }
    const auto& entry = found->second;
    if (scope.is_lvalue) {
        return entry;
    }
    auto loaded = entry.ty->load(builder, entry.value);
    return codegen::StackValue{
        loaded.first,
        loaded.second
    };
}

/// Emits a binary operation.
/// Assignment evaluates its left operand as an lvalue, stores the right operand
/// into it and yields the stored value. Every other operator is looked up in
/// `scope.binops` by operand types.
/// @throws CompileError when no matching binop exists or its lookup/codegen throws.
codegen::StackValue BinaryOperationExpression::codegen(codegen::Builder& builder, codegen::Scope& scope) {
    const bool is_assignment = this->m_binop == types::BinOp::Assignment;
    // Only assignment needs the lvalue scope, so the scope copy is skipped for all other operators.
    auto lhs = [&]() {
        if (!is_assignment) {
            return this->m_lhs->codegen(builder, scope);
        }
        auto lvalued = scope.with_lvalue();
        return this->m_lhs->codegen(builder, lvalued);
    }();
    auto rhs = this->m_rhs->codegen(builder, scope);
    try {
        if (is_assignment) {
            builder.builder->CreateStore(rhs.value, lhs.value, false);
            return rhs;
        }
        auto binop = types::find_binop(
            scope.binops,
            lhs.ty,
            this->m_binop,
            rhs.ty);
        if (binop) {
            return codegen::StackValue{
                binop->codegen(builder, lhs.value, rhs.value),
                binop->result
            };
        }
    }
    catch (const std::runtime_error& error) {
        throw CompileError(error.what(), this->m_meta);
    }
    // Thrown outside the try so our own diagnostic is not caught and re-wrapped.
    throw CompileError("invalid binop", this->m_meta);
}

/// Emits a call. Arguments are evaluated left to right before the callee expression.
/// @throws CompileError if the callee's type does not lower to an LLVM function type.
codegen::StackValue FunctionCallExpression::codegen(codegen::Builder& builder, codegen::Scope& scope) {
    std::vector<llvm::Value*> args{};
    args.reserve(this->m_args.size());
    for (auto& arg : this->m_args) {
        args.push_back(arg->codegen(builder, scope).value);
    }
    auto function = this->m_fn_expr->codegen(builder, scope);
    auto fn_ty = llvm::dyn_cast<llvm::FunctionType>(function.ty->codegen(builder));
    if (!fn_ty) {
        throw CompileError("called value is not a function", this->m_meta);
    }
    // LLVM refuses to attach a name to a void-typed value, so void calls stay unnamed.
    auto value = builder.builder->CreateCall(
        fn_ty, function.value, args, fn_ty->getReturnType()->isVoidTy() ? "" : "call");
    return codegen::StackValue{
        value,
        *function.ty->return_type(),
    };
}

/// Emits `ret <expr>` into the current block. A null `builder.block` means the
/// current position is unreachable, so nothing is emitted.
void ReturnStatement::codegen(codegen::Builder& builder, codegen::Scope& scope) {
    if (!builder.block)
        return;
    builder.builder->SetInsertPoint(builder.block);
    auto value = this->m_expr->codegen(builder, scope);
    builder.builder->CreateRet(value.value);
    // The block is now terminated. Clearing it makes the `!builder.block` guards
    // skip any following statements instead of emitting code after the `ret`.
    builder.block = nullptr;
}

/// Evaluates an expression for its side effects and discards the result.
void ExpressionStatement::codegen(codegen::Builder& builder, codegen::Scope& scope) {
    if (!builder.block)
        return;
    builder.builder->SetInsertPoint(builder.block);
    this->m_expr->codegen(builder, scope);
}

/// Declares a local variable. It allocates stack storage, optionally stores the
/// initializer, and binds the name to that storage in `scope`.
void InitializationStatement::codegen(codegen::Builder& builder, codegen::Scope& scope) {
    if (!builder.block)
        return;
    builder.builder->SetInsertPoint(builder.block);
    auto ty = this->m_type->codegen(builder);
    auto ptr = builder.builder->CreateAlloca(ty, nullptr, this->m_name);
    if (this->m_expr.has_value()) {
        auto value = this->m_expr->get()->codegen(builder, scope);
        builder.builder->CreateStore(value.value, ptr, false);
    }
    scope.values[this->m_name] = codegen::StackValue{ ptr, this->m_type };
}

/// Emits an if/else.
/// Each arm branches to a shared "after" block only if the arm can still fall
/// through (i.e. it did not end in a return). If no path reaches "after", the
/// block is removed and the current position is marked unreachable.
void IfStatement::codegen(codegen::Builder& builder, codegen::Scope& scope) {
    if (!builder.block)
        return;
    builder.builder->SetInsertPoint(builder.block);
    llvm::Value* condition = this->m_condition->codegen(builder, scope).value;
    // `br` needs an i1, but C accepts any scalar condition: compare other values against zero/null.
    if (!condition->getType()->isIntegerTy(1)) {
        condition = builder.builder->CreateIsNotNull(condition, "cond");
    }
    auto function = builder.block->getParent();
    auto then_block = llvm::BasicBlock::Create(*builder.context, "then", function);
    llvm::BasicBlock* else_block = nullptr;
    if (this->m_else.has_value())
        else_block = llvm::BasicBlock::Create(*builder.context, "else", function);
    auto after_block = llvm::BasicBlock::Create(*builder.context, "after", function);
    builder.builder->CreateCondBr(condition, then_block, else_block ? else_block : after_block);

    // Closes the arm just emitted: branch to `after` only if control can still fall through.
    auto finish_arm = [&]() {
        if (builder.block && !builder.block->getTerminator()) {
            builder.builder->SetInsertPoint(builder.block);
            builder.builder->CreateBr(after_block);
            return true;
        }
        return false;
    };

    // Without an else arm, the false edge of the conditional branch already reaches `after`.
    bool after_reachable = else_block == nullptr;

    builder.block = then_block;
    builder.builder->SetInsertPoint(then_block);
    this->m_then->codegen(builder, scope);
    if (finish_arm())
        after_reachable = true;

    if (else_block) {
        builder.block = else_block;
        builder.builder->SetInsertPoint(else_block);
        this->m_else->get()->codegen(builder, scope);
        if (finish_arm())
            after_reachable = true;
    }

    if (after_reachable) {
        builder.block = after_block;
        builder.builder->SetInsertPoint(after_block);
    }
    else {
        // Both arms returned: `after` has no predecessors and would stay unterminated.
        after_block->eraseFromParent();
        builder.block = nullptr;
    }
}

/// Declares (and, when it has a body, defines) a function.
/// - Binds the function's name in `scope` before the body is emitted, so the
///   function can call itself.
/// - Reuses a matching earlier declaration from the module.
/// - Spills named parameters into allocas.
/// - Emits the body, and adds a terminator if control can fall off the end.
void Function::codegen(codegen::Builder& builder, codegen::Scope& scope) {
    std::shared_ptr<types::Type> ret_ty_ptr{ this->m_return_ty };
    std::vector<std::shared_ptr<types::Type>> param_ty_ptrs{};
    param_ty_ptrs.reserve(this->m_params.size());
    for (auto& param : this->m_params) {
        param_ty_ptrs.push_back(param.second);
    }
    std::shared_ptr<types::Type> fn_ty_ptr = std::make_shared<types::FunctionType>(ret_ty_ptr, param_ty_ptrs, this->m_is_vararg);
    auto fn_ty = llvm::cast<llvm::FunctionType>(fn_ty_ptr->codegen(builder));

    // Reuse an earlier declaration with the same signature (e.g. a prototype)
    // instead of creating a second symbol that LLVM would rename to "name.1".
    llvm::Function* function = builder.mod->getFunction(this->m_name);
    const bool reuse_existing = function != nullptr
        && function->getFunctionType() == fn_ty
        && (function->isDeclaration() || !this->m_statements);
    if (!reuse_existing) {
        function = llvm::Function::Create(
            fn_ty,
            llvm::GlobalValue::LinkageTypes::ExternalLinkage,
            this->m_name,
            builder.mod.get()
        );
    }
    scope.values[this->m_name] = codegen::StackValue{ function, fn_ty_ptr };

    if (this->m_statements) {
        auto entry = llvm::BasicBlock::Create(*builder.context, "entry", function, nullptr);
        builder.block = entry;
        builder.builder->SetInsertPoint(entry);
        codegen::Scope inner_scope{ scope };
        unsigned index = 0;
        for (auto& param : this->m_params) {
            const auto& param_ty_ptr = param_ty_ptrs[index];
            auto arg = function->getArg(index++);
            if (param.first) {
                arg->setName(*param.first);
                // Scope entries are storage that identifier references load from and
                // assignments store into, so the incoming value gets its own slot.
                auto slot = builder.builder->CreateAlloca(
                    param_ty_ptr->codegen(builder), nullptr, *param.first + ".addr");
                builder.builder->CreateStore(arg, slot);
                inner_scope.values[*param.first] = codegen::StackValue{
                    slot,
                    param_ty_ptr,
                };
            }
        }
        for (auto& statement : *this->m_statements) {
            statement->codegen(builder, inner_scope);
        }
        // C lets control reach the closing brace; the IR block still needs a terminator.
        if (builder.block && !builder.block->getTerminator()) {
            builder.builder->SetInsertPoint(builder.block);
            llvm::Type* ret_ty = function->getReturnType();
            if (ret_ty->isVoidTy()) {
                builder.builder->CreateRetVoid();
            }
            else if (this->m_name == "main" && ret_ty->isIntegerTy()) {
                // Reaching the end of main is equivalent to returning 0 (C99 5.1.2.2.3).
                builder.builder->CreateRet(llvm::ConstantInt::get(ret_ty, 0));
            }
            else {
                // Using the value of such a call is undefined in C, so any value is valid here.
                builder.builder->CreateRet(llvm::UndefValue::get(ret_ty));
            }
        }
    }
    // verifyFunction returns true for malformed IR; print its diagnostics instead of dropping them.
    llvm::verifyFunction(*function, &llvm::errs());
    builder.block = nullptr;
}
}
namespace types {
/// Lowers a fundamental C type to its LLVM IR type.
/// Both `Void` and any unrecognised kind map to the IR void type.
llvm::Type* FundamentalType::codegen(codegen::Builder& builder) {
    auto& ir = *builder.builder;
    switch (this->m_ty) {
    case FundamentalTypeKind::Int:
        return ir.getInt32Ty();
    case FundamentalTypeKind::Bool:
        return ir.getInt1Ty();
    case FundamentalTypeKind::Char:
        return ir.getInt8Ty();
    case FundamentalTypeKind::Void:
    default:
        return ir.getVoidTy();
    }
}

/// Lowers a function signature. Parameter types are lowered first (in order),
/// then the return type, and the vararg flag is carried over unchanged.
llvm::Type* FunctionType::codegen(codegen::Builder& builder) {
    std::vector<llvm::Type*> lowered_params;
    lowered_params.reserve(this->m_param_tys.size());
    for (const auto& param_ty : this->m_param_tys)
        lowered_params.push_back(param_ty->codegen(builder));
    llvm::Type* result_ty = this->m_ret_ty->codegen(builder);
    return llvm::FunctionType::get(result_ty, lowered_params, this->m_vararg);
}

/// Lowers any C pointer type. This overload of llvm::PointerType::get takes
/// only a context and an address space, so the pointee type is not consulted.
llvm::Type* PointerType::codegen(codegen::Builder& builder) {
    constexpr unsigned default_address_space = 0;
    return llvm::PointerType::get(*builder.context, default_address_space);
}
}