umi_memory/dst/
llm.rs

1//! SimLLM - Deterministic LLM Simulation
2//!
3//! TigerStyle: Deterministic LLM responses for simulation testing.
4//!
5//! See ADR-012 for design rationale.
6
7use std::sync::{Arc, Mutex};
8
9use serde::de::DeserializeOwned;
10use serde_json::json;
11
12use super::clock::SimClock;
13use super::fault::{FaultInjector, FaultType};
14use super::rng::DeterministicRng;
15use crate::constants::{
16    LLM_ENTITIES_COUNT_MAX, LLM_LATENCY_MS_DEFAULT, LLM_LATENCY_MS_MAX, LLM_LATENCY_MS_MIN,
17    LLM_PROMPT_BYTES_MAX, LLM_QUERY_REWRITES_COUNT_MAX, LLM_RESPONSE_BYTES_MAX,
18};
19
20// =============================================================================
21// Error Types
22// =============================================================================
23
/// Errors from LLM operations.
///
/// TigerStyle: Explicit error variants for all failure modes.
///
/// `Clone` is derived so injected faults can be surfaced repeatedly and
/// matched in tests without ownership gymnastics.
#[derive(Debug, Clone, thiserror::Error)]
pub enum LLMError {
    /// Request timed out before a response was produced
    #[error("LLM request timed out")]
    Timeout,

    /// Rate limit exceeded
    #[error("Rate limit exceeded")]
    RateLimit,

    /// Context/prompt too long (payload is the offending byte length;
    /// fault injection uses 0 as a placeholder)
    #[error("Context length exceeded: {0} bytes")]
    ContextOverflow(usize),

    /// Response format invalid
    #[error("Invalid response format: {0}")]
    InvalidResponse(String),

    /// Service unavailable (also used as the catch-all mapping for
    /// network-level faults)
    #[error("Service unavailable")]
    ServiceUnavailable,

    /// JSON serialization/deserialization error
    #[error("JSON error: {0}")]
    JsonError(String),

    /// Prompt validation failed (e.g. empty prompt)
    #[error("Invalid prompt: {0}")]
    InvalidPrompt(String),
}
57
58// =============================================================================
59// SimLLM
60// =============================================================================
61
/// Common names for entity extraction simulation.
///
/// Matched case-insensitively against prompts in `sim_entity_extraction`.
const COMMON_NAMES: &[&str] = &[
    "Alice", "Bob", "Charlie", "David", "Eve", "Frank", "Grace", "Henry", "Ivy", "Jack",
];

