1use anyhow::{Context as AnyhowContext, Result};
7use hashbrown::HashMap;
8use std::fmt::Write;
9use std::path::{Path, PathBuf};
10use std::sync::Arc;
11use tokio::sync::RwLock;
12
13use super::entity_resolver::{EntityMatch, FileLocation};
14use super::workspace_state::WorkspaceState;
15use crate::tools::grep_file::{GrepSearchInput, GrepSearchManager};
16use crate::utils::file_utils::read_file_with_context;
17
/// Maximum number of ranked files that `ProactiveGatherer::gather_context` reads snippets from.
const MAX_CONTEXT_FILES: usize = 3;

// Not referenced anywhere yet, hence the `dead_code` expectation. Presumably intended to cap
// the number of snippets stored per file — TODO confirm before wiring it up.
#[expect(dead_code)]
const MAX_SNIPPETS_PER_FILE: usize = 20;

/// Soft token budget: `GatheredContext::is_over_budget` reports `true` once the running
/// estimate exceeds this value, and `gather_context` then stops reading further files.
const MAX_CONTEXT_TOKENS: usize = 2000;

/// Number of surrounding lines included before and after an entity location when a
/// snippet is cut from a file.
const CONTEXT_LINES: usize = 10;
30
/// A contiguous range of lines taken from a workspace file.
#[derive(Debug, Clone)]
pub struct FileSnippet {
    /// File the lines were read from.
    pub file: PathBuf,
    /// First line of the range. `ProactiveGatherer` produces 1-based, inclusive bounds.
    pub line_start: usize,
    /// Last line of the range (1-based, inclusive when produced by `ProactiveGatherer`).
    pub line_end: usize,
    /// The text of the range. `ProactiveGatherer` joins the lines with `\n`.
    pub content: String,
    /// Relevance weight of this snippet. Snippets built by `ProactiveGatherer` currently
    /// always use `1.0`.
    pub relevance_score: f32,
}
40
/// Context collected ahead of time for inclusion in a prompt, together with a rough
/// estimate of its size in tokens.
#[derive(Debug, Clone, Default)]
pub struct GatheredContext {
    /// Relevant files in insertion order. `add_file` skips paths that are already present.
    pub files: Vec<PathBuf>,

    /// Snippets grouped by the file they were taken from.
    pub snippets: HashMap<PathBuf, Vec<FileSnippet>>,

    /// Search results attached via `add_search_results`, if any.
    pub search_results: Option<serde_json::Value>,

    /// Approximate token count of the attached snippets and search results. The `add_*`
    /// methods estimate it as byte length / 4.
    pub estimated_tokens: usize,
}
56
57impl GatheredContext {
58 pub fn new() -> Self {
60 Self::default()
61 }
62
63 pub fn add_file(&mut self, file: PathBuf) {
65 if !self.files.contains(&file) {
66 self.files.push(file);
67 }
68 }
69
70 pub fn add_files(&mut self, files: Vec<PathBuf>) {
72 for file in files {
73 self.add_file(file);
74 }
75 }
76
77 pub fn add_snippet(&mut self, file: PathBuf, snippet: FileSnippet) {
79 self.estimated_tokens += snippet.content.len() / 4;
81
82 self.snippets.entry(file.clone()).or_default().push(snippet);
83 }
84
85 pub fn add_search_results(&mut self, results: serde_json::Value) {
87 if let Ok(json_str) = serde_json::to_string(&results) {
89 self.estimated_tokens += json_str.len() / 4;
90 }
91
92 self.search_results = Some(results);
93 }
94
95 pub fn is_over_budget(&self) -> bool {
97 self.estimated_tokens > MAX_CONTEXT_TOKENS
98 }
99
100 pub fn to_prompt_text(&self) -> String {
102 let mut text = String::from("## Proactively Gathered Context\n\n");
103
104 if !self.files.is_empty() {
105 text.push_str("### Relevant Files:\n");
106 for file in &self.files {
107 let _ = writeln!(text, "- {}", file.display());
108 }
109 text.push('\n');
110 }
111
112 if !self.snippets.is_empty() {
113 text.push_str("### File Snippets:\n\n");
114 for (file, snippets) in &self.snippets {
115 let _ = writeln!(text, "**{}**:", file.display());
116 for snippet in snippets {
117 let _ = write!(
118 text,
119 "```\n{} (lines {}-{})\n{}\n```\n\n",
120 file.display(),
121 snippet.line_start,
122 snippet.line_end,
123 snippet.content
124 );
125 }
126 }
127 }
128
129 text
130 }
131}
132
/// Collects file snippets (and optionally grep results) relevant to resolved entities.
/// Ranking is informed by the shared workspace state.
pub struct ProactiveGatherer {
    /// Grep backend used by `proactive_search`. `None` disables searching.
    grep_manager: Option<Arc<GrepSearchManager>>,

