1mod sim;
33
34#[cfg(feature = "anthropic")]
35mod anthropic;
36
37#[cfg(feature = "openai")]
38mod openai;
39
40pub use sim::SimLLMProvider;
41
42#[cfg(feature = "anthropic")]
43pub use anthropic::AnthropicProvider;
44
45#[cfg(feature = "openai")]
46pub use openai::OpenAIProvider;
47
48use async_trait::async_trait;
49use serde::de::DeserializeOwned;
50
51use crate::constants::{LLM_PROMPT_BYTES_MAX, LLM_RESPONSE_BYTES_MAX};
52
/// Errors that an [`LLMProvider`] implementation can return.
///
/// Cloneable so callers can stash or re-report an error; `is_retryable`
/// classifies which variants are worth retrying.
#[derive(Debug, Clone, thiserror::Error)]
pub enum ProviderError {
    /// The request did not complete within the allotted time.
    #[error("Request timed out")]
    Timeout,

    /// The provider rejected the request due to rate limiting.
    #[error("Rate limit exceeded, retry after {retry_after_secs:?}s")]
    RateLimit {
        /// Suggested wait before retrying, if the provider supplied one.
        retry_after_secs: Option<u64>,
    },

    /// The prompt was too large for the model's context window.
    #[error("Context length exceeded: {tokens} tokens")]
    ContextOverflow {
        /// Token count reported for the oversized request.
        tokens: usize,
    },

    /// The provider returned a response the client could not accept.
    #[error("Invalid response: {message}")]
    InvalidResponse {
        message: String,
    },

    /// The provider is temporarily down or overloaded.
    #[error("Service unavailable: {message}")]
    ServiceUnavailable {
        message: String,
    },

    /// Credentials were missing or rejected.
    #[error("Authentication failed")]
    AuthenticationFailed,

    /// The response could not be parsed as JSON (see `complete_json`).
    #[error("JSON error: {message}")]
    JsonError {
        message: String,
    },

    /// A transport-level failure (DNS, connection, TLS, ...).
    #[error("Network error: {message}")]
    NetworkError {
        message: String,
    },

