1use std::sync::Arc;
9
10use async_trait::async_trait;
11
12use super::{CompletionRequest, LLMProvider, ProviderError};
13use crate::dst::{DeterministicRng, FaultInjector, LLMError, SimClock, SimLLM};
14
15#[derive(Debug, Clone)]
49pub struct SimLLMProvider {
50 inner: SimLLM,
52}
53
54impl SimLLMProvider {
55 #[must_use]
72 pub fn from_sim_llm(sim_llm: SimLLM) -> Self {
73 Self { inner: sim_llm }
74 }
75
76 #[must_use]
88 pub fn with_seed(seed: u64) -> Self {
89 let clock = SimClock::new();
90 let rng = DeterministicRng::new(seed);
91 let faults = Arc::new(FaultInjector::new(DeterministicRng::new(seed)));
92
93 let sim_llm = SimLLM::new(clock, rng, faults).without_latency();
95
96 Self { inner: sim_llm }
97 }
98
99 #[must_use]
114 pub fn with_faults(seed: u64, faults: Arc<FaultInjector>) -> Self {
115 let clock = SimClock::new();
116 let rng = DeterministicRng::new(seed);
117
118 let sim_llm = SimLLM::new(clock, rng, faults).without_latency();
120
121 Self { inner: sim_llm }
122 }
123
124 #[must_use]
126 pub fn seed(&self) -> u64 {
127 self.inner.seed()
128 }
129}
130
131#[async_trait]
132impl LLMProvider for SimLLMProvider {
133 async fn complete(&self, request: &CompletionRequest) -> Result<String, ProviderError> {
134 let full_prompt = match &request.system {
136 Some(system) => format!("{}\n\n{}", system, request.prompt),
137 None => request.prompt.clone(),
138 };
139
140 self.inner
142 .complete(&full_prompt)
143 .await
144 .map_err(llm_error_to_provider_error)
145 }
146
147 fn name(&self) -> &'static str {
148 "sim"
149 }
150
151 fn is_simulation(&self) -> bool {
152 true
153 }
154}
155
156fn llm_error_to_provider_error(err: LLMError) -> ProviderError {
158 match err {
159 LLMError::Timeout => ProviderError::Timeout,
160 LLMError::RateLimit => ProviderError::rate_limit(None),
161 LLMError::ContextOverflow(size) => ProviderError::context_overflow(size),
162 LLMError::InvalidResponse(msg) => ProviderError::invalid_response(msg),
163 LLMError::ServiceUnavailable => ProviderError::service_unavailable("service unavailable"),
164 LLMError::JsonError(msg) => ProviderError::json_error(msg),
165 LLMError::InvalidPrompt(msg) => ProviderError::invalid_request(msg),
166 }
167}
168
#[cfg(test)]
mod tests {
    //! Unit tests for `SimLLMProvider`: determinism, metadata accessors, prompt
    //! handling, and mapping of simulator errors / injected faults to `ProviderError`.

    use super::*;
    use crate::dst::{FaultConfig, FaultType};

    /// Two providers built from the same seed must answer the same prompt identically.
    #[tokio::test]
    async fn test_determinism() {
        let provider1 = SimLLMProvider::with_seed(42);
        let provider2 = SimLLMProvider::with_seed(42);

        let request = CompletionRequest::new("Extract entities from: Alice works at Acme.");

        let response1 = provider1.complete(&request).await.unwrap();
        let response2 = provider2.complete(&request).await.unwrap();

        assert_eq!(
            response1, response2,
            "Same seed should produce same response"
        );
    }

    /// Providers with different seeds should each still produce a well-formed
    /// extraction response.
    ///
    /// NOTE(review): this does not assert that the two responses differ; the
    /// simulator's output may not depend on the seed for this prompt — confirm
    /// before tightening the assertion.
    #[tokio::test]
    async fn test_different_seeds() {
        let provider1 = SimLLMProvider::with_seed(42);
        let provider2 = SimLLMProvider::with_seed(12345);

        let request = CompletionRequest::new("Extract entities from: Bob met Charlie.");

        let response1 = provider1.complete(&request).await.unwrap();
        let response2 = provider2.complete(&request).await.unwrap();

        assert!(response1.contains("entities") || response1.contains("Bob"));
        assert!(response2.contains("entities") || response2.contains("Bob"));
    }

