Skip to main content

solidity_language_server/
rename.rs

1use crate::goto;
2use crate::goto::CachedBuild;
3use crate::references;
4use serde_json::Value;
5use std::collections::HashMap;
6use tower_lsp::lsp_types::{Position, Range, TextEdit, Url, WorkspaceEdit};
7
8fn get_text_at_range(source_bytes: &[u8], range: &Range) -> Option<String> {
9    let start_byte = goto::pos_to_bytes(source_bytes, range.start);
10    let end_byte = goto::pos_to_bytes(source_bytes, range.end);
11    if end_byte > source_bytes.len() {
12        return None;
13    }
14    String::from_utf8(source_bytes[start_byte..end_byte].to_vec()).ok()
15}
16
17fn get_name_location_index(
18    ast_data: &Value,
19    file_uri: &Url,
20    position: Position,
21    source_bytes: &[u8],
22) -> Option<usize> {
23    let sources = ast_data.get("sources")?;
24    let (nodes, path_to_abs, _external_refs) = goto::cache_ids(sources);
25    let path = file_uri.to_file_path().ok()?;
26    let path_str = path.to_str()?;
27    let abs_path = path_to_abs.get(path_str)?;
28    let byte_position = goto::pos_to_bytes(source_bytes, position);
29    let node_id = references::byte_to_id(&nodes, abs_path, byte_position)?;
30    let file_nodes = nodes.get(abs_path)?;
31    let node_info = file_nodes.get(&node_id)?;
32
33    if !node_info.name_locations.is_empty() {
34        for (i, name_loc) in node_info.name_locations.iter().enumerate() {
35            let parts: Vec<&str> = name_loc.split(':').collect();
36            if parts.len() == 3
37                && let (Ok(start), Ok(length)) =
38                    (parts[0].parse::<usize>(), parts[1].parse::<usize>())
39            {
40                let end = start + length;
41                if start <= byte_position && byte_position < end {
42                    return Some(i);
43                }
44            }
45        }
46    }
47    None
48}
49
50pub fn get_identifier_at_position(source_bytes: &[u8], position: Position) -> Option<String> {
51    let text = String::from_utf8_lossy(source_bytes);
52    let lines: Vec<&str> = text.lines().collect();
53    if position.line as usize >= lines.len() {
54        return None;
55    }
56    let line = lines[position.line as usize];
57    if position.character as usize > line.len() {
58        return None;
59    }
60    let mut start = position.character as usize;
61    let mut end = position.character as usize;
62
63    while start > 0
64        && (line.as_bytes()[start - 1].is_ascii_alphanumeric()
65            || line.as_bytes()[start - 1] == b'_')
66    {
67        start -= 1;
68    }
69    while end < line.len()
70        && (line.as_bytes()[end].is_ascii_alphanumeric() || line.as_bytes()[end] == b'_')
71    {
72        end += 1;
73    }
74
75    if start == end {
76        return None;
77    }
78    if line.as_bytes()[start].is_ascii_digit() {
79        return None;
80    }
81
82    Some(line[start..end].to_string())
83}
84
85pub fn get_identifier_range(source_bytes: &[u8], position: Position) -> Option<Range> {
86    let text = String::from_utf8_lossy(source_bytes);
87    let lines: Vec<&str> = text.lines().collect();
88    if position.line as usize >= lines.len() {
89        return None;
90    }
91    let line = lines[position.line as usize];
92    if position.character as usize > line.len() {
93        return None;
94    }
95    let mut start = position.character as usize;
96    let mut end = position.character as usize;
97
98    while start > 0
99        && (line.as_bytes()[start - 1].is_ascii_alphanumeric()
100            || line.as_bytes()[start - 1] == b'_')
101    {
102        start -= 1;
103    }
104    while end < line.len()
105        && (line.as_bytes()[end].is_ascii_alphanumeric() || line.as_bytes()[end] == b'_')
106    {
107        end += 1;
108    }
109
110    if start == end {
111        return None;
112    }
113    if line.as_bytes()[start].is_ascii_digit() {
114        return None;
115    }
116
117    Some(Range {
118        start: Position {
119            line: position.line,
120            character: start as u32,
121        },
122        end: Position {
123            line: position.line,
124            character: end as u32,
125        },
126    })
127}
128
129type Type = HashMap<Url, HashMap<(u32, u32, u32, u32), TextEdit>>;
130
131pub fn rename_symbol(
132    build: &CachedBuild,
133    file_uri: &Url,
134    position: Position,
135    source_bytes: &[u8],
136    new_name: String,
137    other_builds: &[&CachedBuild],
138) -> Option<WorkspaceEdit> {
139    let original_identifier = get_identifier_at_position(source_bytes, position)?;
140    let name_location_index =
141        get_name_location_index(&build.ast, file_uri, position, source_bytes);
142    let mut locations = references::goto_references_with_index(
143        &build.ast,
144        file_uri,
145        position,
146        source_bytes,
147        name_location_index,
148    );
149
150    // Cross-file: scan other cached ASTs for the same target definition
151    if let Some((def_abs_path, def_byte_offset)) =
152        references::resolve_target_location(build, file_uri, position, source_bytes)
153    {
154        for other_build in other_builds {
155            let other_locations = references::goto_references_for_target(
156                other_build,
157                &def_abs_path,
158                def_byte_offset,
159                name_location_index,
160            );
161            locations.extend(other_locations);
162        }
163    }
164
165    // Deduplicate
166    let mut seen = std::collections::HashSet::new();
167    locations.retain(|loc| {
168        seen.insert((
169            loc.uri.clone(),
170            loc.range.start.line,
171            loc.range.start.character,
172            loc.range.end.line,
173            loc.range.end.character,
174        ))
175    });
176
177    if locations.is_empty() {
178        return None;
179    }
180    let mut changes: Type = HashMap::new();
181    for location in locations {
182        // Read the file to check the text at the range
183        let absolute_path = match location.uri.to_file_path() {
184            Ok(p) => p,
185            Err(_) => continue,
186        };
187        let file_source_bytes = match std::fs::read(&absolute_path) {
188            Ok(b) => b,
189            Err(_) => continue,
190        };
191        let text_at_range = match get_text_at_range(&file_source_bytes, &location.range) {
192            Some(t) => t,
193            None => continue,
194        };
195        if text_at_range == original_identifier {
196            let text_edit = TextEdit {
197                range: location.range,
198                new_text: new_name.clone(),
199            };
200            let key = (
201                location.range.start.line,
202                location.range.start.character,
203                location.range.end.line,
204                location.range.end.character,
205            );
206            changes.entry(location.uri).or_default().insert(key, text_edit);
207        }
208    }
209    let changes_vec: HashMap<Url, Vec<TextEdit>> = changes.into_iter()
210        .map(|(uri, edits_map)| (uri, edits_map.into_values().collect()))
211        .collect();
212    Some(WorkspaceEdit {
213        changes: Some(changes_vec),
214        document_changes: None,
215        change_annotations: None,
216    })
217}