Fix linker working with recursive imports

This commit is contained in:
Sofia 2025-08-05 21:03:53 +03:00
parent 1c3386bc9a
commit 1ba0de442a
5 changed files with 501 additions and 228 deletions

View File

@ -1,5 +1,7 @@
import triple_import_vec2::Vec2;
import triple_import_ship::Ship;
fn main() -> i32 { return 0; }
fn main() -> u32 {
let a = Ship::new();
return a.position.x;
}

View File

@ -1,3 +1,11 @@
import triple_import_vec2::Vec2;
struct Ship { position: Vec2 }
struct Ship { position: Vec2 }
impl Ship {
pub fn new() -> Ship {
Ship {
position: Vec2 {x: 15, y: 16}
}
}
}

View File

@ -1,2 +1,2 @@
struct Vec2 { x: f32, y: f32 }
struct Vec2 { x: u32, y: u32 }

View File

@ -1,8 +1,9 @@
use std::{
cell::RefCell,
cell::{RefCell, RefMut},
collections::{HashMap, HashSet},
convert::Infallible,
fs::{self},
hash::Hash,
path::PathBuf,
rc::Rc,
};
@ -46,6 +47,12 @@ pub enum ErrorKind {
NoMainDefined,
#[error("Main module has no main-function!")]
NoMainFunction,
#[error("Type {0} has cyclical fields!")]
CyclicalType(String),
#[error("Type {0} is imported cyclically!")]
RecursiveTypeImport(String),
#[error("Type {} does not exist in module {}", 0.0, 0.1)]
NoSuchTypeInModule(CustomTypeKey),
#[error("Function {1} in module {0} is private!")]
FunctionIsPrivate(String, String),
}
@ -53,11 +60,13 @@ pub enum ErrorKind {
pub fn compile_std(module_map: &mut ErrorModules) -> Result<Module, ReidError> {
let (id, tokens) = parse_module(STD_SOURCE, STD_NAME, None, module_map, None)?;
let module = compile_module(id, tokens, module_map, None, false)?.map_err(|(_, e)| e)?;
dbg!(id, module.module_id);
let module_id = module.module_id;
let mut mir_context = super::Context::from(vec![module], Default::default());
let std_compiled = mir_context.modules.remove(&module_id).unwrap();
dbg!(std_compiled.module_id);
Ok(std_compiled)
}
@ -70,11 +79,21 @@ pub struct LinkerPass<'map> {
#[derive(Default, Clone)]
pub struct LinkerState {
extern_imported_types: HashMap<SourceModuleId, HashMap<String, SourceModuleId>>,
foreign_types: HashMap<SourceModuleId, HashMap<String, SourceModuleId>>,
}
type LinkerPassState<'st, 'sc> = PassState<'st, 'sc, LinkerState, ErrorKind>;
#[derive(Clone, Debug)]
struct LinkerModule {
module: Rc<RefCell<Module>>,
// Functions imported directly from a module
function_imports: HashMap<String, (SourceModuleId, Metadata)>,
// Types imported either directly by the user or indirectly via functions.
// May contain type-imports that are again recursively imported elsewhere.
type_imports: HashMap<String, (SourceModuleId, Metadata)>,
}
impl<'map> Pass for LinkerPass<'map> {
type Data = LinkerState;
type TError = ErrorKind;
@ -102,40 +121,56 @@ impl<'map> Pass for LinkerPass<'map> {
}
};
let mut modules = HashMap::<SourceModuleId, Rc<RefCell<_>>>::new();
let mut modules = HashMap::<SourceModuleId, LinkerModule>::new();
let mut module_ids = HashMap::<String, SourceModuleId>::new();
for (mod_id, module) in context.modules.drain() {
modules.insert(mod_id, Rc::new(RefCell::new(module)));
modules.insert(
mod_id,
LinkerModule {
module: Rc::new(RefCell::new(module)),
function_imports: HashMap::new(),
type_imports: HashMap::new(),
},
);
}
let mut modules_to_process: Vec<Rc<RefCell<_>>> = modules.values().cloned().collect();
let mut module_queue: Vec<LinkerModule> = modules.values().cloned().collect();
let mut still_required_types = HashSet::<CustomTypeKey>::new();
while let Some(mut importer) = module_queue.pop() {
let importer_mod = importer.module.borrow_mut();
while let Some(module) = modules_to_process.pop() {
let mut extern_types = HashMap::new();
let mut already_imported_binops = HashSet::<BinopKey>::new();
let mut already_imported_types = HashSet::<CustomTypeKey>::new();
let mut importer_module = module.borrow_mut();
for import in importer_module.imports.clone() {
// Gp go through all imports in this specific modulee
for import in importer_mod.imports.clone() {
let Import(path, _) = &import;
if path.len() != 2 {
state.ok::<_, Infallible>(Err(ErrorKind::InnerModulesNotYetSupported(import.clone())), import.1);
}
// Cut the import statement into parts
let Some((module_name, _)) = path.get(0) else {
continue;
};
let Some((import_name, _)) = path.get(1) else {
continue;
};
let mut imported = if let Some(mod_id) = module_ids.get(module_name) {
// Actually compile or fetch the imported module
let imported = if let Some(mod_id) = module_ids.get(module_name) {
modules.get(mod_id).unwrap()
} else if module_name == STD_NAME {
let std = compile_std(&mut self.module_map)?;
modules.insert(std.module_id, Rc::new(RefCell::new(compile_std(&mut self.module_map)?)));
module_ids.insert(std.name, std.module_id);
modules.get(&std.module_id).unwrap()
let module_id = std.module_id;
modules.insert(
std.module_id,
LinkerModule {
module: Rc::new(RefCell::new(std)),
function_imports: HashMap::new(),
type_imports: HashMap::new(),
},
);
module_ids.insert(module_name.clone(), module_id);
modules.get(&module_id).unwrap()
} else {
let file_path = PathBuf::from(&context.base.clone()).join(module_name.to_owned() + ".reid");
@ -176,9 +211,16 @@ impl<'map> Pass for LinkerPass<'map> {
}
let module_id = imported_module.module_id;
module_ids.insert(imported_module.name.clone(), imported_module.module_id);
modules.insert(module_id, Rc::new(RefCell::new(imported_module)));
modules.insert(
module_id,
LinkerModule {
module: Rc::new(RefCell::new(imported_module)),
function_imports: HashMap::new(),
type_imports: HashMap::new(),
},
);
let imported = modules.get_mut(&module_id).unwrap();
modules_to_process.push(imported.clone());
module_queue.push(imported.clone());
imported
}
Err((_, err)) => {
@ -203,70 +245,126 @@ impl<'map> Pass for LinkerPass<'map> {
continue;
}
}
}
.borrow_mut();
let Some((import_name, _)) = path.get(1) else {
continue;
};
let imported_id = imported.module_id;
let mut imported_types = Vec::new();
let imported_module = imported.module.borrow();
if let Some(func) = imported.functions.iter_mut().find(|f| f.name == *import_name) {
let func_name = func.name.clone();
let func_signature = func.signature();
if let Some(func) = imported_module.functions.iter().find(|f| f.name == *import_name) {
// If the imported item is a function, add it to the list of imported functions
importer
.function_imports
.insert(func.name.clone(), (imported_module.module_id, import.1));
} else if let Some(ty) = imported_module.typedefs.iter().find(|t| t.name == *import_name) {
// If the imported item is a type, add it to the list of imported types
// imported_types.insert((CustomTypeKey(ty.name.clone(), ty.source_module), true));
importer
.type_imports
.insert(ty.name.clone(), (imported_module.module_id, import.1));
}
}
if !func.is_pub {
let module_id = importer_mod.module_id;
drop(importer_mod);
modules.insert(module_id, importer);
}
for (_, linker_module) in &modules {
let mut importer_module = linker_module.module.borrow_mut();
let mut unresolved_types = HashMap::new();
// 1. Import functions and find all types that are dependencies of
// functions
for (name, (function_source, import_meta)) in &linker_module.function_imports {
dbg!(&name, &function_source, &modules.keys());
let mut function_module = modules.get(&function_source).unwrap().module.borrow_mut();
let func_module_name = function_module.name.clone();
let func_module_id = function_module.module_id;
let function = function_module.functions.iter_mut().find(|f| f.name == *name).unwrap();
// If function is not pub, error
if !function.is_pub {
state.ok::<_, Infallible>(
Err(ErrorKind::FunctionIsPrivate(func_module_name, function.name.clone())),
import_meta.clone(),
);
continue;
}
// If function already exists, error
if let Some(existing) = importer_module.functions.iter().find(|f| f.name == *name) {
if let Err(e) = existing.equals_as_imported(&function) {
state.ok::<_, Infallible>(
Err(ErrorKind::FunctionIsPrivate(module_name.clone(), func_name.clone())),
import.1,
Err(ErrorKind::FunctionImportIssue(func_module_name, name.clone(), e)),
import_meta.clone(),
);
continue;
}
}
func.is_imported = true;
function.is_imported = true;
if let Some(existing) = importer_module.functions.iter().find(|f| f.name == *func_name) {
if let Err(e) = existing.equals_as_imported(func) {
state.ok::<_, Infallible>(
Err(ErrorKind::FunctionImportIssue(
module_name.clone(),
func_name.clone(),
e,
)),
import.1,
);
for ty in import_type(&function.return_type) {
unresolved_types.insert(ty, (import_meta.clone(), true));
}
for param in &function.parameters {
for ty in import_type(&param.ty) {
unresolved_types.insert(ty, (import_meta.clone(), true));
}
}
importer_module.functions.push(FunctionDefinition {
name: function.name.clone(),
linkage_name: None,
is_pub: false,
is_imported: false,
return_type: function.return_type.clone(),
parameters: function.parameters.clone(),
kind: super::FunctionDefinitionKind::Extern(true),
source: Some(func_module_id),
signature_meta: function.signature(),
});
}
// 2. Add all manually imported types to the list of types that need
// to be resolved and recursed
for (name, (source_module, meta)) in &linker_module.type_imports {
unresolved_types.insert(
CustomTypeKey(name.clone(), source_module.clone()),
(meta.clone(), false),
);
}
// 3. Recurse these types to find their true sources, find their
// dependencies, and list them all. Store manually imported types
// in a separate mapping for later.
let mut imported_types = HashSet::new();
let mut foreign_keys = HashSet::new();
let mut already_imported_binops = HashSet::new();
for (ty, (meta, is_dependency)) in unresolved_types {
// First deal with manually imported types
if !is_dependency {
// Add them to the list of foreign types (types that are
// later replaced in-source by name)
let imported_ty_key = match resolve_type(&ty, &modules) {
Ok(ty) => {
foreign_keys.insert(ty.clone());
ty
}
}
Err(e) => {
state.note_errors(&vec![e], meta);
return Ok(());
}
};
let types = import_type(&func.return_type, false);
let return_type = func.return_type.clone();
imported_types.extend(types);
let mut param_tys = Vec::new();
for param in &func.parameters {
let types = import_type(&param.ty, false);
imported_types.extend(types);
param_tys.push(param.clone());
}
importer_module.functions.push(FunctionDefinition {
name: func_name.clone(),
linkage_name: None,
is_pub: false,
is_imported: false,
return_type,
parameters: param_tys,
kind: super::FunctionDefinitionKind::Extern(true),
source: Some(imported.module_id),
signature_meta: func_signature,
});
} else if let Some(ty) = imported.typedefs.iter_mut().find(|f| f.name == *import_name) {
let external_key = CustomTypeKey(ty.name.clone(), ty.source_module);
let imported_ty = TypeKind::CustomType(external_key.clone());
imported_types.push((external_key, true));
let mut imported = modules.get(&imported_ty_key.1).unwrap().module.borrow_mut();
let imported_module_name = imported.name.clone();
let imported_module_id = imported.module_id.clone();
let imported_ty = TypeKind::CustomType(imported_ty_key);
// Add all binary operators that are defined for this type
for binop in &mut imported.binop_defs {
if binop.lhs.ty != imported_ty && binop.rhs.ty != imported_ty {
continue;
@ -297,6 +395,7 @@ impl<'map> Pass for LinkerPass<'map> {
}
}
// Import all functions that are associated with this type
for (ty, func) in &mut imported.associated_functions {
if *ty != imported_ty {
continue;
@ -306,8 +405,11 @@ impl<'map> Pass for LinkerPass<'map> {
if !func.is_pub {
state.ok::<_, Infallible>(
Err(ErrorKind::FunctionIsPrivate(module_name.clone(), func_name.clone())),
import.1,
Err(ErrorKind::FunctionIsPrivate(
imported_module_name.clone(),
func_name.clone(),
)),
meta.clone(),
);
continue;
}
@ -322,26 +424,42 @@ impl<'map> Pass for LinkerPass<'map> {
if let Err(e) = existing.equals_as_imported(func) {
state.ok::<_, Infallible>(
Err(ErrorKind::FunctionImportIssue(
module_name.clone(),
imported_module_name.clone(),
func_name.clone(),
e,
)),
import.1,
meta.clone(),
);
}
}
let types = import_type(&func.return_type, false);
let mut assoc_function_types = HashSet::new();
let types = import_type(&func.return_type);
let return_type = func.return_type.clone();
imported_types.extend(types);
assoc_function_types.extend(types);
let mut param_tys = Vec::new();
for param in &func.parameters {
let types = import_type(&param.ty, false);
imported_types.extend(types);
let types = import_type(&param.ty);
assoc_function_types.extend(types);
param_tys.push(param.clone());
}
for inner_ty in assoc_function_types {
dbg!(&inner_ty, &imported_module_id);
if inner_ty.1 != imported_module_id {
let resolved = match resolve_types_recursively(&inner_ty, &modules, &mut HashSet::new())
{
Ok(ty) => ty,
Err(e) => {
state.note_errors(&vec![e], meta);
return Ok(());
}
};
imported_types.extend(resolved);
}
}
importer_module.associated_functions.push((
ty.clone(),
FunctionDefinition {
@ -352,81 +470,32 @@ impl<'map> Pass for LinkerPass<'map> {
return_type,
parameters: param_tys,
kind: super::FunctionDefinitionKind::Extern(true),
source: Some(imported_id),
source: Some(imported_module_id),
signature_meta: func.signature_meta,
},
));
}
} else {
state.ok::<_, Infallible>(
Err(ErrorKind::ImportDoesNotExist(module_name.clone(), import_name.clone())),
import.1,
);
continue;
}
let mut seen = HashSet::new();
let mut current_extern_types = HashSet::new();
seen.extend(imported_types.clone().iter().map(|t| t.0.clone()));
for ty in still_required_types.clone() {
if ty.1 == imported_id && !seen.contains(&ty) {
imported_types.push((ty, false));
let resolved = match resolve_types_recursively(&ty, &modules, &mut HashSet::new()) {
Ok(ty) => ty,
Err(e) => {
state.note_errors(&vec![e], meta);
return Ok(());
}
}
};
imported_types.extend(resolved);
}
current_extern_types.extend(imported_types.clone().iter().filter(|t| t.1).map(|t| t.0.clone()));
for extern_type in &current_extern_types {
extern_types.insert(extern_type.0.clone(), extern_type.1);
}
let imported_mod_id = imported.module_id;
let imported_mod_typedefs = &mut imported.typedefs;
for typekey in imported_types.clone() {
let typedef = imported_mod_typedefs
.iter()
.find(|ty| CustomTypeKey(ty.name.clone(), ty.source_module) == typekey.0)
.unwrap();
let inner = find_inner_types(typedef, seen.clone(), imported_mod_typedefs);
for ty in inner {
if ty.1 == imported_id && imported_mod_typedefs.iter().find(|t| t.name == ty.0).is_some() {
seen.insert(ty);
} else {
still_required_types.insert(ty);
}
}
}
// TODO: Unable to import same-named type from multiple places..
let seen = seen
.difference(&already_imported_types)
// 4. Import all listed types.
for typekey in &imported_types {
let imported_ty_module = modules.get(&typekey.1).unwrap().module.borrow();
if let Some(mut typedef) = imported_ty_module
.typedefs
.iter()
.find(|ty| CustomTypeKey(ty.name.clone(), ty.source_module) == *typekey)
.cloned()
.collect::<HashSet<_>>();
already_imported_types.extend(seen.clone());
for typekey in &already_imported_types {
if current_extern_types.contains(typekey) {
let module_id = importer_module.module_id;
let typedef = importer_module
.typedefs
.iter_mut()
.find(|t| t.name == typekey.0 && t.source_module == typekey.1);
if let Some(typedef) = typedef {
typedef.importer = Some(module_id);
}
}
}
for typekey in seen.into_iter() {
let mut typedef = imported_mod_typedefs
.iter()
.find(|ty| CustomTypeKey(ty.name.clone(), imported_mod_id) == typekey)
.unwrap()
.clone();
if current_extern_types.contains(&typekey) {
{
if foreign_keys.contains(&typekey) {
typedef = TypeDefinition {
importer: Some(importer_module.module_id),
..typedef
@ -436,16 +505,23 @@ impl<'map> Pass for LinkerPass<'map> {
importer_module.typedefs.push(typedef);
}
}
// Set foreign types
let mut foreign_types = HashMap::new();
for key in imported_types {
foreign_types.insert(key.0.clone(), key.1);
}
state
.scope
.data
.extern_imported_types
.insert(importer_module.module_id, extern_types);
.foreign_types
.insert(importer_module.module_id, foreign_types);
}
let mut modules: Vec<Module> = modules
.into_values()
.map(|v| Rc::into_inner(v).unwrap().into_inner())
.map(|v| Rc::into_inner(v.module).unwrap().into_inner())
.collect();
for module in modules.drain(..) {
@ -456,13 +532,13 @@ impl<'map> Pass for LinkerPass<'map> {
}
fn module(&mut self, module: &mut Module, state: PassState<Self::Data, Self::TError>) -> PassResult {
let extern_types = &state.scope.data.extern_imported_types.get(&module.module_id);
if let Some(extern_types) = extern_types {
let foreign_types = &state.scope.data.foreign_types.get(&module.module_id);
if let Some(foreign_types) = foreign_types {
for ty in &mut module.typedefs {
match &mut ty.kind {
TypeDefinitionKind::Struct(StructType(fields)) => {
for field in fields {
field.1 = field.1.update_imported(extern_types, module.module_id);
field.1 = field.1.update_imported(foreign_types);
}
}
}
@ -478,11 +554,11 @@ impl<'map> Pass for LinkerPass<'map> {
) -> PassResult {
if matches!(function.kind, FunctionDefinitionKind::Local(_, _)) {
let mod_id = state.scope.module_id.unwrap();
let extern_types = &state.scope.data.extern_imported_types.get(&mod_id);
if let Some(extern_types) = extern_types {
function.return_type = function.return_type.update_imported(*extern_types, mod_id);
let foreign_types = &state.scope.data.foreign_types.get(&mod_id);
if let Some(foreign_types) = foreign_types {
function.return_type = function.return_type.update_imported(*foreign_types);
for param in function.parameters.iter_mut() {
param.ty = param.ty.update_imported(extern_types, mod_id);
param.ty = param.ty.update_imported(foreign_types);
}
}
}
@ -491,11 +567,11 @@ impl<'map> Pass for LinkerPass<'map> {
fn stmt(&mut self, stmt: &mut super::Statement, state: PassState<Self::Data, Self::TError>) -> PassResult {
let mod_id = state.scope.module_id.unwrap();
let extern_types = &state.scope.data.extern_imported_types.get(&mod_id);
if let Some(extern_types) = extern_types {
let foreign_types = &state.scope.data.foreign_types.get(&mod_id);
if let Some(foreign_types) = foreign_types {
match &mut stmt.0 {
super::StmtKind::Let(var_ref, _, _) => {
var_ref.0 = var_ref.0.update_imported(extern_types, mod_id);
var_ref.0 = var_ref.0.update_imported(foreign_types);
}
_ => {}
}
@ -505,28 +581,24 @@ impl<'map> Pass for LinkerPass<'map> {
fn expr(&mut self, expr: &mut super::Expression, state: PassState<Self::Data, Self::TError>) -> PassResult {
let mod_id = state.scope.module_id.unwrap();
let extern_types = &state.scope.data.extern_imported_types.get(&mod_id);
if let Some(extern_types) = extern_types {
let foreign_types = &state.scope.data.foreign_types.get(&mod_id);
if let Some(foreign_types) = foreign_types {
match &mut expr.0 {
super::ExprKind::Variable(var_ref) => {
var_ref.0 = var_ref.0.update_imported(extern_types, mod_id);
var_ref.0 = var_ref.0.update_imported(foreign_types);
}
super::ExprKind::Indexed(.., type_kind, _) => {
*type_kind = type_kind.update_imported(extern_types, mod_id)
}
super::ExprKind::Accessed(.., type_kind, _, _) => {
*type_kind = type_kind.update_imported(extern_types, mod_id)
}
super::ExprKind::BinOp(.., type_kind) => *type_kind = type_kind.update_imported(extern_types, mod_id),
super::ExprKind::Indexed(.., type_kind, _) => *type_kind = type_kind.update_imported(foreign_types),
super::ExprKind::Accessed(.., type_kind, _, _) => *type_kind = type_kind.update_imported(foreign_types),
super::ExprKind::BinOp(.., type_kind) => *type_kind = type_kind.update_imported(foreign_types),
super::ExprKind::Borrow(..) => {}
super::ExprKind::Deref(..) => {}
super::ExprKind::CastTo(_, type_kind) => *type_kind = type_kind.update_imported(extern_types, mod_id),
super::ExprKind::CastTo(_, type_kind) => *type_kind = type_kind.update_imported(foreign_types),
super::ExprKind::AssociatedFunctionCall(type_kind, _) => {
*type_kind = type_kind.update_imported(extern_types, mod_id)
*type_kind = type_kind.update_imported(foreign_types)
}
super::ExprKind::Struct(key, _) => {
*key = if let Some(mod_id) = extern_types.get(&key.0) {
*key = if let Some(mod_id) = foreign_types.get(&key.0) {
CustomTypeKey(key.0.clone(), *mod_id)
} else {
key.clone()
@ -540,81 +612,101 @@ impl<'map> Pass for LinkerPass<'map> {
}
impl TypeKind {
fn update_imported(
&self,
extern_types: &HashMap<String, SourceModuleId>,
importer_mod_id: SourceModuleId,
) -> TypeKind {
fn update_imported(&self, foreign_types: &HashMap<String, SourceModuleId>) -> TypeKind {
match &self {
TypeKind::Array(type_kind, len) => {
TypeKind::Array(Box::new(type_kind.update_imported(extern_types, importer_mod_id)), *len)
TypeKind::Array(Box::new(type_kind.update_imported(foreign_types)), *len)
}
TypeKind::CustomType(custom_type_key) => {
if let Some(mod_id) = extern_types.get(&custom_type_key.0) {
dbg!(foreign_types, &custom_type_key);
if let Some(mod_id) = foreign_types.get(&custom_type_key.0) {
TypeKind::CustomType(CustomTypeKey(custom_type_key.0.clone(), *mod_id))
} else {
self.clone()
}
}
TypeKind::Borrow(type_kind, mutable) => TypeKind::Borrow(
Box::new(type_kind.update_imported(extern_types, importer_mod_id)),
*mutable,
),
TypeKind::UserPtr(type_kind) => {
TypeKind::UserPtr(Box::new(type_kind.update_imported(extern_types, importer_mod_id)))
}
TypeKind::CodegenPtr(type_kind) => {
TypeKind::CodegenPtr(Box::new(type_kind.update_imported(extern_types, importer_mod_id)))
TypeKind::Borrow(type_kind, mutable) => {
TypeKind::Borrow(Box::new(type_kind.update_imported(foreign_types)), *mutable)
}
TypeKind::UserPtr(type_kind) => TypeKind::UserPtr(Box::new(type_kind.update_imported(foreign_types))),
TypeKind::CodegenPtr(type_kind) => TypeKind::CodegenPtr(Box::new(type_kind.update_imported(foreign_types))),
_ => self.clone(),
}
}
}
fn import_type(ty: &TypeKind, usable_import: bool) -> Vec<(CustomTypeKey, bool)> {
fn import_type(ty: &TypeKind) -> Vec<CustomTypeKey> {
let mut imported_types = Vec::new();
match &ty {
TypeKind::CustomType(key) => imported_types.push((key.clone(), usable_import)),
TypeKind::Borrow(ty, _) => imported_types.extend(import_type(ty, usable_import)),
TypeKind::Array(ty, _) => imported_types.extend(import_type(ty, usable_import)),
TypeKind::UserPtr(ty) => imported_types.extend(import_type(ty, usable_import)),
TypeKind::CodegenPtr(ty) => imported_types.extend(import_type(ty, usable_import)),
TypeKind::CustomType(key) => imported_types.push(key.clone()),
TypeKind::Borrow(ty, _) => imported_types.extend(import_type(ty)),
TypeKind::Array(ty, _) => imported_types.extend(import_type(ty)),
TypeKind::UserPtr(ty) => imported_types.extend(import_type(ty)),
TypeKind::CodegenPtr(ty) => imported_types.extend(import_type(ty)),
_ => {}
};
imported_types
}
fn find_inner_types(
typedef: &TypeDefinition,
mut seen: HashSet<CustomTypeKey>,
typedefs: &Vec<TypeDefinition>,
) -> Vec<CustomTypeKey> {
match &typedef.kind {
crate::mir::TypeDefinitionKind::Struct(struct_type) => {
let typekeys = struct_type
.0
.iter()
.filter_map(|t| match &t.1 {
TypeKind::CustomType(key) => Some(key),
_ => None,
})
.cloned()
.collect::<Vec<_>>();
for typekey in typekeys {
if seen.contains(&typekey) {
continue;
}
seen.insert(typekey.clone());
if typekey.1 == typedef.source_module {
if let Some(inner) = typedefs.iter().find(|t| t.name == typekey.0) {
let ret = find_inner_types(inner, seen.clone(), typedefs);
seen.extend(ret);
}
}
fn resolve_type(
ty: &CustomTypeKey,
modules: &HashMap<SourceModuleId, LinkerModule>,
) -> Result<CustomTypeKey, ErrorKind> {
let mut source_module_id = ty.1;
let mut seen = HashSet::new();
loop {
seen.insert(source_module_id);
let source_module = modules.get(&source_module_id).unwrap();
if let Some((new_module_id, _)) = source_module.type_imports.get(&ty.0) {
if seen.contains(new_module_id) {
return Err(ErrorKind::RecursiveTypeImport(ty.0.clone()));
}
seen.into_iter().collect()
source_module_id = *new_module_id;
} else {
break;
}
}
Ok(CustomTypeKey(ty.0.clone(), source_module_id))
}
fn resolve_types_recursively(
ty: &CustomTypeKey,
modules: &HashMap<SourceModuleId, LinkerModule>,
seen: &mut HashSet<CustomTypeKey>,
) -> Result<Vec<CustomTypeKey>, ErrorKind> {
let resolved_ty = resolve_type(ty, modules)?;
if seen.contains(&resolved_ty) {
return Err(ErrorKind::CyclicalType(ty.0.clone()));
}
let mut types = Vec::new();
types.push(resolved_ty.clone());
seen.insert(resolved_ty.clone());
let resolved = modules
.get(&resolved_ty.1)
.unwrap()
.module
.borrow()
.typedefs
.iter()
.find(|t| t.name == resolved_ty.0)
.ok_or(ErrorKind::NoSuchTypeInModule(ty.clone()))
.cloned()?;
match resolved.kind {
TypeDefinitionKind::Struct(StructType(fields)) => {
for field in fields {
match &field.1 {
TypeKind::CustomType(ty_key) => {
types.extend(resolve_types_recursively(ty_key, modules, seen)?);
}
_ => {}
}
}
}
}
Ok(types)
}

View File

@ -0,0 +1,171 @@
use std::{path::PathBuf, process::Command, time::SystemTime};
use reid::{
compile_module,
ld::LDRunner,
mir::{self},
parse_module, perform_all_passes,
};
use reid_lib::{compile::CompileOutput, Context};
use util::assert_err;
mod util;
fn test_compile(source: &str, name: &str) -> CompileOutput {
assert_err(assert_err(std::panic::catch_unwind(|| {
let mut map = Default::default();
let (id, tokens) = assert_err(parse_module(source, name, None, &mut map, None));
let module = assert_err(assert_err(compile_module(id, tokens, &mut map, None, true)).map_err(|(_, e)| e));
let mut mir_context = mir::Context::from(vec![module], Default::default());
assert_err(perform_all_passes(&mut mir_context, &mut map));
let context = Context::new(format!("Reid ({})", env!("CARGO_PKG_VERSION")));
let codegen = assert_err(mir_context.codegen(&context));
Ok::<_, ()>(codegen.compile(None, Vec::new()).output())
})))
}
fn test(source: &str, name: &str, expected_exit_code: Option<i32>) {
assert_err(assert_err(std::panic::catch_unwind(|| {
let output = test_compile(source, name);
let time = SystemTime::now();
let in_path = PathBuf::from(format!(
"/tmp/temp-{}.o",
time.duration_since(SystemTime::UNIX_EPOCH).unwrap().as_nanos()
));
std::fs::write(&in_path, &output.obj_buffer).expect("Could not write OBJ-file!");
let out_path = in_path.with_extension("out");
LDRunner::from_command("ld")
.with_library("c")
.invoke(&in_path, &out_path);
std::fs::remove_file(in_path).unwrap();
let executed = Command::new(&out_path).output();
std::fs::remove_file(out_path).unwrap();
if let Some(expected_exit_code) = expected_exit_code {
assert_eq!(expected_exit_code, executed.unwrap().status.code().unwrap());
}
Ok::<(), ()>(())
})))
}
#[test]
fn arithmetic_compiles_well() {
test(include_str!("../../examples/arithmetic.reid"), "test", Some(48));
}
#[test]
fn array_structs_compiles_well() {
test(include_str!("../../examples/array_structs.reid"), "test", Some(5));
}
#[test]
fn array_compiles_well() {
test(include_str!("../../examples/array.reid"), "test", Some(3));
}
#[test]
fn borrow_compiles_well() {
test(include_str!("../../examples/borrow.reid"), "test", Some(17));
}
#[test]
fn borrow_hard_compiles_well() {
test(include_str!("../../examples/borrow_hard.reid"), "test", Some(17));
}
#[test]
fn cast_compiles_well() {
test(include_str!("../../examples/cast.reid"), "test", Some(6));
}
#[test]
fn char_compiles_well() {
test(include_str!("../../examples/char.reid"), "test", Some(98));
}
#[test]
fn div_mod_compiles_well() {
test(include_str!("../../examples/div_mod.reid"), "test", Some(12));
}
#[test]
fn fibonacci_compiles_well() {
test(include_str!("../../examples/fibonacci.reid"), "test", Some(1));
}
#[test]
fn float_compiles_well() {
test(include_str!("../../examples/float.reid"), "test", Some(1));
}
#[test]
fn hello_world_compiles_well() {
test(include_str!("../../examples/hello_world.reid"), "test", None);
}
#[test]
fn hello_world_harder_compiles_well() {
test(include_str!("../../examples/hello_world_harder.reid"), "test", None);
}
#[test]
fn mutable_compiles_well() {
test(include_str!("../../examples/mutable.reid"), "test", Some(21));
}
#[test]
fn ptr_compiles_well() {
test(include_str!("../../examples/ptr.reid"), "test", Some(5));
}
#[test]
fn std_test_compiles_well() {
test(include_str!("../../examples/std_test.reid"), "test", Some(3));
}
#[test]
fn strings_compiles_well() {
test(include_str!("../../examples/strings.reid"), "test", Some(5));
}
#[test]
fn struct_compiles_well() {
test(include_str!("../../examples/struct.reid"), "test", Some(17));
}
#[test]
fn loops_compiles_well() {
test(include_str!("../../examples/loops.reid"), "test", Some(10));
}
#[test]
fn ptr_hard_compiles_well() {
test(include_str!("../../examples/ptr_hard.reid"), "test", Some(0));
}
#[test]
fn loop_hard_compiles_well() {
test(include_str!("../../examples/loop_hard.reid"), "test", Some(0));
}
#[test]
fn custom_binop_compiles_well() {
test(include_str!("../../examples/custom_binop.reid"), "test", Some(21));
}
#[test]
fn array_short_compiles_well() {
test(include_str!("../../examples/array_short.reid"), "test", Some(5));
}
#[test]
fn imported_type_compiles_well() {
test(include_str!("../../examples/imported_type.reid"), "test", Some(0));
}
#[test]
fn associated_functions() {
test(
include_str!("../../examples/associated_functions.reid"),
"test",
Some(4),
);
}
#[test]
fn mutable_inner_functions() {
test(include_str!("../../examples/mutable_inner.reid"), "test", Some(0));
}
#[test]
fn cpu_raytracer_compiles() {
test_compile(include_str!("../../examples/cpu_raytracer.reid"), "test");
}