1use std::collections::{HashMap, HashSet, VecDeque};
6
7use crate::model::entity::SemanticEntity;
8use crate::parser::graph::{EntityAdjacencyMap, EntityGraph};
9
/// One entity selected for inclusion in a context bundle.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContextEntry {
    /// Identifier of the entity (the same key as `SemanticEntity::id`).
    pub entity_id: String,
    /// Human-readable name of the entity.
    pub entity_name: String,
    /// Entity kind as reported by the parser (e.g. `function`, `class`).
    pub entity_type: String,
    /// Path of the file the entity was extracted from.
    pub file_path: String,
    /// Why the entity was included: `target`, `direct_dependency`,
    /// `direct_dependent`, `transitive_dependency` or `transitive_dependent`.
    pub role: String,
    /// Either the entity's full source, or only its first line when the
    /// token budget did not allow the full body.
    pub content: String,
    /// Heuristic token count of `content` (see `estimate_tokens`).
    pub estimated_tokens: usize,
}

/// Outcome of [`build_context_result`]: the selected entries plus bookkeeping
/// about how the token budget was spent.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ContextResult {
    /// Entries in priority order (target first, then direct neighbours,
    /// then transitive neighbours).
    pub entries: Vec<ContextEntry>,
    /// Sum of `estimated_tokens` over all `entries`.
    pub total_tokens: usize,
    /// `true` when at least one entity was shortened to its signature or
    /// dropped because the budget ran out.
    pub truncated: bool,
    /// `true` when even the target's signature exceeded the budget; in that
    /// case `entries` is empty.
    pub target_omitted: bool,
}
28
/// Cheap token estimate: about 1.3 tokens per whitespace-separated word,
/// rounded down.
fn estimate_tokens(content: &str) -> usize {
    const TOKENS_PER_TEN_WORDS: usize = 13;
    const TEN_WORDS: usize = 10;

    let word_count = content.split_whitespace().count();
    (word_count * TOKENS_PER_TEN_WORDS) / TEN_WORDS
}
34
/// Returns the first line of `content` as a stand-in for the entity's
/// signature, or an empty string when `content` has no lines.
fn signature_only(content: &str) -> String {
    content
        .lines()
        .next()
        .map(str::to_owned)
        .unwrap_or_default()
}
39
40pub fn build_context(
49 graph: &EntityGraph,
50 entity_id: &str,
51 all_entities: &[SemanticEntity],
52 token_budget: usize,
53) -> Vec<ContextEntry> {
54 build_context_result(graph, entity_id, all_entities, token_budget).entries
55}
56
57pub fn build_context_result(
59 graph: &EntityGraph,
60 entity_id: &str,
61 all_entities: &[SemanticEntity],
62 token_budget: usize,
63) -> ContextResult {
64 let entity_lookup: HashMap<&str, &SemanticEntity> =
66 all_entities.iter().map(|e| (e.id.as_str(), e)).collect();
67
68 let mut result = ContextResult::default();
69 let mut included_ids = HashSet::new();
70
71 if let Some(entity) = entity_lookup.get(entity_id) {
74 let full_tokens = estimate_tokens(&entity.content);
75 if full_tokens <= token_budget {
76 push_entry(
77 &mut result,
78 entity,
79 "target",
80 entity.content.clone(),
81 full_tokens,
82 &mut included_ids,
83 );
84 } else {
85 result.truncated = true;
86 let sig = signature_only(&entity.content);
87 let sig_tokens = estimate_tokens(&sig);
88 if sig_tokens <= token_budget {
89 push_entry(
90 &mut result,
91 entity,
92 "target",
93 sig,
94 sig_tokens,
95 &mut included_ids,
96 );
97 } else {
98 result.target_omitted = true;
101 return result;
102 }
103 };
104 }
105
106 let direct_dependencies = graph.get_dependencies(entity_id);
107 for dep_info in &direct_dependencies {
108 add_full_or_signature(
109 &mut result,
110 &entity_lookup,
111 dep_info.id.as_str(),
112 "direct_dependency",
113 token_budget,
114 &mut included_ids,
115 );
116 }
117
118 let direct_dependents = graph.get_dependents(entity_id);
119 for dep_info in &direct_dependents {
120 add_full_or_signature(
121 &mut result,
122 &entity_lookup,
123 dep_info.id.as_str(),
124 "direct_dependent",
125 token_budget,
126 &mut included_ids,
127 );
128 }
129
130 let direct_dependency_ids: HashSet<&str> =
131 direct_dependencies.iter().map(|d| d.id.as_str()).collect();
132 let direct_dependent_ids: HashSet<&str> =
133 direct_dependents.iter().map(|d| d.id.as_str()).collect();
134
135 for dep_info in collect_reachable_related(graph, entity_id, &graph.dependencies) {
136 if direct_dependency_ids.contains(dep_info.id.as_str()) {
137 continue;
138 }
139 add_signature(
140 &mut result,
141 &entity_lookup,
142 dep_info.id.as_str(),
143 "transitive_dependency",
144 token_budget,
145 &mut included_ids,
146 );
147 }
148
149 for dep_info in collect_reachable_related(graph, entity_id, &graph.dependents) {
150 if direct_dependent_ids.contains(dep_info.id.as_str()) {
151 continue;
152 }
153 add_signature(
154 &mut result,
155 &entity_lookup,
156 dep_info.id.as_str(),
157 "transitive_dependent",
158 token_budget,
159 &mut included_ids,
160 );
161 }
162
163 result
164}
165
166fn push_entry(
167 result: &mut ContextResult,
168 entity: &SemanticEntity,
169 role: &str,
170 content: String,
171 tokens: usize,
172 included_ids: &mut HashSet<String>,
173) {
174 result.entries.push(ContextEntry {
175 entity_id: entity.id.clone(),
176 entity_name: entity.name.clone(),
177 entity_type: entity.entity_type.clone(),
178 file_path: entity.file_path.clone(),
179 role: role.to_string(),
180 content,
181 estimated_tokens: tokens,
182 });
183 result.total_tokens += tokens;
184 included_ids.insert(entity.id.clone());
185}
186
187fn add_full_or_signature(
188 result: &mut ContextResult,
189 entity_lookup: &HashMap<&str, &SemanticEntity>,
190 entity_id: &str,
191 role: &str,
192 token_budget: usize,
193 included_ids: &mut HashSet<String>,
194) {
195 if included_ids.contains(entity_id) {
196 return;
197 }
198
199 let Some(entity) = entity_lookup.get(entity_id) else {
200 return;
201 };
202
203 let full_tokens = estimate_tokens(&entity.content);
204 if result.total_tokens + full_tokens <= token_budget {
205 push_entry(
206 result,
207 entity,
208 role,
209 entity.content.clone(),
210 full_tokens,
211 included_ids,
212 );
213 return;
214 }
215
216 result.truncated = true;
217 add_signature(
218 result,
219 entity_lookup,
220 entity_id,
221 role,
222 token_budget,
223 included_ids,
224 );
225}
226
227fn add_signature(
228 result: &mut ContextResult,
229 entity_lookup: &HashMap<&str, &SemanticEntity>,
230 entity_id: &str,
231 role: &str,
232 token_budget: usize,
233 included_ids: &mut HashSet<String>,
234) {
235 if included_ids.contains(entity_id) {
236 return;
237 }
238
239 let Some(entity) = entity_lookup.get(entity_id) else {
240 return;
241 };
242
243 let sig = signature_only(&entity.content);
244 let tokens = estimate_tokens(&sig);
245 if result.total_tokens + tokens <= token_budget {
246 push_entry(result, entity, role, sig, tokens, included_ids);
247 } else {
248 result.truncated = true;
249 }
250}
251
252fn collect_reachable_related<'a>(
254 graph: &'a EntityGraph,
255 entity_id: &str,
256 relationships: &'a EntityAdjacencyMap,
257) -> Vec<&'a crate::parser::graph::EntityInfo> {
258 const MAX_VISITED: usize = 10_000;
259
260 let mut visited: HashSet<&str> = HashSet::new();
261 let mut queue: VecDeque<&str> = VecDeque::new();
262 let mut result = Vec::new();
263
264 let start_key = match graph.entities.get_key_value(entity_id) {
265 Some((key, _)) => key.as_str(),
266 None => return result,
267 };
268
269 queue.push_back(start_key);
270 visited.insert(start_key);
271
272 while let Some(current) = queue.pop_front() {
273 if result.len() >= MAX_VISITED {
274 break;
275 }
276
277 if let Some(next_ids) = relationships.get(current) {
278 for next_id in next_ids {
279 if visited.insert(next_id.as_str()) {
280 if let Some(info) = graph.entities.get(next_id.as_str()) {
281 result.push(info);
282 if result.len() >= MAX_VISITED {
283 return result;
284 }
285 }
286 queue.push_back(next_id.as_str());
287 }
288 }
289 }
290 }
291
292 result
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::graph::{EntityGraph, EntityInfo, EntityRef, RefType};
    use std::collections::HashMap;

    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens("hello world"), 2);
        assert_eq!(estimate_tokens("fn foo(a: i32, b: i32) -> bool {"), 10);
        assert_eq!(estimate_tokens(""), 0);
    }

    #[test]
    fn test_signature_only() {
        assert_eq!(
            signature_only("fn foo(a: i32) {\n    a + 1\n}"),
            "fn foo(a: i32) {"
        );
        assert_eq!(signature_only(""), "");
    }

    #[test]
    fn test_target_omitted_when_signature_exceeds_budget() {
        let entities = vec![entity(
            "a.py::function::helper_b",
            "helper_b",
            "def helper_b():\n    return 1",
        )];
        let graph = graph_from_entities(&entities, vec![]);

        let result = build_context_result(&graph, "a.py::function::helper_b", &entities, 1);

        assert!(result.entries.is_empty());
        assert_eq!(result.total_tokens, 0);
        assert!(result.truncated);
        assert!(result.target_omitted);
    }

    #[test]
    fn test_target_signature_respects_budget() {
        let entities = vec![entity(
            "a.py::function::helper_b",
            "helper_b",
            "def helper_b():\n    return expensive_value()",
        )];
        let graph = graph_from_entities(&entities, vec![]);

        let result = build_context_result(&graph, "a.py::function::helper_b", &entities, 2);

        assert_eq!(result.total_tokens, 2);
        assert!(result.truncated);
        assert!(!result.target_omitted);
        assert_eq!(result.entries.len(), 1);
        assert_eq!(result.entries[0].role, "target");
        assert_eq!(result.entries[0].content, "def helper_b():");
    }

    #[test]
    fn test_direct_dependency_falls_back_to_signature() {
        // main: 4 words -> 5 tokens; helper: 7 words -> 9 tokens; helper signature: 2 tokens.
        let entities = vec![
            entity(
                "a.py::function::main",
                "main",
                "def main():\n    return helper()",
            ),
            entity(
                "a.py::function::helper",
                "helper",
                "def helper():\n    x = 1\n    return x",
            ),
        ];
        let graph = graph_from_entities(
            &entities,
            vec![edge("a.py::function::main", "a.py::function::helper")],
        );

        let result = build_context_result(&graph, "a.py::function::main", &entities, 8);

        assert_eq!(result.entries.len(), 2);
        assert_eq!(result.entries[0].role, "target");
        assert_eq!(result.entries[0].content, "def main():\n    return helper()");
        assert_eq!(result.entries[1].role, "direct_dependency");
        assert_eq!(result.entries[1].content, "def helper():");
        assert_eq!(result.entries[1].estimated_tokens, 2);
        assert_eq!(result.total_tokens, 7);
        assert!(result.truncated);
        assert!(!result.target_omitted);
    }

    #[test]
    fn test_cycle_does_not_duplicate_entries() {
        let entities = vec![
            entity("a.py::function::a", "a", "def a():\n    return b()"),
            entity("a.py::function::b", "b", "def b():\n    return a()"),
        ];
        let graph = graph_from_entities(
            &entities,
            vec![
                edge("a.py::function::a", "a.py::function::b"),
                edge("a.py::function::b", "a.py::function::a"),
            ],
        );

        let result = build_context_result(&graph, "a.py::function::a", &entities, 999);
        let roles_and_names: Vec<(&str, &str)> = result
            .entries
            .iter()
            .map(|entry| (entry.role.as_str(), entry.entity_name.as_str()))
            .collect();

        assert_eq!(
            roles_and_names,
            vec![("target", "a"), ("direct_dependency", "b")]
        );
        assert_eq!(result.total_tokens, 10);
        assert!(!result.truncated);

        let entries = build_context(&graph, "a.py::function::a", &entities, 999);
        assert_eq!(entries.len(), result.entries.len());
    }

    #[test]
    fn test_context_includes_dependencies_before_dependents() {
        let entities = vec![
            entity(
                "a.py::function::main",
                "main",
                "def main():\n    return helper_a() + helper_b()",
            ),
            entity(
                "a.py::function::helper_a",
                "helper_a",
                "def helper_a():\n    return leaf()",
            ),
            entity(
                "a.py::function::helper_b",
                "helper_b",
                "def helper_b():\n    return 2",
            ),
            entity("a.py::function::leaf", "leaf", "def leaf():\n    return 1"),
            entity(
                "a.py::class::Caller",
                "Caller",
                "class Caller:\n    def go(self):\n        return main()",
            ),
            entity(
                "a.py::class::Outer",
                "Outer",
                "class Outer:\n    def go(self):\n        return Caller().go()",
            ),
        ];
        let graph = graph_from_entities(
            &entities,
            vec![
                edge("a.py::function::main", "a.py::function::helper_a"),
                edge("a.py::function::main", "a.py::function::helper_b"),
                edge("a.py::function::helper_a", "a.py::function::leaf"),
                edge("a.py::class::Caller", "a.py::function::main"),
                edge("a.py::class::Outer", "a.py::class::Caller"),
            ],
        );

        let result = build_context_result(&graph, "a.py::function::main", &entities, 999);
        let roles_and_names: Vec<(&str, &str)> = result
            .entries
            .iter()
            .map(|entry| (entry.role.as_str(), entry.entity_name.as_str()))
            .collect();

        assert_eq!(
            roles_and_names,
            vec![
                ("target", "main"),
                ("direct_dependency", "helper_a"),
                ("direct_dependency", "helper_b"),
                ("direct_dependent", "Caller"),
                ("transitive_dependency", "leaf"),
                ("transitive_dependent", "Outer"),
            ]
        );
        assert!(!result.truncated);
        assert!(!result.target_omitted);
        assert!(result.total_tokens <= 999);
    }

    #[test]
    fn test_collect_transitive_caps_results() {
        let mut entities = Vec::new();
        let mut edges = Vec::new();

        for index in 0..=10_001 {
            let id = format!("a.py::function::helper_{index}");
            entities.push(entity(
                &id,
                &format!("helper_{index}"),
                "def helper():\n    return 1",
            ));
            if index > 0 {
                edges.push(edge(&format!("a.py::function::helper_{}", index - 1), &id));
            }
        }

        let graph = graph_from_entities(&entities, edges);
        let result =
            collect_reachable_related(&graph, "a.py::function::helper_0", &graph.dependencies);

        assert_eq!(result.len(), 10_000);
    }

    /// Builds a minimal `SemanticEntity`; the entity type is taken from the
    /// second `::`-separated segment of `id`.
    fn entity(id: &str, name: &str, content: &str) -> SemanticEntity {
        SemanticEntity {
            id: id.to_string(),
            file_path: "a.py".to_string(),
            entity_type: id.split("::").nth(1).unwrap_or("function").to_string(),
            name: name.to_string(),
            parent_id: None,
            content: content.to_string(),
            content_hash: String::new(),
            structural_hash: None,
            start_line: 1,
            end_line: content.lines().count(),
            metadata: None,
        }
    }

    /// A `Calls` reference from `from_entity` to `to_entity`.
    fn edge(from_entity: &str, to_entity: &str) -> EntityRef {
        EntityRef {
            from_entity: from_entity.to_string(),
            to_entity: to_entity.to_string(),
            ref_type: RefType::Calls,
        }
    }

    /// Builds an `EntityGraph` from test entities and explicit edges.
    fn graph_from_entities(entities: &[SemanticEntity], edges: Vec<EntityRef>) -> EntityGraph {
        let entity_infos: HashMap<String, EntityInfo> = entities
            .iter()
            .map(|entity| {
                (
                    entity.id.clone(),
                    EntityInfo {
                        id: entity.id.clone(),
                        name: entity.name.clone(),
                        entity_type: entity.entity_type.clone(),
                        file_path: entity.file_path.clone(),
                        parent_id: entity.parent_id.clone(),
                        start_line: entity.start_line,
                        end_line: entity.end_line,
                    },
                )
            })
            .collect();

        EntityGraph::from_parts(entity_infos, edges)
    }
}