diff --git a/examples/associated_functions_shorthand.reid b/examples/associated_functions_shorthand.reid new file mode 100644 index 0000000..d06e0df --- /dev/null +++ b/examples/associated_functions_shorthand.reid @@ -0,0 +1,27 @@ +import std::print; +import std::from_str; +import std::String; + +struct Otus { + field: u32, +} + +impl Otus { + fn test(self) -> u32 { + self.field + } +} + +impl i32 { + fn test(self) -> u32 { + 43 + } +} + +fn main() -> u32 { + let otus = Otus { field: 17 }; + print(from_str("otus: ") + otus.test() as u64); + + + return otus.test(); +} diff --git a/reid/src/ast/mod.rs b/reid/src/ast/mod.rs index caeb2af..18cc297 100644 --- a/reid/src/ast/mod.rs +++ b/reid/src/ast/mod.rs @@ -89,6 +89,8 @@ pub enum ExpressionKind { Indexed(Box, Box), /// Struct-accessed, e.g. . Accessed(Box, String), + /// Associated function call, but with a shorthand + AccessCall(Box, Box), Binop(BinaryOperator, Box, Box), FunctionCall(Box), AssociatedFunctionCall(Type, Box), diff --git a/reid/src/ast/parse.rs b/reid/src/ast/parse.rs index 9430dc2..27809bb 100644 --- a/reid/src/ast/parse.rs +++ b/reid/src/ast/parse.rs @@ -358,12 +358,20 @@ impl Parse for PrimaryExpression { stream.get_range().unwrap(), ); } - ValueIndex::Struct(StructValueIndex(name)) => { - expr = Expression( - ExpressionKind::Accessed(Box::new(expr), name), - stream.get_range().unwrap(), - ); - } + ValueIndex::Dot(val) => match val { + DotIndexKind::StructValueIndex(name) => { + expr = Expression( + ExpressionKind::Accessed(Box::new(expr), name), + stream.get_range().unwrap(), + ); + } + DotIndexKind::FunctionCall(function_call_expression) => { + expr = Expression( + ExpressionKind::AccessCall(Box::new(expr), Box::new(function_call_expression)), + stream.get_range().unwrap(), + ); + } + }, } } @@ -473,27 +481,35 @@ impl Parse for BinaryOperator { impl Parse for FunctionCallExpression { fn parse(mut stream: TokenStream) -> Result { if let Some(Token::Identifier(name)) = stream.next() { - stream.expect(Token::ParenOpen)?; - - let mut args = Vec::new(); - - if let Ok(exp) = stream.parse() { - args.push(exp); - - while stream.expect(Token::Comma).is_ok() { - args.push(stream.parse()?); - } - } - - stream.expect(Token::ParenClose)?; - - Ok(FunctionCallExpression(name, args, stream.get_range().unwrap())) + let args = stream.parse::()?; + Ok(FunctionCallExpression(name, args.0, stream.get_range().unwrap())) } else { Err(stream.expected_err("identifier")?) } } } +#[derive(Debug)] +pub struct FunctionArgs(Vec); + +impl Parse for FunctionArgs { + fn parse(mut stream: TokenStream) -> Result { + stream.expect(Token::ParenOpen)?; + + let mut params = Vec::new(); + if let Ok(exp) = stream.parse() { + params.push(exp); + + while stream.expect(Token::Comma).is_ok() { + params.push(stream.parse()?); + } + } + stream.expect(Token::ParenClose)?; + + Ok(FunctionArgs(params)) + } +} + impl Parse for IfExpression { fn parse(mut stream: TokenStream) -> Result { stream.expect(Token::If)?; @@ -766,14 +782,14 @@ impl Parse for NamedField { #[derive(Debug, Clone)] pub enum ValueIndex { Array(ArrayValueIndex), - Struct(StructValueIndex), + Dot(DotIndexKind), } impl Parse for ValueIndex { fn parse(mut stream: TokenStream) -> Result { match stream.peek() { Some(Token::BracketOpen) => Ok(ValueIndex::Array(stream.parse()?)), - Some(Token::Dot) => Ok(ValueIndex::Struct(stream.parse()?)), + Some(Token::Dot) => Ok(ValueIndex::Dot(stream.parse()?)), _ => Err(stream.expecting_err("value or struct index")?), } } @@ -792,13 +808,24 @@ impl Parse for ArrayValueIndex { } #[derive(Debug, Clone)] -pub struct StructValueIndex(String); +pub enum DotIndexKind { + StructValueIndex(String), + FunctionCall(FunctionCallExpression), +} -impl Parse for StructValueIndex { +impl Parse for DotIndexKind { fn parse(mut stream: TokenStream) -> Result { stream.expect(Token::Dot)?; if let Some(Token::Identifier(name)) = stream.next() { - Ok(StructValueIndex(name)) + if let Ok(args) = stream.parse::() { + Ok(Self::FunctionCall(FunctionCallExpression( + name, + args.0, + stream.get_range_prev().unwrap(), + ))) + } else { + Ok(Self::StructValueIndex(name)) + } } else { return Err(stream.expected_err("struct index (number)")?); } diff --git a/reid/src/ast/process.rs b/reid/src/ast/process.rs index 29ef2e4..4ae0e87 100644 --- a/reid/src/ast/process.rs +++ b/reid/src/ast/process.rs @@ -422,6 +422,19 @@ impl ast::Expression { meta: fn_call_expr.2.as_meta(module_id), }, ), + ast::ExpressionKind::AccessCall(expression, fn_call_expr) => { + let mut params: Vec<_> = fn_call_expr.1.iter().map(|e| e.process(module_id)).collect(); + params.insert(0, expression.process(module_id)); + mir::ExprKind::AssociatedFunctionCall( + mir::TypeKind::Vague(mir::VagueType::Unknown), + mir::FunctionCall { + name: fn_call_expr.0.clone(), + return_type: mir::TypeKind::Vague(mir::VagueType::Unknown), + parameters: params, + meta: fn_call_expr.2.as_meta(module_id), + }, + ) + } }; mir::Expression(kind, self.1.as_meta(module_id)) diff --git a/reid/src/mir/typecheck/mod.rs b/reid/src/mir/typecheck/mod.rs index 93b1fd7..ed524c0 100644 --- a/reid/src/mir/typecheck/mod.rs +++ b/reid/src/mir/typecheck/mod.rs @@ -82,6 +82,8 @@ pub enum ErrorKind { BinaryOpAlreadyDefined(BinaryOperator, TypeKind, TypeKind), #[error("Binary operation {0} between {1} and {2} is not defined")] InvalidBinop(BinaryOperator, TypeKind, TypeKind), + #[error("Could not infer type for {0:?}. Try adding type annotations.")] + CouldNotInferType(String), } #[derive(Clone, Debug, PartialEq, Eq)] @@ -294,13 +296,13 @@ impl TypeKind { } } - pub(super) fn assert_known(&self, refs: &TypeRefs, state: &TypecheckPassState) -> Result { - self.is_known(refs, state).map(|_| self.clone()) + pub(super) fn assert_known(&self, state: &TypecheckPassState) -> Result { + self.is_known(state).map(|_| self.clone()) } - pub(super) fn is_known(&self, refs: &TypeRefs, state: &TypecheckPassState) -> Result<(), ErrorKind> { + pub(super) fn is_known(&self, state: &TypecheckPassState) -> Result<(), ErrorKind> { match &self { - TypeKind::Array(type_kind, _) => type_kind.as_ref().is_known(refs, state), + TypeKind::Array(type_kind, _) => type_kind.as_ref().is_known(state), TypeKind::CustomType(custom_type_key) => { state .scope @@ -311,9 +313,9 @@ impl TypeKind { state.module_id.unwrap(), )) } - TypeKind::Borrow(type_kind, _) => type_kind.is_known(refs, state), - TypeKind::UserPtr(type_kind) => type_kind.is_known(refs, state), - TypeKind::CodegenPtr(type_kind) => type_kind.is_known(refs, state), + TypeKind::Borrow(type_kind, _) => type_kind.is_known(state), + TypeKind::UserPtr(type_kind) => type_kind.is_known(state), + TypeKind::CodegenPtr(type_kind) => type_kind.is_known(state), TypeKind::Vague(vague_type) => Err(ErrorKind::TypeIsVague(*vague_type)), _ => Ok(()), } diff --git a/reid/src/mir/typecheck/typecheck.rs b/reid/src/mir/typecheck/typecheck.rs index 2c8aa3b..1495b1a 100644 --- a/reid/src/mir/typecheck/typecheck.rs +++ b/reid/src/mir/typecheck/typecheck.rs @@ -112,7 +112,7 @@ impl BinopDefinition { fn typecheck(&mut self, typerefs: &TypeRefs, state: &mut TypecheckPassState) -> Result { for param in vec![&self.lhs, &self.rhs] { let param_t = state.or_else( - param.1.assert_known(typerefs, state), + param.1.assert_known(state), TypeKind::Vague(Vague::Unknown), self.signature(), ); @@ -130,7 +130,7 @@ impl BinopDefinition { state.ok(res, self.signature()); } - let return_type = self.return_type.clone().assert_known(typerefs, state)?; + let return_type = self.return_type.clone().assert_known(state)?; state.scope.return_type_hint = Some(self.return_type.clone()); let inferred = self @@ -150,7 +150,7 @@ impl FunctionDefinition { fn typecheck(&mut self, typerefs: &TypeRefs, state: &mut TypecheckPassState) -> Result { for param in &self.parameters { let param_t = state.or_else( - param.1.assert_known(typerefs, state), + param.1.assert_known(state), TypeKind::Vague(Vague::Unknown), self.signature(), ); @@ -168,7 +168,7 @@ impl FunctionDefinition { state.ok(res, self.signature()); } - let return_type = self.return_type.clone().assert_known(typerefs, state)?; + let return_type = self.return_type.clone().assert_known(state)?; let inferred = self.kind.typecheck(typerefs, state, Some(self.return_type.clone())); match inferred { @@ -327,7 +327,7 @@ impl Block { } StmtKind::While(WhileStatement { condition, block, meta }) => { let condition_ty = condition.typecheck(&mut state, typerefs, HintKind::Coerce(TypeKind::Bool))?; - if condition_ty.assert_known(typerefs, &state)? != TypeKind::Bool { + if condition_ty.assert_known(&state)? != TypeKind::Bool { state.note_errors(&vec![ErrorKind::TypesIncompatible(condition_ty, TypeKind::Bool)], *meta); } diff --git a/reid/src/mir/typecheck/typeinference.rs b/reid/src/mir/typecheck/typeinference.rs index a49862e..bd7712c 100644 --- a/reid/src/mir/typecheck/typeinference.rs +++ b/reid/src/mir/typecheck/typeinference.rs @@ -595,6 +595,19 @@ impl Expression { Ok(type_refs.from_type(type_kind).unwrap()) } ExprKind::AssociatedFunctionCall(type_kind, function_call) => { + if type_kind.is_known(state).is_err() { + let first_param = function_call + .parameters + .get_mut(0) + .expect("Unknown-type associated function NEEDS to always have at least one parameter!"); + let param_ty = first_param.infer_types(state, type_refs).unwrap().resolve_deep(); + *type_kind = state.or_else( + param_ty.ok_or(ErrorKind::CouldNotInferType(format!("{}", first_param))), + Void, + first_param.1, + ); + } + // Get function definition and types let fn_call = state .scope