Skip to main content

vtcode_core/core/
prompt_caching.rs

1use crate::config::constants::prompt_cache;
2use crate::config::core::PromptCachingConfig;
3use crate::llm::provider::{Message, MessageContent, MessageRole};
4use hashbrown::HashMap;
5use serde::{Deserialize, Serialize};
6use std::fmt::Write;
7use std::path::{Path, PathBuf};
8use tokio::fs;
9use vtcode_commons::utils::current_timestamp;
10
11use crate::utils::tokens::estimate_tokens;
12
13/// Cached prompt entry
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct CachedPrompt {
16    pub prompt_hash: String,
17    pub original_prompt: String,
18    pub optimized_prompt: String,
19    pub model_used: String,
20    pub tokens_saved: Option<u32>,
21    pub quality_score: Option<f64>,
22    pub created_at: u64,
23    pub last_used: u64,
24    pub usage_count: u32,
25}
26
27/// Prompt caching configuration
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct PromptCacheConfig {
30    pub enabled: bool,
31    pub cache_dir: PathBuf,
32    pub max_cache_size: usize,
33    pub max_age_days: u64,
34    pub enable_auto_cleanup: bool,
35    pub min_quality_threshold: f64,
36}
37
38impl Default for PromptCacheConfig {
39    fn default() -> Self {
40        Self {
41            enabled: prompt_cache::DEFAULT_ENABLED,
42            cache_dir: default_cache_dir(),
43            max_cache_size: prompt_cache::DEFAULT_MAX_ENTRIES,
44            max_age_days: prompt_cache::DEFAULT_MAX_AGE_DAYS,
45            enable_auto_cleanup: prompt_cache::DEFAULT_AUTO_CLEANUP,
46            min_quality_threshold: prompt_cache::DEFAULT_MIN_QUALITY_THRESHOLD,
47        }
48    }
49}
50
51impl PromptCacheConfig {
52    /// Build runtime configuration from high-level TOML settings
53    pub fn from_settings(settings: &PromptCachingConfig, workspace_root: Option<&Path>) -> Self {
54        Self {
55            enabled: settings.enabled,
56            cache_dir: settings.resolve_cache_dir(workspace_root),
57            max_cache_size: settings.max_entries,
58            max_age_days: settings.max_age_days,
59            enable_auto_cleanup: settings.enable_auto_cleanup,
60            min_quality_threshold: settings.min_quality_threshold,
61        }
62    }
63
64    pub fn is_enabled(&self) -> bool {
65        self.enabled
66    }
67}
68
69fn default_cache_dir() -> PathBuf {
70    if let Some(home) = dirs::home_dir() {
71        return home.join(prompt_cache::DEFAULT_CACHE_DIR);
72    }
73    PathBuf::from(prompt_cache::DEFAULT_CACHE_DIR)
74}
75
76/// Prompt caching system
77pub struct PromptCache {
78    config: PromptCacheConfig,
79    cache: HashMap<String, CachedPrompt>,
80    dirty: bool,
81}
82
83impl PromptCache {
84    pub async fn new() -> Self {
85        Self::with_config(PromptCacheConfig::default()).await
86    }
87
88    pub async fn with_config(config: PromptCacheConfig) -> Self {
89        let mut cache = Self {
90            config,
91            cache: HashMap::new(),
92            dirty: false,
93        };
94
95        // Load existing cache
96        if cache.config.enabled {
97            let _ = cache.load_cache().await;
98
99            // Auto cleanup if enabled
100            if cache.config.enable_auto_cleanup {
101                let _ = cache.cleanup_expired();
102            }
103        }
104
105        cache
106    }
107
108    /// Get cached optimized prompt
109    pub fn get(&mut self, prompt_hash: &str) -> Option<CachedPrompt> {
110        if !self.config.enabled {
111            return None;
112        }
113        self.cache.get_mut(prompt_hash).map(|entry| {
114            entry.last_used = current_timestamp();
115            entry.usage_count += 1;
116            self.dirty = true;
117            entry.clone()
118        })
119    }
120
121    /// Store optimized prompt in cache
122    pub fn put(&mut self, entry: CachedPrompt) -> Result<(), PromptCacheError> {
123        if !self.config.enabled {
124            return Ok(());
125        }
126        // Check quality threshold
127        if entry
128            .quality_score
129            .is_some_and(|quality| quality < self.config.min_quality_threshold)
130        {
131            return Ok(()); // Don't cache low-quality entries
132        }
133
134        // Check cache size limit
135        if self.cache.len() >= self.config.max_cache_size {
136            self.evict_oldest()?;
137        }
138
139        self.cache.insert(entry.prompt_hash.clone(), entry);
140        self.dirty = true;
141
142        Ok(())
143    }
144
145    /// Check if prompt is cached
146    pub fn contains(&self, prompt_hash: &str) -> bool {
147        self.config.enabled && self.cache.contains_key(prompt_hash)
148    }
149
150    /// Get cache statistics
151    pub fn stats(&self) -> CacheStats {
152        if !self.config.enabled {
153            return CacheStats::default();
154        }
155        let total_entries = self.cache.len();
156        let total_usage = self.cache.values().map(|e| e.usage_count).sum::<u32>();
157        let total_tokens_saved = self
158            .cache
159            .values()
160            .filter_map(|e| e.tokens_saved)
161            .sum::<u32>();
162        let avg_quality = if !self.cache.is_empty() {
163            self.cache
164                .values()
165                .filter_map(|e| e.quality_score)
166                .sum::<f64>()
167                / self.cache.len() as f64
168        } else {
169            0.0
170        };
171
172        CacheStats {
173            total_entries,
174            total_usage,
175            total_tokens_saved,
176            avg_quality,
177        }
178    }
179
180    /// Clear all cache entries
181    pub async fn clear(&mut self) -> Result<(), PromptCacheError> {
182        if !self.config.enabled {
183            return Ok(());
184        }
185        self.cache.clear();
186        self.dirty = true;
187        self.save_cache().await
188    }
189
190    /// Generate hash for prompt
191    pub fn hash_prompt(prompt: &str) -> String {
192        vtcode_commons::utils::calculate_sha256(prompt.as_bytes())
193    }
194
195    /// Save cache to disk
196    pub async fn save_cache(&self) -> Result<(), PromptCacheError> {
197        if !self.config.enabled || !self.dirty {
198            return Ok(());
199        }
200
201        // Ensure cache directory exists
202        fs::create_dir_all(&self.config.cache_dir)
203            .await
204            .map_err(PromptCacheError::Io)?;
205
206        let cache_path = self.config.cache_dir.join("prompt_cache.json");
207        let data =
208            serde_json::to_string_pretty(&self.cache).map_err(PromptCacheError::Serialization)?;
209
210        fs::write(cache_path, data)
211            .await
212            .map_err(PromptCacheError::Io)?;
213
214        Ok(())
215    }
216
217    /// Load cache from disk
218    async fn load_cache(&mut self) -> Result<(), PromptCacheError> {
219        if !self.config.enabled {
220            return Ok(());
221        }
222        let cache_path = self.config.cache_dir.join("prompt_cache.json");
223
224        if !fs::try_exists(&cache_path).await.unwrap_or(false) {
225            return Ok(());
226        }
227
228        let data = fs::read_to_string(cache_path)
229            .await
230            .map_err(PromptCacheError::Io)?;
231
232        self.cache = serde_json::from_str(&data).map_err(PromptCacheError::Serialization)?;
233
234        Ok(())
235    }
236
237    /// Clean up expired cache entries
238    fn cleanup_expired(&mut self) -> Result<(), PromptCacheError> {
239        if !self.config.enabled {
240            return Ok(());
241        }
242        let now = current_timestamp();
243        let max_age_seconds = self.config.max_age_days * 24 * 60 * 60;
244
245        self.cache
246            .retain(|_, entry| now - entry.created_at < max_age_seconds);
247
248        self.dirty = true;
249        Ok(())
250    }
251
252    /// Evict oldest cache entries when cache is full
253    fn evict_oldest(&mut self) -> Result<(), PromptCacheError> {
254        if !self.config.enabled {
255            return Ok(());
256        }
257        if self.cache.is_empty() {
258            return Ok(());
259        }
260
261        // Find the oldest entry
262        let Some(oldest_key) = self
263            .cache
264            .iter()
265            .min_by_key(|(_, entry)| entry.last_used)
266            .map(|(key, _)| key.clone())
267        else {
268            return Ok(());
269        };
270
271        self.cache.remove(&oldest_key);
272        self.dirty = true;
273
274        Ok(())
275    }
276}
277
// Note: `Drop` cannot be async, so `PromptCache` does not save itself when it
// is dropped. Callers must explicitly call `save_cache()` (or use a wrapper
// that does so) to persist changes.
280
281/// Cache statistics
282#[derive(Debug, Clone, Serialize, Deserialize)]
283pub struct CacheStats {
284    pub total_entries: usize,
285    pub total_usage: u32,
286    pub total_tokens_saved: u32,
287    pub avg_quality: f64,
288}
289
290impl Default for CacheStats {
291    fn default() -> Self {
292        Self {
293            total_entries: 0,
294            total_usage: 0,
295            total_tokens_saved: 0,
296            avg_quality: 0.0,
297        }
298    }
299}
300
301/// Prompt cache errors
302#[derive(Debug, thiserror::Error)]
303pub enum PromptCacheError {
304    #[error("IO error: {0}")]
305    Io(#[from] std::io::Error),
306
307    #[error("Serialization error: {0}")]
308    Serialization(#[from] serde_json::Error),
309
310    #[error("Cache full")]
311    CacheFull,
312}
313
314/// Prompt optimizer that uses caching
315pub struct PromptOptimizer {
316    cache: PromptCache,
317    llm_provider: Box<dyn crate::llm::provider::LLMProvider>,
318}
319
320impl PromptOptimizer {
321    pub async fn new(llm_provider: Box<dyn crate::llm::provider::LLMProvider>) -> Self {
322        Self {
323            cache: PromptCache::new().await,
324            llm_provider,
325        }
326    }
327
328    pub fn with_cache(mut self, cache: PromptCache) -> Self {
329        self.cache = cache;
330        self
331    }
332
333    /// Explicitly save the cache to disk
334    pub async fn save_cache(&self) -> Result<(), PromptCacheError> {
335        self.cache.save_cache().await
336    }
337
338    /// Optimize a prompt using caching
339    pub async fn optimize_prompt(
340        &mut self,
341        original_prompt: &str,
342        target_model: &str,
343        context: Option<&str>,
344    ) -> Result<String, PromptOptimizationError> {
345        let prompt_hash = PromptCache::hash_prompt(original_prompt);
346
347        // Check cache first
348        if let Some(cached) = self.cache.get(&prompt_hash) {
349            return Ok(cached.optimized_prompt);
350        }
351
352        // Generate optimized prompt
353        let optimized = self
354            .generate_optimized_prompt(original_prompt, target_model, context)
355            .await?;
356
357        // Calculate tokens saved (rough estimate)
358        let original_tokens = estimate_tokens(original_prompt);
359        let optimized_tokens = estimate_tokens(&optimized);
360        let tokens_saved = original_tokens.saturating_sub(optimized_tokens);
361
362        // Create cache entry
363        let entry = CachedPrompt {
364            prompt_hash: prompt_hash.clone(),
365            original_prompt: original_prompt.to_string(),
366            optimized_prompt: optimized.clone(),
367            model_used: target_model.to_string(),
368            tokens_saved: Some(tokens_saved.min(u32::MAX as usize) as u32),
369            quality_score: Some(0.8), // Placeholder quality score
370            created_at: current_timestamp(),
371            last_used: current_timestamp(),
372            usage_count: 1,
373        };
374
375        // Store in cache
376        self.cache.put(entry)?;
377
378        Ok(optimized)
379    }
380
381    /// Generate optimized prompt using LLM
382    async fn generate_optimized_prompt(
383        &self,
384        original_prompt: &str,
385        target_model: &str,
386        context: Option<&str>,
387    ) -> Result<String, PromptOptimizationError> {
388        let system_prompt = format!(
389            "You are an expert prompt engineer. Your task is to optimize prompts for {} \
390             to make them more effective, clearer, and more likely to produce high-quality responses. \
391             Focus on improving clarity, specificity, structure, and effectiveness while preserving \
392             the original intent and requirements.",
393            target_model
394        );
395
396        let mut user_prompt = format!(
397            "Please optimize the following prompt for {}:\n\nORIGINAL PROMPT:\n{}\n\n",
398            target_model, original_prompt
399        );
400
401        if let Some(ctx) = context {
402            let _ = write!(user_prompt, "CONTEXT:\n{}\n\n", ctx);
403        }
404
405        user_prompt.push_str(
406            "OPTIMIZATION REQUIREMENTS:\n\
407             1. Make the prompt clearer and more specific\n\
408             2. Improve structure and formatting\n\
409             3. Add relevant context or examples if helpful\n\
410             4. Ensure the prompt is appropriate for the target model\n\
411             5. Maintain the original intent and requirements\n\
412             6. Keep the optimized prompt concise but comprehensive\n\n\
413             Provide only the optimized prompt without any explanation or additional text.",
414        );
415
416        let request = crate::llm::provider::LLMRequest {
417            messages: vec![
418                Message {
419                    role: MessageRole::System,
420                    content: MessageContent::Text(system_prompt),
421                    ..Default::default()
422                },
423                Message {
424                    role: MessageRole::User,
425                    content: MessageContent::Text(user_prompt),
426                    ..Default::default()
427                },
428            ],
429            model: target_model.to_string(),
430            max_tokens: Some(2000),
431            temperature: Some(0.3),
432            ..Default::default()
433        };
434
435        let response = self
436            .llm_provider
437            .generate(request)
438            .await
439            .map_err(|e| PromptOptimizationError::LLMError(e.to_string()))?;
440
441        Ok(response
442            .content
443            .unwrap_or_else(|| original_prompt.to_string()))
444    }
445
446    /// Get cache statistics
447    pub fn cache_stats(&self) -> CacheStats {
448        self.cache.stats()
449    }
450
451    /// Clear cache
452    pub async fn clear_cache(&mut self) -> Result<(), PromptCacheError> {
453        self.cache.clear().await
454    }
455}
456
457/// Prompt optimization errors
458#[derive(Debug, thiserror::Error)]
459pub enum PromptOptimizationError {
460    #[error("LLM error: {0}")]
461    LLMError(String),
462
463    #[error("Cache error: {0}")]
464    CacheError(#[from] PromptCacheError),
465}
466
#[cfg(test)]
mod tests {
    use super::*;

    /// Enabled config pointing at a per-process directory that does not exist.
    ///
    /// Every field is explicit, so tests neither read the user's real
    /// prompt cache nor depend on the `prompt_cache::DEFAULT_*` constants.
    fn isolated_config(tag: &str, max_cache_size: usize) -> PromptCacheConfig {
        PromptCacheConfig {
            enabled: true,
            cache_dir: std::env::temp_dir().join(format!(
                "vtcode-prompt-cache-test-{}-{}",
                tag,
                std::process::id()
            )),
            max_cache_size,
            max_age_days: 30,
            enable_auto_cleanup: false,
            min_quality_threshold: 0.5,
        }
    }

    fn sample_entry(hash: &str, quality: Option<f64>, last_used: u64) -> CachedPrompt {
        CachedPrompt {
            prompt_hash: hash.to_owned(),
            original_prompt: "original".to_owned(),
            optimized_prompt: "optimized".to_owned(),
            model_used: crate::config::constants::models::google::GEMINI_3_FLASH_PREVIEW.to_owned(),
            tokens_saved: Some(100),
            quality_score: quality,
            created_at: last_used,
            last_used,
            usage_count: 0,
        }
    }

    #[test]
    fn test_prompt_hash() {
        let prompt = "Test prompt";
        let hash1 = PromptCache::hash_prompt(prompt);
        let hash2 = PromptCache::hash_prompt(prompt);
        assert_eq!(hash1, hash2);
        assert!(!hash1.is_empty());
    }

    #[tokio::test]
    async fn test_cache_operations() {
        let mut cache = PromptCache::with_config(isolated_config("ops", 16)).await;

        cache
            .put(sample_entry("test_hash", Some(0.9), 1000))
            .unwrap();
        assert!(cache.contains("test_hash"));

        let retrieved = cache.get("test_hash");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().usage_count, 1);
    }

    #[tokio::test]
    async fn low_quality_entries_are_not_cached() {
        let mut cache = PromptCache::with_config(isolated_config("quality", 16)).await;

        cache.put(sample_entry("low", Some(0.1), 1)).unwrap();
        assert!(!cache.contains("low"));
        assert_eq!(cache.stats().total_entries, 0);
    }

    #[tokio::test]
    async fn full_cache_evicts_least_recently_used() {
        let mut cache = PromptCache::with_config(isolated_config("evict", 1)).await;

        cache.put(sample_entry("a", Some(0.9), 1)).unwrap();
        cache.put(sample_entry("b", Some(0.9), 2)).unwrap();

        assert!(!cache.contains("a"));
        assert!(cache.contains("b"));
        assert_eq!(cache.stats().total_entries, 1);
    }

    #[tokio::test]
    async fn disabled_cache_config_is_no_op() {
        let settings = PromptCachingConfig {
            enabled: false,
            cache_dir: "relative/cache".to_owned(),
            ..PromptCachingConfig::default()
        };
        let cfg = PromptCacheConfig::from_settings(&settings, None);
        assert!(!cfg.is_enabled());

        let mut cache = PromptCache::with_config(cfg).await;
        assert!(!cache.contains("missing"));
        assert_eq!(cache.stats().total_entries, 0);

        let entry = CachedPrompt {
            prompt_hash: "noop".to_owned(),
            original_prompt: "original".to_owned(),
            optimized_prompt: "optimized".to_owned(),
            model_used: crate::config::constants::models::google::GEMINI_3_FLASH_PREVIEW.to_owned(),
            tokens_saved: Some(10),
            quality_score: Some(0.9),
            created_at: 1,
            last_used: 1,
            usage_count: 0,
        };

        cache.put(entry).unwrap();
        assert!(!cache.contains("noop"));
        assert_eq!(cache.stats().total_entries, 0);
    }
}