use crate::config::constants::prompt_cache;
use crate::config::core::PromptCachingConfig;
use crate::llm::provider::{Message, MessageRole};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};

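/// A single cached prompt-optimization result, with usage metadata that
/// drives expiry and least-recently-used eviction.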
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedPrompt {
    pub prompt_hash: String,
    pub original_prompt: String,
    pub optimized_prompt: String,
    pub model_used: String,
    pub tokens_saved: Option<u32>,
    pub quality_score: Option<f64>,
    pub created_at: u64,
    pub last_used: u64,
    pub usage_count: u32,
}

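/// Runtime configuration for the prompt cache: where entries live on disk,
/// how many to keep, how long they remain valid, and the minimum quality
/// score an entry must reach to be stored.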
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptCacheConfig {
    pub enabled: bool,
    pub cache_dir: PathBuf,
    pub max_cache_size: usize,
    pub max_age_days: u64,
    pub enable_auto_cleanup: bool,
    pub min_quality_threshold: f64,
}

impl Default for PromptCacheConfig {
    fn default() -> Self {
        Self {
            enabled: prompt_cache::DEFAULT_ENABLED,
            cache_dir: default_cache_dir(),
            max_cache_size: prompt_cache::DEFAULT_MAX_ENTRIES,
            max_age_days: prompt_cache::DEFAULT_MAX_AGE_DAYS,
            enable_auto_cleanup: prompt_cache::DEFAULT_AUTO_CLEANUP,
            min_quality_threshold: prompt_cache::DEFAULT_MIN_QUALITY_THRESHOLD,
        }
    }
}

impl PromptCacheConfig {
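    /// Build a cache config from user-facing settings, resolving the cache
    /// directory relative to the workspace root when one is provided.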
    pub fn from_settings(settings: &PromptCachingConfig, workspace_root: Option<&Path>) -> Self {
        Self {
            enabled: settings.enabled,
            cache_dir: settings.resolve_cache_dir(workspace_root),
            max_cache_size: settings.max_entries,
            max_age_days: settings.max_age_days,
            enable_auto_cleanup: settings.enable_auto_cleanup,
            min_quality_threshold: settings.min_quality_threshold,
        }
    }

    pub fn is_enabled(&self) -> bool {
        self.enabled
    }
}

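/// The default cache location: `DEFAULT_CACHE_DIR` under the user's home
/// directory, or a relative path when no home directory is available.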
fn default_cache_dir() -> PathBuf {
    if let Some(home) = dirs::home_dir() {
        return home.join(prompt_cache::DEFAULT_CACHE_DIR);
    }
    PathBuf::from(prompt_cache::DEFAULT_CACHE_DIR)
}

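/// An in-memory prompt cache backed by a JSON file on disk. Entries are
/// loaded on construction, mutations set a `dirty` flag, and the state is
/// flushed back to disk by `save_cache` (also invoked on drop).
///
/// A minimal usage sketch (illustrative only; marked `ignore` so no crate
/// path is assumed for the doc test):
///
/// ```ignore
/// let mut cache = PromptCache::new();
/// let hash = PromptCache::hash_prompt("Summarize this repository");
/// match cache.get(&hash) {
///     Some(entry) => println!("cache hit: {}", entry.optimized_prompt),
///     None => { /* optimize, then cache.put(entry)? */ }
/// }
/// ```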
pub struct PromptCache {
    config: PromptCacheConfig,
    cache: HashMap<String, CachedPrompt>,
    dirty: bool,
}

impl PromptCache {
    pub fn new() -> Self {
        Self::with_config(PromptCacheConfig::default())
    }

    pub fn with_config(config: PromptCacheConfig) -> Self {
        let mut cache = Self {
            config,
            cache: HashMap::new(),
            dirty: false,
        };

        if cache.config.enabled {
            // Best-effort load and cleanup: a missing or corrupt cache file
            // should not prevent construction.
            let _ = cache.load_cache();

            if cache.config.enable_auto_cleanup {
                let _ = cache.cleanup_expired();
            }
        }

        cache
    }

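    /// Look up an entry by prompt hash. A hit refreshes the entry's
    /// `last_used` timestamp and usage count, which feed LRU eviction.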
    pub fn get(&mut self, prompt_hash: &str) -> Option<&CachedPrompt> {
        if !self.config.enabled {
            return None;
        }
        if let Some(entry) = self.cache.get_mut(prompt_hash) {
            entry.last_used = Self::current_timestamp();
            entry.usage_count += 1;
            self.dirty = true;
            Some(entry)
        } else {
            None
        }
    }

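    /// Insert an entry, silently skipping it when caching is disabled or when
    /// its quality score falls below the configured threshold. The least
    /// recently used entry is evicted first once the cache is at capacity.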
    pub fn put(&mut self, entry: CachedPrompt) -> Result<(), PromptCacheError> {
        if !self.config.enabled {
            return Ok(());
        }
        if let Some(quality) = entry.quality_score {
            if quality < self.config.min_quality_threshold {
                return Ok(());
            }
        }

        if self.cache.len() >= self.config.max_cache_size {
            self.evict_oldest()?;
        }

        self.cache.insert(entry.prompt_hash.clone(), entry);
        self.dirty = true;

        Ok(())
    }

    pub fn contains(&self, prompt_hash: &str) -> bool {
        self.config.enabled && self.cache.contains_key(prompt_hash)
    }

    pub fn stats(&self) -> CacheStats {
        if !self.config.enabled {
            return CacheStats::default();
        }
        let total_entries = self.cache.len();
        let total_usage = self.cache.values().map(|e| e.usage_count).sum::<u32>();
        let total_tokens_saved = self
            .cache
            .values()
            .filter_map(|e| e.tokens_saved)
            .sum::<u32>();
        // Average only over entries that actually carry a quality score, so
        // unscored entries do not drag the mean toward zero.
        let scored: Vec<f64> = self
            .cache
            .values()
            .filter_map(|e| e.quality_score)
            .collect();
        let avg_quality = if scored.is_empty() {
            0.0
        } else {
            scored.iter().sum::<f64>() / scored.len() as f64
        };

        CacheStats {
            total_entries,
            total_usage,
            total_tokens_saved,
            avg_quality,
        }
    }

    pub fn clear(&mut self) -> Result<(), PromptCacheError> {
        if !self.config.enabled {
            return Ok(());
        }
        self.cache.clear();
        self.dirty = true;
        self.save_cache()
    }

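    /// Hash a prompt into a stable cache key: the lowercase hex encoding of
    /// its SHA-256 digest.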
    pub fn hash_prompt(prompt: &str) -> String {
        use sha2::{Digest, Sha256};
        let mut hasher = Sha256::new();
        hasher.update(prompt.as_bytes());
        format!("{:x}", hasher.finalize())
    }

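    /// Persist the cache to `prompt_cache.json` in the configured cache
    /// directory. A no-op when caching is disabled or nothing has changed.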
    pub fn save_cache(&mut self) -> Result<(), PromptCacheError> {
        if !self.config.enabled || !self.dirty {
            return Ok(());
        }

        fs::create_dir_all(&self.config.cache_dir)?;

        let cache_path = self.config.cache_dir.join("prompt_cache.json");
        let data = serde_json::to_string_pretty(&self.cache)?;

        fs::write(cache_path, data)?;
        // Mark the in-memory state as persisted so repeated saves are no-ops.
        self.dirty = false;

        Ok(())
    }

    fn load_cache(&mut self) -> Result<(), PromptCacheError> {
        if !self.config.enabled {
            return Ok(());
        }
        let cache_path = self.config.cache_dir.join("prompt_cache.json");

        if !cache_path.exists() {
            return Ok(());
        }

        let data = fs::read_to_string(cache_path)?;

        self.cache = serde_json::from_str(&data)?;

        Ok(())
    }

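    /// Drop entries older than `max_age_days`. Runs automatically during
    /// construction when auto-cleanup is enabled.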
    fn cleanup_expired(&mut self) -> Result<(), PromptCacheError> {
        if !self.config.enabled {
            return Ok(());
        }
        let now = Self::current_timestamp();
        let max_age_seconds = self.config.max_age_days * 24 * 60 * 60;

        let before = self.cache.len();
        // saturating_sub guards against clock skew making created_at > now.
        self.cache
            .retain(|_, entry| now.saturating_sub(entry.created_at) < max_age_seconds);

        if self.cache.len() != before {
            self.dirty = true;
        }
        Ok(())
    }

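    /// Evict the least recently used entry to make room for a new one.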
    fn evict_oldest(&mut self) -> Result<(), PromptCacheError> {
        if !self.config.enabled {
            return Ok(());
        }

        if let Some(oldest_key) = self
            .cache
            .iter()
            .min_by_key(|(_, entry)| entry.last_used)
            .map(|(key, _)| key.clone())
        {
            self.cache.remove(&oldest_key);
            self.dirty = true;
        }

        Ok(())
    }

    fn current_timestamp() -> u64 {
        // A system clock before the Unix epoch is effectively unreachable;
        // fall back to zero rather than panicking.
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0)
    }
}

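// Flush unsaved changes when the cache goes out of scope. Errors are
// intentionally ignored: there is no useful recovery path in a destructor.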
impl Drop for PromptCache {
    fn drop(&mut self) {
        let _ = self.save_cache();
    }
}

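/// Aggregate statistics over the cached entries, as reported by
/// [`PromptCache::stats`].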
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStats {
    pub total_entries: usize,
    pub total_usage: u32,
    pub total_tokens_saved: u32,
    pub avg_quality: f64,
}

#[derive(Debug, thiserror::Error)]
pub enum PromptCacheError {
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),

    #[error("Cache full")]
    CacheFull,
}

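/// Rewrites prompts through an LLM provider, caching the results so that the
/// same prompt is only optimized once per cache lifetime.
///
/// A minimal sketch of the intended call flow (provider construction elided,
/// and marked `ignore` since it is illustrative rather than compiled):
///
/// ```ignore
/// let mut optimizer = PromptOptimizer::new(provider);
/// let optimized = optimizer
///     .optimize_prompt("Explain this diff", "target-model", None)
///     .await?;
/// ```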
pub struct PromptOptimizer {
    cache: PromptCache,
    llm_provider: Box<dyn crate::llm::provider::LLMProvider>,
}

impl PromptOptimizer {
    pub fn new(llm_provider: Box<dyn crate::llm::provider::LLMProvider>) -> Self {
        Self {
            cache: PromptCache::new(),
            llm_provider,
        }
    }

    pub fn with_cache(mut self, cache: PromptCache) -> Self {
        self.cache = cache;
        self
    }

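    /// Return the cached optimization for this prompt when one exists;
    /// otherwise ask the LLM to rewrite the prompt, record the result with an
    /// estimated token saving and a placeholder quality score, and cache it.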
    pub async fn optimize_prompt(
        &mut self,
        original_prompt: &str,
        target_model: &str,
        context: Option<&str>,
    ) -> Result<String, PromptOptimizationError> {
        let prompt_hash = PromptCache::hash_prompt(original_prompt);

        if let Some(cached) = self.cache.get(&prompt_hash) {
            return Ok(cached.optimized_prompt.clone());
        }

        let optimized = self
            .generate_optimized_prompt(original_prompt, target_model, context)
            .await?;

        let original_tokens = Self::estimate_tokens(original_prompt);
        let optimized_tokens = Self::estimate_tokens(&optimized);
        let tokens_saved = original_tokens.saturating_sub(optimized_tokens);

        let entry = CachedPrompt {
            prompt_hash: prompt_hash.clone(),
            original_prompt: original_prompt.to_string(),
            optimized_prompt: optimized.clone(),
            model_used: target_model.to_string(),
            tokens_saved: Some(tokens_saved),
            // Fixed placeholder score; real quality evaluation is not wired up yet.
            quality_score: Some(0.8),
            created_at: PromptCache::current_timestamp(),
            last_used: PromptCache::current_timestamp(),
            usage_count: 1,
        };

        self.cache.put(entry)?;

        Ok(optimized)
    }

    async fn generate_optimized_prompt(
        &self,
        original_prompt: &str,
        target_model: &str,
        context: Option<&str>,
    ) -> Result<String, PromptOptimizationError> {
        let system_prompt = format!(
            "You are an expert prompt engineer. Your task is to optimize prompts for {} \
             to make them more effective, clearer, and more likely to produce high-quality responses. \
             Focus on improving clarity, specificity, structure, and effectiveness while preserving \
             the original intent and requirements.",
            target_model
        );

        let mut user_prompt = format!(
            "Please optimize the following prompt for {}:\n\nORIGINAL PROMPT:\n{}\n\n",
            target_model, original_prompt
        );

        if let Some(ctx) = context {
            user_prompt.push_str(&format!("CONTEXT:\n{}\n\n", ctx));
        }

        user_prompt.push_str(
            "OPTIMIZATION REQUIREMENTS:\n\
             1. Make the prompt clearer and more specific\n\
             2. Improve structure and formatting\n\
             3. Add relevant context or examples if helpful\n\
             4. Ensure the prompt is appropriate for the target model\n\
             5. Maintain the original intent and requirements\n\
             6. Keep the optimized prompt concise but comprehensive\n\n\
             Provide only the optimized prompt without any explanation or additional text.",
        );

        let request = crate::llm::provider::LLMRequest {
            messages: vec![
                Message {
                    role: MessageRole::System,
                    content: system_prompt,
                    tool_calls: None,
                    tool_call_id: None,
                },
                Message {
                    role: MessageRole::User,
                    content: user_prompt,
                    tool_calls: None,
                    tool_call_id: None,
                },
            ],
            system_prompt: None,
            tools: None,
            model: target_model.to_string(),
            max_tokens: Some(2000),
            temperature: Some(0.3),
            stream: false,
            tool_choice: None,
            parallel_tool_calls: None,
            parallel_tool_config: None,
            reasoning_effort: None,
        };

        let response = self
            .llm_provider
            .generate(request)
            .await
            .map_err(|e| PromptOptimizationError::LLMError(e.to_string()))?;

        // Fall back to the original prompt when the provider returns no content.
        Ok(response
            .content
            .unwrap_or_else(|| original_prompt.to_string()))
    }

    /// Rough token estimate using the common ~4 characters per token heuristic.
    fn estimate_tokens(text: &str) -> u32 {
        (text.len() / 4) as u32
    }

    pub fn cache_stats(&self) -> CacheStats {
        self.cache.stats()
    }

    pub fn clear_cache(&mut self) -> Result<(), PromptCacheError> {
        self.cache.clear()
    }
}

#[derive(Debug, thiserror::Error)]
pub enum PromptOptimizationError {
    #[error("LLM error: {0}")]
    LLMError(String),

    #[error("Cache error: {0}")]
    CacheError(#[from] PromptCacheError),
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prompt_hash() {
        let prompt = "Test prompt";
        let hash1 = PromptCache::hash_prompt(prompt);
        let hash2 = PromptCache::hash_prompt(prompt);
        assert_eq!(hash1, hash2);
        assert!(!hash1.is_empty());
    }

    #[test]
    fn test_cache_operations() {
        let mut cache = PromptCache::new();

        let entry = CachedPrompt {
            prompt_hash: "test_hash".to_string(),
            original_prompt: "original".to_string(),
            optimized_prompt: "optimized".to_string(),
            model_used: crate::config::constants::models::GEMINI_2_5_FLASH.to_string(),
            tokens_saved: Some(100),
            quality_score: Some(0.9),
            created_at: 1000,
            last_used: 1000,
            usage_count: 0,
        };

        cache.put(entry).unwrap();
        assert!(cache.contains("test_hash"));

        let retrieved = cache.get("test_hash");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().usage_count, 1);
    }

    #[test]
    fn disabled_cache_config_is_no_op() {
        let settings = PromptCachingConfig {
            enabled: false,
            cache_dir: "relative/cache".to_string(),
            ..PromptCachingConfig::default()
        };
        let cfg = PromptCacheConfig::from_settings(&settings, None);
        assert!(!cfg.is_enabled());

        let mut cache = PromptCache::with_config(cfg);
        assert!(!cache.contains("missing"));
        assert_eq!(cache.stats().total_entries, 0);

        let entry = CachedPrompt {
            prompt_hash: "noop".to_string(),
            original_prompt: "original".to_string(),
            optimized_prompt: "optimized".to_string(),
            model_used: crate::config::constants::models::GEMINI_2_5_FLASH.to_string(),
            tokens_saved: Some(10),
            quality_score: Some(0.9),
            created_at: 1,
            last_used: 1,
            usage_count: 0,
        };

        cache.put(entry).unwrap();
        assert!(!cache.contains("noop"));
        assert_eq!(cache.stats().total_entries, 0);
    }
}