vtcode_core/core/prompt_caching.rs

use crate::config::constants::prompt_cache;
use crate::config::core::PromptCachingConfig;
use crate::llm::provider::{Message, MessageRole};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};

/// Cached prompt entry
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedPrompt {
    pub prompt_hash: String,
    pub original_prompt: String,
    pub optimized_prompt: String,
    pub model_used: String,
    pub tokens_saved: Option<u32>,
    pub quality_score: Option<f64>,
    pub created_at: u64,
    pub last_used: u64,
    pub usage_count: u32,
}
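
// Entries are persisted by `save_cache` as a single JSON object keyed by
// `prompt_hash` (illustrative shape only; values abbreviated):
//
// {
//   "9f86d0…": {
//     "prompt_hash": "9f86d0…",
//     "original_prompt": "…",
//     "optimized_prompt": "…",
//     "model_used": "…",
//     "tokens_saved": 12,
//     "quality_score": 0.8,
//     "created_at": 1700000000,
//     "last_used": 1700000000,
//     "usage_count": 3
//   }
// }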

/// Prompt caching configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptCacheConfig {
    pub enabled: bool,
    pub cache_dir: PathBuf,
    pub max_cache_size: usize,
    pub max_age_days: u64,
    pub enable_auto_cleanup: bool,
    pub min_quality_threshold: f64,
}

impl Default for PromptCacheConfig {
    fn default() -> Self {
        Self {
            enabled: prompt_cache::DEFAULT_ENABLED,
            cache_dir: default_cache_dir(),
            max_cache_size: prompt_cache::DEFAULT_MAX_ENTRIES,
            max_age_days: prompt_cache::DEFAULT_MAX_AGE_DAYS,
            enable_auto_cleanup: prompt_cache::DEFAULT_AUTO_CLEANUP,
            min_quality_threshold: prompt_cache::DEFAULT_MIN_QUALITY_THRESHOLD,
        }
    }
}

impl PromptCacheConfig {
    /// Build runtime configuration from high-level TOML settings
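    ///
    /// `cache_dir` is produced by `PromptCachingConfig::resolve_cache_dir`,
    /// which receives the optional workspace root. A minimal sketch
    /// (`settings` and `workspace_root` come from the caller):
    ///
    /// ```ignore
    /// let cfg = PromptCacheConfig::from_settings(&settings, Some(workspace_root));
    /// assert_eq!(cfg.enabled, settings.enabled);
    /// ```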
    pub fn from_settings(settings: &PromptCachingConfig, workspace_root: Option<&Path>) -> Self {
        Self {
            enabled: settings.enabled,
            cache_dir: settings.resolve_cache_dir(workspace_root),
            max_cache_size: settings.max_entries,
            max_age_days: settings.max_age_days,
            enable_auto_cleanup: settings.enable_auto_cleanup,
            min_quality_threshold: settings.min_quality_threshold,
        }
    }

    pub fn is_enabled(&self) -> bool {
        self.enabled
    }
}

fn default_cache_dir() -> PathBuf {
    if let Some(home) = dirs::home_dir() {
        return home.join(prompt_cache::DEFAULT_CACHE_DIR);
    }
    PathBuf::from(prompt_cache::DEFAULT_CACHE_DIR)
}

/// Prompt caching system
pub struct PromptCache {
    config: PromptCacheConfig,
    cache: HashMap<String, CachedPrompt>,
    dirty: bool,
}

impl PromptCache {
    pub fn new() -> Self {
        Self::with_config(PromptCacheConfig::default())
    }

    pub fn with_config(config: PromptCacheConfig) -> Self {
        let mut cache = Self {
            config,
            cache: HashMap::new(),
            dirty: false,
        };

        // Load existing cache
        if cache.config.enabled {
            let _ = cache.load_cache();

            // Auto cleanup if enabled
            if cache.config.enable_auto_cleanup {
                let _ = cache.cleanup_expired();
            }
        }

        cache
    }

    /// Get cached optimized prompt
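    ///
    /// A hit bumps `last_used` and `usage_count` and marks the cache dirty
    /// so it is written back on drop. Sketch of the lookup flow:
    ///
    /// ```ignore
    /// let mut cache = PromptCache::new();
    /// let hash = PromptCache::hash_prompt("Summarize this file");
    /// if let Some(hit) = cache.get(&hash) {
    ///     println!("reusing: {}", hit.optimized_prompt);
    /// }
    /// ```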
    pub fn get(&mut self, prompt_hash: &str) -> Option<&CachedPrompt> {
        if !self.config.enabled {
            return None;
        }
        if let Some(entry) = self.cache.get_mut(prompt_hash) {
            entry.last_used = Self::current_timestamp();
            entry.usage_count += 1;
            self.dirty = true;
            Some(entry)
        } else {
            None
        }
    }

    /// Store optimized prompt in cache
    pub fn put(&mut self, entry: CachedPrompt) -> Result<(), PromptCacheError> {
        if !self.config.enabled {
            return Ok(());
        }
        // Check quality threshold
        if let Some(quality) = entry.quality_score {
            if quality < self.config.min_quality_threshold {
                return Ok(()); // Don't cache low-quality entries
            }
        }

        // Check cache size limit
        if self.cache.len() >= self.config.max_cache_size {
            self.evict_oldest()?;
        }

        self.cache.insert(entry.prompt_hash.clone(), entry);
        self.dirty = true;

        Ok(())
    }

    /// Check if prompt is cached
    pub fn contains(&self, prompt_hash: &str) -> bool {
        self.config.enabled && self.cache.contains_key(prompt_hash)
    }

    /// Get cache statistics
    pub fn stats(&self) -> CacheStats {
        if !self.config.enabled {
            return CacheStats::default();
        }
        let total_entries = self.cache.len();
        let total_usage = self.cache.values().map(|e| e.usage_count).sum::<u32>();
        let total_tokens_saved = self
            .cache
            .values()
            .filter_map(|e| e.tokens_saved)
            .sum::<u32>();
        // Average quality over entries that actually carry a score, so
        // unscored entries do not drag the average toward zero.
        let scored: Vec<f64> = self
            .cache
            .values()
            .filter_map(|e| e.quality_score)
            .collect();
        let avg_quality = if scored.is_empty() {
            0.0
        } else {
            scored.iter().sum::<f64>() / scored.len() as f64
        };

        CacheStats {
            total_entries,
            total_usage,
            total_tokens_saved,
            avg_quality,
        }
    }

    /// Clear all cache entries
    pub fn clear(&mut self) -> Result<(), PromptCacheError> {
        if !self.config.enabled {
            return Ok(());
        }
        self.cache.clear();
        self.dirty = true;
        self.save_cache()
    }

    /// Generate hash for prompt
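    ///
    /// The result is the lowercase hex encoding of the SHA-256 digest, so
    /// it is always 64 characters and stable across runs:
    ///
    /// ```ignore
    /// assert_eq!(PromptCache::hash_prompt("a").len(), 64);
    /// ```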
    pub fn hash_prompt(prompt: &str) -> String {
        use sha2::{Digest, Sha256};
        let mut hasher = Sha256::new();
        hasher.update(prompt.as_bytes());
        format!("{:x}", hasher.finalize())
    }

    /// Save cache to disk
    pub fn save_cache(&self) -> Result<(), PromptCacheError> {
        if !self.config.enabled || !self.dirty {
            return Ok(());
        }

        // Ensure cache directory exists; `?` converts errors via the
        // `#[from]` impls on `PromptCacheError`.
        fs::create_dir_all(&self.config.cache_dir)?;

        let cache_path = self.config.cache_dir.join("prompt_cache.json");
        let data = serde_json::to_string_pretty(&self.cache)?;

        fs::write(cache_path, data)?;

        Ok(())
    }

    /// Load cache from disk
    fn load_cache(&mut self) -> Result<(), PromptCacheError> {
        if !self.config.enabled {
            return Ok(());
        }
        let cache_path = self.config.cache_dir.join("prompt_cache.json");

        if !cache_path.exists() {
            return Ok(());
        }

        let data = fs::read_to_string(cache_path)?;

        self.cache = serde_json::from_str(&data)?;

        Ok(())
    }

    /// Clean up expired cache entries
    fn cleanup_expired(&mut self) -> Result<(), PromptCacheError> {
        if !self.config.enabled {
            return Ok(());
        }
        let now = Self::current_timestamp();
        let max_age_seconds = self.config.max_age_days * 24 * 60 * 60;

        // `saturating_sub` guards against entries stamped in the future
        // (e.g. after a clock adjustment), which would otherwise underflow
        // and panic in debug builds. Only mark the cache dirty if an entry
        // was actually removed, so a no-op cleanup does not force a save.
        let before = self.cache.len();
        self.cache
            .retain(|_, entry| now.saturating_sub(entry.created_at) < max_age_seconds);

        if self.cache.len() != before {
            self.dirty = true;
        }
        Ok(())
    }

    /// Evict the least recently used entry when the cache is full
    fn evict_oldest(&mut self) -> Result<(), PromptCacheError> {
        if !self.config.enabled {
            return Ok(());
        }

        // `min_by_key` returns `None` on an empty cache, so no separate
        // emptiness check (or `unwrap`) is needed.
        if let Some(oldest_key) = self
            .cache
            .iter()
            .min_by_key(|(_, entry)| entry.last_used)
            .map(|(key, _)| key.clone())
        {
            self.cache.remove(&oldest_key);
            self.dirty = true;
        }

        Ok(())
    }

    /// Get current timestamp (seconds since the Unix epoch)
    fn current_timestamp() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|duration| duration.as_secs())
            .unwrap_or(0)
    }
}

impl Drop for PromptCache {
    fn drop(&mut self) {
        // Best-effort persistence; errors during drop are intentionally
        // ignored because there is no caller to report them to.
        let _ = self.save_cache();
    }
}

/// Cache statistics
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStats {
    pub total_entries: usize,
    pub total_usage: u32,
    pub total_tokens_saved: u32,
    pub avg_quality: f64,
}

/// Prompt cache errors
#[derive(Debug, thiserror::Error)]
pub enum PromptCacheError {
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),

    #[error("Cache full")]
    CacheFull,
}

/// Prompt optimizer that uses caching
pub struct PromptOptimizer {
    cache: PromptCache,
    llm_provider: Box<dyn crate::llm::provider::LLMProvider>,
}

impl PromptOptimizer {
    pub fn new(llm_provider: Box<dyn crate::llm::provider::LLMProvider>) -> Self {
        Self {
            cache: PromptCache::new(),
            llm_provider,
        }
    }

    pub fn with_cache(mut self, cache: PromptCache) -> Self {
        self.cache = cache;
        self
    }

    /// Optimize a prompt using caching
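    ///
    /// A hit returns the stored optimization without touching the provider;
    /// a miss asks the provider to rewrite the prompt and caches the result.
    /// A minimal sketch (provider construction and model name are
    /// placeholders):
    ///
    /// ```ignore
    /// let mut optimizer = PromptOptimizer::new(provider);
    /// let improved = optimizer
    ///     .optimize_prompt("summarize the diff", "target-model", None)
    ///     .await?;
    /// ```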
    pub async fn optimize_prompt(
        &mut self,
        original_prompt: &str,
        target_model: &str,
        context: Option<&str>,
    ) -> Result<String, PromptOptimizationError> {
        let prompt_hash = PromptCache::hash_prompt(original_prompt);

        // Check cache first
        if let Some(cached) = self.cache.get(&prompt_hash) {
            return Ok(cached.optimized_prompt.clone());
        }

        // Generate optimized prompt
        let optimized = self
            .generate_optimized_prompt(original_prompt, target_model, context)
            .await?;

        // Calculate tokens saved (rough estimate)
        let original_tokens = Self::estimate_tokens(original_prompt);
        let optimized_tokens = Self::estimate_tokens(&optimized);
        let tokens_saved = original_tokens.saturating_sub(optimized_tokens);

        // Create cache entry; take the timestamp once so `created_at` and
        // `last_used` always match.
        let now = PromptCache::current_timestamp();
        let entry = CachedPrompt {
            prompt_hash: prompt_hash.clone(),
            original_prompt: original_prompt.to_string(),
            optimized_prompt: optimized.clone(),
            model_used: target_model.to_string(),
            tokens_saved: Some(tokens_saved),
            quality_score: Some(0.8), // Placeholder quality score
            created_at: now,
            last_used: now,
            usage_count: 1,
        };

        // Store in cache
        self.cache.put(entry)?;

        Ok(optimized)
    }

    /// Generate optimized prompt using LLM
    async fn generate_optimized_prompt(
        &self,
        original_prompt: &str,
        target_model: &str,
        context: Option<&str>,
    ) -> Result<String, PromptOptimizationError> {
        let system_prompt = format!(
            "You are an expert prompt engineer. Your task is to optimize prompts for {} \
             to make them more effective, clearer, and more likely to produce high-quality responses. \
             Focus on improving clarity, specificity, structure, and effectiveness while preserving \
             the original intent and requirements.",
            target_model
        );

        let mut user_prompt = format!(
            "Please optimize the following prompt for {}:\n\nORIGINAL PROMPT:\n{}\n\n",
            target_model, original_prompt
        );

        if let Some(ctx) = context {
            user_prompt.push_str(&format!("CONTEXT:\n{}\n\n", ctx));
        }

        user_prompt.push_str(
            "OPTIMIZATION REQUIREMENTS:\n\
             1. Make the prompt clearer and more specific\n\
             2. Improve structure and formatting\n\
             3. Add relevant context or examples if helpful\n\
             4. Ensure the prompt is appropriate for the target model\n\
             5. Maintain the original intent and requirements\n\
             6. Keep the optimized prompt concise but comprehensive\n\n\
             Provide only the optimized prompt without any explanation or additional text.",
        );

        let request = crate::llm::provider::LLMRequest {
            messages: vec![
                Message {
                    role: MessageRole::System,
                    content: system_prompt,
                    tool_calls: None,
                    tool_call_id: None,
                },
                Message {
                    role: MessageRole::User,
                    content: user_prompt,
                    tool_calls: None,
                    tool_call_id: None,
                },
            ],
            system_prompt: None,
            tools: None,
            model: target_model.to_string(),
            max_tokens: Some(2000),
            temperature: Some(0.3),
            stream: false,
            tool_choice: None,
            parallel_tool_calls: None,
            parallel_tool_config: None,
            reasoning_effort: None,
        };

        let response = self
            .llm_provider
            .generate(request)
            .await
            .map_err(|e| PromptOptimizationError::LLMError(e.to_string()))?;

        // Fall back to the original prompt if the provider returns no text.
        Ok(response
            .content
            .unwrap_or_else(|| original_prompt.to_string()))
    }

    /// Estimate token count (rough approximation)
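    ///
    /// `len()` counts bytes, so a 400-byte English prompt estimates to 100
    /// tokens; multi-byte UTF-8 text skews the result accordingly.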
    fn estimate_tokens(text: &str) -> u32 {
        // Rough approximation: 1 token ≈ 4 characters for English text
        (text.len() / 4) as u32
    }

    /// Get cache statistics
    pub fn cache_stats(&self) -> CacheStats {
        self.cache.stats()
    }

    /// Clear cache
    pub fn clear_cache(&mut self) -> Result<(), PromptCacheError> {
        self.cache.clear()
    }
}

/// Prompt optimization errors
#[derive(Debug, thiserror::Error)]
pub enum PromptOptimizationError {
    #[error("LLM error: {0}")]
    LLMError(String),

    #[error("Cache error: {0}")]
    CacheError(#[from] PromptCacheError),
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prompt_hash() {
        let prompt = "Test prompt";
        let hash1 = PromptCache::hash_prompt(prompt);
        let hash2 = PromptCache::hash_prompt(prompt);
        assert_eq!(hash1, hash2);
        assert!(!hash1.is_empty());
    }

    #[test]
    fn test_cache_operations() {
        // Use an explicit, enabled config pointed at a temp directory so
        // the test neither depends on the compiled-in defaults nor touches
        // the user's real cache directory.
        let config = PromptCacheConfig {
            enabled: true,
            cache_dir: std::env::temp_dir().join("vtcode_prompt_cache_test"),
            min_quality_threshold: 0.0,
            ..PromptCacheConfig::default()
        };
        let mut cache = PromptCache::with_config(config);
        cache.clear().unwrap();

        let entry = CachedPrompt {
            prompt_hash: "test_hash".to_string(),
            original_prompt: "original".to_string(),
            optimized_prompt: "optimized".to_string(),
            model_used: crate::config::constants::models::GEMINI_2_5_FLASH.to_string(),
            tokens_saved: Some(100),
            quality_score: Some(0.9),
            created_at: 1000,
            last_used: 1000,
            usage_count: 0,
        };

        cache.put(entry).unwrap();
        assert!(cache.contains("test_hash"));

        let retrieved = cache.get("test_hash");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().usage_count, 1);
    }

    #[test]
    fn disabled_cache_config_is_no_op() {
        let settings = PromptCachingConfig {
            enabled: false,
            cache_dir: "relative/cache".to_string(),
            ..PromptCachingConfig::default()
        };
        let cfg = PromptCacheConfig::from_settings(&settings, None);
        assert!(!cfg.is_enabled());

        let mut cache = PromptCache::with_config(cfg);
        assert!(!cache.contains("missing"));
        assert_eq!(cache.stats().total_entries, 0);

        let entry = CachedPrompt {
            prompt_hash: "noop".to_string(),
            original_prompt: "original".to_string(),
            optimized_prompt: "optimized".to_string(),
            model_used: crate::config::constants::models::GEMINI_2_5_FLASH.to_string(),
            tokens_saved: Some(10),
            quality_score: Some(0.9),
            created_at: 1,
            last_used: 1,
            usage_count: 0,
        };

        cache.put(entry).unwrap();
        assert!(!cache.contains("noop"));
        assert_eq!(cache.stats().total_entries, 0);
    }
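
    // Exercises the LRU eviction path in `put`: with capacity one, adding a
    // second entry must evict the entry with the older `last_used` stamp.
    #[test]
    fn eviction_drops_least_recently_used_entry() {
        let config = PromptCacheConfig {
            enabled: true,
            cache_dir: std::env::temp_dir().join("vtcode_prompt_cache_evict_test"),
            max_cache_size: 1,
            max_age_days: 365,
            enable_auto_cleanup: false,
            min_quality_threshold: 0.0,
        };
        let mut cache = PromptCache::with_config(config);
        cache.clear().unwrap();

        let make_entry = |hash: &str, last_used: u64| CachedPrompt {
            prompt_hash: hash.to_string(),
            original_prompt: "original".to_string(),
            optimized_prompt: "optimized".to_string(),
            model_used: crate::config::constants::models::GEMINI_2_5_FLASH.to_string(),
            tokens_saved: None,
            quality_score: None,
            created_at: last_used,
            last_used,
            usage_count: 0,
        };

        cache.put(make_entry("old", 1)).unwrap();
        cache.put(make_entry("new", 2)).unwrap();

        assert!(!cache.contains("old"));
        assert!(cache.contains("new"));
    }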
}