/// Common organizations for entity extraction simulation.
///
/// Matched case-insensitively against prompts in `sim_entity_extraction`.
const COMMON_ORGS: &[&str] = &[
    "Acme",
    "Google",
    "Microsoft",
    "Apple",
    "Amazon",
    "OpenAI",
    "Anthropic",
];
77
/// Simulated LLM for deterministic testing.
///
/// TigerStyle:
/// - Deterministic responses via seeded RNG
/// - Prompt routing to domain-specific generators
/// - Fault injection integration
/// - Thread-safe via `Mutex` for use in async contexts
///
/// Cloning is cheap: the RNG and fault injector are shared via `Arc`, so
/// clones draw from the same deterministic stream.
///
/// # Example
///
/// ```rust
/// use umi_memory::dst::{SimLLM, SimClock, DeterministicRng, FaultInjector};
/// use std::sync::Arc;
///
/// let clock = SimClock::new();
/// let rng = DeterministicRng::new(42);
/// let faults = Arc::new(FaultInjector::new(DeterministicRng::new(42)));
/// let llm = SimLLM::new(clock, rng, faults);
///
/// // Same seed = same response
/// ```
#[derive(Debug, Clone)]
pub struct SimLLM {
    /// Simulated clock for latency
    clock: SimClock,
    /// RNG with thread-safe interior mutability (Arc for Clone)
    rng: Arc<Mutex<DeterministicRng>>,
    /// Shared fault injector consulted on every `complete` call
    fault_injector: Arc<FaultInjector>,
    /// Base latency for simulated responses (jitter is added on top)
    base_latency_ms: u64,
    /// Whether to simulate latency (disable for tests without time advancement)
    simulate_latency_enabled: bool,
}
112
113impl SimLLM {
114    /// Create a new SimLLM.
115    ///
116    /// # Arguments
117    /// - `clock`: Simulated clock for latency
118    /// - `rng`: Deterministic RNG for response generation
119    /// - `fault_injector`: Shared fault injector
120    #[must_use]
121    pub fn new(clock: SimClock, rng: DeterministicRng, fault_injector: Arc<FaultInjector>) -> Self {
122        Self {
123            clock,
124            rng: Arc::new(Mutex::new(rng)),
125            fault_injector,
126            base_latency_ms: LLM_LATENCY_MS_DEFAULT,
127            simulate_latency_enabled: true,
128        }
129    }
130
131    /// Disable latency simulation (useful for tests without time advancement).
132    ///
133    /// By default, SimLLM simulates latency using the clock. This blocks if
134    /// the clock isn't being advanced. Use this method to disable latency
135    /// for simple tests.
136    #[must_use]
137    pub fn without_latency(mut self) -> Self {
138        self.simulate_latency_enabled = false;
139        self
140    }
141
142    /// Set base latency for simulated responses.
143    ///
144    /// # Panics
145    /// Panics if latency is outside valid range.
146    #[must_use]
147    pub fn with_latency(mut self, latency_ms: u64) -> Self {
148        // Precondition
149        assert!(
150            latency_ms >= LLM_LATENCY_MS_MIN && latency_ms <= LLM_LATENCY_MS_MAX,
151            "latency must be in [{}, {}], got {}",
152            LLM_LATENCY_MS_MIN,
153            LLM_LATENCY_MS_MAX,
154            latency_ms
155        );
156
157        self.base_latency_ms = latency_ms;
158        self
159    }
160
161    /// Complete a prompt with a deterministic response.
162    ///
163    /// # Errors
164    /// Returns `LLMError` on fault injection or validation failure.
165    ///
166    /// # Panics
167    /// Debug panics on precondition/postcondition violations.
168    pub async fn complete(&self, prompt: &str) -> Result<String, LLMError> {
169        // Preconditions (runtime checks - return errors for recoverable cases)
170        if prompt.is_empty() {
171            return Err(LLMError::InvalidPrompt("prompt must not be empty".into()));
172        }
173        if prompt.len() > LLM_PROMPT_BYTES_MAX {
174            return Err(LLMError::ContextOverflow(prompt.len()));
175        }
176
177        // Check for faults
178        if let Some(fault) = self.fault_injector.should_inject("llm_complete") {
179            return Err(self.fault_to_error(fault));
180        }
181
182        // Simulate latency
183        self.simulate_latency().await;
184
185        // Route prompt to appropriate generator
186        let response = self.route_prompt(prompt);
187
188        // Postconditions
189        debug_assert!(!response.is_empty(), "response must not be empty");
190        debug_assert!(
191            response.len() <= LLM_RESPONSE_BYTES_MAX,
192            "response exceeds limit"
193        );
194
195        Ok(response)
196    }
197
198    /// Complete a prompt expecting a JSON response.
199    ///
200    /// # Errors
201    /// Returns `LLMError` on fault injection, validation, or JSON parse failure.
202    pub async fn complete_json<T: DeserializeOwned>(&self, prompt: &str) -> Result<T, LLMError> {
203        let response = self.complete(prompt).await?;
204
205        serde_json::from_str(&response)
206            .map_err(|e| LLMError::JsonError(format!("Failed to parse JSON: {}", e)))
207    }
208
209    /// Route prompt to the appropriate generator based on content.
210    fn route_prompt(&self, prompt: &str) -> String {
211        let prompt_lower = prompt.to_lowercase();
212
213        if prompt_lower.contains("extract") && prompt_lower.contains("entit") {
214            self.sim_entity_extraction(prompt)
215        } else if prompt_lower.contains("rewrite") && prompt_lower.contains("query") {
216            self.sim_query_rewrite(prompt)
217        } else if prompt_lower.contains("detect") && prompt_lower.contains("evolution") {
218            self.sim_evolution_detection(prompt)
219        } else if prompt_lower.contains("detect")
220            && (prompt_lower.contains("relation") || prompt_lower.contains("relationship"))
221        {
222            self.sim_relation_detection(prompt)
223        } else {
224            self.sim_generic(prompt)
225        }
226    }
227
228    /// Simulate entity extraction response.
229    fn sim_entity_extraction(&self, prompt: &str) -> String {
230        let mut entities = Vec::new();
231        let mut rng = self.rng.lock().unwrap();
232
233        // Detect common names in prompt
234        for name in COMMON_NAMES {
235            if prompt.to_uppercase().contains(&name.to_uppercase()) {
236                if entities.len() >= LLM_ENTITIES_COUNT_MAX {
237                    break;
238                }
239                entities.push(json!({
240                    "name": name,
241                    "entity_type": "person",
242                    "content": format!("Information about {}", name),
243                    "confidence": 0.7 + rng.next_float() * 0.3,
244                }));
245            }
246        }
247
248        // Detect common organizations in prompt
249        for org in COMMON_ORGS {
250            if prompt.to_uppercase().contains(&org.to_uppercase()) {
251                if entities.len() >= LLM_ENTITIES_COUNT_MAX {
252                    break;
253                }
254                entities.push(json!({
255                    "name": org,
256                    "entity_type": "organization",
257                    "content": format!("Organization: {}", org),
258                    "confidence": 0.8 + rng.next_float() * 0.2,
259                }));
260            }
261        }
262
263        // Fallback to note entity if nothing found
264        if entities.is_empty() {
265            let hash = self.prompt_hash(prompt);
266            let snippet = &prompt[..100.min(prompt.len())];
267            entities.push(json!({
268                "name": format!("Note_{}", hash),
269                "entity_type": "note",
270                "content": snippet,
271                "confidence": 0.5 + rng.next_float() * 0.3,
272            }));
273        }
274
275        serde_json::to_string(&json!({
276            "entities": entities,
277            "relations": [],
278        }))
279        .unwrap()
280    }
281
282    /// Simulate query rewrite response.
283    fn sim_query_rewrite(&self, prompt: &str) -> String {
284        let mut rng = self.rng.lock().unwrap();
285
286        // Extract the actual query from the prompt (simple heuristic)
287        let query = prompt
288            .lines()
289            .find(|line| line.trim().starts_with("Query:") || line.trim().starts_with("query:"))
290            .map(|line| {
291                line.trim_start_matches("Query:")
292                    .trim_start_matches("query:")
293                    .trim()
294            })
295            .unwrap_or(&prompt[..50.min(prompt.len())]);
296
297        // Generate variations
298        let num_rewrites = rng.next_usize(2, LLM_QUERY_REWRITES_COUNT_MAX);
299        let mut rewrites = vec![query.to_string()];
300
301        let prefixes = [
302            "What is",
303            "Tell me about",
304            "Information on",
305            "Details about",
306        ];
307        let suffixes = ["?", " please", " in detail", ""];
308
309        for _ in 0..num_rewrites - 1 {
310            let prefix = prefixes[rng.next_usize(0, prefixes.len() - 1)];
311            let suffix = suffixes[rng.next_usize(0, suffixes.len() - 1)];
312            rewrites.push(format!("{} {}{}", prefix, query, suffix));
313        }
314
315        serde_json::to_string(&json!({
316            "queries": rewrites,
317        }))
318        .unwrap()
319    }
320
    /// Simulate evolution detection response.
    ///
    /// Rolls a weighted evolution type, then reports "not detected" ~30% of
    /// the time; otherwise returns the type plus a canned reason and
    /// source/target ids derived from the prompt hash.
    ///
    /// NOTE: the RNG draw order (type roll, detection roll, reason index,
    /// confidence) is part of the deterministic contract — do not reorder.
    fn sim_evolution_detection(&self, prompt: &str) -> String {
        let mut rng = self.rng.lock().unwrap();

        // Weighted evolution types (update most common)
        let evolution_types = [
            ("update", 0.4),
            ("extend", 0.3),
            ("derive", 0.2),
            ("contradict", 0.1),
        ];

        // Cumulative-weight selection: the first bucket whose running total
        // exceeds the roll wins. Weights sum to 1.0, so a bucket is always
        // reachable; "update" is the defensive default.
        let roll = rng.next_float();
        let mut cumulative = 0.0;
        let mut selected_type = "update";

        for (etype, weight) in &evolution_types {
            cumulative += weight;
            if roll < cumulative {
                selected_type = etype;
                break;
            }
        }

        // Sometimes no evolution detected. The type roll above was still
        // consumed, which keeps the RNG stream identical on both paths.
        if rng.next_bool(0.3) {
            return serde_json::to_string(&json!({
                "detected": false,
                "evolution_type": null,
                "reason": null,
                "confidence": 0.0,
            }))
            .unwrap();
        }

        // Canned reasons per evolution type
        let reasons = match selected_type {
            "update" => vec![
                "New information replaces outdated data",
                "Values have been updated",
                "Status has changed",
            ],
            "extend" => vec![
                "Additional details provided",
                "New attributes added",
                "Information expanded",
            ],
            "derive" => vec![
                "Conclusion drawn from existing data",
                "Inference based on prior knowledge",
                "Logically follows from previous entity",
            ],
            "contradict" => vec![
                "Information conflicts with existing record",
                "Inconsistent values detected",
                "Contradictory statement found",
            ],
            _ => vec!["Evolution detected"],
        };

        let reason = reasons[rng.next_usize(0, reasons.len() - 1)];
        let confidence = 0.6 + rng.next_float() * 0.4;

        // Source/target ids are derived from the prompt hash so the same
        // prompt always names the same entity pair.
        let hash = self.prompt_hash(prompt);

        serde_json::to_string(&json!({
            "detected": true,
            "evolution_type": selected_type,
            "source_id": format!("entity_{}", hash % 1000),
            "target_id": format!("entity_{}", (hash / 1000) % 1000),
            "reason": reason,
            "confidence": confidence,
        }))
        .unwrap()
    }
396
397    /// Simulate relation detection response.
398    fn sim_relation_detection(&self, prompt: &str) -> String {
399        let mut rng = self.rng.lock().unwrap();
400
401        // Sometimes no relation detected
402        if rng.next_bool(0.4) {
403            return serde_json::to_string(&json!({
404                "relations": [],
405            }))
406            .unwrap();
407        }
408
409        let relation_types = [
410            "works_at",
411            "knows",
412            "located_in",
413            "part_of",
414            "created_by",
415            "related_to",
416        ];
417
418        let num_relations = rng.next_usize(1, 3);
419        let mut relations = Vec::new();
420        let hash = self.prompt_hash(prompt);
421
422        for i in 0..num_relations {
423            let rel_type = relation_types[rng.next_usize(0, relation_types.len() - 1)];
424            relations.push(json!({
425                "source": format!("entity_{}", (hash + i as u64) % 100),
426                "target": format!("entity_{}", (hash + i as u64 + 50) % 100),
427                "relation_type": rel_type,
428                "confidence": 0.5 + rng.next_float() * 0.5,
429            }));
430        }
431
432        serde_json::to_string(&json!({
433            "relations": relations,
434        }))
435        .unwrap()
436    }
437
438    /// Generic response for unrecognized prompts.
439    fn sim_generic(&self, prompt: &str) -> String {
440        let hash = self.prompt_hash(prompt);
441        let mut rng = self.rng.lock().unwrap();
442
443        let responses = [
444            "Acknowledged.",
445            "Understood.",
446            "Processing complete.",
447            "Request handled.",
448            "Task completed successfully.",
449        ];
450
451        let response = responses[rng.next_usize(0, responses.len() - 1)];
452
453        serde_json::to_string(&json!({
454            "response": response,
455            "prompt_hash": hash,
456            "success": true,
457        }))
458        .unwrap()
459    }
460
461    /// Generate a deterministic hash from prompt.
462    fn prompt_hash(&self, prompt: &str) -> u64 {
463        // Simple FNV-1a hash for determinism
464        let mut hash: u64 = 0xcbf2_9ce4_8422_2325;
465        for byte in prompt.bytes() {
466            hash ^= u64::from(byte);
467            hash = hash.wrapping_mul(0x0100_0000_01b3);
468        }
469        hash
470    }
471
472    /// Simulate latency using the clock.
473    async fn simulate_latency(&self) {
474        if !self.simulate_latency_enabled {
475            return;
476        }
477
478        let jitter = {
479            let mut rng = self.rng.lock().unwrap();
480            rng.next_usize(0, 50) as u64
481        };
482        let latency = self.base_latency_ms + jitter;
483        self.clock.sleep_ms(latency).await;
484    }
485
486    /// Convert fault type to LLM error.
487    fn fault_to_error(&self, fault: FaultType) -> LLMError {
488        match fault {
489            FaultType::LlmTimeout => LLMError::Timeout,
490            FaultType::LlmRateLimit => LLMError::RateLimit,
491            FaultType::LlmContextOverflow => LLMError::ContextOverflow(0),
492            FaultType::LlmInvalidResponse => {
493                LLMError::InvalidResponse("Simulated invalid response".into())
494            }
495            FaultType::LlmServiceUnavailable => LLMError::ServiceUnavailable,
496            // Map network faults to service unavailable
497            FaultType::NetworkTimeout | FaultType::NetworkConnectionRefused => {
498                LLMError::ServiceUnavailable
499            }
500            // Default mapping
501            _ => LLMError::ServiceUnavailable,
502        }
503    }
504
    /// Get the current seed (for debugging).
    ///
    /// # Panics
    /// Panics if the internal RNG mutex is poisoned.
    #[must_use]
    pub fn seed(&self) -> u64 {
        self.rng.lock().unwrap().seed()
    }
510}
511
512// =============================================================================
513// Tests
514// =============================================================================
515
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dst::fault::FaultConfig;

    // Builds a SimLLM with matching seeds for RNG and fault injector,
    // latency disabled so tests need not advance the clock.
    fn create_test_llm(seed: u64) -> SimLLM {
        let clock = SimClock::new();
        let rng = DeterministicRng::new(seed);
        let faults = Arc::new(FaultInjector::new(DeterministicRng::new(seed)));
        SimLLM::new(clock, rng, faults).without_latency()
    }

    // Core DST property: identical seeds produce byte-identical responses.
    #[tokio::test]
    async fn test_determinism() {
        let llm1 = create_test_llm(42);
        let llm2 = create_test_llm(42);

        let prompt = "Extract entities from: Alice works at Acme Corp.";

        let response1 = llm1.complete(prompt).await.unwrap();
        let response2 = llm2.complete(prompt).await.unwrap();

        assert_eq!(
            response1, response2,
            "Same seed should produce same response"
        );
    }

    #[tokio::test]
    async fn test_different_seeds_different_responses() {
        let llm1 = create_test_llm(42);
        let llm2 = create_test_llm(12345);

        let prompt = "Extract entities from: Bob met Charlie at Google.";

        let response1 = llm1.complete(prompt).await.unwrap();
        let response2 = llm2.complete(prompt).await.unwrap();

        // Responses may still be similar due to pattern matching,
        // but confidence values should differ
        assert!(response1.contains("Bob") || response1.contains("Charlie"));
        assert!(response2.contains("Bob") || response2.contains("Charlie"));
    }

    // "extract" + "entit" keywords should hit the entity-extraction path.
    #[tokio::test]
    async fn test_entity_extraction_routing() {
        let llm = create_test_llm(42);

        let prompt = "Extract entities from the following text: Alice and Bob work at Microsoft.";
        let response = llm.complete(prompt).await.unwrap();

        assert!(response.contains("entities"));
        assert!(response.contains("Alice") || response.contains("Bob"));
    }

    // "rewrite" + "query" keywords should hit the query-rewrite path.
    #[tokio::test]
    async fn test_query_rewrite_routing() {
        let llm = create_test_llm(42);

        let prompt =
            "Rewrite the following query for better search:\nQuery: what is rust programming";
        let response = llm.complete(prompt).await.unwrap();

        assert!(response.contains("queries"));
    }

    // "detect" + "evolution" keywords should hit the evolution path.
    #[tokio::test]
    async fn test_evolution_detection_routing() {
        let llm = create_test_llm(42);

        let prompt = "Detect evolution relationship between:\nOld: Alice is 25\nNew: Alice is 26";
        let response = llm.complete(prompt).await.unwrap();

        assert!(response.contains("evolution_type") || response.contains("detected"));
    }

    // Prompts matching no routing keywords fall through to sim_generic.
    #[tokio::test]
    async fn test_generic_routing() {
        let llm = create_test_llm(42);

        let prompt = "Hello, how are you?";
        let response = llm.complete(prompt).await.unwrap();

        assert!(response.contains("response") || response.contains("success"));
    }

    #[tokio::test]
    async fn test_empty_prompt_error() {
        let llm = create_test_llm(42);

        let result = llm.complete("").await;
        assert!(matches!(result, Err(LLMError::InvalidPrompt(_))));
    }

    #[tokio::test]
    async fn test_prompt_too_long_error() {
        let llm = create_test_llm(42);

        let long_prompt = "x".repeat(LLM_PROMPT_BYTES_MAX + 1);
        let result = llm.complete(&long_prompt).await;

        assert!(matches!(result, Err(LLMError::ContextOverflow(_))));
    }

    // Probability 1.0 makes the injector fire deterministically.
    #[tokio::test]
    async fn test_fault_injection_timeout() {
        let clock = SimClock::new();
        let rng = DeterministicRng::new(42);
        let mut injector = FaultInjector::new(DeterministicRng::new(42));
        injector.register(FaultConfig::new(FaultType::LlmTimeout, 1.0));
        let faults = Arc::new(injector);

        let llm = SimLLM::new(clock, rng, faults).without_latency();
        let result = llm.complete("test prompt").await;

        assert!(matches!(result, Err(LLMError::Timeout)));
    }

    #[tokio::test]
    async fn test_fault_injection_rate_limit() {
        let clock = SimClock::new();
        let rng = DeterministicRng::new(42);
        let mut injector = FaultInjector::new(DeterministicRng::new(42));
        injector.register(FaultConfig::new(FaultType::LlmRateLimit, 1.0));
        let faults = Arc::new(injector);

        let llm = SimLLM::new(clock, rng, faults).without_latency();
        let result = llm.complete("test prompt").await;

        assert!(matches!(result, Err(LLMError::RateLimit)));
    }

    // complete_json should deserialize the generic response shape.
    #[tokio::test]
    async fn test_complete_json() {
        let llm = create_test_llm(42);

        #[derive(serde::Deserialize)]
        struct GenericResponse {
            response: String,
            success: bool,
        }

        let prompt = "Hello, world!";
        let result: GenericResponse = llm.complete_json(prompt).await.unwrap();

        assert!(result.success);
        assert!(!result.response.is_empty());
    }

    // With latency enabled, complete() blocks on the sim clock until a
    // concurrent task advances it past base latency + jitter.
    #[tokio::test]
    async fn test_with_latency() {
        let clock = SimClock::new();
        let rng = DeterministicRng::new(42);
        let faults = Arc::new(FaultInjector::new(DeterministicRng::new(42)));

        let llm = SimLLM::new(clock.clone(), rng, faults).with_latency(500);

        // Spawn a task to advance time while the LLM waits
        let clock_for_advance = clock.clone();
        let advance_handle = tokio::spawn(async move {
            // Give the complete() call time to start waiting
            tokio::task::yield_now().await;
            // Advance time enough to cover latency + jitter (500 + up to 50)
            clock_for_advance.advance_ms(600);
        });

        let start = clock.now_ms();
        llm.complete("test").await.unwrap();
        let end = clock.now_ms();

        advance_handle.await.unwrap();

        // Clock should have advanced (we advanced by 600ms)
        assert!(
            end >= start + 500,
            "Expected clock to advance at least 500ms, start={}, end={}",
            start,
            end
        );
    }

    #[test]
    fn test_prompt_hash_determinism() {
        let llm = create_test_llm(42);

        let hash1 = llm.prompt_hash("test prompt");
        let hash2 = llm.prompt_hash("test prompt");
        let hash3 = llm.prompt_hash("different prompt");

        assert_eq!(hash1, hash2);
        assert_ne!(hash1, hash3);
    }

    // with_latency asserts its range precondition; the panic message prefix
    // here must stay in sync with the assert in with_latency.
    #[test]
    #[should_panic(expected = "latency must be in")]
    fn test_invalid_latency() {
        let clock = SimClock::new();
        let rng = DeterministicRng::new(42);
        let faults = Arc::new(FaultInjector::new(DeterministicRng::new(42)));

        let _ = SimLLM::new(clock, rng, faults).with_latency(999999);
    }
}