Skip to main content

sem_core/parser/
context.rs

1//! Context budgeting: pack optimal entity context into a token budget.
2//! Priority: target entity (full) > direct dependents (full) > transitive (signature only).
3
4use std::collections::HashMap;
5
6use crate::model::entity::SemanticEntity;
7use crate::parser::graph::EntityGraph;
8
/// One entity selected for inclusion in a context pack by `build_context`.
#[derive(Debug, Clone)]
pub struct ContextEntry {
    /// ID of the entity (copied from `SemanticEntity::id`).
    pub entity_id: String,
    /// Name of the entity (copied from `SemanticEntity::name`).
    pub entity_name: String,
    /// Kind of the entity (copied from `SemanticEntity::entity_type`).
    pub entity_type: String,
    /// Path of the file the entity lives in (copied from `SemanticEntity::file_path`).
    pub file_path: String,
    /// Why the entity was included: one of "target", "direct_dependent",
    /// or "transitive_dependent".
    pub role: String,
    /// Text included for this entity: either its full content or only its
    /// first line (signature) when the budget forced truncation.
    pub content: String,
    /// Heuristic token estimate for `content` (see `estimate_tokens`).
    pub estimated_tokens: usize,
}
19
/// Roughly estimate how many model tokens `content` occupies.
///
/// Uses a fixed ratio of about 1.3 tokens per whitespace-separated word,
/// evaluated in integer arithmetic, so the result is rounded down
/// (one word counts as one token, two words as two, ten words as thirteen).
fn estimate_tokens(content: &str) -> usize {
    // Integer form of `words * 1.3`; the final division truncates toward zero.
    const TOKENS_PER_TEN_WORDS: usize = 13;
    let word_count = content.split_whitespace().count();
    word_count * TOKENS_PER_TEN_WORDS / 10
}
25
/// Return the first line of `content`, treated as the entity's signature.
///
/// The line terminator (`\n` or `\r\n`) is not included. Empty input yields
/// an empty string.
///
/// NOTE(review): if entity content can begin with doc comments or attributes,
/// this returns that line rather than the actual signature; confirm what
/// `SemanticEntity::content` holds.
fn signature_only(content: &str) -> String {
    match content.lines().next() {
        Some(first_line) => String::from(first_line),
        None => String::new(),
    }
}
30
31/// Build a context set for a target entity within a token budget.
32///
33/// Greedy knapsack by priority:
34/// 1. Target entity (full content)
35/// 2. Direct dependents (full content)
36/// 3. Transitive dependents (signature only)
37pub fn build_context(
38    graph: &EntityGraph,
39    entity_id: &str,
40    all_entities: &[SemanticEntity],
41    token_budget: usize,
42) -> Vec<ContextEntry> {
43    // Build content lookup: entity_id -> SemanticEntity
44    let entity_lookup: HashMap<&str, &SemanticEntity> = all_entities
45        .iter()
46        .map(|e| (e.id.as_str(), e))
47        .collect();
48
49    let mut entries = Vec::new();
50    let mut tokens_used = 0usize;
51
52    // 1. Target entity (always included, truncated to signature if it exceeds budget)
53    if let Some(entity) = entity_lookup.get(entity_id) {
54        let full_tokens = estimate_tokens(&entity.content);
55        let (content, tokens) = if full_tokens <= token_budget {
56            (entity.content.clone(), full_tokens)
57        } else {
58            // Truncate to signature so the target is always present (#145)
59            let sig = signature_only(&entity.content);
60            let sig_tokens = estimate_tokens(&sig);
61            (sig, sig_tokens)
62        };
63        entries.push(ContextEntry {
64            entity_id: entity.id.clone(),
65            entity_name: entity.name.clone(),
66            entity_type: entity.entity_type.clone(),
67            file_path: entity.file_path.clone(),
68            role: "target".to_string(),
69            content,
70            estimated_tokens: tokens,
71        });
72        tokens_used += tokens;
73    }
74
75    // 2. Direct dependents (full content)
76    let direct_deps = graph.get_dependents(entity_id);
77    for dep_info in &direct_deps {
78        if tokens_used >= token_budget {
79            break;
80        }
81        if let Some(entity) = entity_lookup.get(dep_info.id.as_str()) {
82            let tokens = estimate_tokens(&entity.content);
83            if tokens_used + tokens <= token_budget {
84                entries.push(ContextEntry {
85                    entity_id: entity.id.clone(),
86                    entity_name: entity.name.clone(),
87                    entity_type: entity.entity_type.clone(),
88                    file_path: entity.file_path.clone(),
89                    role: "direct_dependent".to_string(),
90                    content: entity.content.clone(),
91                    estimated_tokens: tokens,
92                });
93                tokens_used += tokens;
94            }
95        }
96    }
97
98    // 3. Transitive dependents (signature only)
99    let all_impact = graph.impact_analysis(entity_id);
100    let direct_ids: std::collections::HashSet<&str> =
101        direct_deps.iter().map(|d| d.id.as_str()).collect();
102
103    for dep_info in &all_impact {
104        if tokens_used >= token_budget {
105            break;
106        }
107        // Skip direct deps (already included with full content)
108        if direct_ids.contains(dep_info.id.as_str()) {
109            continue;
110        }
111        if let Some(entity) = entity_lookup.get(dep_info.id.as_str()) {
112            let sig = signature_only(&entity.content);
113            let tokens = estimate_tokens(&sig);
114            if tokens_used + tokens <= token_budget {
115                entries.push(ContextEntry {
116                    entity_id: entity.id.clone(),
117                    entity_name: entity.name.clone(),
118                    entity_type: entity.entity_type.clone(),
119                    file_path: entity.file_path.clone(),
120                    role: "transitive_dependent".to_string(),
121                    content: sig,
122                    estimated_tokens: tokens,
123                });
124                tokens_used += tokens;
125            }
126        }
127    }
128
129    entries
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens("hello world"), 2); // 2 * 13 / 10 = 2
        assert_eq!(estimate_tokens("fn foo(a: i32, b: i32) -> bool {"), 10); // 8 words * 13 / 10 = 10
    }

    #[test]
    fn test_estimate_tokens_edge_cases() {
        // No words at all -> zero tokens.
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(estimate_tokens(" \n\t "), 0);
        // Integer arithmetic rounds down: 1 * 13 / 10 = 1.
        assert_eq!(estimate_tokens("word"), 1);
        // 10 words * 13 / 10 = 13.
        assert_eq!(estimate_tokens("a b c d e f g h i j"), 13);
    }

    #[test]
    fn test_signature_only() {
        assert_eq!(
            signature_only("fn foo(a: i32) {\n    a + 1\n}"),
            "fn foo(a: i32) {"
        );
    }

    #[test]
    fn test_signature_only_edge_cases() {
        // Empty content has no first line.
        assert_eq!(signature_only(""), "");
        // Single-line content is returned unchanged.
        assert_eq!(signature_only("struct Unit;"), "struct Unit;");
        // CRLF line endings are stripped like plain LF.
        assert_eq!(signature_only("fn foo() {\r\n}"), "fn foo() {");
    }
}