// umi_memory/llm/mod.rs

1//! LLM Provider Trait - Unified Interface for Sim and Production
2//!
3//! TigerStyle: Simulation-first LLM abstraction.
4//!
5//! See ADR-013 for design rationale.
6//!
7//! # Architecture
8//!
9//! ```text
10//! LLMProvider (trait)
11//! ├── SimLLMProvider      (always available, wraps DST SimLLM)
12//! ├── AnthropicProvider   (feature: anthropic)
13//! └── OpenAIProvider      (feature: openai)
14//! ```
15//!
16//! # Usage
17//!
18//! ```rust
19//! use umi_memory::llm::{LLMProvider, SimLLMProvider, CompletionRequest};
20//!
21//! #[tokio::main]
22//! async fn main() {
23//!     // Simulation (always available, no external deps)
24//!     let provider = SimLLMProvider::with_seed(42);
25//!
26//!     let request = CompletionRequest::new("Extract entities from: Alice works at Acme.");
27//!     let response = provider.complete(&request).await.unwrap();
28//!     println!("Response: {}", response);
29//! }
30//! ```
31
32mod sim;
33
34#[cfg(feature = "anthropic")]
35mod anthropic;
36
37#[cfg(feature = "openai")]
38mod openai;
39
40pub use sim::SimLLMProvider;
41
42#[cfg(feature = "anthropic")]
43pub use anthropic::AnthropicProvider;
44
45#[cfg(feature = "openai")]
46pub use openai::OpenAIProvider;
47
48use async_trait::async_trait;
49use serde::de::DeserializeOwned;
50
51use crate::constants::{LLM_PROMPT_BYTES_MAX, LLM_RESPONSE_BYTES_MAX};
52
53// =============================================================================
54// Error Types
55// =============================================================================
56
/// Unified error type for all LLM providers.
///
/// TigerStyle: Explicit variants for all failure modes.
///
/// `Display` and `std::error::Error` are implemented by hand below so the
/// rendered message for each variant is plainly visible in this file.
#[derive(Debug, Clone)]
pub enum ProviderError {
    /// The request did not complete in time.
    Timeout,

    /// The provider rejected the request due to rate limiting.
    RateLimit {
        /// Seconds until rate limit resets (if known)
        retry_after_secs: Option<u64>,
    },

    /// The prompt/context exceeded the provider's token window.
    ContextOverflow {
        /// Number of tokens that exceeded the limit
        tokens: usize,
    },

    /// The provider returned a response we could not accept.
    InvalidResponse {
        /// Description of what was invalid
        message: String,
    },

    /// The provider's service is currently unavailable.
    ServiceUnavailable {
        /// Reason for unavailability
        message: String,
    },

    /// Credentials were rejected by the provider.
    AuthenticationFailed,

    /// A JSON serialization/deserialization failure.
    JsonError {
        /// Description of the JSON error
        message: String,
    },

    /// A transport-level failure while talking to the provider.
    NetworkError {
        /// Description of the network error
        message: String,
    },

    /// The request parameters themselves were invalid.
    InvalidRequest {
        /// Description of what was invalid
        message: String,
    },
}

impl std::fmt::Display for ProviderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Messages are byte-identical to the previous `thiserror` attributes.
        match self {
            Self::Timeout => write!(f, "Request timed out"),
            Self::RateLimit { retry_after_secs } => {
                write!(f, "Rate limit exceeded, retry after {:?}s", retry_after_secs)
            }
            Self::ContextOverflow { tokens } => {
                write!(f, "Context length exceeded: {} tokens", tokens)
            }
            Self::InvalidResponse { message } => write!(f, "Invalid response: {}", message),
            Self::ServiceUnavailable { message } => write!(f, "Service unavailable: {}", message),
            Self::AuthenticationFailed => write!(f, "Authentication failed"),
            Self::JsonError { message } => write!(f, "JSON error: {}", message),
            Self::NetworkError { message } => write!(f, "Network error: {}", message),
            Self::InvalidRequest { message } => write!(f, "Invalid request: {}", message),
        }
    }
}

