Skip to main content

sem_core/parser/
context.rs

1//! Context budgeting: pack optimal entity context into a token budget.
2//! Priority: target entity (full) > direct dependents (full) > transitive (signature only).
3
4use std::collections::HashMap;
5
6use crate::model::entity::SemanticEntity;
7use crate::parser::graph::EntityGraph;
8
/// One entity selected by [`build_context`] for inclusion in a packed context.
///
/// Derives `PartialEq`/`Eq` in addition to `Debug`/`Clone` so entries can be
/// compared directly (for example with `assert_eq!` in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContextEntry {
    /// Name of the entity, copied from `SemanticEntity::name`.
    pub entity_name: String,
    /// Type of the entity, copied from `SemanticEntity::entity_type`.
    pub entity_type: String,
    /// File path of the entity, copied from `SemanticEntity::file_path`.
    pub file_path: String,
    /// Why the entity is in the context: `"target"`, `"direct_dependent"`,
    /// or `"transitive_dependent"`.
    pub role: String,
    /// Text included for the entity: the full content for the target and for
    /// direct dependents, only the first line for transitive dependents.
    pub content: String,
    /// Token estimate for `content`, i.e. the amount charged against the budget.
    pub estimated_tokens: usize,
}
18
/// Approximate the number of tokens in `content`.
///
/// Counts whitespace-separated words and scales the count by 13/10 using
/// integer arithmetic, so the result is rounded down (2 words -> 2, 8 words -> 10).
fn estimate_tokens(content: &str) -> usize {
    // The ~1.3 tokens-per-word heuristic expressed as a ratio to stay in integers.
    const TOKENS_PER_TEN_WORDS: usize = 13;

    let word_count = content
        .split_whitespace()
        .fold(0usize, |seen, _word| seen + 1);

    (word_count * TOKENS_PER_TEN_WORDS) / 10
}
24
/// Return the first line of `content` as an owned string, used as a stand-in
/// for the entity's signature.
///
/// Yields an empty string when `content` contains no lines at all.
fn signature_only(content: &str) -> String {
    match content.lines().next() {
        Some(first_line) => first_line.to_owned(),
        None => String::new(),
    }
}
29
30/// Build a context set for a target entity within a token budget.
31///
32/// Greedy knapsack by priority:
33/// 1. Target entity (full content)
34/// 2. Direct dependents (full content)
35/// 3. Transitive dependents (signature only)
36pub fn build_context(
37    graph: &EntityGraph,
38    entity_id: &str,
39    all_entities: &[SemanticEntity],
40    token_budget: usize,
41) -> Vec<ContextEntry> {
42    // Build content lookup: entity_id -> SemanticEntity
43    let entity_lookup: HashMap<&str, &SemanticEntity> = all_entities
44        .iter()
45        .map(|e| (e.id.as_str(), e))
46        .collect();
47
48    let mut entries = Vec::new();
49    let mut tokens_used = 0usize;
50
51    // 1. Target entity (full)
52    if let Some(entity) = entity_lookup.get(entity_id) {
53        let tokens = estimate_tokens(&entity.content);
54        if tokens_used + tokens <= token_budget {
55            entries.push(ContextEntry {
56                entity_name: entity.name.clone(),
57                entity_type: entity.entity_type.clone(),
58                file_path: entity.file_path.clone(),
59                role: "target".to_string(),
60                content: entity.content.clone(),
61                estimated_tokens: tokens,
62            });
63            tokens_used += tokens;
64        }
65    }
66
67    // 2. Direct dependents (full content)
68    let direct_deps = graph.get_dependents(entity_id);
69    for dep_info in &direct_deps {
70        if tokens_used >= token_budget {
71            break;
72        }
73        if let Some(entity) = entity_lookup.get(dep_info.id.as_str()) {
74            let tokens = estimate_tokens(&entity.content);
75            if tokens_used + tokens <= token_budget {
76                entries.push(ContextEntry {
77                    entity_name: entity.name.clone(),
78                    entity_type: entity.entity_type.clone(),
79                    file_path: entity.file_path.clone(),
80                    role: "direct_dependent".to_string(),
81                    content: entity.content.clone(),
82                    estimated_tokens: tokens,
83                });
84                tokens_used += tokens;
85            }
86        }
87    }
88
89    // 3. Transitive dependents (signature only)
90    let all_impact = graph.impact_analysis(entity_id);
91    let direct_ids: std::collections::HashSet<&str> =
92        direct_deps.iter().map(|d| d.id.as_str()).collect();
93
94    for dep_info in &all_impact {
95        if tokens_used >= token_budget {
96            break;
97        }
98        // Skip direct deps (already included with full content)
99        if direct_ids.contains(dep_info.id.as_str()) {
100            continue;
101        }
102        if let Some(entity) = entity_lookup.get(dep_info.id.as_str()) {
103            let sig = signature_only(&entity.content);
104            let tokens = estimate_tokens(&sig);
105            if tokens_used + tokens <= token_budget {
106                entries.push(ContextEntry {
107                    entity_name: entity.name.clone(),
108                    entity_type: entity.entity_type.clone(),
109                    file_path: entity.file_path.clone(),
110                    role: "transitive_dependent".to_string(),
111                    content: sig,
112                    estimated_tokens: tokens,
113                });
114                tokens_used += tokens;
115            }
116        }
117    }
118
119    entries
120}
121
#[cfg(test)]
mod tests {
    use super::{estimate_tokens, signature_only};

    /// The estimate is `words * 13 / 10`, rounded down by integer division.
    #[test]
    fn test_estimate_tokens() {
        let cases: [(&str, usize); 2] = [
            // 2 words: 2 * 13 / 10 = 2
            ("hello world", 2),
            // 8 words: 8 * 13 / 10 = 10
            ("fn foo(a: i32, b: i32) -> bool {", 10),
        ];

        for &(input, expected) in cases.iter() {
            assert_eq!(estimate_tokens(input), expected);
        }
    }

    /// Only the first line of a multi-line body is kept.
    #[test]
    fn test_signature_only() {
        let body = "fn foo(a: i32) {\n    a + 1\n}";
        let expected = "fn foo(a: i32) {";

        assert_eq!(signature_only(body), expected);
    }
}
139}