    /// Shared workspace activity (recently accessed and hot files) used to boost file ranking.
    workspace_state: Arc<RwLock<WorkspaceState>>,

    /// Root directory that proactive searches run against.
    workspace_root: PathBuf,
}
144
145impl ProactiveGatherer {
146 pub fn new(workspace_root: PathBuf, workspace_state: Arc<RwLock<WorkspaceState>>) -> Self {
148 Self {
149 grep_manager: None,
150 workspace_state,
151 workspace_root,
152 }
153 }
154
155 pub fn with_grep(
157 workspace_root: PathBuf,
158 workspace_state: Arc<RwLock<WorkspaceState>>,
159 grep_manager: Arc<GrepSearchManager>,
160 ) -> Self {
161 Self {
162 grep_manager: Some(grep_manager),
163 workspace_state,
164 workspace_root,
165 }
166 }
167
168 pub async fn gather_context(&self, entity_matches: &[EntityMatch]) -> Result<GatheredContext> {
170 let mut context = GatheredContext::new();
171
172 let mut candidate_files = Vec::new();
174 for entity_match in entity_matches {
175 for location in &entity_match.locations {
176 candidate_files.push((location.path.clone(), entity_match.total_score()));
177 }
178 }
179
180 let ranked_files = self.rank_files_for_context(candidate_files).await;
182
183 for (file, _score) in ranked_files.iter().take(MAX_CONTEXT_FILES) {
185 if let Some(snippet) = self.read_file_snippet(file, entity_matches).await? {
186 context.add_file(file.clone());
187 context.add_snippet(file.clone(), snippet);
188 }
189
190 if context.is_over_budget() {
192 break;
193 }
194 }
195
196 Ok(context)
197 }
198
199 async fn rank_files_for_context(&self, files: Vec<(PathBuf, f32)>) -> Vec<(PathBuf, f32)> {
201 let state = self.workspace_state.read().await;
202
203 let mut scored_files: Vec<(PathBuf, f32)> = files
204 .into_iter()
205 .map(|(file, base_score)| {
206 let mut score = base_score;
207
208 if state.was_recently_accessed(&file) {
210 score += 0.3;
211 }
212
213 if state.hot_files().iter().any(|(f, _)| f == &file) {
215 score += 0.2;
216 }
217
218 (file, score)
219 })
220 .collect();
221
222 scored_files.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
224
225 scored_files
226 }
227
228 async fn read_file_snippet(
230 &self,
231 file: &Path,
232 entity_matches: &[EntityMatch],
233 ) -> Result<Option<FileSnippet>> {
234 let mut locations: Vec<&FileLocation> = Vec::new();
236 for entity_match in entity_matches {
237 for location in &entity_match.locations {
238 if location.path == file {
239 locations.push(location);
240 }
241 }
242 }
243
244 if locations.is_empty() {
245 return Ok(None);
246 }
247
248 let location = locations[0];
250
251 let content = read_file_with_context(file, "context file snippet")
253 .await
254 .with_context(|| format!("Failed to read file {:?}", file))?;
255
256 let lines: Vec<&str> = content.lines().collect();
257
258 let line_start = location.line_start.saturating_sub(CONTEXT_LINES).max(1);
260 let line_end = (location.line_end + CONTEXT_LINES).min(lines.len());
261
262 let snippet_content = lines[line_start.saturating_sub(1)..line_end].join("\n");
264
265 Ok(Some(FileSnippet {
266 file: file.to_path_buf(),
267 line_start,
268 line_end,
269 content: snippet_content,
270 relevance_score: 1.0,
271 }))
272 }
273
274 pub fn infer_search_term(&self, vague_term: &str) -> Option<String> {
276 if vague_term.len() >= 3 {
279 Some(vague_term.to_string())
280 } else {
281 None
282 }
283 }
284
285 pub async fn proactive_search(&self, search_term: &str) -> Result<Option<serde_json::Value>> {
287 if let Some(grep_manager) = &self.grep_manager {
288 let input = GrepSearchInput::with_defaults(
289 search_term.to_string(),
290 self.workspace_root.to_string_lossy().to_string(),
291 );
292
293 match grep_manager.perform_search(input).await {
294 Ok(result) => {
295 let json_value = serde_json::to_value(&result.matches)
297 .with_context(|| "Failed to serialize search results")?;
298 Ok(Some(json_value))
299 }
300 Err(_) => Ok(None),
301 }
302 } else {
303 Ok(None)
304 }
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a snippet for `file` covering `line_start..=line_end` with the given content.
    fn make_snippet(file: &Path, line_start: usize, line_end: usize, content: String) -> FileSnippet {
        FileSnippet {
            file: file.to_path_buf(),
            line_start,
            line_end,
            content,
            relevance_score: 1.0,
        }
    }

    /// Builds a gatherer over a throwaway workspace root with fresh workspace state.
    fn make_gatherer(root: &str) -> ProactiveGatherer {
        let workspace_state = Arc::new(RwLock::new(WorkspaceState::new()));
        ProactiveGatherer::new(PathBuf::from(root), workspace_state)
    }

    #[test]
    fn test_gathered_context_new() {
        let context = GatheredContext::new();
        assert_eq!(context.files.len(), 0);
        assert_eq!(context.estimated_tokens, 0);
        assert!(!context.is_over_budget());
    }

    #[test]
    fn test_gathered_context_add_file() {
        let mut context = GatheredContext::new();
        let file = PathBuf::from("src/test.rs");

        context.add_file(file.clone());

        assert_eq!(context.files.len(), 1);
        assert_eq!(context.files[0], file);
    }

    #[test]
    fn test_gathered_context_add_file_skips_duplicates() {
        let mut context = GatheredContext::new();
        let a = PathBuf::from("src/a.rs");
        let b = PathBuf::from("src/b.rs");

        context.add_file(a.clone());
        context.add_file(a.clone());
        context.add_files(vec![b.clone(), a.clone(), b.clone()]);

        assert_eq!(context.files, vec![a, b]);
    }

    #[test]
    fn test_gathered_context_add_snippet() {
        let mut context = GatheredContext::new();
        let file = PathBuf::from("src/test.rs");

        let snippet = make_snippet(&file, 1, 10, "fn test() {}\n".repeat(10));
        context.add_snippet(file, snippet);

        assert_eq!(context.snippets.len(), 1);
        assert!(context.estimated_tokens > 0);
    }

    #[test]
    fn test_gathered_context_budget_threshold() {
        let file = PathBuf::from("src/big.rs");

        // Exactly at the budget is still within it (about four bytes per token).
        let mut at_budget = GatheredContext::new();
        at_budget.add_snippet(
            file.clone(),
            make_snippet(&file, 1, 1, "x".repeat(MAX_CONTEXT_TOKENS * 4)),
        );
        assert_eq!(at_budget.estimated_tokens, MAX_CONTEXT_TOKENS);
        assert!(!at_budget.is_over_budget());

        let mut over_budget = GatheredContext::new();
        over_budget.add_snippet(
            file.clone(),
            make_snippet(&file, 1, 1, "x".repeat((MAX_CONTEXT_TOKENS + 1) * 4)),
        );
        assert!(over_budget.is_over_budget());
    }

    #[test]
    fn test_gathered_context_to_prompt_text() {
        let mut context = GatheredContext::new();
        let file = PathBuf::from("src/test.rs");

        context.add_file(file.clone());
        let snippet = make_snippet(&file, 1, 5, "fn main() {}\n".to_string());
        context.add_snippet(file, snippet);

        let text = context.to_prompt_text();

        assert!(text.contains("Proactively Gathered Context"));
        assert!(text.contains("src/test.rs"));
        assert!(text.contains("fn main()"));
    }

    #[tokio::test]
    async fn test_proactive_gatherer_new() {
        let workspace_root = PathBuf::from("/test");
        let workspace_state = Arc::new(RwLock::new(WorkspaceState::new()));

        let gatherer = ProactiveGatherer::new(workspace_root.clone(), workspace_state);

        assert_eq!(gatherer.workspace_root, workspace_root);
    }

    #[test]
    fn test_infer_search_term_length_threshold() {
        let gatherer = make_gatherer("/test");

        assert_eq!(gatherer.infer_search_term("ab"), None);
        assert_eq!(gatherer.infer_search_term("abc"), Some("abc".to_string()));
        assert_eq!(
            gatherer.infer_search_term("parse_config"),
            Some("parse_config".to_string())
        );
    }
}