Skip to main content

sem_core/parser/
context.rs

1//! Context budgeting: pack optimal entity context into a token budget.
2//! Priority: target entity (full) > direct dependents (full) > transitive (signature only).
3
4use std::collections::HashMap;
5
6use crate::model::entity::SemanticEntity;
7use crate::parser::graph::EntityGraph;
8
/// One entity selected by [`build_context`] for inclusion in a packed context.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContextEntry {
    /// Id of the entity, copied from `SemanticEntity::id`.
    pub entity_id: String,
    /// Name of the entity, copied from `SemanticEntity::name`.
    pub entity_name: String,
    /// Entity type, copied from `SemanticEntity::entity_type`.
    pub entity_type: String,
    /// File path, copied from `SemanticEntity::file_path`.
    pub file_path: String,
    /// Why the entity was selected: "target", "direct_dependent", or "transitive_dependent".
    pub role: String,
    /// Text included for this entity: either its full content or only its
    /// signature line, depending on its priority tier.
    pub content: String,
    /// Heuristic token cost of `content`.
    pub estimated_tokens: usize,
}
19
/// Approximate the token cost of `content`.
///
/// Counts whitespace-separated words and scales by 1.3 with truncating
/// integer arithmetic: 8 words -> 10 tokens, 1 word -> 1 token, and empty or
/// whitespace-only text -> 0 tokens.
fn estimate_tokens(content: &str) -> usize {
    const TOKENS_PER_TEN_WORDS: usize = 13;
    const TEN_WORDS: usize = 10;

    let word_count = content.split_whitespace().count();
    word_count * TOKENS_PER_TEN_WORDS / TEN_WORDS
}
25
/// Extract the signature line of an entity's content.
///
/// Returns the first line containing non-whitespace characters, unchanged
/// (leading indentation is kept). Leading blank lines are skipped so that
/// content starting with a newline does not produce an empty signature.
/// Returns an empty string when `content` is empty or whitespace-only.
///
/// Only one line is returned, so a signature spanning several lines is
/// truncated to its first line.
fn signature_only(content: &str) -> String {
    content
        .lines()
        .find(|line| !line.trim().is_empty())
        .unwrap_or("")
        .to_string()
}
30
31/// Build a context set for a target entity within a token budget.
32///
33/// Greedy knapsack by priority:
34/// 1. Target entity (full content)
35/// 2. Direct dependents (full content)
36/// 3. Transitive dependents (signature only)
37pub fn build_context(
38    graph: &EntityGraph,
39    entity_id: &str,
40    all_entities: &[SemanticEntity],
41    token_budget: usize,
42) -> Vec<ContextEntry> {
43    // Build content lookup: entity_id -> SemanticEntity
44    let entity_lookup: HashMap<&str, &SemanticEntity> = all_entities
45        .iter()
46        .map(|e| (e.id.as_str(), e))
47        .collect();
48
49    let mut entries = Vec::new();
50    let mut tokens_used = 0usize;
51
52    // 1. Target entity (full)
53    if let Some(entity) = entity_lookup.get(entity_id) {
54        let tokens = estimate_tokens(&entity.content);
55        if tokens_used + tokens <= token_budget {
56            entries.push(ContextEntry {
57                entity_id: entity.id.clone(),
58                entity_name: entity.name.clone(),
59                entity_type: entity.entity_type.clone(),
60                file_path: entity.file_path.clone(),
61                role: "target".to_string(),
62                content: entity.content.clone(),
63                estimated_tokens: tokens,
64            });
65            tokens_used += tokens;
66        }
67    }
68
69    // 2. Direct dependents (full content)
70    let direct_deps = graph.get_dependents(entity_id);
71    for dep_info in &direct_deps {
72        if tokens_used >= token_budget {
73            break;
74        }
75        if let Some(entity) = entity_lookup.get(dep_info.id.as_str()) {
76            let tokens = estimate_tokens(&entity.content);
77            if tokens_used + tokens <= token_budget {
78                entries.push(ContextEntry {
79                    entity_id: entity.id.clone(),
80                    entity_name: entity.name.clone(),
81                    entity_type: entity.entity_type.clone(),
82                    file_path: entity.file_path.clone(),
83                    role: "direct_dependent".to_string(),
84                    content: entity.content.clone(),
85                    estimated_tokens: tokens,
86                });
87                tokens_used += tokens;
88            }
89        }
90    }
91
92    // 3. Transitive dependents (signature only)
93    let all_impact = graph.impact_analysis(entity_id);
94    let direct_ids: std::collections::HashSet<&str> =
95        direct_deps.iter().map(|d| d.id.as_str()).collect();
96
97    for dep_info in &all_impact {
98        if tokens_used >= token_budget {
99            break;
100        }
101        // Skip direct deps (already included with full content)
102        if direct_ids.contains(dep_info.id.as_str()) {
103            continue;
104        }
105        if let Some(entity) = entity_lookup.get(dep_info.id.as_str()) {
106            let sig = signature_only(&entity.content);
107            let tokens = estimate_tokens(&sig);
108            if tokens_used + tokens <= token_budget {
109                entries.push(ContextEntry {
110                    entity_id: entity.id.clone(),
111                    entity_name: entity.name.clone(),
112                    entity_type: entity.entity_type.clone(),
113                    file_path: entity.file_path.clone(),
114                    role: "transitive_dependent".to_string(),
115                    content: sig,
116                    estimated_tokens: tokens,
117                });
118                tokens_used += tokens;
119            }
120        }
121    }
122
123    entries
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens("hello world"), 2); // 2 * 13 / 10 = 2
        assert_eq!(estimate_tokens("fn foo(a: i32, b: i32) -> bool {"), 10); // 8 words * 13 / 10 = 10
    }

    #[test]
    fn test_estimate_tokens_edge_cases() {
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(estimate_tokens(" \t\n  "), 0);
        assert_eq!(estimate_tokens("word"), 1); // 1 * 13 / 10 = 1 (truncated)
        assert_eq!(estimate_tokens("a b c d e f g h i j"), 13); // 10 words
        assert_eq!(estimate_tokens("a\tb\nc"), 3); // any whitespace separates words
    }

    #[test]
    fn test_signature_only() {
        assert_eq!(
            signature_only("fn foo(a: i32) {\n    a + 1\n}"),
            "fn foo(a: i32) {"
        );
    }

    #[test]
    fn test_signature_only_edge_cases() {
        assert_eq!(signature_only(""), "");
        assert_eq!(signature_only("fn bar() {}"), "fn bar() {}");
        // CRLF line endings are stripped by `str::lines`.
        assert_eq!(signature_only("fn baz() {\r\n}"), "fn baz() {");
    }
}