// No variant wraps a source error, so the default `source() -> None` is correct.
impl std::error::Error for ProviderError {}
119
120impl ProviderError {
121    /// Create a timeout error.
122    #[must_use]
123    pub fn timeout() -> Self {
124        Self::Timeout
125    }
126
127    /// Create a rate limit error.
128    #[must_use]
129    pub fn rate_limit(retry_after_secs: Option<u64>) -> Self {
130        Self::RateLimit { retry_after_secs }
131    }
132
133    /// Create a context overflow error.
134    #[must_use]
135    pub fn context_overflow(tokens: usize) -> Self {
136        Self::ContextOverflow { tokens }
137    }
138
139    /// Create an invalid response error.
140    #[must_use]
141    pub fn invalid_response(message: impl Into<String>) -> Self {
142        Self::InvalidResponse {
143            message: message.into(),
144        }
145    }
146
147    /// Create a service unavailable error.
148    #[must_use]
149    pub fn service_unavailable(message: impl Into<String>) -> Self {
150        Self::ServiceUnavailable {
151            message: message.into(),
152        }
153    }
154
155    /// Create a JSON error.
156    #[must_use]
157    pub fn json_error(message: impl Into<String>) -> Self {
158        Self::JsonError {
159            message: message.into(),
160        }
161    }
162
163    /// Create a network error.
164    #[must_use]
165    pub fn network_error(message: impl Into<String>) -> Self {
166        Self::NetworkError {
167            message: message.into(),
168        }
169    }
170
171    /// Create an invalid request error.
172    #[must_use]
173    pub fn invalid_request(message: impl Into<String>) -> Self {
174        Self::InvalidRequest {
175            message: message.into(),
176        }
177    }
178
179    /// Check if this error is retryable.
180    #[must_use]
181    pub fn is_retryable(&self) -> bool {
182        matches!(
183            self,
184            Self::Timeout | Self::RateLimit { .. } | Self::ServiceUnavailable { .. }
185        )
186    }
187}
188
189// =============================================================================
190// Request Types
191// =============================================================================
192
/// Request for LLM completion.
///
/// TigerStyle: Explicit fields, no hidden defaults.
///
/// Construct via [`CompletionRequest::new`] plus the `with_*` builder methods;
/// a `None` field means "use the provider's own default" for that knob.
#[derive(Clone, Debug)]
pub struct CompletionRequest {
    /// The prompt text (required)
    pub prompt: String,
    /// Optional system message (for chat-style APIs)
    pub system: Option<String>,
    /// Maximum tokens to generate (provider default if None)
    pub max_tokens: Option<usize>,
    /// Temperature (0.0-1.0, provider default if None)
    pub temperature: Option<f32>,
    /// Whether to request JSON output
    pub json_mode: bool,
}
209
210impl CompletionRequest {
211    /// Create a new completion request with just a prompt.
212    ///
213    /// # Panics
214    /// Panics if prompt is empty or exceeds `LLM_PROMPT_BYTES_MAX`.
215    #[must_use]
216    pub fn new(prompt: impl Into<String>) -> Self {
217        let prompt = prompt.into();
218
219        // Preconditions
220        assert!(!prompt.is_empty(), "prompt must not be empty");
221        assert!(
222            prompt.len() <= LLM_PROMPT_BYTES_MAX,
223            "prompt exceeds {} bytes",
224            LLM_PROMPT_BYTES_MAX
225        );
226
227        Self {
228            prompt,
229            system: None,
230            max_tokens: None,
231            temperature: None,
232            json_mode: false,
233        }
234    }
235
236    /// Set the system message.
237    #[must_use]
238    pub fn with_system(mut self, system: impl Into<String>) -> Self {
239        self.system = Some(system.into());
240        self
241    }
242
243    /// Set maximum tokens to generate.
244    #[must_use]
245    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
246        self.max_tokens = Some(max_tokens);
247        self
248    }
249
250    /// Set temperature.
251    ///
252    /// # Panics
253    /// Panics if temperature is not in [0.0, 1.0].
254    #[must_use]
255    pub fn with_temperature(mut self, temperature: f32) -> Self {
256        assert!(
257            (0.0..=1.0).contains(&temperature),
258            "temperature must be in [0.0, 1.0]"
259        );
260        self.temperature = Some(temperature);
261        self
262    }
263
264    /// Enable JSON mode (request structured output).
265    #[must_use]
266    pub fn with_json_mode(mut self) -> Self {
267        self.json_mode = true;
268        self
269    }
270}
271
272// =============================================================================
273// Provider Trait
274// =============================================================================
275
/// Trait for LLM providers.
///
/// TigerStyle: Unified interface for simulation and production.
///
/// All providers implement this trait, allowing higher-level components
/// to work with any provider without knowing the concrete type.
///
/// NOTE(review): the generic default method `complete_json` makes this trait
/// not dyn-compatible (no `dyn LLMProvider`); consumers must use generics as
/// in the example below. Confirm trait objects are never needed, or bound the
/// method with `where Self: Sized`.
///
/// # Example
///
/// ```rust
/// use umi_memory::llm::{LLMProvider, SimLLMProvider, CompletionRequest};
///
/// async fn extract_entities<P: LLMProvider>(provider: &P, text: &str) -> String {
///     let request = CompletionRequest::new(format!("Extract entities from: {}", text));
///     provider.complete(&request).await.unwrap()
/// }
/// ```
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Complete a prompt with a text response.
    ///
    /// # Errors
    /// Returns `ProviderError` on failure.
    async fn complete(&self, request: &CompletionRequest) -> Result<String, ProviderError>;

    /// Complete a prompt expecting a JSON response.
    ///
    /// This is a convenience method that calls `complete` and parses the response.
    ///
    /// # Errors
    /// Returns `ProviderError` on failure or JSON parse error.
    async fn complete_json<T: DeserializeOwned + Send>(
        &self,
        request: &CompletionRequest,
    ) -> Result<T, ProviderError> {
        let response = self.complete(request).await?;

        // Debug-build sanity bound: providers are expected to cap output at
        // LLM_RESPONSE_BYTES_MAX. This does NOT check that the response is
        // valid JSON — serde_json below is the real validator.
        debug_assert!(
            response.len() <= LLM_RESPONSE_BYTES_MAX,
            "response exceeds limit"
        );

        // Parse failures are surfaced as ProviderError::JsonError.
        serde_json::from_str(&response).map_err(|e| ProviderError::json_error(e.to_string()))
    }

    /// Get the provider name for logging/debugging.
    fn name(&self) -> &'static str;

    /// Check if this is a simulation provider.
    ///
    /// Returns `true` for `SimLLMProvider`, `false` for real providers.
    fn is_simulation(&self) -> bool;
}
330
331// =============================================================================
332// Tests
333// =============================================================================
334
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_completion_request_new() {
        // A bare `new` sets the prompt and leaves every option at its default.
        let request = CompletionRequest::new("Hello, world!");
        assert_eq!(request.prompt, "Hello, world!");
        assert_eq!(request.system, None);
        assert_eq!(request.max_tokens, None);
        assert_eq!(request.temperature, None);
        assert!(!request.json_mode);
    }

    #[test]
    fn test_completion_request_builder() {
        // Each builder sets exactly one field; order must not matter.
        let request = CompletionRequest::new("Hello")
            .with_json_mode()
            .with_temperature(0.7)
            .with_max_tokens(100)
            .with_system("You are a helpful assistant");

        assert_eq!(request.prompt, "Hello");
        assert_eq!(request.system.as_deref(), Some("You are a helpful assistant"));
        assert_eq!(request.max_tokens, Some(100));
        assert_eq!(request.temperature, Some(0.7));
        assert!(request.json_mode);
    }

    #[test]
    #[should_panic(expected = "prompt must not be empty")]
    fn test_completion_request_empty_prompt() {
        let _ = CompletionRequest::new("");
    }

    #[test]
    #[should_panic(expected = "temperature must be in")]
    fn test_completion_request_invalid_temperature() {
        let _ = CompletionRequest::new("Hello").with_temperature(1.5);
    }

    #[test]
    fn test_provider_error_is_retryable() {
        // Transient failures are retryable...
        let transient = [
            ProviderError::timeout(),
            ProviderError::rate_limit(Some(60)),
            ProviderError::service_unavailable("down"),
        ];
        for err in transient {
            assert!(err.is_retryable());
        }
        // ...permanent ones are not.
        assert!(!ProviderError::AuthenticationFailed.is_retryable());
        assert!(!ProviderError::json_error("parse failed").is_retryable());
    }

    #[test]
    fn test_provider_error_constructors() {
        assert!(matches!(
            ProviderError::context_overflow(10000),
            ProviderError::ContextOverflow { tokens: 10000 }
        ));
        assert!(matches!(
            ProviderError::invalid_response("bad format"),
            ProviderError::InvalidResponse { .. }
        ));
        assert!(matches!(
            ProviderError::network_error("connection refused"),
            ProviderError::NetworkError { .. }
        ));
    }
}
399}