vtcode_core/core/prompt_caching.rs

use crate::llm::provider::{Message, MessageRole};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use std::time::{SystemTime, UNIX_EPOCH};

/// Cached prompt entry
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedPrompt {
    pub prompt_hash: String,
    pub original_prompt: String,
    pub optimized_prompt: String,
    pub model_used: String,
    pub tokens_saved: Option<u32>,
    pub quality_score: Option<f64>,
    pub created_at: u64,
    pub last_used: u64,
    pub usage_count: u32,
}

/// Prompt caching configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptCacheConfig {
    pub cache_dir: PathBuf,
    pub max_cache_size: usize,
    pub max_age_days: u64,
    pub enable_auto_cleanup: bool,
    pub min_quality_threshold: f64,
}

impl Default for PromptCacheConfig {
    fn default() -> Self {
        Self {
            cache_dir: dirs::home_dir()
                .unwrap_or_else(|| PathBuf::from("."))
                .join(".vtcode")
                .join("cache")
                .join("prompts"),
            max_cache_size: 1000, // Maximum number of cached prompts
            max_age_days: 30,     // Cache entries older than 30 days are cleaned up
            enable_auto_cleanup: true,
            min_quality_threshold: 0.7, // Minimum quality score to cache
        }
    }
}
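// Illustrative override of the defaults before constructing a cache (a
// sketch; the field values below are arbitrary, chosen only for the example):
//
//     let config = PromptCacheConfig {
//         max_cache_size: 100,
//         max_age_days: 7,
//         ..PromptCacheConfig::default()
//     };
//     let cache = PromptCache::with_config(config);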

/// Prompt caching system
pub struct PromptCache {
    config: PromptCacheConfig,
    cache: HashMap<String, CachedPrompt>,
    dirty: bool,
}

impl PromptCache {
    pub fn new() -> Self {
        Self::with_config(PromptCacheConfig::default())
    }

    pub fn with_config(config: PromptCacheConfig) -> Self {
        let mut cache = Self {
            config,
            cache: HashMap::new(),
            dirty: false,
        };

        // Load any persisted cache; errors are ignored so a missing or
        // corrupt cache file never prevents startup
        let _ = cache.load_cache();

        // Drop expired entries up front if auto cleanup is enabled
        if cache.config.enable_auto_cleanup {
            let _ = cache.cleanup_expired();
        }

        cache
    }

    /// Get a cached optimized prompt, updating its last-used time and usage
    /// count on a hit (which is why this takes `&mut self`)
    pub fn get(&mut self, prompt_hash: &str) -> Option<&CachedPrompt> {
        let entry = self.cache.get_mut(prompt_hash)?;
        entry.last_used = Self::current_timestamp();
        entry.usage_count += 1;
        self.dirty = true;
        Some(entry)
    }

    /// Store an optimized prompt in the cache; entries below the quality
    /// threshold are silently skipped
    pub fn put(&mut self, entry: CachedPrompt) -> Result<(), PromptCacheError> {
        // Don't cache low-quality entries
        if let Some(quality) = entry.quality_score {
            if quality < self.config.min_quality_threshold {
                return Ok(());
            }
        }

        // Make room when the cache is at capacity
        if self.cache.len() >= self.config.max_cache_size {
            self.evict_oldest()?;
        }

        self.cache.insert(entry.prompt_hash.clone(), entry);
        self.dirty = true;

        Ok(())
    }

    /// Check if prompt is cached
    pub fn contains(&self, prompt_hash: &str) -> bool {
        self.cache.contains_key(prompt_hash)
    }

    /// Get cache statistics
    pub fn stats(&self) -> CacheStats {
        let total_entries = self.cache.len();
        let total_usage = self.cache.values().map(|e| e.usage_count).sum::<u32>();
        let total_tokens_saved = self
            .cache
            .values()
            .filter_map(|e| e.tokens_saved)
            .sum::<u32>();
        let avg_quality = if !self.cache.is_empty() {
            self.cache
                .values()
                .filter_map(|e| e.quality_score)
                .sum::<f64>()
                / self.cache.len() as f64
        } else {
            0.0
        };

        CacheStats {
            total_entries,
            total_usage,
            total_tokens_saved,
            avg_quality,
        }
    }

    /// Clear all cache entries
    pub fn clear(&mut self) -> Result<(), PromptCacheError> {
        self.cache.clear();
        self.dirty = true;
        self.save_cache()
    }

    /// Generate a SHA-256 hash for a prompt
    pub fn hash_prompt(prompt: &str) -> String {
        use sha2::{Digest, Sha256};
        let mut hasher = Sha256::new();
        hasher.update(prompt.as_bytes());
        format!("{:x}", hasher.finalize())
    }

    /// Save cache to disk
    pub fn save_cache(&self) -> Result<(), PromptCacheError> {
        if !self.dirty {
            return Ok(());
        }

        // Ensure cache directory exists; `?` converts errors via `#[from]`
        fs::create_dir_all(&self.config.cache_dir)?;

        let cache_path = self.config.cache_dir.join("prompt_cache.json");
        let data = serde_json::to_string_pretty(&self.cache)?;

        fs::write(cache_path, data)?;

        Ok(())
    }

    /// Load cache from disk
    fn load_cache(&mut self) -> Result<(), PromptCacheError> {
        let cache_path = self.config.cache_dir.join("prompt_cache.json");

        if !cache_path.exists() {
            return Ok(());
        }

        let data = fs::read_to_string(cache_path)?;

        self.cache = serde_json::from_str(&data)?;

        Ok(())
    }

    /// Clean up expired cache entries
    fn cleanup_expired(&mut self) -> Result<(), PromptCacheError> {
        let now = Self::current_timestamp();
        let max_age_seconds = self.config.max_age_days * 24 * 60 * 60;

        // `saturating_sub` guards against clock skew: entries created "in the
        // future" are treated as age zero rather than underflowing
        let before = self.cache.len();
        self.cache
            .retain(|_, entry| now.saturating_sub(entry.created_at) < max_age_seconds);

        if self.cache.len() != before {
            self.dirty = true;
        }
        Ok(())
    }

    /// Evict the least recently used entry when the cache is full
    fn evict_oldest(&mut self) -> Result<(), PromptCacheError> {
        // Find the entry with the smallest `last_used` timestamp; `None`
        // only if the cache is empty, in which case there is nothing to evict
        let lru_key = self
            .cache
            .iter()
            .min_by_key(|(_, entry)| entry.last_used)
            .map(|(key, _)| key.clone());

        if let Some(key) = lru_key {
            self.cache.remove(&key);
            self.dirty = true;
        }

        Ok(())
    }

    /// Get the current Unix timestamp in seconds (0 if the system clock is
    /// set before the epoch, rather than panicking)
    fn current_timestamp() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0)
    }
}

impl Drop for PromptCache {
    fn drop(&mut self) {
        // Best-effort persistence on shutdown; the error is deliberately
        // ignored because `drop` cannot return one and must not panic
        let _ = self.save_cache();
    }
}

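// Illustrative round trip (a sketch; the prompt text is a placeholder):
//
//     let mut cache = PromptCache::new();
//     let hash = PromptCache::hash_prompt("Summarize this file");
//     match cache.get(&hash) {
//         Some(hit) => println!("cached: {}", hit.optimized_prompt),
//         None => { /* optimize, then store the result with `cache.put(entry)` */ }
//     }
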
/// Cache statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
    pub total_entries: usize,
    pub total_usage: u32,
    pub total_tokens_saved: u32,
    pub avg_quality: f64,
}

/// Prompt cache errors
#[derive(Debug, thiserror::Error)]
pub enum PromptCacheError {
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),

    #[error("Cache full")]
    CacheFull,
}

/// Prompt optimizer that uses caching
pub struct PromptOptimizer {
    cache: PromptCache,
    llm_provider: Box<dyn crate::llm::provider::LLMProvider>,
}

impl PromptOptimizer {
    pub fn new(llm_provider: Box<dyn crate::llm::provider::LLMProvider>) -> Self {
        Self {
            cache: PromptCache::new(),
            llm_provider,
        }
    }

    pub fn with_cache(mut self, cache: PromptCache) -> Self {
        self.cache = cache;
        self
    }

    /// Optimize a prompt using caching
    pub async fn optimize_prompt(
        &mut self,
        original_prompt: &str,
        target_model: &str,
        context: Option<&str>,
    ) -> Result<String, PromptOptimizationError> {
        let prompt_hash = PromptCache::hash_prompt(original_prompt);

        // Check cache first
        if let Some(cached) = self.cache.get(&prompt_hash) {
            return Ok(cached.optimized_prompt.clone());
        }

        // Generate optimized prompt
        let optimized = self
            .generate_optimized_prompt(original_prompt, target_model, context)
            .await?;

        // Calculate tokens saved (rough estimate)
        let original_tokens = Self::estimate_tokens(original_prompt);
        let optimized_tokens = Self::estimate_tokens(&optimized);
        let tokens_saved = original_tokens.saturating_sub(optimized_tokens);

        // Create cache entry
        let now = PromptCache::current_timestamp();
        let entry = CachedPrompt {
            prompt_hash,
            original_prompt: original_prompt.to_string(),
            optimized_prompt: optimized.clone(),
            model_used: target_model.to_string(),
            tokens_saved: Some(tokens_saved),
            quality_score: Some(0.8), // Placeholder quality score
            created_at: now,
            last_used: now,
            usage_count: 1,
        };

        // Store in cache
        self.cache.put(entry)?;

        Ok(optimized)
    }

    /// Generate optimized prompt using LLM
    async fn generate_optimized_prompt(
        &self,
        original_prompt: &str,
        target_model: &str,
        context: Option<&str>,
    ) -> Result<String, PromptOptimizationError> {
        let system_prompt = format!(
            "You are an expert prompt engineer. Your task is to optimize prompts for {} \
             to make them more effective, clearer, and more likely to produce high-quality responses. \
             Focus on improving clarity, specificity, structure, and effectiveness while preserving \
             the original intent and requirements.",
            target_model
        );

        let mut user_prompt = format!(
            "Please optimize the following prompt for {}:\n\nORIGINAL PROMPT:\n{}\n\n",
            target_model, original_prompt
        );

        if let Some(ctx) = context {
            user_prompt.push_str(&format!("CONTEXT:\n{}\n\n", ctx));
        }

        user_prompt.push_str(
            "OPTIMIZATION REQUIREMENTS:\n\
             1. Make the prompt clearer and more specific\n\
             2. Improve structure and formatting\n\
             3. Add relevant context or examples if helpful\n\
             4. Ensure the prompt is appropriate for the target model\n\
             5. Maintain the original intent and requirements\n\
             6. Keep the optimized prompt concise but comprehensive\n\n\
             Provide only the optimized prompt without any explanation or additional text.",
        );

        let request = crate::llm::provider::LLMRequest {
            messages: vec![
                Message {
                    role: MessageRole::System,
                    content: system_prompt,
                    tool_calls: None,
                    tool_call_id: None,
                },
                Message {
                    role: MessageRole::User,
                    content: user_prompt,
                    tool_calls: None,
                    tool_call_id: None,
                },
            ],
            system_prompt: None,
            tools: None,
            model: target_model.to_string(),
            max_tokens: Some(2000),
            temperature: Some(0.3),
            stream: false,
            tool_choice: None,
            parallel_tool_calls: None,
            parallel_tool_config: None,
            reasoning_effort: None,
        };

        let response = self
            .llm_provider
            .generate(request)
            .await
            .map_err(|e| PromptOptimizationError::LLMError(e.to_string()))?;

        // Fall back to the original prompt if the provider returns no content
        Ok(response
            .content
            .unwrap_or_else(|| original_prompt.to_string()))
    }

    /// Estimate token count (rough approximation)
    fn estimate_tokens(text: &str) -> u32 {
        // Rough approximation: 1 token ≈ 4 characters for English text
        (text.len() / 4) as u32
    }

    /// Get cache statistics
    pub fn cache_stats(&self) -> CacheStats {
        self.cache.stats()
    }

    /// Clear cache
    pub fn clear_cache(&mut self) -> Result<(), PromptCacheError> {
        self.cache.clear()
    }
}

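// Illustrative call site (a sketch; `provider` stands for any boxed
// `LLMProvider` implementation supplied by the caller, and the model name
// is a placeholder):
//
//     let mut optimizer = PromptOptimizer::new(provider);
//     let optimized = optimizer
//         .optimize_prompt("Explain the borrow checker", "<target model>", None)
//         .await?;
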
/// Prompt optimization errors
#[derive(Debug, thiserror::Error)]
pub enum PromptOptimizationError {
    #[error("LLM error: {0}")]
    LLMError(String),

    #[error("Cache error: {0}")]
    CacheError(#[from] PromptCacheError),
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prompt_hash() {
        let prompt = "Test prompt";
        let hash1 = PromptCache::hash_prompt(prompt);
        let hash2 = PromptCache::hash_prompt(prompt);
        assert_eq!(hash1, hash2);
        assert!(!hash1.is_empty());
    }

    #[test]
    fn test_cache_operations() {
        let mut cache = PromptCache::new();

        let entry = CachedPrompt {
            prompt_hash: "test_hash".to_string(),
            original_prompt: "original".to_string(),
            optimized_prompt: "optimized".to_string(),
            model_used: crate::config::constants::models::GEMINI_2_5_FLASH.to_string(),
            tokens_saved: Some(100),
            quality_score: Some(0.9),
            created_at: 1000,
            last_used: 1000,
            usage_count: 0,
        };

        cache.put(entry).unwrap();
        assert!(cache.contains("test_hash"));

        let retrieved = cache.get("test_hash");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().usage_count, 1);
    }
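
    /// Helper for building test entries (illustrative values only)
    fn make_entry(hash: &str, last_used: u64, quality: f64) -> CachedPrompt {
        CachedPrompt {
            prompt_hash: hash.to_string(),
            original_prompt: "original".to_string(),
            optimized_prompt: "optimized".to_string(),
            model_used: crate::config::constants::models::GEMINI_2_5_FLASH.to_string(),
            tokens_saved: Some(10),
            quality_score: Some(quality),
            created_at: last_used,
            last_used,
            usage_count: 0,
        }
    }

    /// Sketch of the eviction contract: with a `max_cache_size` of one, a
    /// second insert should evict the least recently used entry. The cache
    /// directory is an arbitrary temp path so the user's real cache is
    /// untouched (an assumption of this test, not part of the API).
    #[test]
    fn test_evicts_least_recently_used_when_full() {
        let config = PromptCacheConfig {
            cache_dir: std::env::temp_dir().join("vtcode_prompt_cache_evict_test"),
            max_cache_size: 1,
            ..PromptCacheConfig::default()
        };
        let mut cache = PromptCache::with_config(config);
        cache.clear().unwrap(); // start from a known-empty state

        cache.put(make_entry("first", 1000, 0.9)).unwrap();
        cache.put(make_entry("second", 2000, 0.9)).unwrap();

        assert!(!cache.contains("first"));
        assert!(cache.contains("second"));
    }

    /// Sketch of the quality gate: entries below `min_quality_threshold`
    /// (0.7 by default) are silently dropped by `put`.
    #[test]
    fn test_put_skips_low_quality_entries() {
        let config = PromptCacheConfig {
            cache_dir: std::env::temp_dir().join("vtcode_prompt_cache_quality_test"),
            ..PromptCacheConfig::default()
        };
        let mut cache = PromptCache::with_config(config);
        cache.clear().unwrap();

        cache.put(make_entry("low_quality", 1000, 0.1)).unwrap();
        assert!(!cache.contains("low_quality"));
    }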
}