Skip to main content

sem_core/parser/
context.rs

//! Context budgeting: greedily pack prioritized entity context into a token budget.
//! Priority: target entity > direct dependencies > direct dependents > transitive dependencies >
//! transitive dependents.
4
5use std::collections::{HashMap, HashSet, VecDeque};
6
7use crate::model::entity::SemanticEntity;
8use crate::parser::graph::EntityGraph;
9
/// One entity selected for inclusion in a context set.
#[derive(Clone, Debug)]
pub struct ContextEntry {
    /// Id of the entity this entry was built from.
    pub entity_id: String,
    /// Name of the entity.
    pub entity_name: String,
    /// Type of the entity, copied from the source entity.
    pub entity_type: String,
    /// Path of the file that contains the entity.
    pub file_path: String,
    /// Why the entity was included. Values produced in this module: `"target"`,
    /// `"direct_dependency"`, `"direct_dependent"`, `"transitive_dependency"`,
    /// `"transitive_dependent"`.
    pub role: String,
    /// Text included for the entity: either its full content or only its signature line.
    pub content: String,
    /// Estimated token cost charged against the budget for `content`.
    pub estimated_tokens: usize,
}
20
21#[derive(Debug, Clone, Default)]
22pub struct ContextResult {
23    pub entries: Vec<ContextEntry>,
24    pub total_tokens: usize,
25    pub truncated: bool,
26    pub target_omitted: bool,
27}
28
/// Approximate the number of tokens in `content`.
///
/// Heuristic: every whitespace-separated word counts as roughly 1.3 tokens. The result is
/// rounded down, so an empty or whitespace-only string costs 0 and a single word costs 1.
fn estimate_tokens(content: &str) -> usize {
    // 1.3 expressed as the ratio 13/10 so the estimate stays in integer arithmetic.
    const RATIO_NUMERATOR: usize = 13;
    const RATIO_DENOMINATOR: usize = 10;

    let word_count = content.split_whitespace().count();
    word_count * RATIO_NUMERATOR / RATIO_DENOMINATOR
}
34
/// Extract the signature line of an entity's content.
///
/// The signature is the first line that contains something other than whitespace, returned
/// exactly as written (indentation included). Leading blank lines are skipped so that
/// content starting with a newline does not collapse to an empty, zero-token signature.
/// Returns an empty string when `content` has no non-blank line.
fn signature_only(content: &str) -> String {
    content
        .lines()
        .find(|line| !line.trim().is_empty())
        .unwrap_or("")
        .to_string()
}
39
40/// Build a context set for a target entity within a token budget.
41///
42/// Greedy knapsack by priority:
43/// 1. Target entity (full content)
44/// 2. Direct dependencies (full content, signature fallback)
45/// 3. Direct dependents (full content, signature fallback)
46/// 4. Transitive dependencies (signature only)
47/// 5. Transitive dependents (signature only)
48pub fn build_context(
49    graph: &EntityGraph,
50    entity_id: &str,
51    all_entities: &[SemanticEntity],
52    token_budget: usize,
53) -> Vec<ContextEntry> {
54    build_context_result(graph, entity_id, all_entities, token_budget).entries
55}
56
57/// Build a context set plus budget metadata for a target entity.
58pub fn build_context_result(
59    graph: &EntityGraph,
60    entity_id: &str,
61    all_entities: &[SemanticEntity],
62    token_budget: usize,
63) -> ContextResult {
64    // Build content lookup: entity_id -> SemanticEntity
65    let entity_lookup: HashMap<&str, &SemanticEntity> =
66        all_entities.iter().map(|e| (e.id.as_str(), e)).collect();
67
68    let mut result = ContextResult::default();
69    let mut included_ids = HashSet::new();
70
71    // 1. Target entity. Keep the budget strict: if even the signature does not fit,
72    // omit the target and return an empty result instead of overspending.
73    if let Some(entity) = entity_lookup.get(entity_id) {
74        let full_tokens = estimate_tokens(&entity.content);
75        if full_tokens <= token_budget {
76            push_entry(
77                &mut result,
78                entity,
79                "target",
80                entity.content.clone(),
81                full_tokens,
82                &mut included_ids,
83            );
84        } else {
85            result.truncated = true;
86            let sig = signature_only(&entity.content);
87            let sig_tokens = estimate_tokens(&sig);
88            if sig_tokens <= token_budget {
89                push_entry(
90                    &mut result,
91                    entity,
92                    "target",
93                    sig,
94                    sig_tokens,
95                    &mut included_ids,
96                );
97            } else {
98                // Strict context budget contract: no related entries are useful if the
99                // requested target cannot be represented inside the budget.
100                result.target_omitted = true;
101                return result;
102            }
103        };
104    }
105
106    let direct_dependencies = graph.get_dependencies(entity_id);
107    for dep_info in &direct_dependencies {
108        add_full_or_signature(
109            &mut result,
110            &entity_lookup,
111            dep_info.id.as_str(),
112            "direct_dependency",
113            token_budget,
114            &mut included_ids,
115        );
116    }
117
118    let direct_dependents = graph.get_dependents(entity_id);
119    for dep_info in &direct_dependents {
120        add_full_or_signature(
121            &mut result,
122            &entity_lookup,
123            dep_info.id.as_str(),
124            "direct_dependent",
125            token_budget,
126            &mut included_ids,
127        );
128    }
129
130    let direct_dependency_ids: HashSet<&str> =
131        direct_dependencies.iter().map(|d| d.id.as_str()).collect();
132    let direct_dependent_ids: HashSet<&str> =
133        direct_dependents.iter().map(|d| d.id.as_str()).collect();
134
135    for dep_info in collect_reachable_related(graph, entity_id, &graph.dependencies) {
136        if direct_dependency_ids.contains(dep_info.id.as_str()) {
137            continue;
138        }
139        add_signature(
140            &mut result,
141            &entity_lookup,
142            dep_info.id.as_str(),
143            "transitive_dependency",
144            token_budget,
145            &mut included_ids,
146        );
147    }
148
149    for dep_info in collect_reachable_related(graph, entity_id, &graph.dependents) {
150        if direct_dependent_ids.contains(dep_info.id.as_str()) {
151            continue;
152        }
153        add_signature(
154            &mut result,
155            &entity_lookup,
156            dep_info.id.as_str(),
157            "transitive_dependent",
158            token_budget,
159            &mut included_ids,
160        );
161    }
162
163    result
164}
165
166fn push_entry(
167    result: &mut ContextResult,
168    entity: &SemanticEntity,
169    role: &str,
170    content: String,
171    tokens: usize,
172    included_ids: &mut HashSet<String>,
173) {
174    result.entries.push(ContextEntry {
175        entity_id: entity.id.clone(),
176        entity_name: entity.name.clone(),
177        entity_type: entity.entity_type.clone(),
178        file_path: entity.file_path.clone(),
179        role: role.to_string(),
180        content,
181        estimated_tokens: tokens,
182    });
183    result.total_tokens += tokens;
184    included_ids.insert(entity.id.clone());
185}
186
187fn add_full_or_signature(
188    result: &mut ContextResult,
189    entity_lookup: &HashMap<&str, &SemanticEntity>,
190    entity_id: &str,
191    role: &str,
192    token_budget: usize,
193    included_ids: &mut HashSet<String>,
194) {
195    if included_ids.contains(entity_id) {
196        return;
197    }
198
199    let Some(entity) = entity_lookup.get(entity_id) else {
200        return;
201    };
202
203    let full_tokens = estimate_tokens(&entity.content);
204    if result.total_tokens + full_tokens <= token_budget {
205        push_entry(
206            result,
207            entity,
208            role,
209            entity.content.clone(),
210            full_tokens,
211            included_ids,
212        );
213        return;
214    }
215
216    result.truncated = true;
217    add_signature(
218        result,
219        entity_lookup,
220        entity_id,
221        role,
222        token_budget,
223        included_ids,
224    );
225}
226
227fn add_signature(
228    result: &mut ContextResult,
229    entity_lookup: &HashMap<&str, &SemanticEntity>,
230    entity_id: &str,
231    role: &str,
232    token_budget: usize,
233    included_ids: &mut HashSet<String>,
234) {
235    if included_ids.contains(entity_id) {
236        return;
237    }
238
239    let Some(entity) = entity_lookup.get(entity_id) else {
240        return;
241    };
242
243    let sig = signature_only(&entity.content);
244    let tokens = estimate_tokens(&sig);
245    if result.total_tokens + tokens <= token_budget {
246        push_entry(result, entity, role, sig, tokens, included_ids);
247    } else {
248        result.truncated = true;
249    }
250}
251
252/// Collect related entities reachable from `entity_id`, excluding the starting entity.
253fn collect_reachable_related<'a>(
254    graph: &'a EntityGraph,
255    entity_id: &str,
256    relationships: &'a HashMap<String, Vec<String>>,
257) -> Vec<&'a crate::parser::graph::EntityInfo> {
258    const MAX_VISITED: usize = 10_000;
259
260    let mut visited: HashSet<&str> = HashSet::new();
261    let mut queue: VecDeque<&str> = VecDeque::new();
262    let mut result = Vec::new();
263
264    let start_key = match graph.entities.get_key_value(entity_id) {
265        Some((key, _)) => key.as_str(),
266        None => return result,
267    };
268
269    queue.push_back(start_key);
270    visited.insert(start_key);
271
272    while let Some(current) = queue.pop_front() {
273        if result.len() >= MAX_VISITED {
274            break;
275        }
276
277        if let Some(next_ids) = relationships.get(current) {
278            for next_id in next_ids {
279                if visited.insert(next_id.as_str()) {
280                    if let Some(info) = graph.entities.get(next_id.as_str()) {
281                        result.push(info);
282                        if result.len() >= MAX_VISITED {
283                            return result;
284                        }
285                    }
286                    queue.push_back(next_id.as_str());
287                }
288            }
289        }
290    }
291
292    result
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::graph::{EntityGraph, EntityInfo, EntityRef, RefType};
    use std::collections::HashMap;

    /// Create a `SemanticEntity` fixture located in `a.py`.
    ///
    /// The entity type is the second `::`-separated segment of `id`, defaulting to
    /// `"function"` when the id has no such segment.
    fn entity(id: &str, name: &str, content: &str) -> SemanticEntity {
        let entity_type = id.split("::").nth(1).unwrap_or("function");
        let line_count = content.lines().count();

        SemanticEntity {
            id: id.to_string(),
            file_path: "a.py".to_string(),
            entity_type: entity_type.to_string(),
            name: name.to_string(),
            parent_id: None,
            content: content.to_string(),
            content_hash: String::new(),
            structural_hash: None,
            start_line: 1,
            end_line: line_count,
            metadata: None,
        }
    }

    /// Create a `Calls` reference from `from_entity` to `to_entity`.
    fn edge(from_entity: &str, to_entity: &str) -> EntityRef {
        EntityRef {
            from_entity: from_entity.to_string(),
            to_entity: to_entity.to_string(),
            ref_type: RefType::Calls,
        }
    }

    /// Build an `EntityGraph` whose nodes mirror `entities` and whose edges are `edges`.
    fn graph_from_entities(entities: &[SemanticEntity], edges: Vec<EntityRef>) -> EntityGraph {
        let mut infos: HashMap<String, EntityInfo> = HashMap::with_capacity(entities.len());
        for source in entities {
            let info = EntityInfo {
                id: source.id.clone(),
                name: source.name.clone(),
                entity_type: source.entity_type.clone(),
                file_path: source.file_path.clone(),
                parent_id: source.parent_id.clone(),
                start_line: source.start_line,
                end_line: source.end_line,
            };
            infos.insert(source.id.clone(), info);
        }

        EntityGraph::from_parts(infos, edges)
    }

    #[test]
    fn test_estimate_tokens() {
        // 2 words * 13 / 10 = 2
        assert_eq!(estimate_tokens("hello world"), 2);
        // 8 words * 13 / 10 = 10
        assert_eq!(estimate_tokens("fn foo(a: i32, b: i32) -> bool {"), 10);
    }

    #[test]
    fn test_signature_only() {
        let signature = signature_only("fn foo(a: i32) {\n    a + 1\n}");
        assert_eq!(signature, "fn foo(a: i32) {");
    }

    #[test]
    fn test_target_omitted_when_signature_exceeds_budget() {
        let target_id = "a.py::function::helper_b";
        let entities = vec![entity(target_id, "helper_b", "def helper_b():\n    return 1")];
        let graph = graph_from_entities(&entities, vec![]);

        // The signature alone is estimated at 2 tokens, which exceeds a budget of 1.
        let outcome = build_context_result(&graph, target_id, &entities, 1);

        assert!(outcome.entries.is_empty());
        assert_eq!(outcome.total_tokens, 0);
        assert!(outcome.truncated);
        assert!(outcome.target_omitted);
    }

    #[test]
    fn test_target_signature_respects_budget() {
        let target_id = "a.py::function::helper_b";
        let entities = vec![entity(
            target_id,
            "helper_b",
            "def helper_b():\n    return expensive_value()",
        )];
        let graph = graph_from_entities(&entities, vec![]);

        // The full body does not fit in 2 tokens, but the signature does.
        let outcome = build_context_result(&graph, target_id, &entities, 2);

        assert_eq!(outcome.total_tokens, 2);
        assert!(outcome.truncated);
        assert!(!outcome.target_omitted);
        assert_eq!(outcome.entries.len(), 1);
        assert_eq!(outcome.entries[0].role, "target");
        assert_eq!(outcome.entries[0].content, "def helper_b():");
    }

    #[test]
    fn test_context_includes_dependencies_before_dependents() {
        let main_id = "a.py::function::main";
        let helper_a_id = "a.py::function::helper_a";
        let helper_b_id = "a.py::function::helper_b";
        let leaf_id = "a.py::function::leaf";
        let caller_id = "a.py::class::Caller";
        let outer_id = "a.py::class::Outer";

        let entities = vec![
            entity(
                main_id,
                "main",
                "def main():\n    return helper_a() + helper_b()",
            ),
            entity(helper_a_id, "helper_a", "def helper_a():\n    return leaf()"),
            entity(helper_b_id, "helper_b", "def helper_b():\n    return 2"),
            entity(leaf_id, "leaf", "def leaf():\n    return 1"),
            entity(
                caller_id,
                "Caller",
                "class Caller:\n    def go(self):\n        return main()",
            ),
            entity(
                outer_id,
                "Outer",
                "class Outer:\n    def go(self):\n        return Caller().go()",
            ),
        ];
        let edges = vec![
            edge(main_id, helper_a_id),
            edge(main_id, helper_b_id),
            edge(helper_a_id, leaf_id),
            edge(caller_id, main_id),
            edge(outer_id, caller_id),
        ];
        let graph = graph_from_entities(&entities, edges);

        let outcome = build_context_result(&graph, main_id, &entities, 999);

        let mut roles_and_names: Vec<(&str, &str)> = Vec::new();
        for entry in &outcome.entries {
            roles_and_names.push((entry.role.as_str(), entry.entity_name.as_str()));
        }

        assert_eq!(
            roles_and_names,
            vec![
                ("target", "main"),
                ("direct_dependency", "helper_a"),
                ("direct_dependency", "helper_b"),
                ("direct_dependent", "Caller"),
                ("transitive_dependency", "leaf"),
                ("transitive_dependent", "Outer"),
            ]
        );
        assert!(!outcome.truncated);
        assert!(!outcome.target_omitted);
        assert!(outcome.total_tokens <= 999);
    }

    #[test]
    fn test_collect_transitive_caps_results() {
        // A single chain helper_0 -> helper_1 -> ... -> helper_10001, i.e. 10,001
        // entities reachable from helper_0: one more than the cap.
        let ids: Vec<String> = (0..=10_001)
            .map(|index| format!("a.py::function::helper_{index}"))
            .collect();
        let entities: Vec<SemanticEntity> = ids
            .iter()
            .enumerate()
            .map(|(index, id)| {
                entity(
                    id,
                    &format!("helper_{index}"),
                    "def helper():\n    return 1",
                )
            })
            .collect();
        let edges: Vec<EntityRef> = ids
            .windows(2)
            .map(|pair| edge(&pair[0], &pair[1]))
            .collect();

        let graph = graph_from_entities(&entities, edges);
        let reachable = collect_reachable_related(
            &graph,
            "a.py::function::helper_0",
            &graph.dependencies,
        );

        assert_eq!(reachable.len(), 10_000);
    }
}
487}