Add a function to test if types are equal

This commit is contained in:
Sofia 2026-04-13 17:09:50 +03:00
parent 8ec4e538f5
commit a901806dfb
3 changed files with 50 additions and 4 deletions

View File

@ -3,6 +3,10 @@
#include "ast.h"
#include "errors.h"
namespace {
}
namespace AST {
std::shared_ptr<types::Type> IntLiteralExpression::typecheck(
typecheck::State&,
@ -109,6 +113,8 @@ namespace AST {
auto bool_ty_ptr = new types::FundamentalType{ types::FundamentalTypeKind::Bool };
this->m_condition->typecheck(state, scope, std::shared_ptr<types::Type>{ bool_ty_ptr });
// TODO check that condition really is a boolean
this->m_then->typecheck(state, scope);
if (this->m_else) {
(*this->m_else)->typecheck(state, scope);

View File

@ -165,4 +165,44 @@ namespace types {
this->m_inner
);
}
bool types_equal(std::shared_ptr<types::Type> type1, std::shared_ptr<types::Type> type2) {
if (type1->m_kind != type2->m_kind)
return false;
if (type1->m_kind == TypeKind::Fundamental) {
auto ty1 = dynamic_cast<FundamentalType*>(type1.get());
auto ty2 = dynamic_cast<FundamentalType*>(type2.get());
return ty1->m_ty == ty2->m_ty;
}
else if (type1->m_kind == TypeKind::Function) {
auto ty1 = dynamic_cast<FunctionType*>(type1.get());
auto ty2 = dynamic_cast<FunctionType*>(type2.get());
if (!types_equal(ty1->m_ret_ty, ty2->m_ret_ty))
return false;
if (ty1->m_vararg != ty2->m_vararg)
return false;
if (ty1->m_param_tys.size() != ty2->m_param_tys.size())
return false;
for (int i = 0; i < static_cast<int>(ty1->m_param_tys.size()); i++) {
auto param1 = ty1->m_param_tys[i];
auto param2 = ty2->m_param_tys[i];
if (!types_equal(param1, param2))
return false;
}
return true;
}
else if (type1->m_kind == TypeKind::Pointer) {
auto ty1 = dynamic_cast<PointerType*>(type1.get());
auto ty2 = dynamic_cast<PointerType*>(type2.get());
return types_equal(ty1->m_inner, ty2->m_inner);
}
else {
return false;
}
}
}

View File

@ -47,9 +47,8 @@ namespace types {
};
class FundamentalType : public Type {
private:
FundamentalTypeKind m_ty;
public:
FundamentalTypeKind m_ty;
FundamentalType(FundamentalTypeKind kind) : Type(TypeKind::Fundamental), m_ty{ kind } {}
virtual ~FundamentalType() override = default;
virtual std::string formatted() override;
@ -83,9 +82,8 @@ namespace types {
class PointerType : public Type {
private:
std::shared_ptr<Type> m_inner;
public:
std::shared_ptr<Type> m_inner;
PointerType(std::shared_ptr<Type> inner)
: Type(TypeKind::Pointer), m_inner{ std::move(inner) } {
}
@ -94,6 +92,8 @@ namespace types {
virtual llvm::Type* codegen(codegen::Builder& builder) override;
virtual std::pair<llvm::Value*, std::shared_ptr<Type>> load(codegen::Builder& builder, llvm::Value* ptr) override;
};
bool types_equal(std::shared_ptr<types::Type> type1, std::shared_ptr<types::Type> type2);
}
#endif