1use crate::config::constants::prompt_cache;
2use crate::config::core::PromptCachingConfig;
3use crate::llm::provider::{Message, MessageContent, MessageRole};
4use hashbrown::HashMap;
5use serde::{Deserialize, Serialize};
6use std::fmt::Write;
7use std::path::{Path, PathBuf};
8use tokio::fs;
9use vtcode_commons::utils::current_timestamp;
10
11use crate::utils::tokens::estimate_tokens;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct CachedPrompt {
16 pub prompt_hash: String,
17 pub original_prompt: String,
18 pub optimized_prompt: String,
19 pub model_used: String,
20 pub tokens_saved: Option<u32>,
21 pub quality_score: Option<f64>,
22 pub created_at: u64,
23 pub last_used: u64,
24 pub usage_count: u32,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct PromptCacheConfig {
30 pub enabled: bool,
31 pub cache_dir: PathBuf,
32 pub max_cache_size: usize,
33 pub max_age_days: u64,
34 pub enable_auto_cleanup: bool,
35 pub min_quality_threshold: f64,
36}
37
38impl Default for PromptCacheConfig {
39 fn default() -> Self {
40 Self {
41 enabled: prompt_cache::DEFAULT_ENABLED,
42 cache_dir: default_cache_dir(),
43 max_cache_size: prompt_cache::DEFAULT_MAX_ENTRIES,
44 max_age_days: prompt_cache::DEFAULT_MAX_AGE_DAYS,
45 enable_auto_cleanup: prompt_cache::DEFAULT_AUTO_CLEANUP,
46 min_quality_threshold: prompt_cache::DEFAULT_MIN_QUALITY_THRESHOLD,
47 }
48 }
49}
50
51impl PromptCacheConfig {
52 pub fn from_settings(settings: &PromptCachingConfig, workspace_root: Option<&Path>) -> Self {
54 Self {
55 enabled: settings.enabled,
56 cache_dir: settings.resolve_cache_dir(workspace_root),
57 max_cache_size: settings.max_entries,
58 max_age_days: settings.max_age_days,
59 enable_auto_cleanup: settings.enable_auto_cleanup,
60 min_quality_threshold: settings.min_quality_threshold,
61 }
62 }
63
64 pub fn is_enabled(&self) -> bool {
65 self.enabled
66 }
67}
68
69fn default_cache_dir() -> PathBuf {
70 if let Some(home) = dirs::home_dir() {
71 return home.join(prompt_cache::DEFAULT_CACHE_DIR);
72 }
73 PathBuf::from(prompt_cache::DEFAULT_CACHE_DIR)
74}
75
76pub struct PromptCache {
78 config: PromptCacheConfig,
79 cache: HashMap<String, CachedPrompt>,
80 dirty: bool,
81}
82
83impl PromptCache {
84 pub async fn new() -> Self {
85 Self::with_config(PromptCacheConfig::default()).await
86 }
87
88 pub async fn with_config(config: PromptCacheConfig) -> Self {
89 let mut cache = Self {
90 config,
91 cache: HashMap::new(),
92 dirty: false,
93 };
94
95 if cache.config.enabled {
97 let _ = cache.load_cache().await;
98
99 if cache.config.enable_auto_cleanup {
101 let _ = cache.cleanup_expired();
102 }
103 }
104
105 cache
106 }
107
108 pub fn get(&mut self, prompt_hash: &str) -> Option<CachedPrompt> {
110 if !self.config.enabled {
111 return None;
112 }
113 self.cache.get_mut(prompt_hash).map(|entry| {
114 entry.last_used = current_timestamp();
115 entry.usage_count += 1;
116 self.dirty = true;
117 entry.clone()
118 })
119 }
120
121 pub fn put(&mut self, entry: CachedPrompt) -> Result<(), PromptCacheError> {
123 if !self.config.enabled {
124 return Ok(());
125 }
126 if entry
128 .quality_score
129 .is_some_and(|quality| quality < self.config.min_quality_threshold)
130 {
131 return Ok(()); }
133
134 if self.cache.len() >= self.config.max_cache_size {
136 self.evict_oldest()?;
137 }
138
139 self.cache.insert(entry.prompt_hash.clone(), entry);
140 self.dirty = true;
141
142 Ok(())
143 }
144
145 pub fn contains(&self, prompt_hash: &str) -> bool {
147 self.config.enabled && self.cache.contains_key(prompt_hash)
148 }
149
150 pub fn stats(&self) -> CacheStats {
152 if !self.config.enabled {
153 return CacheStats::default();
154 }
155 let total_entries = self.cache.len();
156 let total_usage = self.cache.values().map(|e| e.usage_count).sum::<u32>();
157 let total_tokens_saved = self
158 .cache
159 .values()
160 .filter_map(|e| e.tokens_saved)
161 .sum::<u32>();
162 let avg_quality = if !self.cache.is_empty() {
163 self.cache
164 .values()
165 .filter_map(|e| e.quality_score)
166 .sum::<f64>()
167 / self.cache.len() as f64
168 } else {
169 0.0
170 };
171
172 CacheStats {
173 total_entries,
174 total_usage,
175 total_tokens_saved,
176 avg_quality,
177 }
178 }
179
180 pub async fn clear(&mut self) -> Result<(), PromptCacheError> {
182 if !self.config.enabled {
183 return Ok(());
184 }
185 self.cache.clear();
186 self.dirty = true;
187 self.save_cache().await
188 }
189
190 pub fn hash_prompt(prompt: &str) -> String {
192 vtcode_commons::utils::calculate_sha256(prompt.as_bytes())
193 }
194
195 pub async fn save_cache(&self) -> Result<(), PromptCacheError> {
197 if !self.config.enabled || !self.dirty {
198 return Ok(());
199 }
200
201 fs::create_dir_all(&self.config.cache_dir)
203 .await
204 .map_err(PromptCacheError::Io)?;
205
206 let cache_path = self.config.cache_dir.join("prompt_cache.json");
207 let data =
208 serde_json::to_string_pretty(&self.cache).map_err(PromptCacheError::Serialization)?;
209
210 fs::write(cache_path, data)
211 .await
212 .map_err(PromptCacheError::Io)?;
213
214 Ok(())
215 }
216
217 async fn load_cache(&mut self) -> Result<(), PromptCacheError> {
219 if !self.config.enabled {
220 return Ok(());
221 }
222 let cache_path = self.config.cache_dir.join("prompt_cache.json");
223
224 if !fs::try_exists(&cache_path).await.unwrap_or(false) {
225 return Ok(());
226 }
227
228 let data = fs::read_to_string(cache_path)
229 .await
230 .map_err(PromptCacheError::Io)?;
231
232 self.cache = serde_json::from_str(&data).map_err(PromptCacheError::Serialization)?;
233
234 Ok(())
235 }
236
237 fn cleanup_expired(&mut self) -> Result<(), PromptCacheError> {
239 if !self.config.enabled {
240 return Ok(());
241 }
242 let now = current_timestamp();
243 let max_age_seconds = self.config.max_age_days * 24 * 60 * 60;
244
245 self.cache
246 .retain(|_, entry| now - entry.created_at < max_age_seconds);
247
248 self.dirty = true;
249 Ok(())
250 }
251
252 fn evict_oldest(&mut self) -> Result<(), PromptCacheError> {
254 if !self.config.enabled {
255 return Ok(());
256 }
257 if self.cache.is_empty() {
258 return Ok(());
259 }
260
261 let Some(oldest_key) = self
263 .cache
264 .iter()
265 .min_by_key(|(_, entry)| entry.last_used)
266 .map(|(key, _)| key.clone())
267 else {
268 return Ok(());
269 };
270
271 self.cache.remove(&oldest_key);
272 self.dirty = true;
273
274 Ok(())
275 }
276}
277
278#[derive(Debug, Clone, Serialize, Deserialize)]
283pub struct CacheStats {
284 pub total_entries: usize,
285 pub total_usage: u32,
286 pub total_tokens_saved: u32,
287 pub avg_quality: f64,
288}
289
290impl Default for CacheStats {
291 fn default() -> Self {
292 Self {
293 total_entries: 0,
294 total_usage: 0,
295 total_tokens_saved: 0,
296 avg_quality: 0.0,
297 }
298 }
299}
300
301#[derive(Debug, thiserror::Error)]
303pub enum PromptCacheError {
304 #[error("IO error: {0}")]
305 Io(#[from] std::io::Error),
306
307 #[error("Serialization error: {0}")]
308 Serialization(#[from] serde_json::Error),
309
310 #[error("Cache full")]
311 CacheFull,
312}
313
314pub struct PromptOptimizer {
316 cache: PromptCache,
317 llm_provider: Box<dyn crate::llm::provider::LLMProvider>,
318}
319
320impl PromptOptimizer {
321 pub async fn new(llm_provider: Box<dyn crate::llm::provider::LLMProvider>) -> Self {
322 Self {
323 cache: PromptCache::new().await,
324 llm_provider,
325 }
326 }
327
328 pub fn with_cache(mut self, cache: PromptCache) -> Self {
329 self.cache = cache;
330 self
331 }
332
333 pub async fn save_cache(&self) -> Result<(), PromptCacheError> {
335 self.cache.save_cache().await
336 }
337
338 pub async fn optimize_prompt(
340 &mut self,
341 original_prompt: &str,
342 target_model: &str,
343 context: Option<&str>,
344 ) -> Result<String, PromptOptimizationError> {
345 let prompt_hash = PromptCache::hash_prompt(original_prompt);
346
347 if let Some(cached) = self.cache.get(&prompt_hash) {
349 return Ok(cached.optimized_prompt);
350 }
351
352 let optimized = self
354 .generate_optimized_prompt(original_prompt, target_model, context)
355 .await?;
356
357 let original_tokens = estimate_tokens(original_prompt);
359 let optimized_tokens = estimate_tokens(&optimized);
360 let tokens_saved = original_tokens.saturating_sub(optimized_tokens);
361
362 let entry = CachedPrompt {
364 prompt_hash: prompt_hash.clone(),
365 original_prompt: original_prompt.to_string(),
366 optimized_prompt: optimized.clone(),
367 model_used: target_model.to_string(),
368 tokens_saved: Some(tokens_saved.min(u32::MAX as usize) as u32),
369 quality_score: Some(0.8), created_at: current_timestamp(),
371 last_used: current_timestamp(),
372 usage_count: 1,
373 };
374
375 self.cache.put(entry)?;
377
378 Ok(optimized)
379 }
380
381 async fn generate_optimized_prompt(
383 &self,
384 original_prompt: &str,
385 target_model: &str,
386 context: Option<&str>,
387 ) -> Result<String, PromptOptimizationError> {
388 let system_prompt = format!(
389 "You are an expert prompt engineer. Your task is to optimize prompts for {} \
390 to make them more effective, clearer, and more likely to produce high-quality responses. \
391 Focus on improving clarity, specificity, structure, and effectiveness while preserving \
392 the original intent and requirements.",
393 target_model
394 );
395
396 let mut user_prompt = format!(
397 "Please optimize the following prompt for {}:\n\nORIGINAL PROMPT:\n{}\n\n",
398 target_model, original_prompt
399 );
400
401 if let Some(ctx) = context {
402 let _ = write!(user_prompt, "CONTEXT:\n{}\n\n", ctx);
403 }
404
405 user_prompt.push_str(
406 "OPTIMIZATION REQUIREMENTS:\n\
407 1. Make the prompt clearer and more specific\n\
408 2. Improve structure and formatting\n\
409 3. Add relevant context or examples if helpful\n\
410 4. Ensure the prompt is appropriate for the target model\n\
411 5. Maintain the original intent and requirements\n\
412 6. Keep the optimized prompt concise but comprehensive\n\n\
413 Provide only the optimized prompt without any explanation or additional text.",
414 );
415
416 let request = crate::llm::provider::LLMRequest {
417 messages: vec![
418 Message {
419 role: MessageRole::System,
420 content: MessageContent::Text(system_prompt),
421 ..Default::default()
422 },
423 Message {
424 role: MessageRole::User,
425 content: MessageContent::Text(user_prompt),
426 ..Default::default()
427 },
428 ],
429 model: target_model.to_string(),
430 max_tokens: Some(2000),
431 temperature: Some(0.3),
432 ..Default::default()
433 };
434
435 let response = self
436 .llm_provider
437 .generate(request)
438 .await
439 .map_err(|e| PromptOptimizationError::LLMError(e.to_string()))?;
440
441 Ok(response
442 .content
443 .unwrap_or_else(|| original_prompt.to_string()))
444 }
445
446 pub fn cache_stats(&self) -> CacheStats {
448 self.cache.stats()
449 }
450
451 pub async fn clear_cache(&mut self) -> Result<(), PromptCacheError> {
453 self.cache.clear().await
454 }
455}
456
457#[derive(Debug, thiserror::Error)]
459pub enum PromptOptimizationError {
460 #[error("LLM error: {0}")]
461 LLMError(String),
462
463 #[error("Cache error: {0}")]
464 CacheError(#[from] PromptCacheError),
465}
466
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cache entry with fixed contents for the given key.
    fn sample_entry(prompt_hash: &str, tokens_saved: u32, timestamp: u64) -> CachedPrompt {
        CachedPrompt {
            prompt_hash: prompt_hash.to_owned(),
            original_prompt: "original".to_owned(),
            optimized_prompt: "optimized".to_owned(),
            model_used: crate::config::constants::models::google::GEMINI_3_FLASH_PREVIEW.to_owned(),
            tokens_saved: Some(tokens_saved),
            quality_score: Some(0.9),
            created_at: timestamp,
            last_used: timestamp,
            usage_count: 0,
        }
    }

    /// Returns an enabled configuration that neither reads the developer's
    /// real cache directory nor depends on the project-wide defaults.
    fn isolated_config() -> PromptCacheConfig {
        PromptCacheConfig {
            enabled: true,
            // Never written to: the tests below do not persist the cache.
            cache_dir: std::env::temp_dir()
                .join(format!("vtcode-prompt-cache-test-{}", std::process::id())),
            max_cache_size: 8,
            max_age_days: 1,
            enable_auto_cleanup: false,
            min_quality_threshold: 0.0,
        }
    }

    #[test]
    fn test_prompt_hash() {
        let prompt = "Test prompt";
        let hash1 = PromptCache::hash_prompt(prompt);
        let hash2 = PromptCache::hash_prompt(prompt);
        assert_eq!(hash1, hash2);
        assert!(!hash1.is_empty());
    }

    #[tokio::test]
    async fn test_cache_operations() {
        let mut cache = PromptCache::with_config(isolated_config()).await;

        cache.put(sample_entry("test_hash", 100, 1000)).unwrap();
        assert!(cache.contains("test_hash"));

        let retrieved = cache.get("test_hash");
        assert!(retrieved.is_some());
        // A hit bumps the usage counter from 0 to 1.
        assert_eq!(retrieved.unwrap().usage_count, 1);
    }

    #[tokio::test]
    async fn disabled_cache_config_is_no_op() {
        let settings = PromptCachingConfig {
            enabled: false,
            cache_dir: "relative/cache".to_owned(),
            ..PromptCachingConfig::default()
        };
        let cfg = PromptCacheConfig::from_settings(&settings, None);
        assert!(!cfg.is_enabled());

        let mut cache = PromptCache::with_config(cfg).await;
        assert!(!cache.contains("missing"));
        assert_eq!(cache.stats().total_entries, 0);

        // Inserting into a disabled cache succeeds but stores nothing.
        cache.put(sample_entry("noop", 10, 1)).unwrap();
        assert!(!cache.contains("noop"));
        assert_eq!(cache.stats().total_entries, 0);
    }
}