diff --git a/src/typechecker.cpp b/src/typechecker.cpp index bd98d26..913be10 100644 --- a/src/typechecker.cpp +++ b/src/typechecker.cpp @@ -3,6 +3,10 @@ #include "ast.h" #include "errors.h" +namespace { + +} + namespace AST { std::shared_ptr IntLiteralExpression::typecheck( typecheck::State&, @@ -109,6 +113,8 @@ namespace AST { auto bool_ty_ptr = new types::FundamentalType{ types::FundamentalTypeKind::Bool }; this->m_condition->typecheck(state, scope, std::shared_ptr{ bool_ty_ptr }); + // TODO check that condition really is a boolean + this->m_then->typecheck(state, scope); if (this->m_else) { (*this->m_else)->typecheck(state, scope); diff --git a/src/types.cpp b/src/types.cpp index c32c489..785be71 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -165,4 +165,44 @@ namespace types { this->m_inner ); } + + bool types_equal(std::shared_ptr type1, std::shared_ptr type2) { + if (type1->m_kind != type2->m_kind) + return false; + + if (type1->m_kind == TypeKind::Fundamental) { + auto ty1 = dynamic_cast(type1.get()); + auto ty2 = dynamic_cast(type2.get()); + return ty1->m_ty == ty2->m_ty; + } + else if (type1->m_kind == TypeKind::Function) { + auto ty1 = dynamic_cast(type1.get()); + auto ty2 = dynamic_cast(type2.get()); + + if (!types_equal(ty1->m_ret_ty, ty2->m_ret_ty)) + return false; + if (ty1->m_vararg != ty2->m_vararg) + return false; + if (ty1->m_param_tys.size() != ty2->m_param_tys.size()) + return false; + + for (int i = 0; i < static_cast(ty1->m_param_tys.size()); i++) { + auto param1 = ty1->m_param_tys[i]; + auto param2 = ty2->m_param_tys[i]; + if (!types_equal(param1, param2)) + return false; + } + + return true; + } + else if (type1->m_kind == TypeKind::Pointer) { + auto ty1 = dynamic_cast(type1.get()); + auto ty2 = dynamic_cast(type2.get()); + + return types_equal(ty1->m_inner, ty2->m_inner); + } + else { + return false; + } + } } \ No newline at end of file diff --git a/src/types.h b/src/types.h index 61fdafd..596eac2 100644 --- a/src/types.h +++ b/src/types.h @@ -47,9 +47,8 @@ namespace types { }; class FundamentalType : public Type { - private: - FundamentalTypeKind m_ty; public: + FundamentalTypeKind m_ty; FundamentalType(FundamentalTypeKind kind) : Type(TypeKind::Fundamental), m_ty{ kind } {} virtual ~FundamentalType() override = default; virtual std::string formatted() override; @@ -83,9 +82,8 @@ namespace types { class PointerType : public Type { - private: - std::shared_ptr m_inner; public: + std::shared_ptr m_inner; PointerType(std::shared_ptr inner) : Type(TypeKind::Pointer), m_inner{ std::move(inner) } { } @@ -94,6 +92,8 @@ namespace types { virtual llvm::Type* codegen(codegen::Builder& builder) override; virtual std::pair> load(codegen::Builder& builder, llvm::Value* ptr) override; }; + + bool types_equal(std::shared_ptr type1, std::shared_ptr type2); } #endif \ No newline at end of file