Compare commits

...

6 Commits

Author SHA1 Message Date
e3e0cd9f9f Check return type as well 2026-04-13 17:25:36 +03:00
3f4b8569ea Fix bug with parameter typechecking 2026-04-13 17:23:53 +03:00
185f25d412 Typecheck parameter types 2026-04-13 17:21:07 +03:00
a901806dfb Add a function to test if types are equal 2026-04-13 17:09:50 +03:00
8ec4e538f5 Fix a bug, add type kind 2026-04-13 16:56:33 +03:00
cf965dd47a Fix some bugs 2026-04-13 16:50:29 +03:00
3 changed files with 133 additions and 15 deletions

View File

@ -2,6 +2,42 @@
#include "typechecker.h" #include "typechecker.h"
#include "ast.h" #include "ast.h"
#include "errors.h" #include "errors.h"
#include "result.h"
namespace {
enum class TypecheckRes {
Ok,
Castable,
};
Result<TypecheckRes, std::string> check_type(std::shared_ptr<types::Type> checked, std::shared_ptr<types::Type> target) {
if (types::types_equal(checked, target)) {
return TypecheckRes::Ok;
}
return std::string{ "Types " + checked->formatted() + " and " + target->formatted() + " incompatible" };
}
std::unique_ptr<AST::Expression> handle_res(
std::unique_ptr<AST::Expression> expr,
Result<TypecheckRes, std::string> res,
typecheck::State& state) {
if (res.ok()) {
auto result = res.unwrap();
if (result == TypecheckRes::Ok) {
return expr;
}
else {
state.errors.push_back(CompileError("Casting not yet implemented", expr->m_meta));
return expr;
}
}
else {
state.errors.push_back(CompileError(res.unwrap_err(), expr->m_meta));
return expr;
}
}
}
namespace AST { namespace AST {
std::shared_ptr<types::Type> IntLiteralExpression::typecheck( std::shared_ptr<types::Type> IntLiteralExpression::typecheck(
@ -27,13 +63,20 @@ namespace AST {
} }
std::shared_ptr<types::Type> ValueReferenceExpression::typecheck( std::shared_ptr<types::Type> ValueReferenceExpression::typecheck(
typecheck::State&, typecheck::State& state,
typecheck::Scope& scope, typecheck::Scope& scope,
std::optional<std::shared_ptr<types::Type>> std::optional<std::shared_ptr<types::Type>>
) { ) {
if (scope.symbols.find(this->m_name) != scope.symbols.end()) {
return scope.symbols[this->m_name]; return scope.symbols[this->m_name];
} }
state.errors.push_back(CompileError("Value " + this->m_name + " not defined", this->m_meta));
return std::shared_ptr<types::Type>{
new types::FundamentalType{ types::FundamentalTypeKind::Void }
};
}
std::shared_ptr<types::Type> BinaryOperationExpression::typecheck( std::shared_ptr<types::Type> BinaryOperationExpression::typecheck(
typecheck::State& state, typecheck::State& state,
typecheck::Scope& scope, typecheck::Scope& scope,
@ -54,7 +97,12 @@ namespace AST {
) { ) {
auto expr_ty = this->m_fn_expr->typecheck(state, scope, {}); auto expr_ty = this->m_fn_expr->typecheck(state, scope, {});
// TODO make sure function_ty really is a function type if (expr_ty->m_kind != types::TypeKind::Function) {
state.errors.push_back(CompileError("Tried calling a non-function", this->m_meta));
return std::shared_ptr<types::Type> {
new types::FundamentalType{ types::FundamentalTypeKind::Void }
};
}
auto fn_ty = dynamic_cast<types::FunctionType*>(expr_ty.get()); auto fn_ty = dynamic_cast<types::FunctionType*>(expr_ty.get());
@ -65,11 +113,17 @@ namespace AST {
state.errors.push_back(CompileError("too many arguments", this->m_meta)); state.errors.push_back(CompileError("too many arguments", this->m_meta));
} }
else { else {
for (int i = 0; i < static_cast<int>(fn_ty->m_param_tys.size()); i++) { for (int i = 0; i < static_cast<int>(this->m_args.size()); i++) {
if (i < static_cast<int>(fn_ty->m_param_tys.size())) {
auto expected_param_ty = fn_ty->m_param_tys[i]; auto expected_param_ty = fn_ty->m_param_tys[i];
auto param_ty = this->m_args[i]->typecheck(state, scope, expected_param_ty); auto param_ty = this->m_args[i]->typecheck(state, scope, expected_param_ty);
// TODO make sure types actually match auto check_res = check_type(param_ty, expected_param_ty);
this->m_args[i] = handle_res(std::move(this->m_args[i]), check_res, state);
}
else {
this->m_args[i]->typecheck(state, scope, {});
}
} }
} }
@ -77,7 +131,11 @@ namespace AST {
} }
void ReturnStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) { void ReturnStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) {
this->m_expr->typecheck(state, scope, scope.return_ty); auto res_ty = this->m_expr->typecheck(state, scope, scope.return_ty);
if (scope.return_ty) {
auto check_res = check_type(res_ty, *scope.return_ty);
this->m_expr = handle_res(std::move(this->m_expr), check_res, state);
}
} }
void InitializationStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) { void InitializationStatement::typecheck(typecheck::State& state, typecheck::Scope& scope) {
@ -95,6 +153,8 @@ namespace AST {
auto bool_ty_ptr = new types::FundamentalType{ types::FundamentalTypeKind::Bool }; auto bool_ty_ptr = new types::FundamentalType{ types::FundamentalTypeKind::Bool };
this->m_condition->typecheck(state, scope, std::shared_ptr<types::Type>{ bool_ty_ptr }); this->m_condition->typecheck(state, scope, std::shared_ptr<types::Type>{ bool_ty_ptr });
// TODO check that condition really is a boolean
this->m_then->typecheck(state, scope); this->m_then->typecheck(state, scope);
if (this->m_else) { if (this->m_else) {
(*this->m_else)->typecheck(state, scope); (*this->m_else)->typecheck(state, scope);
@ -112,6 +172,13 @@ namespace AST {
scope.symbols[this->m_name] = std::shared_ptr<types::Type>{ function_ty }; scope.symbols[this->m_name] = std::shared_ptr<types::Type>{ function_ty };
typecheck::Scope inner{ scope }; typecheck::Scope inner{ scope };
inner.return_ty = return_ty;
for (auto& param : this->m_params) {
if (param.first) {
inner.symbols[*param.first] = param.second;
}
}
if (this->m_statements) { if (this->m_statements) {
for (auto& statement : *this->m_statements) { for (auto& statement : *this->m_statements) {

View File

@ -165,4 +165,44 @@ namespace types {
this->m_inner this->m_inner
); );
} }
bool types_equal(std::shared_ptr<types::Type> type1, std::shared_ptr<types::Type> type2) {
if (type1->m_kind != type2->m_kind)
return false;
if (type1->m_kind == TypeKind::Fundamental) {
auto ty1 = dynamic_cast<FundamentalType*>(type1.get());
auto ty2 = dynamic_cast<FundamentalType*>(type2.get());
return ty1->m_ty == ty2->m_ty;
}
else if (type1->m_kind == TypeKind::Function) {
auto ty1 = dynamic_cast<FunctionType*>(type1.get());
auto ty2 = dynamic_cast<FunctionType*>(type2.get());
if (!types_equal(ty1->m_ret_ty, ty2->m_ret_ty))
return false;
if (ty1->m_vararg != ty2->m_vararg)
return false;
if (ty1->m_param_tys.size() != ty2->m_param_tys.size())
return false;
for (int i = 0; i < static_cast<int>(ty1->m_param_tys.size()); i++) {
auto param1 = ty1->m_param_tys[i];
auto param2 = ty2->m_param_tys[i];
if (!types_equal(param1, param2))
return false;
}
return true;
}
else if (type1->m_kind == TypeKind::Pointer) {
auto ty1 = dynamic_cast<PointerType*>(type1.get());
auto ty2 = dynamic_cast<PointerType*>(type2.get());
return types_equal(ty1->m_inner, ty2->m_inner);
}
else {
return false;
}
}
} }

View File

@ -18,6 +18,12 @@ namespace types {
int operator_precedence(BinOp& op); int operator_precedence(BinOp& op);
std::string format_operator(BinOp& op); std::string format_operator(BinOp& op);
enum class TypeKind {
Fundamental,
Function,
Pointer,
};
enum FundamentalTypeKind { enum FundamentalTypeKind {
Int, Int,
Bool, Bool,
@ -27,6 +33,8 @@ namespace types {
class Type { class Type {
public: public:
TypeKind m_kind;
Type(TypeKind kind) : m_kind{ kind } {}
virtual ~Type() = default; virtual ~Type() = default;
virtual std::string formatted() = 0; virtual std::string formatted() = 0;
virtual llvm::Type* codegen(codegen::Builder& builder) = 0; virtual llvm::Type* codegen(codegen::Builder& builder) = 0;
@ -39,10 +47,9 @@ namespace types {
}; };
class FundamentalType : public Type { class FundamentalType : public Type {
private:
FundamentalTypeKind m_ty;
public: public:
FundamentalType(FundamentalTypeKind kind) : m_ty{ kind } {} FundamentalTypeKind m_ty;
FundamentalType(FundamentalTypeKind kind) : Type(TypeKind::Fundamental), m_ty{ kind } {}
virtual ~FundamentalType() override = default; virtual ~FundamentalType() override = default;
virtual std::string formatted() override; virtual std::string formatted() override;
virtual llvm::Type* codegen(codegen::Builder& builder) override; virtual llvm::Type* codegen(codegen::Builder& builder) override;
@ -60,7 +67,10 @@ namespace types {
std::vector<std::shared_ptr<Type>> m_param_tys; std::vector<std::shared_ptr<Type>> m_param_tys;
bool m_vararg; bool m_vararg;
FunctionType(std::shared_ptr<Type> ret_ty, std::vector<std::shared_ptr<Type>> param_tys, bool vararg) FunctionType(std::shared_ptr<Type> ret_ty, std::vector<std::shared_ptr<Type>> param_tys, bool vararg)
: m_ret_ty{ std::move(ret_ty) }, m_param_tys{ std::move(param_tys) }, m_vararg{ vararg } { : Type(TypeKind::Function)
, m_ret_ty{ std::move(ret_ty) }
, m_param_tys{ std::move(param_tys) }
, m_vararg{ vararg } {
} }
virtual ~FunctionType() override = default; virtual ~FunctionType() override = default;
virtual std::string formatted() override; virtual std::string formatted() override;
@ -72,17 +82,18 @@ namespace types {
class PointerType : public Type { class PointerType : public Type {
private:
std::shared_ptr<Type> m_inner;
public: public:
std::shared_ptr<Type> m_inner;
PointerType(std::shared_ptr<Type> inner) PointerType(std::shared_ptr<Type> inner)
: m_inner{ std::move(inner) } { : Type(TypeKind::Pointer), m_inner{ std::move(inner) } {
} }
virtual ~PointerType() override = default; virtual ~PointerType() override = default;
virtual std::string formatted() override; virtual std::string formatted() override;
virtual llvm::Type* codegen(codegen::Builder& builder) override; virtual llvm::Type* codegen(codegen::Builder& builder) override;
virtual std::pair<llvm::Value*, std::shared_ptr<Type>> load(codegen::Builder& builder, llvm::Value* ptr) override; virtual std::pair<llvm::Value*, std::shared_ptr<Type>> load(codegen::Builder& builder, llvm::Value* ptr) override;
}; };
bool types_equal(std::shared_ptr<types::Type> type1, std::shared_ptr<types::Type> type2);
} }
#endif #endif