umi_memory/llm/
sim.rs

1//! SimLLMProvider - Simulation-First LLM Provider
2//!
3//! TigerStyle: Primary implementation, wraps DST SimLLM.
4//!
5//! This is the DEFAULT provider for all tests and development.
6//! Real providers (Anthropic, OpenAI) are secondary.
7
8use std::sync::Arc;
9
10use async_trait::async_trait;
11
12use super::{CompletionRequest, LLMProvider, ProviderError};
13use crate::dst::{DeterministicRng, FaultInjector, LLMError, SimClock, SimLLM};
14
15// =============================================================================
16// SimLLMProvider
17// =============================================================================
18
19/// Simulation LLM provider wrapping DST SimLLM.
20///
21/// TigerStyle: Primary implementation, always available.
22///
23/// This provider wraps the deterministic `SimLLM` from the DST module,
24/// providing the same interface as production providers but with:
25/// - Deterministic responses (same seed = same output)
26/// - Fault injection support
27/// - No external dependencies
28///
29/// # Example
30///
31/// ```rust
32/// use umi_memory::llm::{SimLLMProvider, CompletionRequest, LLMProvider};
33///
34/// #[tokio::main]
35/// async fn main() {
36///     // Create with explicit seed for reproducibility
37///     let provider = SimLLMProvider::with_seed(42);
38///
39///     let request = CompletionRequest::new("Extract entities from: Alice works at Acme.");
40///     let response = provider.complete(&request).await.unwrap();
41///
42///     // Same seed always produces same response
43///     let provider2 = SimLLMProvider::with_seed(42);
44///     let response2 = provider2.complete(&request).await.unwrap();
45///     assert_eq!(response, response2);
46/// }
47/// ```
48#[derive(Debug, Clone)]
49pub struct SimLLMProvider {
50    /// The underlying SimLLM from DST
51    inner: SimLLM,
52}
53
54impl SimLLMProvider {
55    /// Create a new `SimLLMProvider` from an existing `SimLLM`.
56    ///
57    /// Use this when you already have a `SimLLM` from `SimEnvironment`.
58    ///
59    /// # Example
60    ///
61    /// ```rust
62    /// use umi_memory::dst::{Simulation, SimConfig};
63    /// use umi_memory::llm::SimLLMProvider;
64    ///
65    /// let sim = Simulation::new(SimConfig::with_seed(42));
66    /// let env = sim.build();
67    ///
68    /// // Note: In practice, you'd typically use env.llm directly
69    /// // This is for cases where you need the LLMProvider trait
70    /// ```
71    #[must_use]
72    pub fn from_sim_llm(sim_llm: SimLLM) -> Self {
73        Self { inner: sim_llm }
74    }
75
76    /// Create a new standalone `SimLLMProvider` with the given seed.
77    ///
78    /// This is the most common way to create a `SimLLMProvider` for testing.
79    ///
80    /// # Example
81    ///
82    /// ```rust
83    /// use umi_memory::llm::SimLLMProvider;
84    ///
85    /// let provider = SimLLMProvider::with_seed(42);
86    /// ```
87    #[must_use]
88    pub fn with_seed(seed: u64) -> Self {
89        let clock = SimClock::new();
90        let rng = DeterministicRng::new(seed);
91        let faults = Arc::new(FaultInjector::new(DeterministicRng::new(seed)));
92
93        // Disable latency for standalone use (no clock advancement)
94        let sim_llm = SimLLM::new(clock, rng, faults).without_latency();
95
96        Self { inner: sim_llm }
97    }
98
99    /// Create a new `SimLLMProvider` with fault injection.
100    ///
101    /// # Example
102    ///
103    /// ```rust
104    /// use umi_memory::llm::SimLLMProvider;
105    /// use umi_memory::dst::{FaultConfig, FaultType, FaultInjector, DeterministicRng};
106    /// use std::sync::Arc;
107    ///
108    /// let mut injector = FaultInjector::new(DeterministicRng::new(42));
109    /// injector.register(FaultConfig::new(FaultType::LlmTimeout, 0.5));
110    ///
111    /// let provider = SimLLMProvider::with_faults(42, Arc::new(injector));
112    /// ```
113    #[must_use]
114    pub fn with_faults(seed: u64, faults: Arc<FaultInjector>) -> Self {
115        let clock = SimClock::new();
116        let rng = DeterministicRng::new(seed);
117
118        // Disable latency for standalone use
119        let sim_llm = SimLLM::new(clock, rng, faults).without_latency();
120
121        Self { inner: sim_llm }
122    }
123
124    /// Get the seed used by this provider (for debugging/logging).
125    #[must_use]
126    pub fn seed(&self) -> u64 {
127        self.inner.seed()
128    }
129}
130
131#[async_trait]
132impl LLMProvider for SimLLMProvider {
133    async fn complete(&self, request: &CompletionRequest) -> Result<String, ProviderError> {
134        // Build the full prompt (system + user prompt)
135        let full_prompt = match &request.system {
136            Some(system) => format!("{}\n\n{}", system, request.prompt),
137            None => request.prompt.clone(),
138        };
139
140        // Call the underlying SimLLM
141        self.inner
142            .complete(&full_prompt)
143            .await
144            .map_err(llm_error_to_provider_error)
145    }
146
147    fn name(&self) -> &'static str {
148        "sim"
149    }
150
151    fn is_simulation(&self) -> bool {
152        true
153    }
154}
155
156/// Convert `LLMError` from DST to `ProviderError`.
157fn llm_error_to_provider_error(err: LLMError) -> ProviderError {
158    match err {
159        LLMError::Timeout => ProviderError::Timeout,
160        LLMError::RateLimit => ProviderError::rate_limit(None),
161        LLMError::ContextOverflow(size) => ProviderError::context_overflow(size),
162        LLMError::InvalidResponse(msg) => ProviderError::invalid_response(msg),
163        LLMError::ServiceUnavailable => ProviderError::service_unavailable("service unavailable"),
164        LLMError::JsonError(msg) => ProviderError::json_error(msg),
165        LLMError::InvalidPrompt(msg) => ProviderError::invalid_request(msg),
166    }
167}
168
169// =============================================================================
170// Tests
171// =============================================================================
172
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dst::{FaultConfig, FaultType};

    #[tokio::test]
    async fn test_determinism() {
        let provider1 = SimLLMProvider::with_seed(42);
        let provider2 = SimLLMProvider::with_seed(42);

        let request = CompletionRequest::new("Extract entities from: Alice works at Acme.");

        let response1 = provider1.complete(&request).await.unwrap();
        let response2 = provider2.complete(&request).await.unwrap();

        assert_eq!(
            response1, response2,
            "Same seed should produce same response"
        );
    }

    #[tokio::test]
    async fn test_different_seeds() {
        let provider1 = SimLLMProvider::with_seed(42);
        let provider2 = SimLLMProvider::with_seed(12345);

        let request = CompletionRequest::new("Extract entities from: Bob met Charlie.");

        let response1 = provider1.complete(&request).await.unwrap();
        let response2 = provider2.complete(&request).await.unwrap();

        // Responses should be structurally similar but may have different values
        assert!(response1.contains("entities") || response1.contains("Bob"));
        assert!(response2.contains("entities") || response2.contains("Bob"));
    }

    // Synchronous accessors need no async runtime, so plain `#[test]` is used.
    #[test]
    fn test_name() {
        let provider = SimLLMProvider::with_seed(42);
        assert_eq!(provider.name(), "sim");
    }

    #[test]
    fn test_is_simulation() {
        let provider = SimLLMProvider::with_seed(42);
        assert!(provider.is_simulation());
    }

    #[tokio::test]
    async fn test_with_system_prompt() {
        let provider = SimLLMProvider::with_seed(42);

        let request = CompletionRequest::new("Extract entities from: Alice.")
            .with_system("You are an entity extractor.");

        let response = provider.complete(&request).await.unwrap();
        assert!(!response.is_empty());
    }

    #[tokio::test]
    async fn test_with_system_prompt_is_deterministic() {
        let request = CompletionRequest::new("Extract entities from: Alice.")
            .with_system("You are an entity extractor.");

        let first = SimLLMProvider::with_seed(42)
            .complete(&request)
            .await
            .unwrap();
        let second = SimLLMProvider::with_seed(42)
            .complete(&request)
            .await
            .unwrap();

        assert_eq!(first, second, "System prompt must not break determinism");
    }

    #[tokio::test]
    async fn test_complete_json() {
        let provider = SimLLMProvider::with_seed(42);

        #[derive(serde::Deserialize)]
        struct GenericResponse {
            response: String,
            success: bool,
        }

        let request = CompletionRequest::new("Hello, world!").with_json_mode();
        let result: GenericResponse = provider.complete_json(&request).await.unwrap();

        assert!(result.success);
        assert!(!result.response.is_empty());
    }

    #[tokio::test]
    async fn test_entity_extraction_prompt() {
        let provider = SimLLMProvider::with_seed(42);

        let request =
            CompletionRequest::new("Extract entities from the text: Alice works at Microsoft.");

        let response = provider.complete(&request).await.unwrap();

        // Should route to entity extraction and contain entities
        assert!(response.contains("entities"));
        assert!(response.contains("Alice") || response.contains("Microsoft"));
    }

    #[tokio::test]
    async fn test_query_rewrite_prompt() {
        let provider = SimLLMProvider::with_seed(42);

        let request = CompletionRequest::new(
            "Rewrite this query for better search results:\nQuery: what is rust programming",
        );

        let response = provider.complete(&request).await.unwrap();

        // Should route to query rewrite
        assert!(response.contains("queries"));
    }

    #[tokio::test]
    async fn test_fault_injection_timeout() {
        let mut injector = FaultInjector::new(DeterministicRng::new(42));
        injector.register(FaultConfig::new(FaultType::LlmTimeout, 1.0));

        let provider = SimLLMProvider::with_faults(42, Arc::new(injector));
        let request = CompletionRequest::new("Test prompt");

        let result = provider.complete(&request).await;
        assert!(matches!(result, Err(ProviderError::Timeout)));
    }

    #[tokio::test]
    async fn test_fault_injection_rate_limit() {
        let mut injector = FaultInjector::new(DeterministicRng::new(42));
        injector.register(FaultConfig::new(FaultType::LlmRateLimit, 1.0));

        let provider = SimLLMProvider::with_faults(42, Arc::new(injector));
        let request = CompletionRequest::new("Test prompt");

        let result = provider.complete(&request).await;
        assert!(matches!(result, Err(ProviderError::RateLimit { .. })));
    }

    #[tokio::test]
    async fn test_from_sim_llm_matches_with_seed() {
        // Build the same `SimLLM` that `with_seed` builds internally.
        // Wrapping it via `from_sim_llm` must yield an equivalent provider.
        let seed: u64 = 7;
        let faults = Arc::new(FaultInjector::new(DeterministicRng::new(seed)));
        let sim_llm =
            SimLLM::new(SimClock::new(), DeterministicRng::new(seed), faults).without_latency();

        let wrapped = SimLLMProvider::from_sim_llm(sim_llm);
        let seeded = SimLLMProvider::with_seed(seed);
        assert_eq!(wrapped.seed(), seed);

        let request = CompletionRequest::new("Extract entities from: Alice works at Acme.");
        let from_wrapped = wrapped.complete(&request).await.unwrap();
        let from_seeded = seeded.complete(&request).await.unwrap();

        assert_eq!(
            from_wrapped, from_seeded,
            "from_sim_llm should behave like with_seed for an identical SimLLM"
        );
    }

    #[test]
    fn test_error_mapping() {
        assert!(matches!(
            llm_error_to_provider_error(LLMError::Timeout),
            ProviderError::Timeout
        ));
        assert!(matches!(
            llm_error_to_provider_error(LLMError::RateLimit),
            ProviderError::RateLimit { .. }
        ));
    }

    #[test]
    fn test_seed_getter() {
        let provider = SimLLMProvider::with_seed(12345);
        assert_eq!(provider.seed(), 12345);
    }
}