Make get-references work for multi-file projects

Sofia 2025-08-03 23:54:49 +03:00
parent d27ec2bb70
commit 109fedb624
2 changed files with 259 additions and 143 deletions
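
The gist of the change: a SemanticKind::Reference now records which module owns the symbol it points at, and definition/reference lookups thread a StateMap (module id -> that module's AnalysisState) through every call, so a lookup can hop from one file's analysis into another's. Below is a minimal, self-contained sketch of that cross-module definition chase; the type names mirror the diff, but the bodies are simplified stand-ins, not the crate's real definitions.

use std::collections::HashMap;

// Simplified stand-ins for the crate's real types; only the shape matters here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct SourceModuleId(u32);

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct SymbolId(usize);

#[derive(Clone, Copy)]
enum SemanticKind {
    Definition,
    // A reference now names the module that owns the symbol it points at.
    Reference(SourceModuleId, SymbolId),
}

struct AnalysisState {
    module_id: SourceModuleId,
    symbol_table: Vec<SemanticKind>,
}

type StateMap = HashMap<SourceModuleId, AnalysisState>;

impl AnalysisState {
    // Chase a reference chain to its defining symbol, hopping into other
    // modules' states through `map` whenever a link leaves the current module.
    fn find_definition(&self, id: SymbolId, map: &StateMap) -> (SourceModuleId, SymbolId) {
        match self.symbol_table[id.0] {
            SemanticKind::Reference(module_id, next) if module_id == self.module_id => {
                self.find_definition(next, map)
            }
            SemanticKind::Reference(module_id, next) => map
                .get(&module_id)
                .map(|state| state.find_definition(next, map))
                .unwrap_or((self.module_id, id)),
            SemanticKind::Definition => (self.module_id, id),
        }
    }
}

fn main() {
    let (m1, m2) = (SourceModuleId(1), SourceModuleId(2));
    let mut map = StateMap::new();
    // Module 1 defines symbol #0; module 2's symbol #0 is a reference to it.
    map.insert(m1, AnalysisState { module_id: m1, symbol_table: vec![SemanticKind::Definition] });
    map.insert(m2, AnalysisState {
        module_id: m2,
        symbol_table: vec![SemanticKind::Reference(m1, SymbolId(0))],
    });
    let def = map[&m2].find_definition(SymbolId(0), &map);
    assert_eq!(def, (m1, SymbolId(0)));
}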

View File

@@ -1,4 +1,4 @@
-use std::{collections::HashMap, fmt::format, path::PathBuf};
+use std::{collections::HashMap, fmt::format, hash::Hash, path::PathBuf};
use reid::{
ast::{
@@ -45,26 +45,36 @@ pub struct StaticAnalysis {
}
impl StaticAnalysis {
-pub fn find_definition(&self, token_idx: usize) -> Option<&FullToken> {
+pub fn find_definition(&self, token_idx: usize, map: &StateMap) -> Option<(SourceModuleId, &FullToken)> {
let semantic_token = self.state.map.get(&token_idx)?;
let symbol_id = semantic_token.symbol?;
-let definition_id = self.state.find_definition(&symbol_id);
+let (module_id, definition_id) = self.state.find_definition(&symbol_id, map);
-let def_token_idx = self.state.symbol_to_token.get(&definition_id)?;
-self.tokens.get(*def_token_idx)
+if module_id == self.state.module_id {
+self.state
+.symbol_to_token
+.get(&definition_id)
+.and_then(|def_token_idx| self.tokens.get(*def_token_idx).map(|t| ((module_id, t))))
+} else {
+map.get(&module_id)
+.and_then(|state| state.symbol_to_token.get(&definition_id))
+.and_then(|def_token_idx| self.tokens.get(*def_token_idx).map(|t| ((module_id, t))))
+}
}
-pub fn find_references(&self, token_idx: usize) -> Option<Vec<SymbolId>> {
+pub fn find_references(&self, token_idx: usize, map: &StateMap) -> Option<Vec<(SourceModuleId, SymbolId)>> {
let mut references = Vec::new();
let semantic_token = self.state.map.get(&token_idx)?;
let symbol_id = semantic_token.symbol?;
-let definition_id = self.state.find_definition(&symbol_id);
-references.push(definition_id);
+let (def_module_id, definition_id) = self.state.find_definition(&symbol_id, map);
+references.push((def_module_id, definition_id));
-for (symbol_idx, semantic_symbol) in self.state.symbol_table.iter().enumerate() {
-if let SemanticKind::Reference(ref_idx) = semantic_symbol.kind {
-if ref_idx == definition_id {
-references.push(SymbolId(symbol_idx));
+for state in map.values() {
+for (symbol_idx, semantic_symbol) in state.symbol_table.iter().enumerate() {
+if let SemanticKind::Reference(module_id, ref_idx) = semantic_symbol.kind {
+if def_module_id == module_id && ref_idx == definition_id {
+references.push((state.module_id, SymbolId(symbol_idx)));
+}
}
}
}
@@ -121,6 +131,8 @@ pub struct AnalysisState {
/// SymbolID -> Symbol
pub symbol_to_token: HashMap<SymbolId, usize>,
+module_id: SourceModuleId,
functions: HashMap<String, SymbolId>,
associated_functions: HashMap<(TypeKind, String), SymbolId>,
properties: HashMap<(TypeKind, String), SymbolId>,
@@ -130,7 +142,7 @@ pub struct AnalysisState {
pub type StateMap = HashMap<SourceModuleId, AnalysisState>;
impl AnalysisState {
-pub fn get_symbol(&self, id: SymbolId) -> &Symbol {
+pub fn get_local_symbol(&self, id: SymbolId) -> &Symbol {
self.symbol_table.get(id.0).unwrap()
}
}
@@ -190,11 +202,19 @@ impl AnalysisState {
id
}
-pub fn find_definition(&self, id: &SymbolId) -> SymbolId {
-let symbol = self.get_symbol(*id);
+pub fn find_definition(&self, id: &SymbolId, map: &StateMap) -> (SourceModuleId, SymbolId) {
+let symbol = self.get_local_symbol(*id);
match symbol.kind {
-SemanticKind::Reference(idx) => self.find_definition(&idx),
-_ => *id,
+SemanticKind::Reference(module_id, idx) => {
+if module_id == self.module_id {
+self.find_definition(&idx, map)
+} else {
+map.get(&module_id)
+.map(|state| state.find_definition(&idx, map))
+.unwrap_or((self.module_id, *id))
+}
+}
+_ => (self.module_id, *id),
}
}
}
@@ -210,6 +230,9 @@ pub struct AnalysisScope<'a> {
state: &'a mut AnalysisState,
tokens: &'a Vec<FullToken>,
variables: HashMap<String, SymbolId>,
+types: HashMap<TypeKind, (SourceModuleId, SymbolId)>,
+functions: HashMap<String, (SourceModuleId, SymbolId)>,
+associated_functions: HashMap<(TypeKind, String), (SourceModuleId, SymbolId)>,
map: &'a StateMap,
}
@@ -220,6 +243,9 @@ impl<'a> AnalysisScope<'a> {
map: self.map,
tokens: self.tokens,
variables: self.variables.clone(),
+types: self.types.clone(),
+functions: self.functions.clone(),
+associated_functions: self.associated_functions.clone(),
}
}
@@ -237,6 +263,36 @@ impl<'a> AnalysisScope<'a> {
}
return None;
}
+pub fn find_property(&self, ty: TypeKind, property: String) -> Option<(SourceModuleId, SymbolId)> {
+match &ty {
+TypeKind::CustomType(CustomTypeKey(_, module_id)) => {
+if *module_id == self.state.module_id {
+self.state
+.properties
+.get(&(ty.clone(), property.clone()))
+.cloned()
+.map(|p| (*module_id, p))
+} else {
+if let Some(state) = self.map.get(&module_id) {
+state
+.properties
+.get(&(ty.clone(), property.clone()))
+.cloned()
+.map(|p| (*module_id, p))
+} else {
+None
+}
+}
+}
+_ => self
+.state
+.properties
+.get(&(ty, property.clone()))
+.cloned()
+.map(|p| (self.state.module_id, p)),
+}
+}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
@@ -252,7 +308,7 @@ pub enum SemanticKind {
Comment,
Operator,
Keyword,
-Reference(SymbolId),
+Reference(SourceModuleId, SymbolId),
}
impl Default for SemanticKind {
@@ -262,7 +318,7 @@ impl Default for SemanticKind {
}
impl SemanticKind {
-pub fn into_token_idx(&self, state: &AnalysisState) -> Option<u32> {
+pub fn into_token_idx(&self, map: &StateMap) -> Option<u32> {
let token_type = match self {
SemanticKind::Variable => SemanticTokenType::VARIABLE,
SemanticKind::Function => SemanticTokenType::FUNCTION,
@@ -275,7 +331,16 @@
SemanticKind::Operator => SemanticTokenType::OPERATOR,
SemanticKind::Keyword => SemanticTokenType::KEYWORD,
SemanticKind::Default => return None,
-SemanticKind::Reference(symbol_id) => return state.get_symbol(*symbol_id).kind.into_token_idx(state),
+SemanticKind::Reference(module_id, symbol_id) => {
+return map
+.get(module_id)
+.unwrap()
+.symbol_table
+.get(symbol_id.0)
+.unwrap()
+.kind
+.into_token_idx(map);
+}
};
TOKEN_LEGEND
.iter()
@@ -297,7 +362,7 @@
SemanticKind::Comment => return None,
SemanticKind::Operator => return None,
SemanticKind::Keyword => return None,
-SemanticKind::Reference(_) => SEMANTIC_REFERENCE,
+SemanticKind::Reference(..) => SEMANTIC_REFERENCE,
};
MODIFIER_LEGEND
.iter()
@@ -355,6 +420,7 @@ pub fn analyze_context(
associated_functions: HashMap::new(),
properties: HashMap::new(),
types: HashMap::new(),
+module_id: module.module_id,
};
let mut scope = AnalysisScope {
@@ -362,6 +428,9 @@
tokens: &module.tokens,
variables: HashMap::new(),
map,
+types: HashMap::new(),
+functions: HashMap::new(),
+associated_functions: HashMap::new(),
};
for (i, token) in module.tokens.iter().enumerate() {
@@ -454,6 +523,12 @@
for typedef in &module.typedefs {
if typedef.source_module != module.module_id {
+if let Some(state) = map.get(&typedef.source_module) {
+let ty = TypeKind::CustomType(CustomTypeKey(typedef.name.clone(), typedef.source_module));
+if let Some(symbol) = state.types.get(&ty) {
+scope.types.insert(ty, (typedef.source_module, *symbol));
+}
+}
continue;
}
@@ -466,10 +541,10 @@
.state
.new_symbol(struct_idx, SemanticKind::Struct, module.module_id);
scope.state.set_symbol(struct_idx, struct_symbol);
-scope.state.types.insert(
-TypeKind::CustomType(CustomTypeKey(typedef.name.clone(), typedef.source_module)),
-struct_symbol,
-);
+let ty = TypeKind::CustomType(CustomTypeKey(typedef.name.clone(), typedef.source_module));
+scope.state.types.insert(ty.clone(), struct_symbol);
+scope.types.insert(ty, (typedef.source_module, struct_symbol));
for field in fields {
let field_idx = scope
@@ -526,6 +601,17 @@
}
for (ty, function) in &module.associated_functions {
+if let Some(source_id) = function.source {
+if let Some(state) = map.get(&source_id) {
+if let Some(symbol) = state.associated_functions.get(&(ty.clone(), function.name.clone())) {
+scope
+.associated_functions
+.insert((ty.clone(), function.name.clone()), (source_id, *symbol));
+}
+}
+continue;
+}
let idx = scope
.token_idx(&function.signature(), |t| matches!(t, Token::Identifier(_)))
.unwrap_or(function.signature().range.end);
@@ -559,6 +645,18 @@
}
for function in &module.functions {
+if let Some(source_id) = function.source {
+if source_id != module.module_id {
+dbg!(source_id, &function.name);
+if let Some(state) = map.get(&source_id) {
+if let Some(symbol) = state.functions.get(&function.name) {
+scope.functions.insert(function.name.clone(), (source_id, *symbol));
+}
+}
+continue;
+}
+}
scope
.state
.init_types(&function.signature(), Some(function.return_type.clone()));
@@ -668,9 +766,11 @@ pub fn analyze_expr(
.token_idx(&var_ref.2, |t| matches!(t, Token::Identifier(_)))
.unwrap_or(var_ref.2.range.end);
let symbol = if let Some(symbol_id) = scope.variables.get(&var_ref.1) {
-scope
-.state
-.new_symbol(idx, SemanticKind::Reference(*symbol_id), source_module.module_id)
+scope.state.new_symbol(
+idx,
+SemanticKind::Reference(source_module.module_id, *symbol_id),
+source_module.module_id,
+)
} else {
scope.state.new_symbol(idx, SemanticKind::Type, source_module.module_id)
};
@@ -709,12 +809,12 @@
.token_idx(&meta, |t| matches!(t, Token::Identifier(_)))
.unwrap_or(meta.range.end);
-let field_symbol = if let Some(symbol_id) =
-scope.state.properties.get(&(accessed_type.clone(), name.clone()))
+let field_symbol = if let Some((module_id, symbol_id)) =
+scope.find_property(accessed_type.clone(), name.clone())
{
scope.state.new_symbol(
field_idx,
-SemanticKind::Reference(*symbol_id),
+SemanticKind::Reference(module_id, symbol_id),
source_module.module_id,
)
} else {
@@ -755,9 +855,11 @@
.unwrap_or(expr.1.range.end);
let struct_symbol = if let Some(symbol_id) = scope.state.types.get(&struct_type) {
-scope
-.state
-.new_symbol(struct_idx, SemanticKind::Reference(*symbol_id), source_module.module_id)
+scope.state.new_symbol(
+struct_idx,
+SemanticKind::Reference(source_module.module_id, *symbol_id),
+source_module.module_id,
+)
} else {
scope
.state
@@ -772,9 +874,11 @@
let field_symbol =
if let Some(symbol_id) = scope.state.properties.get(&(struct_type.clone(), field_name.clone())) {
-scope
-.state
-.new_symbol(field_idx, SemanticKind::Reference(*symbol_id), source_module.module_id)
+scope.state.new_symbol(
+field_idx,
+SemanticKind::Reference(source_module.module_id, *symbol_id),
+source_module.module_id,
+)
} else {
scope
.state
@@ -816,10 +920,12 @@
let idx = scope
.token_idx(&meta, |t| matches!(t, Token::Identifier(_)))
.unwrap_or(meta.range.end);
-let symbol = if let Some(symbol_id) = scope.state.functions.get(name) {
-scope
-.state
-.new_symbol(idx, SemanticKind::Reference(*symbol_id), source_module.module_id)
+let symbol = if let Some((module_id, symbol_id)) = scope.functions.get(name) {
+scope.state.new_symbol(
+idx,
+SemanticKind::Reference(*module_id, *symbol_id),
+source_module.module_id,
+)
} else {
scope
.state
@@ -842,10 +948,12 @@
ty.clone()
};
-let type_symbol = if let Some(symbol_id) = scope.state.types.get(&invoked_ty) {
-scope
-.state
-.new_symbol(type_idx, SemanticKind::Reference(*symbol_id), source_module.module_id)
+let type_symbol = if let Some((module_id, symbol_id)) = scope.types.get(&invoked_ty) {
+scope.state.new_symbol(
+type_idx,
+SemanticKind::Reference(*module_id, *symbol_id),
+source_module.module_id,
+)
} else {
scope
.state
@@ -856,14 +964,14 @@
let fn_idx = scope
.token_idx(&meta, |t| matches!(t, Token::Identifier(_)))
.unwrap_or(meta.range.end);
-let fn_symbol = if let Some(symbol_id) = scope
-.state
-.associated_functions
-.get(&(invoked_ty.clone(), name.clone()))
+let fn_symbol = if let Some((module_id, symbol_id)) =
+scope.associated_functions.get(&(invoked_ty.clone(), name.clone()))
{
-scope
-.state
-.new_symbol(fn_idx, SemanticKind::Reference(*symbol_id), source_module.module_id)
+scope.state.new_symbol(
+fn_idx,
+SemanticKind::Reference(*module_id, *symbol_id),
+source_module.module_id,
+)
} else {
scope
.state

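The second file wires this into the language server: find_references now yields (SourceModuleId, SymbolId) pairs, and the handler resolves each pair against the analysis of the module that owns it, so reference locations can point into other files. A rough sketch of that fan-out follows; Location, Analysis, module_to_path and the file name in main are hypothetical placeholders for lsp_types::Location, StaticAnalysis and the server's real fields.

use std::collections::HashMap;
use std::path::PathBuf;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct SourceModuleId(u32);

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct SymbolId(usize);

// Placeholder for lsp_types::Location: which file, which token.
#[derive(Debug)]
struct Location {
    file: PathBuf,
    token_idx: usize,
}

// Placeholder for the per-file StaticAnalysis.
struct Analysis {
    symbol_to_token: HashMap<SymbolId, usize>,
}

fn references_to_locations(
    refs: &[(SourceModuleId, SymbolId)],
    module_to_path: &HashMap<SourceModuleId, PathBuf>,
    analyses: &HashMap<PathBuf, Analysis>,
) -> Vec<Location> {
    let mut locations = Vec::new();
    for (module_id, symbol_id) in refs {
        // Resolve each symbol in the module that owns it, not in the
        // module where the request originated.
        let Some(path) = module_to_path.get(module_id) else { continue };
        let Some(analysis) = analyses.get(path) else { continue };
        if let Some(token_idx) = analysis.symbol_to_token.get(symbol_id) {
            locations.push(Location { file: path.clone(), token_idx: *token_idx });
        }
    }
    locations
}

fn main() {
    let m = SourceModuleId(1);
    let file = PathBuf::from("main.reid");
    let module_to_path = HashMap::from([(m, file.clone())]);
    let analyses = HashMap::from([(file, Analysis { symbol_to_token: HashMap::from([(SymbolId(0), 7)]) })]);
    println!("{:?}", references_to_locations(&[(m, SymbolId(0))], &module_to_path, &analyses));
}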
View File

@@ -15,11 +15,11 @@ use tower_lsp::lsp_types::{
ReferenceParams, RenameParams, SemanticToken, SemanticTokensLegend, SemanticTokensOptions, SemanticTokensParams,
SemanticTokensResult, SemanticTokensServerCapabilities, ServerCapabilities, TextDocumentItem,
TextDocumentRegistrationOptions, TextDocumentSyncCapability, TextDocumentSyncKind, TextDocumentSyncOptions,
-TextEdit, WorkspaceEdit, WorkspaceFoldersServerCapabilities, WorkspaceServerCapabilities,
+TextEdit, Url, WorkspaceEdit, WorkspaceFoldersServerCapabilities, WorkspaceServerCapabilities,
};
use tower_lsp::{Client, LanguageServer, LspService, Server, jsonrpc};
-use crate::analysis::{MODIFIER_LEGEND, StaticAnalysis, TOKEN_LEGEND, analyze};
+use crate::analysis::{MODIFIER_LEGEND, StateMap, StaticAnalysis, TOKEN_LEGEND, analyze};
mod analysis;
@@ -79,9 +79,7 @@ impl LanguageServer for Backend {
static_registration_options: Default::default(),
},
)),
-definition_provider: Some(OneOf::Left(true)),
references_provider: Some(OneOf::Left(true)),
-rename_provider: Some(OneOf::Left(true)),
..Default::default()
};
Ok(InitializeResult {
@@ -228,8 +226,8 @@ impl LanguageServer for Backend {
if let Some(token_analysis) = analysis.state.map.get(&i) {
if let Some(symbol_id) = token_analysis.symbol {
-let symbol = analysis.state.get_symbol(symbol_id);
-if let Some(idx) = symbol.kind.into_token_idx(&analysis.state) {
+let symbol = analysis.state.get_local_symbol(symbol_id);
+if let Some(idx) = symbol.kind.into_token_idx(&self.state_map()) {
let semantic_token = SemanticToken {
delta_line,
delta_start,
@@ -252,30 +250,34 @@
})))
}
-async fn goto_definition(&self, params: GotoDefinitionParams) -> jsonrpc::Result<Option<GotoDefinitionResponse>> {
-let path = PathBuf::from(params.text_document_position_params.text_document.uri.path());
-let analysis = self.analysis.get(&path);
-let position = params.text_document_position_params.position;
+// async fn goto_definition(&self, params: GotoDefinitionParams) -> jsonrpc::Result<Option<GotoDefinitionResponse>> {
+// let path = PathBuf::from(params.text_document_position_params.text_document.uri.path());
+// let analysis = self.analysis.get(&path);
+// let position = params.text_document_position_params.position;
-if let Some(analysis) = &analysis {
-let token = analysis.tokens.iter().enumerate().find(|(_, tok)| {
-tok.position.1 == position.line + 1
-&& (tok.position.0 <= position.character + 1
-&& (tok.position.0 + tok.token.len() as u32) > position.character + 1)
-});
+// if let Some(analysis) = &analysis {
+// let token = analysis.tokens.iter().enumerate().find(|(_, tok)| {
+// tok.position.1 == position.line + 1
+// && (tok.position.0 <= position.character + 1
+// && (tok.position.0 + tok.token.len() as u32) > position.character + 1)
+// });
-if let Some(token) = token {
-if let Some(def_token) = analysis.find_definition(token.0) {
-return Ok(Some(GotoDefinitionResponse::Scalar(lsp_types::Location {
-uri: params.text_document_position_params.text_document.uri,
-range: token_to_range(def_token),
-})));
-}
-}
-};
+// if let Some(token) = token {
+// if let Some((module_id, def_token)) = analysis.find_definition(token.0, &self.state_map()) {
+// return if let Some(path) = self.module_to_url.get(&module_id) {
+// Ok(Some(GotoDefinitionResponse::Scalar(lsp_types::Location {
+// uri: Url::from_file_path(path.value()).unwrap(),
+// range: token_to_range(def_token),
+// })))
+// } else {
+// Ok(None)
+// };
+// }
+// }
+// };
-Ok(None)
-}
+// Ok(None)
+// }
async fn references(&self, params: ReferenceParams) -> jsonrpc::Result<Option<Vec<Location>>> {
let path = PathBuf::from(params.text_document_position.text_document.uri.path());
@@ -289,20 +291,24 @@
&& (tok.position.0 + tok.token.len() as u32) > position.character + 1)
});
if let Some(token) = token {
-let tokens = analysis.find_references(token.0).map(|symbols| {
-symbols
-.iter()
-.map(|symbol_id| analysis.state.symbol_to_token.get(&symbol_id).cloned().unwrap())
-.collect::<Vec<_>>()
-});
+let reference_tokens = analysis.find_references(token.0, &self.state_map());
+dbg!(&reference_tokens);
let mut locations = Vec::new();
-if let Some(tokens) = tokens {
-for token_idx in tokens {
-let token = analysis.tokens.get(token_idx).unwrap();
-locations.push(Location {
-uri: params.text_document_position.text_document.uri.clone(),
-range: token_to_range(token),
-});
+if let Some(reference_tokens) = reference_tokens {
+for (module_id, symbol_idx) in reference_tokens {
+if let Some(path) = self.module_to_url.get(&module_id) {
+let url = Url::from_file_path(path.value()).unwrap();
+if let Some(inner_analysis) = self.analysis.get(path.value()) {
+if let Some(token_idx) = inner_analysis.state.symbol_to_token.get(&symbol_idx) {
+let token = inner_analysis.tokens.get(*token_idx).unwrap();
+locations.push(lsp_types::Location {
+uri: url,
+range: token_to_range(token),
+});
+}
+}
+}
}
}
Ok(Some(locations))
@@ -314,49 +320,48 @@
}
}
-async fn rename(&self, params: RenameParams) -> jsonrpc::Result<Option<WorkspaceEdit>> {
-let path = PathBuf::from(params.text_document_position.text_document.uri.path());
-let file_name = path.file_name().unwrap().to_str().unwrap().to_owned();
-let analysis = self.analysis.get(&path);
-let position = params.text_document_position.position;
+// async fn rename(&self, params: RenameParams) -> jsonrpc::Result<Option<WorkspaceEdit>> {
+// let path = PathBuf::from(params.text_document_position.text_document.uri.path());
+// let analysis = self.analysis.get(&path);
+// let position = params.text_document_position.position;
-if let Some(analysis) = &analysis {
-let token = analysis.tokens.iter().enumerate().find(|(_, tok)| {
-tok.position.1 == position.line + 1
-&& (tok.position.0 <= position.character + 1
-&& (tok.position.0 + tok.token.len() as u32) > position.character + 1)
-});
-if let Some(token) = token {
-let tokens = analysis.find_references(token.0).map(|symbols| {
-symbols
-.iter()
-.map(|symbol_id| analysis.state.symbol_to_token.get(&symbol_id).cloned().unwrap())
-.collect::<Vec<_>>()
-});
-let mut edits = Vec::new();
-if let Some(tokens) = tokens {
-for token_idx in tokens {
-let token = analysis.tokens.get(token_idx).unwrap();
-edits.push(TextEdit {
-range: token_to_range(token),
-new_text: params.new_name.clone(),
-});
-}
-}
-let mut changes = HashMap::new();
-changes.insert(params.text_document_position.text_document.uri, edits);
-Ok(Some(WorkspaceEdit {
-changes: Some(changes),
-document_changes: None,
-change_annotations: None,
-}))
-} else {
-Ok(None)
-}
-} else {
-Ok(None)
-}
-}
+// if let Some(analysis) = &analysis {
+// let token = analysis.tokens.iter().enumerate().find(|(_, tok)| {
+// tok.position.1 == position.line + 1
+// && (tok.position.0 <= position.character + 1
+// && (tok.position.0 + tok.token.len() as u32) > position.character + 1)
+// });
+// if let Some(token) = token {
+// let tokens = analysis.find_references(token.0, &self.state_map()).map(|symbols| {
+// symbols
+// .iter()
+// .map(|symbol_id| analysis.state.symbol_to_token.get(&symbol_id).cloned().unwrap())
+// .collect::<Vec<_>>()
+// });
+// let mut edits = Vec::new();
+// if let Some(tokens) = tokens {
+// for token_idx in tokens {
+// let token = analysis.tokens.get(token_idx).unwrap();
+// edits.push(TextEdit {
+// range: token_to_range(token),
+// new_text: params.new_name.clone(),
+// });
+// }
+// }
+// let mut changes = HashMap::new();
+// changes.insert(params.text_document_position.text_document.uri, edits);
+// Ok(Some(WorkspaceEdit {
+// changes: Some(changes),
+// document_changes: None,
+// change_annotations: None,
+// }))
+// } else {
+// Ok(None)
+// }
+// } else {
+// Ok(None)
+// }
+// }
}
fn token_to_range(token: &FullToken) -> lsp_types::Range {
@@ -373,6 +378,17 @@
}
impl Backend {
+fn state_map(&self) -> StateMap {
+let mut state_map = HashMap::new();
+for path_state in self.analysis.iter() {
+let (path, state) = path_state.pair();
+if let Some(module_id) = self.path_to_module.get(path) {
+state_map.insert(*module_id, state.state.clone());
+}
+}
+state_map
+}
async fn recompile(&self, params: TextDocumentItem) {
let file_path = PathBuf::from(params.uri.clone().path());
@@ -397,19 +413,11 @@
module_id
};
-let mut state_map = HashMap::new();
-for path_state in self.analysis.iter() {
-let (path, state) = path_state.pair();
-if let Some(module_id) = self.path_to_module.get(path) {
-state_map.insert(*module_id, state.state.clone());
-}
-}
let parse_res = parse(&params.text, file_path.clone(), &mut map, module_id);
let (tokens, result) = match parse_res {
Ok((module_id, tokens)) => (
tokens.clone(),
-analyze(module_id, tokens, file_path.clone(), &mut map, &state_map),
+analyze(module_id, tokens, file_path.clone(), &mut map, &self.state_map()),
),
Err(e) => (Vec::new(), Err(e)),
};
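
The recompile path above now delegates to the new Backend::state_map helper instead of assembling the module map inline. A condensed model of that pattern follows, with plain HashMaps standing in for the server's concurrent maps and the analysis types reduced to stubs: the helper clones every file's AnalysisState into a snapshot keyed by module id, so a request handler gets a consistent project-wide view without holding any locks while it works. Cloning per call costs time proportional to project size, but it keeps the handlers simple; sharing the states behind reference counting would be the leaner, more involved alternative.

use std::collections::HashMap;
use std::path::PathBuf;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct SourceModuleId(u32);

// Per-module analysis results; the real struct carries symbol tables etc.
#[derive(Clone)]
struct AnalysisState;

struct StaticAnalysis {
    state: AnalysisState,
}

struct Backend {
    // The real server keeps these in concurrent maps; HashMaps stand in here.
    analysis: HashMap<PathBuf, StaticAnalysis>,
    path_to_module: HashMap<PathBuf, SourceModuleId>,
}

impl Backend {
    // Snapshot every known file's state into one map keyed by module id.
    fn state_map(&self) -> HashMap<SourceModuleId, AnalysisState> {
        let mut state_map = HashMap::new();
        for (path, analysis) in &self.analysis {
            if let Some(module_id) = self.path_to_module.get(path) {
                state_map.insert(*module_id, analysis.state.clone());
            }
        }
        state_map
    }
}

fn main() {
    let path = PathBuf::from("main.reid");
    let backend = Backend {
        analysis: HashMap::from([(path.clone(), StaticAnalysis { state: AnalysisState })]),
        path_to_module: HashMap::from([(path, SourceModuleId(1))]),
    };
    assert_eq!(backend.state_map().len(), 1);
}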