use std::sync::{Arc, Mutex};

use serde::de::DeserializeOwned;
use serde_json::json;

use super::clock::SimClock;
use super::fault::{FaultInjector, FaultType};
use super::rng::DeterministicRng;
use crate::constants::{
    LLM_ENTITIES_COUNT_MAX, LLM_LATENCY_MS_DEFAULT, LLM_LATENCY_MS_MAX, LLM_LATENCY_MS_MIN,
    LLM_PROMPT_BYTES_MAX, LLM_QUERY_REWRITES_COUNT_MAX, LLM_RESPONSE_BYTES_MAX,
};

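/// Errors produced by the simulated LLM client.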
#[derive(Debug, Clone, thiserror::Error)]
pub enum LLMError {
    #[error("LLM request timed out")]
    Timeout,

    #[error("Rate limit exceeded")]
    RateLimit,

    #[error("Context length exceeded: {0} bytes")]
    ContextOverflow(usize),

    #[error("Invalid response format: {0}")]
    InvalidResponse(String),

    #[error("Service unavailable")]
    ServiceUnavailable,

    #[error("JSON error: {0}")]
    JsonError(String),

    #[error("Invalid prompt: {0}")]
    InvalidPrompt(String),
}

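/// Person names matched (case-insensitively) by the entity-extraction stub.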
const COMMON_NAMES: &[&str] = &[
    "Alice", "Bob", "Charlie", "David", "Eve", "Frank", "Grace", "Henry", "Ivy", "Jack",
];

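/// Organization names matched (case-insensitively) by the entity-extraction stub.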
const COMMON_ORGS: &[&str] = &[
    "Acme",
    "Google",
    "Microsoft",
    "Apple",
    "Amazon",
    "OpenAI",
    "Anthropic",
];

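/// Deterministic, in-process stand-in for an LLM client, intended for
/// simulation testing: prompts are routed to canned JSON responses, randomness
/// comes from a seeded [`DeterministicRng`], time comes from [`SimClock`], and
/// failures come from the shared [`FaultInjector`].
///
/// A minimal usage sketch (assuming the surrounding `dst` helpers are in scope):
///
/// ```ignore
/// let llm = SimLLM::new(clock, DeterministicRng::new(42), fault_injector)
///     .without_latency();
/// let json = llm.complete("Extract entities from: Alice works at Acme").await?;
/// ```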
#[derive(Debug, Clone)]
pub struct SimLLM {
    /// Simulated clock used to model request latency.
    clock: SimClock,
    /// Seeded RNG behind a mutex; clones of the simulator share it.
    rng: Arc<Mutex<DeterministicRng>>,
    /// Shared fault injector consulted before every request.
    fault_injector: Arc<FaultInjector>,
    /// Base latency applied to each request, in milliseconds.
    base_latency_ms: u64,
    /// Whether simulated latency is applied at all.
    simulate_latency_enabled: bool,
}

impl SimLLM {
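    /// Creates a simulator with the default base latency
    /// (`LLM_LATENCY_MS_DEFAULT`) and latency simulation enabled.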
    #[must_use]
    pub fn new(clock: SimClock, rng: DeterministicRng, fault_injector: Arc<FaultInjector>) -> Self {
        Self {
            clock,
            rng: Arc::new(Mutex::new(rng)),
            fault_injector,
            base_latency_ms: LLM_LATENCY_MS_DEFAULT,
            simulate_latency_enabled: true,
        }
    }

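    /// Disables latency simulation so calls complete without waiting on the
    /// simulated clock.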
    #[must_use]
    pub fn without_latency(mut self) -> Self {
        self.simulate_latency_enabled = false;
        self
    }

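    /// Sets the base latency in milliseconds.
    ///
    /// # Panics
    ///
    /// Panics if `latency_ms` is outside `[LLM_LATENCY_MS_MIN, LLM_LATENCY_MS_MAX]`.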
    #[must_use]
    pub fn with_latency(mut self, latency_ms: u64) -> Self {
        assert!(
            latency_ms >= LLM_LATENCY_MS_MIN && latency_ms <= LLM_LATENCY_MS_MAX,
            "latency must be in [{}, {}], got {}",
            LLM_LATENCY_MS_MIN,
            LLM_LATENCY_MS_MAX,
            latency_ms
        );

        self.base_latency_ms = latency_ms;
        self
    }

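    /// Completes `prompt` with a deterministic, routed response.
    ///
    /// Validates the prompt, consults the fault injector, simulates latency,
    /// and then dispatches on keywords in the prompt (entity extraction, query
    /// rewriting, evolution detection, relation detection, or a generic reply).
    ///
    /// # Errors
    ///
    /// Returns [`LLMError::InvalidPrompt`] for an empty prompt,
    /// [`LLMError::ContextOverflow`] when the prompt exceeds
    /// `LLM_PROMPT_BYTES_MAX`, or the error mapped from an injected fault.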
    pub async fn complete(&self, prompt: &str) -> Result<String, LLMError> {
        if prompt.is_empty() {
            return Err(LLMError::InvalidPrompt("prompt must not be empty".into()));
        }
        if prompt.len() > LLM_PROMPT_BYTES_MAX {
            return Err(LLMError::ContextOverflow(prompt.len()));
        }

        if let Some(fault) = self.fault_injector.should_inject("llm_complete") {
            return Err(self.fault_to_error(fault));
        }

        self.simulate_latency().await;

        let response = self.route_prompt(prompt);

        debug_assert!(!response.is_empty(), "response must not be empty");
        debug_assert!(
            response.len() <= LLM_RESPONSE_BYTES_MAX,
            "response exceeds limit"
        );

        Ok(response)
    }

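    /// Completes `prompt` and deserializes the JSON response into `T`.
    ///
    /// # Errors
    ///
    /// Returns any error from [`complete`](Self::complete), or
    /// [`LLMError::JsonError`] if the response does not deserialize into `T`.
    ///
    /// A sketch of typical use (the response type here is illustrative):
    ///
    /// ```ignore
    /// #[derive(serde::Deserialize)]
    /// struct Entities { entities: Vec<serde_json::Value> }
    ///
    /// let out: Entities = llm.complete_json("Extract entities from: Bob").await?;
    /// ```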
    pub async fn complete_json<T: DeserializeOwned>(&self, prompt: &str) -> Result<T, LLMError> {
        let response = self.complete(prompt).await?;

        serde_json::from_str(&response)
            .map_err(|e| LLMError::JsonError(format!("Failed to parse JSON: {}", e)))
    }

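    /// Picks a response generator based on keywords in the (lowercased) prompt.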
    fn route_prompt(&self, prompt: &str) -> String {
        let prompt_lower = prompt.to_lowercase();

        if prompt_lower.contains("extract") && prompt_lower.contains("entit") {
            self.sim_entity_extraction(prompt)
        } else if prompt_lower.contains("rewrite") && prompt_lower.contains("query") {
            self.sim_query_rewrite(prompt)
        } else if prompt_lower.contains("detect") && prompt_lower.contains("evolution") {
            self.sim_evolution_detection(prompt)
        } else if prompt_lower.contains("detect")
            && (prompt_lower.contains("relation") || prompt_lower.contains("relationship"))
        {
            self.sim_relation_detection(prompt)
        } else {
            self.sim_generic(prompt)
        }
    }

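    /// Simulates entity extraction: returns a JSON document with an `entities`
    /// array built from any known names or organizations found in the prompt,
    /// or a single fallback "note" entity when nothing matches.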
    fn sim_entity_extraction(&self, prompt: &str) -> String {
        let mut entities = Vec::new();
        let mut rng = self.rng.lock().unwrap();

        // Case-insensitive matching; uppercase the prompt once instead of per candidate.
        let prompt_upper = prompt.to_uppercase();

        for name in COMMON_NAMES {
            if prompt_upper.contains(&name.to_uppercase()) {
                if entities.len() >= LLM_ENTITIES_COUNT_MAX {
                    break;
                }
                entities.push(json!({
                    "name": name,
                    "entity_type": "person",
                    "content": format!("Information about {}", name),
                    "confidence": 0.7 + rng.next_float() * 0.3,
                }));
            }
        }

        for org in COMMON_ORGS {
            if prompt_upper.contains(&org.to_uppercase()) {
                if entities.len() >= LLM_ENTITIES_COUNT_MAX {
                    break;
                }
                entities.push(json!({
                    "name": org,
                    "entity_type": "organization",
                    "content": format!("Organization: {}", org),
                    "confidence": 0.8 + rng.next_float() * 0.2,
                }));
            }
        }

        if entities.is_empty() {
            let hash = self.prompt_hash(prompt);
            // Take at most the first 100 characters; slicing by byte index could
            // panic in the middle of a multi-byte UTF-8 character.
            let snippet: String = prompt.chars().take(100).collect();
            entities.push(json!({
                "name": format!("Note_{}", hash),
                "entity_type": "note",
                "content": snippet,
                "confidence": 0.5 + rng.next_float() * 0.3,
            }));
        }

        serde_json::to_string(&json!({
            "entities": entities,
            "relations": [],
        }))
        .unwrap()
    }

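    /// Simulates query rewriting: extracts the line starting with `Query:` (or
    /// falls back to the start of the prompt) and returns a JSON `queries`
    /// array containing the original query plus a few templated variants.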
    fn sim_query_rewrite(&self, prompt: &str) -> String {
        let mut rng = self.rng.lock().unwrap();

        // Pull the query out of a "Query:"/"query:" line; fall back to the first
        // 50 characters of the prompt (character-based to stay on UTF-8 boundaries).
        let query = prompt
            .lines()
            .map(str::trim)
            .find(|line| line.starts_with("Query:") || line.starts_with("query:"))
            .map(|line| {
                line.trim_start_matches("Query:")
                    .trim_start_matches("query:")
                    .trim()
                    .to_string()
            })
            .unwrap_or_else(|| prompt.chars().take(50).collect());

        let num_rewrites = rng.next_usize(2, LLM_QUERY_REWRITES_COUNT_MAX);
        let mut rewrites = vec![query.clone()];

        let prefixes = [
            "What is",
            "Tell me about",
            "Information on",
            "Details about",
        ];
        let suffixes = ["?", " please", " in detail", ""];

        for _ in 0..num_rewrites - 1 {
            let prefix = prefixes[rng.next_usize(0, prefixes.len() - 1)];
            let suffix = suffixes[rng.next_usize(0, suffixes.len() - 1)];
            rewrites.push(format!("{} {}{}", prefix, query, suffix));
        }

        serde_json::to_string(&json!({
            "queries": rewrites,
        }))
        .unwrap()
    }

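    /// Simulates evolution detection: roughly 30% of the time reports no
    /// evolution; otherwise picks a weighted evolution type (`update`, `extend`,
    /// `derive`, `contradict`) with a canned reason and confidence.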
    fn sim_evolution_detection(&self, prompt: &str) -> String {
        let mut rng = self.rng.lock().unwrap();

        let evolution_types = [
            ("update", 0.4),
            ("extend", 0.3),
            ("derive", 0.2),
            ("contradict", 0.1),
        ];

        let roll = rng.next_float();
        let mut cumulative = 0.0;
        let mut selected_type = "update";

        for (etype, weight) in &evolution_types {
            cumulative += weight;
            if roll < cumulative {
                selected_type = etype;
                break;
            }
        }

        if rng.next_bool(0.3) {
            return serde_json::to_string(&json!({
                "detected": false,
                "evolution_type": null,
                "reason": null,
                "confidence": 0.0,
            }))
            .unwrap();
        }

        let reasons = match selected_type {
            "update" => vec![
                "New information replaces outdated data",
                "Values have been updated",
                "Status has changed",
            ],
            "extend" => vec![
                "Additional details provided",
                "New attributes added",
                "Information expanded",
            ],
            "derive" => vec![
                "Conclusion drawn from existing data",
                "Inference based on prior knowledge",
                "Logically follows from previous entity",
            ],
            "contradict" => vec![
                "Information conflicts with existing record",
                "Inconsistent values detected",
                "Contradictory statement found",
            ],
            _ => vec!["Evolution detected"],
        };

        let reason = reasons[rng.next_usize(0, reasons.len() - 1)];
        let confidence = 0.6 + rng.next_float() * 0.4;

        let hash = self.prompt_hash(prompt);

        serde_json::to_string(&json!({
            "detected": true,
            "evolution_type": selected_type,
            "source_id": format!("entity_{}", hash % 1000),
            "target_id": format!("entity_{}", (hash / 1000) % 1000),
            "reason": reason,
            "confidence": confidence,
        }))
        .unwrap()
    }

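    /// Simulates relation detection: roughly 40% of the time returns no
    /// relations; otherwise fabricates one to three relations with
    /// hash-derived entity ids.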
    fn sim_relation_detection(&self, prompt: &str) -> String {
        let mut rng = self.rng.lock().unwrap();

        if rng.next_bool(0.4) {
            return serde_json::to_string(&json!({
                "relations": [],
            }))
            .unwrap();
        }

        let relation_types = [
            "works_at",
            "knows",
            "located_in",
            "part_of",
            "created_by",
            "related_to",
        ];

        let num_relations = rng.next_usize(1, 3);
        let mut relations = Vec::new();
        let hash = self.prompt_hash(prompt);

        for i in 0..num_relations {
            let rel_type = relation_types[rng.next_usize(0, relation_types.len() - 1)];
            relations.push(json!({
                "source": format!("entity_{}", (hash + i as u64) % 100),
                "target": format!("entity_{}", (hash + i as u64 + 50) % 100),
                "relation_type": rel_type,
                "confidence": 0.5 + rng.next_float() * 0.5,
            }));
        }

        serde_json::to_string(&json!({
            "relations": relations,
        }))
        .unwrap()
    }

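    /// Simulates a generic completion: returns a small JSON acknowledgement
    /// tagged with the prompt hash.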
    fn sim_generic(&self, prompt: &str) -> String {
        let hash = self.prompt_hash(prompt);
        let mut rng = self.rng.lock().unwrap();

        let responses = [
            "Acknowledged.",
            "Understood.",
            "Processing complete.",
            "Request handled.",
            "Task completed successfully.",
        ];

        let response = responses[rng.next_usize(0, responses.len() - 1)];

        serde_json::to_string(&json!({
            "response": response,
            "prompt_hash": hash,
            "success": true,
        }))
        .unwrap()
    }

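    /// Deterministic FNV-1a hash of the prompt bytes, used to derive stable
    /// ids in simulated responses.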
    fn prompt_hash(&self, prompt: &str) -> u64 {
        let mut hash: u64 = 0xcbf2_9ce4_8422_2325;
        for byte in prompt.bytes() {
            hash ^= u64::from(byte);
            hash = hash.wrapping_mul(0x0100_0000_01b3);
        }
        hash
    }

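    /// Sleeps on the simulated clock for the base latency plus up to 50ms of
    /// deterministic jitter; a no-op when latency simulation is disabled.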
    async fn simulate_latency(&self) {
        if !self.simulate_latency_enabled {
            return;
        }

        let jitter = {
            let mut rng = self.rng.lock().unwrap();
            rng.next_usize(0, 50) as u64
        };
        let latency = self.base_latency_ms + jitter;
        self.clock.sleep_ms(latency).await;
    }

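    /// Maps an injected [`FaultType`] to the [`LLMError`] surfaced to the
    /// caller; unrecognized faults fall back to `ServiceUnavailable`.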
    fn fault_to_error(&self, fault: FaultType) -> LLMError {
        match fault {
            FaultType::LlmTimeout => LLMError::Timeout,
            FaultType::LlmRateLimit => LLMError::RateLimit,
            FaultType::LlmContextOverflow => LLMError::ContextOverflow(0),
            FaultType::LlmInvalidResponse => {
                LLMError::InvalidResponse("Simulated invalid response".into())
            }
            FaultType::LlmServiceUnavailable => LLMError::ServiceUnavailable,
            FaultType::NetworkTimeout | FaultType::NetworkConnectionRefused => {
                LLMError::ServiceUnavailable
            }
            _ => LLMError::ServiceUnavailable,
        }
    }

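    /// Returns the seed of the underlying RNG, so a failing simulation run can
    /// be reproduced.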
    #[must_use]
    pub fn seed(&self) -> u64 {
        self.rng.lock().unwrap().seed()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::dst::fault::FaultConfig;

    fn create_test_llm(seed: u64) -> SimLLM {
        let clock = SimClock::new();
        let rng = DeterministicRng::new(seed);
        let faults = Arc::new(FaultInjector::new(DeterministicRng::new(seed)));
        SimLLM::new(clock, rng, faults).without_latency()
    }

    #[tokio::test]
    async fn test_determinism() {
        let llm1 = create_test_llm(42);
        let llm2 = create_test_llm(42);

        let prompt = "Extract entities from: Alice works at Acme Corp.";

        let response1 = llm1.complete(prompt).await.unwrap();
        let response2 = llm2.complete(prompt).await.unwrap();

        assert_eq!(
            response1, response2,
            "Same seed should produce same response"
        );
    }

    #[tokio::test]
    async fn test_different_seeds_different_responses() {
        let llm1 = create_test_llm(42);
        let llm2 = create_test_llm(12345);

        let prompt = "Extract entities from: Bob met Charlie at Google.";

        let response1 = llm1.complete(prompt).await.unwrap();
        let response2 = llm2.complete(prompt).await.unwrap();

        assert!(response1.contains("Bob") || response1.contains("Charlie"));
        assert!(response2.contains("Bob") || response2.contains("Charlie"));
    }

    #[tokio::test]
    async fn test_entity_extraction_routing() {
        let llm = create_test_llm(42);

        let prompt = "Extract entities from the following text: Alice and Bob work at Microsoft.";
        let response = llm.complete(prompt).await.unwrap();

        assert!(response.contains("entities"));
        assert!(response.contains("Alice") || response.contains("Bob"));
    }

    #[tokio::test]
    async fn test_query_rewrite_routing() {
        let llm = create_test_llm(42);

        let prompt =
            "Rewrite the following query for better search:\nQuery: what is rust programming";
        let response = llm.complete(prompt).await.unwrap();

        assert!(response.contains("queries"));
    }

    #[tokio::test]
    async fn test_evolution_detection_routing() {
        let llm = create_test_llm(42);

        let prompt = "Detect evolution relationship between:\nOld: Alice is 25\nNew: Alice is 26";
        let response = llm.complete(prompt).await.unwrap();

        assert!(response.contains("evolution_type") || response.contains("detected"));
    }

    #[tokio::test]
    async fn test_generic_routing() {
        let llm = create_test_llm(42);

        let prompt = "Hello, how are you?";
        let response = llm.complete(prompt).await.unwrap();

        assert!(response.contains("response") || response.contains("success"));
    }

    #[tokio::test]
    async fn test_empty_prompt_error() {
        let llm = create_test_llm(42);

        let result = llm.complete("").await;
        assert!(matches!(result, Err(LLMError::InvalidPrompt(_))));
    }

    #[tokio::test]
    async fn test_prompt_too_long_error() {
        let llm = create_test_llm(42);

        let long_prompt = "x".repeat(LLM_PROMPT_BYTES_MAX + 1);
        let result = llm.complete(&long_prompt).await;

        assert!(matches!(result, Err(LLMError::ContextOverflow(_))));
    }

    #[tokio::test]
    async fn test_fault_injection_timeout() {
        let clock = SimClock::new();
        let rng = DeterministicRng::new(42);
        let mut injector = FaultInjector::new(DeterministicRng::new(42));
        injector.register(FaultConfig::new(FaultType::LlmTimeout, 1.0));
        let faults = Arc::new(injector);

        let llm = SimLLM::new(clock, rng, faults).without_latency();
        let result = llm.complete("test prompt").await;

        assert!(matches!(result, Err(LLMError::Timeout)));
    }

    #[tokio::test]
    async fn test_fault_injection_rate_limit() {
        let clock = SimClock::new();
        let rng = DeterministicRng::new(42);
        let mut injector = FaultInjector::new(DeterministicRng::new(42));
        injector.register(FaultConfig::new(FaultType::LlmRateLimit, 1.0));
        let faults = Arc::new(injector);

        let llm = SimLLM::new(clock, rng, faults).without_latency();
        let result = llm.complete("test prompt").await;

        assert!(matches!(result, Err(LLMError::RateLimit)));
    }

    #[tokio::test]
    async fn test_complete_json() {
        let llm = create_test_llm(42);

        #[derive(serde::Deserialize)]
        struct GenericResponse {
            response: String,
            success: bool,
        }

        let prompt = "Hello, world!";
        let result: GenericResponse = llm.complete_json(prompt).await.unwrap();

        assert!(result.success);
        assert!(!result.response.is_empty());
    }

    #[tokio::test]
    async fn test_with_latency() {
        let clock = SimClock::new();
        let rng = DeterministicRng::new(42);
        let faults = Arc::new(FaultInjector::new(DeterministicRng::new(42)));

        let llm = SimLLM::new(clock.clone(), rng, faults).with_latency(500);

        // Advance the simulated clock from a separate task so the sleep inside
        // `complete` can resolve.
        let clock_for_advance = clock.clone();
        let advance_handle = tokio::spawn(async move {
            tokio::task::yield_now().await;
            clock_for_advance.advance_ms(600);
        });

        let start = clock.now_ms();
        llm.complete("test").await.unwrap();
        let end = clock.now_ms();

        advance_handle.await.unwrap();

        assert!(
            end >= start + 500,
            "Expected clock to advance at least 500ms, start={}, end={}",
            start,
            end
        );
    }

    #[test]
    fn test_prompt_hash_determinism() {
        let llm = create_test_llm(42);

        let hash1 = llm.prompt_hash("test prompt");
        let hash2 = llm.prompt_hash("test prompt");
        let hash3 = llm.prompt_hash("different prompt");

        assert_eq!(hash1, hash2);
        assert_ne!(hash1, hash3);
    }

    #[test]
    #[should_panic(expected = "latency must be in")]
    fn test_invalid_latency() {
        let clock = SimClock::new();
        let rng = DeterministicRng::new(42);
        let faults = Arc::new(FaultInjector::new(DeterministicRng::new(42)));

        let _ = SimLLM::new(clock, rng, faults).with_latency(999999);
    }
}