From a2e52e0bd21177f505f615578b89e0fbe1ca8d99 Mon Sep 17 00:00:00 2001 From: sofia Date: Sun, 13 Jul 2025 20:31:33 +0300 Subject: [PATCH] Add Array support to llvm-lib --- reid-llvm-lib/src/builder.rs | 19 ++++++++++++++ reid-llvm-lib/src/compile.rs | 49 +++++++++++++++++++++++++++--------- reid-llvm-lib/src/debug.rs | 30 +++++++++++++++------- reid-llvm-lib/src/lib.rs | 3 +++ 4 files changed, 80 insertions(+), 21 deletions(-) diff --git a/reid-llvm-lib/src/builder.rs b/reid-llvm-lib/src/builder.rs index e3bbbd9..12ea5a7 100644 --- a/reid-llvm-lib/src/builder.rs +++ b/reid-llvm-lib/src/builder.rs @@ -277,6 +277,14 @@ impl Builder { Err(()) } } + Extract(list, idx) => { + let list_ty = list.get_type(&self)?; + if let Type::Array(_, len) = list_ty { + if len < idx { Ok(()) } else { Err(()) } + } else { + Err(()) + } + } } } } @@ -338,6 +346,11 @@ impl InstructionValue { Alloca(_, ty) => Ok(Type::Ptr(Box::new(ty.clone()))), Load(_, ty) => Ok(ty.clone()), Store(_, value) => value.get_type(builder), + Extract(arr, _) => match arr.get_type(builder) { + Ok(Type::Array(elem_t, _)) => Ok(*elem_t), + Ok(_) => Err(()), + Err(_) => Err(()), + }, } } } @@ -358,6 +371,10 @@ impl ConstValue { ConstValue::U64(_) => U64, ConstValue::U128(_) => U128, ConstValue::Bool(_) => Bool, + ConstValue::ConstArray(arr) => Array( + Box::new(arr.iter().map(|a| a.get_type()).next().unwrap_or(Void)), + arr.len() as u32, + ), } } } @@ -378,6 +395,7 @@ impl Type { Type::Bool => true, Type::Void => false, Type::Ptr(_) => false, + Type::Array(_, _) => false, } } @@ -396,6 +414,7 @@ impl Type { Type::Bool => false, Type::Void => false, Type::Ptr(_) => false, + Type::Array(_, _) => false, } } } diff --git a/reid-llvm-lib/src/compile.rs b/reid-llvm-lib/src/compile.rs index 6e5088f..d724698 100644 --- a/reid-llvm-lib/src/compile.rs +++ b/reid-llvm-lib/src/compile.rs @@ -349,6 +349,12 @@ impl InstructionHolder { module.values.get(&val).unwrap().value_ref, module.values.get(&ptr).unwrap().value_ref, ), + Extract(instruction_value, idx) => LLVMBuildExtractValue( + module.builder_ref, + module.values.get(instruction_value).unwrap().value_ref, + *idx as u32, + c"extract".as_ptr(), + ), } }; LLVMValue { @@ -415,18 +421,36 @@ impl ConstValue { fn as_llvm(&self, context: LLVMContextRef) -> LLVMValueRef { unsafe { let t = self.get_type().as_llvm(context); - match *self { - ConstValue::Bool(val) => LLVMConstInt(t, val as u64, 1), - ConstValue::I8(val) => LLVMConstInt(t, val as u64, 1), - ConstValue::I16(val) => LLVMConstInt(t, val as u64, 1), - ConstValue::I32(val) => LLVMConstInt(t, val as u64, 1), - ConstValue::I64(val) => LLVMConstInt(t, val as u64, 1), - ConstValue::I128(val) => LLVMConstInt(t, val as u64, 1), - ConstValue::U8(val) => LLVMConstInt(t, val as u64, 1), - ConstValue::U16(val) => LLVMConstInt(t, val as u64, 1), - ConstValue::U32(val) => LLVMConstInt(t, val as u64, 1), - ConstValue::U64(val) => LLVMConstInt(t, val as u64, 1), - ConstValue::U128(val) => LLVMConstInt(t, val as u64, 1), + match self { + ConstValue::Bool(val) => LLVMConstInt(t, *val as u64, 1), + ConstValue::I8(val) => LLVMConstInt(t, *val as u64, 1), + ConstValue::I16(val) => LLVMConstInt(t, *val as u64, 1), + ConstValue::I32(val) => LLVMConstInt(t, *val as u64, 1), + ConstValue::I64(val) => LLVMConstInt(t, *val as u64, 1), + ConstValue::I128(val) => LLVMConstInt(t, *val as u64, 1), + ConstValue::U8(val) => LLVMConstInt(t, *val as u64, 1), + ConstValue::U16(val) => LLVMConstInt(t, *val as u64, 1), + ConstValue::U32(val) => LLVMConstInt(t, *val as u64, 1), + ConstValue::U64(val) => LLVMConstInt(t, *val as u64, 1), + ConstValue::U128(val) => LLVMConstInt(t, *val as u64, 1), + ConstValue::ConstArray(const_values) => { + let elem_ty = const_values + .iter() + .map(|e| e.get_type()) + .next() + .unwrap_or(Type::Void); + + let mut elems = const_values + .iter() + .map(|e| e.as_llvm(context)) + .collect::>(); + + LLVMConstArray( + elem_ty.as_llvm(context), + elems.as_mut_ptr(), + elems.len() as u32, + ) + } } } } @@ -445,6 +469,7 @@ impl Type { Bool => LLVMInt1TypeInContext(context), Void => LLVMVoidType(), Ptr(ty) => LLVMPointerType(ty.as_llvm(context), 0), + Array(elem_t, len) => LLVMArrayType(elem_t.as_llvm(context), *len as u32), } } } diff --git a/reid-llvm-lib/src/debug.rs b/reid-llvm-lib/src/debug.rs index 5a693bf..1694632 100644 --- a/reid-llvm-lib/src/debug.rs +++ b/reid-llvm-lib/src/debug.rs @@ -88,18 +88,19 @@ impl Debug for InstructionValue { impl Debug for Instr { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Param(nth) => fmt_call(f, &"Param", &nth), - Self::Constant(c) => c.fmt(f), - Self::Add(lhs, rhs) => fmt_binop(f, lhs, &"+", rhs), - Self::Sub(lhs, rhs) => fmt_binop(f, lhs, &"-", rhs), - Self::Mult(lhs, rhs) => fmt_binop(f, lhs, &"*", rhs), - Self::And(lhs, rhs) => fmt_binop(f, lhs, &"&&", rhs), - Self::Phi(val) => fmt_call(f, &"Phi", &val), - Self::ICmp(cmp, lhs, rhs) => fmt_binop(f, lhs, cmp, rhs), - Self::FunctionCall(fun, params) => fmt_call(f, fun, params), + Instr::Param(nth) => fmt_call(f, &"Param", &nth), + Instr::Constant(c) => c.fmt(f), + Instr::Add(lhs, rhs) => fmt_binop(f, lhs, &"+", rhs), + Instr::Sub(lhs, rhs) => fmt_binop(f, lhs, &"-", rhs), + Instr::Mult(lhs, rhs) => fmt_binop(f, lhs, &"*", rhs), + Instr::And(lhs, rhs) => fmt_binop(f, lhs, &"&&", rhs), + Instr::Phi(val) => fmt_call(f, &"Phi", &val), + Instr::ICmp(cmp, lhs, rhs) => fmt_binop(f, lhs, cmp, rhs), + Instr::FunctionCall(fun, params) => fmt_call(f, fun, params), Instr::Alloca(name, ty) => write!(f, "alloca<{:?}>({})", ty, name), Instr::Load(val, ty) => write!(f, "load<{:?}>({:?})", ty, val), Instr::Store(ptr, val) => write!(f, "store({:?} = {:?})", ptr, val), + Instr::Extract(instruction_value, idx) => fmt_index(f, instruction_value, idx), } } } @@ -128,6 +129,17 @@ fn fmt_call( f.write_char(')') } +fn fmt_index( + f: &mut std::fmt::Formatter<'_>, + fun: &impl std::fmt::Debug, + params: &impl std::fmt::Debug, +) -> std::fmt::Result { + fun.fmt(f)?; + f.write_char('[')?; + params.fmt(f)?; + f.write_char(']') +} + impl Debug for CmpPredicate { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/reid-llvm-lib/src/lib.rs b/reid-llvm-lib/src/lib.rs index f161f5a..c9b8d8c 100644 --- a/reid-llvm-lib/src/lib.rs +++ b/reid-llvm-lib/src/lib.rs @@ -177,6 +177,7 @@ pub enum Instr { Alloca(String, Type), Load(InstructionValue, Type), Store(InstructionValue, InstructionValue), + Extract(InstructionValue, u32), /// Integer Comparison ICmp(CmpPredicate, InstructionValue, InstructionValue), @@ -199,6 +200,7 @@ pub enum Type { Bool, Void, Ptr(Box), + Array(Box, u32), } #[derive(Debug, Clone, Hash)] @@ -214,6 +216,7 @@ pub enum ConstValue { U64(u64), U128(u128), Bool(bool), + ConstArray(Vec), } #[derive(Clone, Hash)]