    // Synchronous accessors do not need an async runtime, so these use `#[test]`.
    #[test]
    fn test_name() {
        let provider = SimLLMProvider::with_seed(42);
        assert_eq!(provider.name(), "sim");
    }

    #[test]
    fn test_is_simulation() {
        let provider = SimLLMProvider::with_seed(42);
        assert!(provider.is_simulation());
    }

    /// A system prompt is accepted and still yields a non-empty response.
    #[tokio::test]
    async fn test_with_system_prompt() {
        let provider = SimLLMProvider::with_seed(42);

        let request = CompletionRequest::new("Extract entities from: Alice.")
            .with_system("You are an entity extractor.");

        let response = provider.complete(&request).await.unwrap();
        assert!(!response.is_empty());
    }

    /// JSON-mode requests can be deserialized through `complete_json`.
    #[tokio::test]
    async fn test_complete_json() {
        let provider = SimLLMProvider::with_seed(42);

        #[derive(serde::Deserialize)]
        struct GenericResponse {
            response: String,
            success: bool,
        }

        let request = CompletionRequest::new("Hello, world!").with_json_mode();
        let result: GenericResponse = provider.complete_json(&request).await.unwrap();

        assert!(result.success);
        assert!(!result.response.is_empty());
    }

    #[tokio::test]
    async fn test_entity_extraction_prompt() {
        let provider = SimLLMProvider::with_seed(42);

        let request =
            CompletionRequest::new("Extract entities from the text: Alice works at Microsoft.");

        let response = provider.complete(&request).await.unwrap();

        assert!(response.contains("entities"));
        assert!(response.contains("Alice") || response.contains("Microsoft"));
    }

    #[tokio::test]
    async fn test_query_rewrite_prompt() {
        let provider = SimLLMProvider::with_seed(42);

        let request = CompletionRequest::new(
            "Rewrite this query for better search results:\nQuery: what is rust programming",
        );

        let response = provider.complete(&request).await.unwrap();

        assert!(response.contains("queries"));
    }

    /// A timeout fault registered with probability 1.0 surfaces as `ProviderError::Timeout`.
    #[tokio::test]
    async fn test_fault_injection_timeout() {
        let mut injector = FaultInjector::new(DeterministicRng::new(42));
        injector.register(FaultConfig::new(FaultType::LlmTimeout, 1.0));

        let provider = SimLLMProvider::with_faults(42, Arc::new(injector));
        let request = CompletionRequest::new("Test prompt");

        let result = provider.complete(&request).await;
        assert!(matches!(result, Err(ProviderError::Timeout)));
    }

    /// A rate-limit fault registered with probability 1.0 surfaces as `ProviderError::RateLimit`.
    #[tokio::test]
    async fn test_fault_injection_rate_limit() {
        let mut injector = FaultInjector::new(DeterministicRng::new(42));
        injector.register(FaultConfig::new(FaultType::LlmRateLimit, 1.0));

        let provider = SimLLMProvider::with_faults(42, Arc::new(injector));
        let request = CompletionRequest::new("Test prompt");

        let result = provider.complete(&request).await;
        assert!(matches!(result, Err(ProviderError::RateLimit { .. })));
    }

    /// Direct check of the error translation, independent of fault injection.
    #[test]
    fn test_error_mapping() {
        assert!(matches!(
            llm_error_to_provider_error(LLMError::Timeout),
            ProviderError::Timeout
        ));
        assert!(matches!(
            llm_error_to_provider_error(LLMError::RateLimit),
            ProviderError::RateLimit { .. }
        ));
    }

    #[test]
    fn test_seed_getter() {
        let provider = SimLLMProvider::with_seed(12345);
        assert_eq!(provider.seed(), 12345);
    }
}