From 109fedb624c1ad8ee248b08138546994da9a1965 Mon Sep 17 00:00:00 2001 From: sofia Date: Sun, 3 Aug 2025 23:54:49 +0300 Subject: [PATCH] Make get-references work for multifiles --- reid-lsp/src/analysis.rs | 212 +++++++++++++++++++++++++++++---------- reid-lsp/src/main.rs | 190 ++++++++++++++++++----------------- 2 files changed, 259 insertions(+), 143 deletions(-) diff --git a/reid-lsp/src/analysis.rs b/reid-lsp/src/analysis.rs index 2982ef8..ffc42f3 100644 --- a/reid-lsp/src/analysis.rs +++ b/reid-lsp/src/analysis.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, fmt::format, path::PathBuf}; +use std::{collections::HashMap, fmt::format, hash::Hash, path::PathBuf}; use reid::{ ast::{ @@ -45,26 +45,36 @@ pub struct StaticAnalysis { } impl StaticAnalysis { - pub fn find_definition(&self, token_idx: usize) -> Option<&FullToken> { + pub fn find_definition(&self, token_idx: usize, map: &StateMap) -> Option<(SourceModuleId, &FullToken)> { let semantic_token = self.state.map.get(&token_idx)?; let symbol_id = semantic_token.symbol?; - let definition_id = self.state.find_definition(&symbol_id); + let (module_id, definition_id) = self.state.find_definition(&symbol_id, map); - let def_token_idx = self.state.symbol_to_token.get(&definition_id)?; - self.tokens.get(*def_token_idx) + if module_id == self.state.module_id { + self.state + .symbol_to_token + .get(&definition_id) + .and_then(|def_token_idx| self.tokens.get(*def_token_idx).map(|t| ((module_id, t)))) + } else { + map.get(&module_id) + .and_then(|state| state.symbol_to_token.get(&definition_id)) + .and_then(|def_token_idx| self.tokens.get(*def_token_idx).map(|t| ((module_id, t)))) + } } - pub fn find_references(&self, token_idx: usize) -> Option> { + pub fn find_references(&self, token_idx: usize, map: &StateMap) -> Option> { let mut references = Vec::new(); let semantic_token = self.state.map.get(&token_idx)?; let symbol_id = semantic_token.symbol?; - let definition_id = self.state.find_definition(&symbol_id); - references.push(definition_id); + let (def_module_id, definition_id) = self.state.find_definition(&symbol_id, map); + references.push((def_module_id, definition_id)); - for (symbol_idx, semantic_symbol) in self.state.symbol_table.iter().enumerate() { - if let SemanticKind::Reference(ref_idx) = semantic_symbol.kind { - if ref_idx == definition_id { - references.push(SymbolId(symbol_idx)); + for state in map.values() { + for (symbol_idx, semantic_symbol) in state.symbol_table.iter().enumerate() { + if let SemanticKind::Reference(module_id, ref_idx) = semantic_symbol.kind { + if def_module_id == module_id && ref_idx == definition_id { + references.push((state.module_id, SymbolId(symbol_idx))); + } } } } @@ -121,6 +131,8 @@ pub struct AnalysisState { /// SymbolID -> Symbol pub symbol_to_token: HashMap, + module_id: SourceModuleId, + functions: HashMap, associated_functions: HashMap<(TypeKind, String), SymbolId>, properties: HashMap<(TypeKind, String), SymbolId>, @@ -130,7 +142,7 @@ pub struct AnalysisState { pub type StateMap = HashMap; impl AnalysisState { - pub fn get_symbol(&self, id: SymbolId) -> &Symbol { + pub fn get_local_symbol(&self, id: SymbolId) -> &Symbol { self.symbol_table.get(id.0).unwrap() } } @@ -190,11 +202,19 @@ impl AnalysisState { id } - pub fn find_definition(&self, id: &SymbolId) -> SymbolId { - let symbol = self.get_symbol(*id); + pub fn find_definition(&self, id: &SymbolId, map: &StateMap) -> (SourceModuleId, SymbolId) { + let symbol = self.get_local_symbol(*id); match symbol.kind { - SemanticKind::Reference(idx) => self.find_definition(&idx), - _ => *id, + SemanticKind::Reference(module_id, idx) => { + if module_id == self.module_id { + self.find_definition(&idx, map) + } else { + map.get(&module_id) + .map(|state| state.find_definition(&idx, map)) + .unwrap_or((self.module_id, *id)) + } + } + _ => (self.module_id, *id), } } } @@ -210,6 +230,9 @@ pub struct AnalysisScope<'a> { state: &'a mut AnalysisState, tokens: &'a Vec, variables: HashMap, + types: HashMap, + functions: HashMap, + associated_functions: HashMap<(TypeKind, String), (SourceModuleId, SymbolId)>, map: &'a StateMap, } @@ -220,6 +243,9 @@ impl<'a> AnalysisScope<'a> { map: self.map, tokens: self.tokens, variables: self.variables.clone(), + types: self.types.clone(), + functions: self.functions.clone(), + associated_functions: self.associated_functions.clone(), } } @@ -237,6 +263,36 @@ impl<'a> AnalysisScope<'a> { } return None; } + + pub fn find_property(&self, ty: TypeKind, property: String) -> Option<(SourceModuleId, SymbolId)> { + match &ty { + TypeKind::CustomType(CustomTypeKey(_, module_id)) => { + if *module_id == self.state.module_id { + self.state + .properties + .get(&(ty.clone(), property.clone())) + .cloned() + .map(|p| (*module_id, p)) + } else { + if let Some(state) = self.map.get(&module_id) { + state + .properties + .get(&(ty.clone(), property.clone())) + .cloned() + .map(|p| (*module_id, p)) + } else { + None + } + } + } + _ => self + .state + .properties + .get(&(ty, property.clone())) + .cloned() + .map(|p| (self.state.module_id, p)), + } + } } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] @@ -252,7 +308,7 @@ pub enum SemanticKind { Comment, Operator, Keyword, - Reference(SymbolId), + Reference(SourceModuleId, SymbolId), } impl Default for SemanticKind { @@ -262,7 +318,7 @@ impl Default for SemanticKind { } impl SemanticKind { - pub fn into_token_idx(&self, state: &AnalysisState) -> Option { + pub fn into_token_idx(&self, map: &StateMap) -> Option { let token_type = match self { SemanticKind::Variable => SemanticTokenType::VARIABLE, SemanticKind::Function => SemanticTokenType::FUNCTION, @@ -275,7 +331,16 @@ impl SemanticKind { SemanticKind::Operator => SemanticTokenType::OPERATOR, SemanticKind::Keyword => SemanticTokenType::KEYWORD, SemanticKind::Default => return None, - SemanticKind::Reference(symbol_id) => return state.get_symbol(*symbol_id).kind.into_token_idx(state), + SemanticKind::Reference(module_id, symbol_id) => { + return map + .get(module_id) + .unwrap() + .symbol_table + .get(symbol_id.0) + .unwrap() + .kind + .into_token_idx(map); + } }; TOKEN_LEGEND .iter() @@ -297,7 +362,7 @@ impl SemanticKind { SemanticKind::Comment => return None, SemanticKind::Operator => return None, SemanticKind::Keyword => return None, - SemanticKind::Reference(_) => SEMANTIC_REFERENCE, + SemanticKind::Reference(..) => SEMANTIC_REFERENCE, }; MODIFIER_LEGEND .iter() @@ -355,6 +420,7 @@ pub fn analyze_context( associated_functions: HashMap::new(), properties: HashMap::new(), types: HashMap::new(), + module_id: module.module_id, }; let mut scope = AnalysisScope { @@ -362,6 +428,9 @@ pub fn analyze_context( tokens: &module.tokens, variables: HashMap::new(), map, + types: HashMap::new(), + functions: HashMap::new(), + associated_functions: HashMap::new(), }; for (i, token) in module.tokens.iter().enumerate() { @@ -454,6 +523,12 @@ pub fn analyze_context( for typedef in &module.typedefs { if typedef.source_module != module.module_id { + if let Some(state) = map.get(&typedef.source_module) { + let ty = TypeKind::CustomType(CustomTypeKey(typedef.name.clone(), typedef.source_module)); + if let Some(symbol) = state.types.get(&ty) { + scope.types.insert(ty, (typedef.source_module, *symbol)); + } + } continue; } @@ -466,10 +541,10 @@ pub fn analyze_context( .state .new_symbol(struct_idx, SemanticKind::Struct, module.module_id); scope.state.set_symbol(struct_idx, struct_symbol); - scope.state.types.insert( - TypeKind::CustomType(CustomTypeKey(typedef.name.clone(), typedef.source_module)), - struct_symbol, - ); + + let ty = TypeKind::CustomType(CustomTypeKey(typedef.name.clone(), typedef.source_module)); + scope.state.types.insert(ty.clone(), struct_symbol); + scope.types.insert(ty, (typedef.source_module, struct_symbol)); for field in fields { let field_idx = scope @@ -526,6 +601,17 @@ pub fn analyze_context( } for (ty, function) in &module.associated_functions { + if let Some(source_id) = function.source { + if let Some(state) = map.get(&source_id) { + if let Some(symbol) = state.associated_functions.get(&(ty.clone(), function.name.clone())) { + scope + .associated_functions + .insert((ty.clone(), function.name.clone()), (source_id, *symbol)); + } + } + continue; + } + let idx = scope .token_idx(&function.signature(), |t| matches!(t, Token::Identifier(_))) .unwrap_or(function.signature().range.end); @@ -559,6 +645,18 @@ pub fn analyze_context( } for function in &module.functions { + if let Some(source_id) = function.source { + if source_id != module.module_id { + dbg!(source_id, &function.name); + if let Some(state) = map.get(&source_id) { + if let Some(symbol) = state.functions.get(&function.name) { + scope.functions.insert(function.name.clone(), (source_id, *symbol)); + } + } + continue; + } + } + scope .state .init_types(&function.signature(), Some(function.return_type.clone())); @@ -668,9 +766,11 @@ pub fn analyze_expr( .token_idx(&var_ref.2, |t| matches!(t, Token::Identifier(_))) .unwrap_or(var_ref.2.range.end); let symbol = if let Some(symbol_id) = scope.variables.get(&var_ref.1) { - scope - .state - .new_symbol(idx, SemanticKind::Reference(*symbol_id), source_module.module_id) + scope.state.new_symbol( + idx, + SemanticKind::Reference(source_module.module_id, *symbol_id), + source_module.module_id, + ) } else { scope.state.new_symbol(idx, SemanticKind::Type, source_module.module_id) }; @@ -709,12 +809,12 @@ pub fn analyze_expr( .token_idx(&meta, |t| matches!(t, Token::Identifier(_))) .unwrap_or(meta.range.end); - let field_symbol = if let Some(symbol_id) = - scope.state.properties.get(&(accessed_type.clone(), name.clone())) + let field_symbol = if let Some((module_id, symbol_id)) = + scope.find_property(accessed_type.clone(), name.clone()) { scope.state.new_symbol( field_idx, - SemanticKind::Reference(*symbol_id), + SemanticKind::Reference(module_id, symbol_id), source_module.module_id, ) } else { @@ -755,9 +855,11 @@ pub fn analyze_expr( .unwrap_or(expr.1.range.end); let struct_symbol = if let Some(symbol_id) = scope.state.types.get(&struct_type) { - scope - .state - .new_symbol(struct_idx, SemanticKind::Reference(*symbol_id), source_module.module_id) + scope.state.new_symbol( + struct_idx, + SemanticKind::Reference(source_module.module_id, *symbol_id), + source_module.module_id, + ) } else { scope .state @@ -772,9 +874,11 @@ pub fn analyze_expr( let field_symbol = if let Some(symbol_id) = scope.state.properties.get(&(struct_type.clone(), field_name.clone())) { - scope - .state - .new_symbol(field_idx, SemanticKind::Reference(*symbol_id), source_module.module_id) + scope.state.new_symbol( + field_idx, + SemanticKind::Reference(source_module.module_id, *symbol_id), + source_module.module_id, + ) } else { scope .state @@ -816,10 +920,12 @@ pub fn analyze_expr( let idx = scope .token_idx(&meta, |t| matches!(t, Token::Identifier(_))) .unwrap_or(meta.range.end); - let symbol = if let Some(symbol_id) = scope.state.functions.get(name) { - scope - .state - .new_symbol(idx, SemanticKind::Reference(*symbol_id), source_module.module_id) + let symbol = if let Some((module_id, symbol_id)) = scope.functions.get(name) { + scope.state.new_symbol( + idx, + SemanticKind::Reference(*module_id, *symbol_id), + source_module.module_id, + ) } else { scope .state @@ -842,10 +948,12 @@ pub fn analyze_expr( ty.clone() }; - let type_symbol = if let Some(symbol_id) = scope.state.types.get(&invoked_ty) { - scope - .state - .new_symbol(type_idx, SemanticKind::Reference(*symbol_id), source_module.module_id) + let type_symbol = if let Some((module_id, symbol_id)) = scope.types.get(&invoked_ty) { + scope.state.new_symbol( + type_idx, + SemanticKind::Reference(*module_id, *symbol_id), + source_module.module_id, + ) } else { scope .state @@ -856,14 +964,14 @@ pub fn analyze_expr( let fn_idx = scope .token_idx(&meta, |t| matches!(t, Token::Identifier(_))) .unwrap_or(meta.range.end); - let fn_symbol = if let Some(symbol_id) = scope - .state - .associated_functions - .get(&(invoked_ty.clone(), name.clone())) + let fn_symbol = if let Some((module_id, symbol_id)) = + scope.associated_functions.get(&(invoked_ty.clone(), name.clone())) { - scope - .state - .new_symbol(fn_idx, SemanticKind::Reference(*symbol_id), source_module.module_id) + scope.state.new_symbol( + fn_idx, + SemanticKind::Reference(*module_id, *symbol_id), + source_module.module_id, + ) } else { scope .state diff --git a/reid-lsp/src/main.rs b/reid-lsp/src/main.rs index 5398ce7..8e47d08 100644 --- a/reid-lsp/src/main.rs +++ b/reid-lsp/src/main.rs @@ -15,11 +15,11 @@ use tower_lsp::lsp_types::{ ReferenceParams, RenameParams, SemanticToken, SemanticTokensLegend, SemanticTokensOptions, SemanticTokensParams, SemanticTokensResult, SemanticTokensServerCapabilities, ServerCapabilities, TextDocumentItem, TextDocumentRegistrationOptions, TextDocumentSyncCapability, TextDocumentSyncKind, TextDocumentSyncOptions, - TextEdit, WorkspaceEdit, WorkspaceFoldersServerCapabilities, WorkspaceServerCapabilities, + TextEdit, Url, WorkspaceEdit, WorkspaceFoldersServerCapabilities, WorkspaceServerCapabilities, }; use tower_lsp::{Client, LanguageServer, LspService, Server, jsonrpc}; -use crate::analysis::{MODIFIER_LEGEND, StaticAnalysis, TOKEN_LEGEND, analyze}; +use crate::analysis::{MODIFIER_LEGEND, StateMap, StaticAnalysis, TOKEN_LEGEND, analyze}; mod analysis; @@ -79,9 +79,7 @@ impl LanguageServer for Backend { static_registration_options: Default::default(), }, )), - definition_provider: Some(OneOf::Left(true)), references_provider: Some(OneOf::Left(true)), - rename_provider: Some(OneOf::Left(true)), ..Default::default() }; Ok(InitializeResult { @@ -228,8 +226,8 @@ impl LanguageServer for Backend { if let Some(token_analysis) = analysis.state.map.get(&i) { if let Some(symbol_id) = token_analysis.symbol { - let symbol = analysis.state.get_symbol(symbol_id); - if let Some(idx) = symbol.kind.into_token_idx(&analysis.state) { + let symbol = analysis.state.get_local_symbol(symbol_id); + if let Some(idx) = symbol.kind.into_token_idx(&self.state_map()) { let semantic_token = SemanticToken { delta_line, delta_start, @@ -252,30 +250,34 @@ impl LanguageServer for Backend { }))) } - async fn goto_definition(&self, params: GotoDefinitionParams) -> jsonrpc::Result> { - let path = PathBuf::from(params.text_document_position_params.text_document.uri.path()); - let analysis = self.analysis.get(&path); - let position = params.text_document_position_params.position; + // async fn goto_definition(&self, params: GotoDefinitionParams) -> jsonrpc::Result> { + // let path = PathBuf::from(params.text_document_position_params.text_document.uri.path()); + // let analysis = self.analysis.get(&path); + // let position = params.text_document_position_params.position; - if let Some(analysis) = &analysis { - let token = analysis.tokens.iter().enumerate().find(|(_, tok)| { - tok.position.1 == position.line + 1 - && (tok.position.0 <= position.character + 1 - && (tok.position.0 + tok.token.len() as u32) > position.character + 1) - }); + // if let Some(analysis) = &analysis { + // let token = analysis.tokens.iter().enumerate().find(|(_, tok)| { + // tok.position.1 == position.line + 1 + // && (tok.position.0 <= position.character + 1 + // && (tok.position.0 + tok.token.len() as u32) > position.character + 1) + // }); - if let Some(token) = token { - if let Some(def_token) = analysis.find_definition(token.0) { - return Ok(Some(GotoDefinitionResponse::Scalar(lsp_types::Location { - uri: params.text_document_position_params.text_document.uri, - range: token_to_range(def_token), - }))); - } - } - }; + // if let Some(token) = token { + // if let Some((module_id, def_token)) = analysis.find_definition(token.0, &self.state_map()) { + // return if let Some(path) = self.module_to_url.get(&module_id) { + // Ok(Some(GotoDefinitionResponse::Scalar(lsp_types::Location { + // uri: Url::from_file_path(path.value()).unwrap(), + // range: token_to_range(def_token), + // }))) + // } else { + // Ok(None) + // }; + // } + // } + // }; - Ok(None) - } + // Ok(None) + // } async fn references(&self, params: ReferenceParams) -> jsonrpc::Result>> { let path = PathBuf::from(params.text_document_position.text_document.uri.path()); @@ -289,20 +291,24 @@ impl LanguageServer for Backend { && (tok.position.0 + tok.token.len() as u32) > position.character + 1) }); if let Some(token) = token { - let tokens = analysis.find_references(token.0).map(|symbols| { - symbols - .iter() - .map(|symbol_id| analysis.state.symbol_to_token.get(&symbol_id).cloned().unwrap()) - .collect::>() - }); + let reference_tokens = analysis.find_references(token.0, &self.state_map()); + dbg!(&reference_tokens); let mut locations = Vec::new(); - if let Some(tokens) = tokens { - for token_idx in tokens { - let token = analysis.tokens.get(token_idx).unwrap(); - locations.push(Location { - uri: params.text_document_position.text_document.uri.clone(), - range: token_to_range(token), - }); + if let Some(reference_tokens) = reference_tokens { + for (module_id, symbol_idx) in reference_tokens { + if let Some(path) = self.module_to_url.get(&module_id) { + let url = Url::from_file_path(path.value()).unwrap(); + if let Some(inner_analysis) = self.analysis.get(path.value()) { + if let Some(token_idx) = inner_analysis.state.symbol_to_token.get(&symbol_idx) { + let token = inner_analysis.tokens.get(*token_idx).unwrap(); + + locations.push(lsp_types::Location { + uri: url, + range: token_to_range(token), + }); + } + } + } } } Ok(Some(locations)) @@ -314,49 +320,48 @@ impl LanguageServer for Backend { } } - async fn rename(&self, params: RenameParams) -> jsonrpc::Result> { - let path = PathBuf::from(params.text_document_position.text_document.uri.path()); - let file_name = path.file_name().unwrap().to_str().unwrap().to_owned(); - let analysis = self.analysis.get(&path); - let position = params.text_document_position.position; + // async fn rename(&self, params: RenameParams) -> jsonrpc::Result> { + // let path = PathBuf::from(params.text_document_position.text_document.uri.path()); + // let analysis = self.analysis.get(&path); + // let position = params.text_document_position.position; - if let Some(analysis) = &analysis { - let token = analysis.tokens.iter().enumerate().find(|(_, tok)| { - tok.position.1 == position.line + 1 - && (tok.position.0 <= position.character + 1 - && (tok.position.0 + tok.token.len() as u32) > position.character + 1) - }); - if let Some(token) = token { - let tokens = analysis.find_references(token.0).map(|symbols| { - symbols - .iter() - .map(|symbol_id| analysis.state.symbol_to_token.get(&symbol_id).cloned().unwrap()) - .collect::>() - }); - let mut edits = Vec::new(); - if let Some(tokens) = tokens { - for token_idx in tokens { - let token = analysis.tokens.get(token_idx).unwrap(); - edits.push(TextEdit { - range: token_to_range(token), - new_text: params.new_name.clone(), - }); - } - } - let mut changes = HashMap::new(); - changes.insert(params.text_document_position.text_document.uri, edits); - Ok(Some(WorkspaceEdit { - changes: Some(changes), - document_changes: None, - change_annotations: None, - })) - } else { - Ok(None) - } - } else { - Ok(None) - } - } + // if let Some(analysis) = &analysis { + // let token = analysis.tokens.iter().enumerate().find(|(_, tok)| { + // tok.position.1 == position.line + 1 + // && (tok.position.0 <= position.character + 1 + // && (tok.position.0 + tok.token.len() as u32) > position.character + 1) + // }); + // if let Some(token) = token { + // let tokens = analysis.find_references(token.0, &self.state_map()).map(|symbols| { + // symbols + // .iter() + // .map(|symbol_id| analysis.state.symbol_to_token.get(&symbol_id).cloned().unwrap()) + // .collect::>() + // }); + // let mut edits = Vec::new(); + // if let Some(tokens) = tokens { + // for token_idx in tokens { + // let token = analysis.tokens.get(token_idx).unwrap(); + // edits.push(TextEdit { + // range: token_to_range(token), + // new_text: params.new_name.clone(), + // }); + // } + // } + // let mut changes = HashMap::new(); + // changes.insert(params.text_document_position.text_document.uri, edits); + // Ok(Some(WorkspaceEdit { + // changes: Some(changes), + // document_changes: None, + // change_annotations: None, + // })) + // } else { + // Ok(None) + // } + // } else { + // Ok(None) + // } + // } } fn token_to_range(token: &FullToken) -> lsp_types::Range { @@ -373,6 +378,17 @@ fn token_to_range(token: &FullToken) -> lsp_types::Range { } impl Backend { + fn state_map(&self) -> StateMap { + let mut state_map = HashMap::new(); + for path_state in self.analysis.iter() { + let (path, state) = path_state.pair(); + if let Some(module_id) = self.path_to_module.get(path) { + state_map.insert(*module_id, state.state.clone()); + } + } + state_map + } + async fn recompile(&self, params: TextDocumentItem) { let file_path = PathBuf::from(params.uri.clone().path()); @@ -397,19 +413,11 @@ impl Backend { module_id }; - let mut state_map = HashMap::new(); - for path_state in self.analysis.iter() { - let (path, state) = path_state.pair(); - if let Some(module_id) = self.path_to_module.get(path) { - state_map.insert(*module_id, state.state.clone()); - } - } - let parse_res = parse(¶ms.text, file_path.clone(), &mut map, module_id); let (tokens, result) = match parse_res { Ok((module_id, tokens)) => ( tokens.clone(), - analyze(module_id, tokens, file_path.clone(), &mut map, &state_map), + analyze(module_id, tokens, file_path.clone(), &mut map, &self.state_map()), ), Err(e) => (Vec::new(), Err(e)), };