    /// The request was malformed or rejected by client-side validation.
    #[error("Invalid request: {message}")]
    InvalidRequest {
        message: String,
    },
}
119
120impl ProviderError {
121 #[must_use]
123 pub fn timeout() -> Self {
124 Self::Timeout
125 }
126
127 #[must_use]
129 pub fn rate_limit(retry_after_secs: Option<u64>) -> Self {
130 Self::RateLimit { retry_after_secs }
131 }
132
133 #[must_use]
135 pub fn context_overflow(tokens: usize) -> Self {
136 Self::ContextOverflow { tokens }
137 }
138
139 #[must_use]
141 pub fn invalid_response(message: impl Into<String>) -> Self {
142 Self::InvalidResponse {
143 message: message.into(),
144 }
145 }
146
147 #[must_use]
149 pub fn service_unavailable(message: impl Into<String>) -> Self {
150 Self::ServiceUnavailable {
151 message: message.into(),
152 }
153 }
154
155 #[must_use]
157 pub fn json_error(message: impl Into<String>) -> Self {
158 Self::JsonError {
159 message: message.into(),
160 }
161 }
162
163 #[must_use]
165 pub fn network_error(message: impl Into<String>) -> Self {
166 Self::NetworkError {
167 message: message.into(),
168 }
169 }
170
171 #[must_use]
173 pub fn invalid_request(message: impl Into<String>) -> Self {
174 Self::InvalidRequest {
175 message: message.into(),
176 }
177 }
178
179 #[must_use]
181 pub fn is_retryable(&self) -> bool {
182 matches!(
183 self,
184 Self::Timeout | Self::RateLimit { .. } | Self::ServiceUnavailable { .. }
185 )
186 }
187}
188
/// A single completion request, built with [`CompletionRequest::new`] and the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// The user prompt. Validated by `new`: non-empty and at most
    /// `LLM_PROMPT_BYTES_MAX` bytes.
    pub prompt: String,
    /// Optional system prompt sent alongside the user prompt.
    pub system: Option<String>,
    /// Optional cap on tokens generated; `None` leaves it to the provider.
    pub max_tokens: Option<usize>,
    /// Optional sampling temperature in `[0.0, 1.0]` (enforced by the builder);
    /// `None` leaves it to the provider.
    pub temperature: Option<f32>,
    /// When `true`, asks the provider to emit a JSON response.
    pub json_mode: bool,
}
209
210impl CompletionRequest {
211 #[must_use]
216 pub fn new(prompt: impl Into<String>) -> Self {
217 let prompt = prompt.into();
218
219 assert!(!prompt.is_empty(), "prompt must not be empty");
221 assert!(
222 prompt.len() <= LLM_PROMPT_BYTES_MAX,
223 "prompt exceeds {} bytes",
224 LLM_PROMPT_BYTES_MAX
225 );
226
227 Self {
228 prompt,
229 system: None,
230 max_tokens: None,
231 temperature: None,
232 json_mode: false,
233 }
234 }
235
236 #[must_use]
238 pub fn with_system(mut self, system: impl Into<String>) -> Self {
239 self.system = Some(system.into());
240 self
241 }
242
243 #[must_use]
245 pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
246 self.max_tokens = Some(max_tokens);
247 self
248 }
249
250 #[must_use]
255 pub fn with_temperature(mut self, temperature: f32) -> Self {
256 assert!(
257 (0.0..=1.0).contains(&temperature),
258 "temperature must be in [0.0, 1.0]"
259 );
260 self.temperature = Some(temperature);
261 self
262 }
263
264 #[must_use]
266 pub fn with_json_mode(mut self) -> Self {
267 self.json_mode = true;
268 self
269 }
270}
271
/// Common interface over LLM backends (simulation, Anthropic, OpenAI).
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Sends `request` to the backend and returns the raw completion text.
    ///
    /// # Errors
    ///
    /// Returns a [`ProviderError`] describing the failure (timeout, rate
    /// limit, network, invalid response, ...).
    async fn complete(&self, request: &CompletionRequest) -> Result<String, ProviderError>;

    /// Runs [`complete`](Self::complete) and deserializes the response body
    /// as JSON into `T`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`complete`](Self::complete), and returns
    /// [`ProviderError::JsonError`] when the response is not valid JSON for `T`.
    async fn complete_json<T: DeserializeOwned + Send>(
        &self,
        request: &CompletionRequest,
    ) -> Result<T, ProviderError> {
        let response = self.complete(request).await?;

        // Size check is a debug-build sanity guard only; release builds skip
        // it and an oversized response still reaches the JSON parser below.
        debug_assert!(
            response.len() <= LLM_RESPONSE_BYTES_MAX,
            "response exceeds limit"
        );

        serde_json::from_str(&response).map_err(|e| ProviderError::json_error(e.to_string()))
    }

    /// Static identifier for this provider implementation.
    fn name(&self) -> &'static str;

    /// `true` for the simulation backend (e.g. [`SimLLMProvider`]) as opposed
    /// to a real network-backed provider.
    fn is_simulation(&self) -> bool;
}
330
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh request carries only the prompt; everything else is unset.
    #[test]
    fn test_completion_request_new() {
        let req = CompletionRequest::new("Hello, world!");
        assert_eq!(req.prompt, "Hello, world!");
        assert!(req.system.is_none());
        assert!(req.max_tokens.is_none());
        assert!(req.temperature.is_none());
        assert!(!req.json_mode);
    }

    // Every builder method sets its field and leaves the rest intact.
    #[test]
    fn test_completion_request_builder() {
        let base = CompletionRequest::new("Hello");
        let req = base
            .with_system("You are a helpful assistant")
            .with_max_tokens(100)
            .with_temperature(0.7)
            .with_json_mode();

        assert_eq!(req.prompt, "Hello");
        assert_eq!(req.system.as_deref(), Some("You are a helpful assistant"));
        assert_eq!(req.max_tokens, Some(100));
        assert_eq!(req.temperature, Some(0.7));
        assert!(req.json_mode);
    }

    // Construction rejects an empty prompt.
    #[test]
    #[should_panic(expected = "prompt must not be empty")]
    fn test_completion_request_empty_prompt() {
        let _ = CompletionRequest::new(String::new());
    }

    // Temperatures outside [0.0, 1.0] are rejected.
    #[test]
    #[should_panic(expected = "temperature must be in")]
    fn test_completion_request_invalid_temperature() {
        let _ = CompletionRequest::new("Hello").with_temperature(1.5);
    }

    // Transient failures are retryable; auth/parse failures are not.
    #[test]
    fn test_provider_error_is_retryable() {
        let transient = [
            ProviderError::timeout(),
            ProviderError::rate_limit(Some(60)),
            ProviderError::service_unavailable("down"),
        ];
        for err in transient {
            assert!(err.is_retryable());
        }

        let permanent = [
            ProviderError::AuthenticationFailed,
            ProviderError::json_error("parse failed"),
        ];
        for err in permanent {
            assert!(!err.is_retryable());
        }
    }

    // Each constructor produces the matching variant with its payload.
    #[test]
    fn test_provider_error_constructors() {
        assert!(matches!(
            ProviderError::context_overflow(10000),
            ProviderError::ContextOverflow { tokens: 10000 }
        ));
        assert!(matches!(
            ProviderError::invalid_response("bad format"),
            ProviderError::InvalidResponse { .. }
        ));
        assert!(matches!(
            ProviderError::network_error("connection refused"),
            ProviderError::NetworkError { .. }
        ));
    }
}