use crate::llm::provider::{Message, MessageRole};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use std::time::{SystemTime, UNIX_EPOCH};

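/// A single cached optimization result, keyed by the SHA-256 hash of the
/// original prompt.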
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedPrompt {
    pub prompt_hash: String,
    pub original_prompt: String,
    pub optimized_prompt: String,
    pub model_used: String,
    pub tokens_saved: Option<u32>,
    pub quality_score: Option<f64>,
    pub created_at: u64,
    pub last_used: u64,
    pub usage_count: u32,
}

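/// Configuration for [`PromptCache`]: where the cache lives on disk, how
/// large it may grow, and when entries expire.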
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptCacheConfig {
    pub cache_dir: PathBuf,
    pub max_cache_size: usize,
    pub max_age_days: u64,
    pub enable_auto_cleanup: bool,
    pub min_quality_threshold: f64,
}

impl Default for PromptCacheConfig {
    fn default() -> Self {
        Self {
            cache_dir: dirs::home_dir()
                .unwrap_or_else(|| PathBuf::from("."))
                .join(".vtcode")
                .join("cache")
                .join("prompts"),
            max_cache_size: 1000,
            max_age_days: 30,
            enable_auto_cleanup: true,
            min_quality_threshold: 0.7,
        }
    }
}

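/// An on-disk prompt cache persisted as JSON at
/// `<cache_dir>/prompt_cache.json`. Changes are flushed by `save_cache` and,
/// as a fallback, when the value is dropped.
///
/// Illustrative usage (a minimal sketch; module paths depend on where this
/// file sits in the crate):
///
/// ```ignore
/// let mut cache = PromptCache::new();
/// let hash = PromptCache::hash_prompt("Summarize the design doc");
/// if let Some(hit) = cache.get(&hash) {
///     println!("cache hit: {}", hit.optimized_prompt);
/// }
/// ```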
pub struct PromptCache {
    config: PromptCacheConfig,
    cache: HashMap<String, CachedPrompt>,
    dirty: bool,
}

impl PromptCache {
    pub fn new() -> Self {
        Self::with_config(PromptCacheConfig::default())
    }

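    /// Builds a cache from `config`, loading any existing entries from disk
    /// (best effort) and pruning expired ones if auto-cleanup is enabled.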
    pub fn with_config(config: PromptCacheConfig) -> Self {
        let mut cache = Self {
            config,
            cache: HashMap::new(),
            dirty: false,
        };

        // Best effort: a missing or unreadable cache file just starts empty.
        let _ = cache.load_cache();

        if cache.config.enable_auto_cleanup {
            let _ = cache.cleanup_expired();
        }

        cache
    }

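    /// Looks up a cached prompt by hash, bumping its `last_used` timestamp
    /// and `usage_count` on a hit.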
    pub fn get(&mut self, prompt_hash: &str) -> Option<&CachedPrompt> {
        if let Some(entry) = self.cache.get_mut(prompt_hash) {
            entry.last_used = Self::current_timestamp();
            entry.usage_count += 1;
            self.dirty = true;
            Some(entry)
        } else {
            None
        }
    }

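    /// Inserts an entry, silently dropping it if its quality score falls
    /// below the configured threshold, and evicting the least-recently-used
    /// entry when the cache is at capacity.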
    pub fn put(&mut self, entry: CachedPrompt) -> Result<(), PromptCacheError> {
        // Skip entries below the configured quality threshold.
        if let Some(quality) = entry.quality_score {
            if quality < self.config.min_quality_threshold {
                return Ok(());
            }
        }

        // Evict only when inserting a genuinely new key at capacity;
        // overwriting an existing key does not grow the cache.
        if !self.cache.contains_key(&entry.prompt_hash)
            && self.cache.len() >= self.config.max_cache_size
        {
            self.evict_oldest()?;
        }

        self.cache.insert(entry.prompt_hash.clone(), entry);
        self.dirty = true;

        Ok(())
    }

    pub fn contains(&self, prompt_hash: &str) -> bool {
        self.cache.contains_key(prompt_hash)
    }

    pub fn stats(&self) -> CacheStats {
        let total_entries = self.cache.len();
        let total_usage = self.cache.values().map(|e| e.usage_count).sum::<u32>();
        let total_tokens_saved = self
            .cache
            .values()
            .filter_map(|e| e.tokens_saved)
            .sum::<u32>();

        // Average only over entries that actually carry a quality score, so
        // unscored entries do not drag the average toward zero.
        let scores: Vec<f64> = self
            .cache
            .values()
            .filter_map(|e| e.quality_score)
            .collect();
        let avg_quality = if scores.is_empty() {
            0.0
        } else {
            scores.iter().sum::<f64>() / scores.len() as f64
        };

        CacheStats {
            total_entries,
            total_usage,
            total_tokens_saved,
            avg_quality,
        }
    }

    pub fn clear(&mut self) -> Result<(), PromptCacheError> {
        self.cache.clear();
        self.dirty = true;
        self.save_cache()
    }

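    /// Returns the lowercase hex SHA-256 digest of `prompt`, used as the
    /// cache key.
    ///
    /// ```ignore
    /// let hash = PromptCache::hash_prompt("hello");
    /// assert_eq!(hash.len(), 64); // 32 bytes, hex-encoded
    /// ```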
    pub fn hash_prompt(prompt: &str) -> String {
        use sha2::{Digest, Sha256};
        let mut hasher = Sha256::new();
        hasher.update(prompt.as_bytes());
        format!("{:x}", hasher.finalize())
    }

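    /// Writes the cache to `<cache_dir>/prompt_cache.json` if there are
    /// unsaved changes, then clears the dirty flag.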
    pub fn save_cache(&mut self) -> Result<(), PromptCacheError> {
        if !self.dirty {
            return Ok(());
        }

        fs::create_dir_all(&self.config.cache_dir)?;

        let cache_path = self.config.cache_dir.join("prompt_cache.json");
        let data = serde_json::to_string_pretty(&self.cache)?;
        fs::write(cache_path, data)?;

        // Mark clean so Drop does not rewrite an unchanged file.
        self.dirty = false;

        Ok(())
    }

    fn load_cache(&mut self) -> Result<(), PromptCacheError> {
        let cache_path = self.config.cache_dir.join("prompt_cache.json");

        if !cache_path.exists() {
            return Ok(());
        }

        let data = fs::read_to_string(cache_path)?;
        self.cache = serde_json::from_str(&data)?;

        Ok(())
    }

    fn cleanup_expired(&mut self) -> Result<(), PromptCacheError> {
        let now = Self::current_timestamp();
        let max_age_seconds = self.config.max_age_days * 24 * 60 * 60;

        let before = self.cache.len();
        self.cache
            .retain(|_, entry| now.saturating_sub(entry.created_at) < max_age_seconds);

        // Only mark the cache dirty when something was actually removed.
        if self.cache.len() != before {
            self.dirty = true;
        }
        Ok(())
    }

    fn evict_oldest(&mut self) -> Result<(), PromptCacheError> {
        // Evict the least-recently-used entry, if any.
        let oldest_key = self
            .cache
            .iter()
            .min_by_key(|(_, entry)| entry.last_used)
            .map(|(key, _)| key.clone());

        if let Some(key) = oldest_key {
            self.cache.remove(&key);
            self.dirty = true;
        }

        Ok(())
    }

    fn current_timestamp() -> u64 {
        // Falls back to 0 rather than panicking if the system clock reads
        // before the Unix epoch.
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0)
    }
}

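// Best-effort flush of unsaved entries when the cache goes out of scope.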
impl Drop for PromptCache {
    fn drop(&mut self) {
        let _ = self.save_cache();
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
    pub total_entries: usize,
    pub total_usage: u32,
    pub total_tokens_saved: u32,
    pub avg_quality: f64,
}

#[derive(Debug, thiserror::Error)]
pub enum PromptCacheError {
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),

    #[error("Cache full")]
    CacheFull,
}

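/// Rewrites prompts via an LLM provider, caching results so that repeated
/// prompts skip the extra model round-trip.
///
/// Illustrative usage (a minimal sketch; `make_provider` and the model name
/// are placeholders for whatever the crate actually wires up):
///
/// ```ignore
/// let provider: Box<dyn LLMProvider> = make_provider()?;
/// let mut optimizer = PromptOptimizer::new(provider);
/// let optimized = optimizer
///     .optimize_prompt("Explain borrowing in Rust", "gemini-2.5-flash", None)
///     .await?;
/// ```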
pub struct PromptOptimizer {
    cache: PromptCache,
    llm_provider: Box<dyn crate::llm::provider::LLMProvider>,
}

impl PromptOptimizer {
    pub fn new(llm_provider: Box<dyn crate::llm::provider::LLMProvider>) -> Self {
        Self {
            cache: PromptCache::new(),
            llm_provider,
        }
    }

    pub fn with_cache(mut self, cache: PromptCache) -> Self {
        self.cache = cache;
        self
    }

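    /// Returns a cached optimization for `original_prompt` if one exists;
    /// otherwise asks the LLM to rewrite the prompt, caches the result, and
    /// returns it.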
    pub async fn optimize_prompt(
        &mut self,
        original_prompt: &str,
        target_model: &str,
        context: Option<&str>,
    ) -> Result<String, PromptOptimizationError> {
        let prompt_hash = PromptCache::hash_prompt(original_prompt);

        if let Some(cached) = self.cache.get(&prompt_hash) {
            return Ok(cached.optimized_prompt.clone());
        }

        let optimized = self
            .generate_optimized_prompt(original_prompt, target_model, context)
            .await?;

        let original_tokens = Self::estimate_tokens(original_prompt);
        let optimized_tokens = Self::estimate_tokens(&optimized);
        let tokens_saved = original_tokens.saturating_sub(optimized_tokens);

        let entry = CachedPrompt {
            prompt_hash,
            original_prompt: original_prompt.to_string(),
            optimized_prompt: optimized.clone(),
            model_used: target_model.to_string(),
            tokens_saved: Some(tokens_saved),
            // Fixed placeholder until a real quality-scoring pass exists.
            quality_score: Some(0.8),
            created_at: PromptCache::current_timestamp(),
            last_used: PromptCache::current_timestamp(),
            usage_count: 1,
        };

        self.cache.put(entry)?;

        Ok(optimized)
    }

    async fn generate_optimized_prompt(
        &self,
        original_prompt: &str,
        target_model: &str,
        context: Option<&str>,
    ) -> Result<String, PromptOptimizationError> {
        let system_prompt = format!(
            "You are an expert prompt engineer. Your task is to optimize prompts for {} \
             to make them more effective, clearer, and more likely to produce high-quality responses. \
             Focus on improving clarity, specificity, structure, and effectiveness while preserving \
             the original intent and requirements.",
            target_model
        );

        let mut user_prompt = format!(
            "Please optimize the following prompt for {}:\n\nORIGINAL PROMPT:\n{}\n\n",
            target_model, original_prompt
        );

        if let Some(ctx) = context {
            user_prompt.push_str(&format!("CONTEXT:\n{}\n\n", ctx));
        }

        user_prompt.push_str(
            "OPTIMIZATION REQUIREMENTS:\n\
             1. Make the prompt clearer and more specific\n\
             2. Improve structure and formatting\n\
             3. Add relevant context or examples if helpful\n\
             4. Ensure the prompt is appropriate for the target model\n\
             5. Maintain the original intent and requirements\n\
             6. Keep the optimized prompt concise but comprehensive\n\n\
             Provide only the optimized prompt without any explanation or additional text.",
        );

        let request = crate::llm::provider::LLMRequest {
            messages: vec![
                Message {
                    role: MessageRole::System,
                    content: system_prompt,
                    tool_calls: None,
                    tool_call_id: None,
                },
                Message {
                    role: MessageRole::User,
                    content: user_prompt,
                    tool_calls: None,
                    tool_call_id: None,
                },
            ],
            system_prompt: None,
            tools: None,
            model: target_model.to_string(),
            max_tokens: Some(2000),
            temperature: Some(0.3),
            stream: false,
            tool_choice: None,
            parallel_tool_calls: None,
            parallel_tool_config: None,
            reasoning_effort: None,
        };

        let response = self
            .llm_provider
            .generate(request)
            .await
            .map_err(|e| PromptOptimizationError::LLMError(e.to_string()))?;

        // Fall back to the original prompt if the model returned no content.
        Ok(response
            .content
            .unwrap_or_else(|| original_prompt.to_string()))
    }

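    /// Rough token count using the common ~4 characters per token heuristic;
    /// adequate for tracking approximate savings, not for billing.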
    fn estimate_tokens(text: &str) -> u32 {
        (text.len() / 4) as u32
    }

    pub fn cache_stats(&self) -> CacheStats {
        self.cache.stats()
    }

    pub fn clear_cache(&mut self) -> Result<(), PromptCacheError> {
        self.cache.clear()
    }
}

#[derive(Debug, thiserror::Error)]
pub enum PromptOptimizationError {
    #[error("LLM error: {0}")]
    LLMError(String),

    #[error("Cache error: {0}")]
    CacheError(#[from] PromptCacheError),
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prompt_hash() {
        let prompt = "Test prompt";
        let hash1 = PromptCache::hash_prompt(prompt);
        let hash2 = PromptCache::hash_prompt(prompt);
        assert_eq!(hash1, hash2);
        assert!(!hash1.is_empty());
    }

    #[test]
    fn test_cache_operations() {
        // Point the cache at a temp directory so the test never reads or
        // writes the user's real cache under ~/.vtcode.
        let config = PromptCacheConfig {
            cache_dir: std::env::temp_dir().join("vtcode_prompt_cache_test"),
            ..Default::default()
        };
        let mut cache = PromptCache::with_config(config);
        cache.clear().unwrap();

        let entry = CachedPrompt {
            prompt_hash: "test_hash".to_string(),
            original_prompt: "original".to_string(),
            optimized_prompt: "optimized".to_string(),
            model_used: crate::config::constants::models::GEMINI_2_5_FLASH.to_string(),
            tokens_saved: Some(100),
            quality_score: Some(0.9),
            created_at: 1000,
            last_used: 1000,
            usage_count: 0,
        };

        cache.put(entry).unwrap();
        assert!(cache.contains("test_hash"));

        let retrieved = cache.get("test_hash");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().usage_count, 1);
    }
}