Skip to main content

simple_agent_type/
provider.rs

1//! Provider trait and types.
2//!
3//! Defines the interface for LLM providers with transformation hooks.
4
5use crate::error::Result;
6use crate::request::CompletionRequest;
7use crate::response::{CompletionChunk, CompletionResponse};
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::borrow::Cow;
12use std::time::Duration;
13
14/// Retry configuration for failed requests.
15#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
16pub struct RetryConfig {
17    /// Maximum number of retry attempts
18    pub max_attempts: u32,
19    /// Initial backoff duration
20    pub initial_backoff: Duration,
21    /// Maximum backoff duration
22    pub max_backoff: Duration,
23    /// Backoff multiplier for exponential backoff
24    pub backoff_multiplier: f32,
25    /// Add random jitter to backoff
26    pub jitter: bool,
27}
28
29impl Default for RetryConfig {
30    fn default() -> Self {
31        Self {
32            max_attempts: 3,
33            initial_backoff: Duration::from_millis(100),
34            max_backoff: Duration::from_secs(10),
35            backoff_multiplier: 2.0,
36            jitter: true,
37        }
38    }
39}
40
41/// Provider capabilities.
42#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
43pub struct Capabilities {
44    /// Supports streaming responses
45    pub streaming: bool,
46    /// Supports function/tool calling
47    pub function_calling: bool,
48    /// Supports vision/image inputs
49    pub vision: bool,
50    /// Maximum output tokens
51    pub max_tokens: u32,
52}
53
/// HTTP header list: `(name, value)` pairs kept in insertion order.
///
/// Both sides are `Cow<'static, str>`, so an entry may hold either a borrowed
/// `&'static str` or an owned `String`.
pub type Headers = Vec<(Cow<'static, str>, Cow<'static, str>)>;
56
/// Names of commonly used HTTP headers.
///
/// These are string constants, so referring to one does not allocate.
pub mod headers {
    /// `Authorization` header name.
    pub const AUTHORIZATION: &str = "Authorization";
    /// `Content-Type` header name.
    pub const CONTENT_TYPE: &str = "Content-Type";
    /// `x-api-key` header name (API key header used by some providers like Anthropic).
    pub const X_API_KEY: &str = "x-api-key";
}
66
67/// Trait for LLM providers.
68///
69/// Providers implement this trait to support different LLM APIs while
70/// presenting a unified interface to the rest of SimpleAgents.
71///
72/// # Architecture
73///
74/// The provider trait follows a three-phase architecture:
75/// 1. **Transform Request**: Convert unified request to provider format
76/// 2. **Execute**: Make the actual API call
77/// 3. **Transform Response**: Convert provider response to unified format
78///
79/// This design allows for:
80/// - Maximum flexibility in provider-specific transformations
81/// - Clean separation between protocol logic and business logic
82/// - Easy testing and mocking
83///
84/// # Example Implementation
85///
86/// ```rust
87/// use simple_agent_type::provider::{Provider, ProviderRequest, ProviderResponse};
88/// use simple_agent_type::request::CompletionRequest;
89/// use simple_agent_type::response::{CompletionResponse, CompletionChoice, FinishReason, Usage};
90/// use simple_agent_type::message::Message;
91/// use simple_agent_type::error::Result;
92/// use async_trait::async_trait;
93///
94/// struct MyProvider;
95///
96/// #[async_trait]
97/// impl Provider for MyProvider {
98///     fn name(&self) -> &str {
99///         "my-provider"
100///     }
101///
102///     fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
103///         Ok(ProviderRequest::new("http://example.com"))
104///     }
105///
106///     async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
107///         Ok(ProviderResponse::new(200, serde_json::json!({"ok": true})))
108///     }
109///
110///     fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
111///         Ok(CompletionResponse {
112///             id: "resp_1".to_string(),
113///             model: "dummy".to_string(),
114///             choices: vec![CompletionChoice {
115///                 index: 0,
116///                 message: Message::assistant("ok"),
117///                 finish_reason: FinishReason::Stop,
118///                 logprobs: None,
119///             }],
120///             usage: Usage::new(1, 1),
121///             created: None,
122///             provider: Some(self.name().to_string()),
123///             healing_metadata: None,
124///         })
125///     }
126/// }
127///
128/// let provider = MyProvider;
129/// let request = CompletionRequest::builder()
130///     .model("gpt-4")
131///     .message(Message::user("Hello!"))
132///     .build()
133///     .unwrap();
134///
135/// let rt = tokio::runtime::Runtime::new().unwrap();
136/// rt.block_on(async {
137///     let provider_request = provider.transform_request(&request).unwrap();
138///     let provider_response = provider.execute(provider_request).await.unwrap();
139///     let response = provider.transform_response(provider_response).unwrap();
140///     assert_eq!(response.content(), Some("ok"));
141/// });
142/// ```
143#[async_trait]
144pub trait Provider: Send + Sync {
145    /// Provider name (e.g., "openai", "anthropic").
146    fn name(&self) -> &str;
147
148    /// Transform unified request to provider-specific format.
149    ///
150    /// This method converts the standardized `CompletionRequest` into
151    /// the provider's native API format.
152    fn transform_request(&self, req: &CompletionRequest) -> Result<ProviderRequest>;
153
154    /// Execute request against provider API.
155    ///
156    /// This method makes the actual HTTP request to the provider.
157    /// Implementations should handle:
158    /// - Authentication (API keys, tokens)
159    /// - Rate limiting
160    /// - Network errors
161    /// - Provider-specific error codes
162    async fn execute(&self, req: ProviderRequest) -> Result<ProviderResponse>;
163
164    /// Transform provider response to unified format.
165    ///
166    /// This method converts the provider's native response format into
167    /// the standardized `CompletionResponse`.
168    fn transform_response(&self, resp: ProviderResponse) -> Result<CompletionResponse>;
169
170    /// Get retry configuration.
171    ///
172    /// Override to customize retry behavior for this provider.
173    fn retry_config(&self) -> RetryConfig {
174        RetryConfig::default()
175    }
176
177    /// Get provider capabilities.
178    ///
179    /// Override to specify what features this provider supports.
180    fn capabilities(&self) -> Capabilities {
181        Capabilities::default()
182    }
183
184    /// Get default timeout.
185    fn timeout(&self) -> Duration {
186        Duration::from_secs(30)
187    }
188
189    /// Execute streaming request against provider API.
190    ///
191    /// This method returns a stream of completion chunks for streaming responses.
192    /// Not all providers support streaming - default implementation returns an error.
193    ///
194    /// # Arguments
195    /// - `req`: The provider-specific request
196    ///
197    /// # Returns
198    /// A boxed stream of Result<CompletionChunk>
199    ///
200    /// # Example
201    /// ```rust
202    /// use simple_agent_type::provider::{Provider, ProviderRequest, ProviderResponse};
203    /// use simple_agent_type::request::CompletionRequest;
204    /// use simple_agent_type::response::{CompletionResponse, CompletionChunk, CompletionChoice, FinishReason, Usage};
205    /// use simple_agent_type::message::Message;
206    /// use simple_agent_type::error::Result;
207    /// use async_trait::async_trait;
208    /// use futures_core::Stream;
209    /// use std::pin::Pin;
210    /// use std::task::{Context, Poll};
211    ///
212    /// struct EmptyStream;
213    ///
214    /// impl Stream for EmptyStream {
215    ///     type Item = Result<CompletionChunk>;
216    ///
217    ///     fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
218    ///         Poll::Ready(None)
219    ///     }
220    /// }
221    ///
222    /// struct StreamingProvider;
223    ///
224    /// #[async_trait]
225    /// impl Provider for StreamingProvider {
226    ///     fn name(&self) -> &str {
227    ///         "streaming-provider"
228    ///     }
229    ///
230    ///     fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
231    ///         Ok(ProviderRequest::new("http://example.com"))
232    ///     }
233    ///
234    ///     async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
235    ///         Ok(ProviderResponse::new(200, serde_json::json!({"ok": true})))
236    ///     }
237    ///
238    ///     fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
239    ///         Ok(CompletionResponse {
240    ///             id: "resp_1".to_string(),
241    ///             model: "dummy".to_string(),
242    ///             choices: vec![CompletionChoice {
243    ///                 index: 0,
244    ///                 message: Message::assistant("ok"),
245    ///                 finish_reason: FinishReason::Stop,
246    ///                 logprobs: None,
247    ///             }],
248    ///             usage: Usage::new(1, 1),
249    ///             created: None,
250    ///             provider: None,
251    ///             healing_metadata: None,
252    ///         })
253    ///     }
254    ///
255    ///     async fn execute_stream(
256    ///         &self,
257    ///         _req: ProviderRequest,
258    ///     ) -> Result<Box<dyn Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
259    ///         Ok(Box::new(EmptyStream))
260    ///     }
261    /// }
262    ///
263    /// let provider = StreamingProvider;
264    /// let request = CompletionRequest::builder()
265    ///     .model("gpt-4")
266    ///     .message(Message::user("Hello!"))
267    ///     .build()
268    ///     .unwrap();
269    ///
270    /// let rt = tokio::runtime::Runtime::new().unwrap();
271    /// rt.block_on(async {
272    ///     let provider_request = provider.transform_request(&request).unwrap();
273    ///     let _stream = provider.execute_stream(provider_request).await.unwrap();
274    ///     // Use StreamExt::next to consume the stream in real usage.
275    /// });
276    /// ```
277    async fn execute_stream(
278        &self,
279        mut req: ProviderRequest,
280    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
281        if let Value::Object(map) = &mut req.body {
282            if let Some(stream_value) = map.get_mut("stream") {
283                *stream_value = Value::Bool(false);
284            }
285        }
286
287        let provider_response = self.execute(req).await?;
288        let response = self.transform_response(provider_response)?;
289
290        struct SingleChunkStream {
291            chunk: Option<Result<CompletionChunk>>,
292        }
293
294        impl futures_core::Stream for SingleChunkStream {
295            type Item = Result<CompletionChunk>;
296
297            fn poll_next(
298                mut self: std::pin::Pin<&mut Self>,
299                _cx: &mut std::task::Context<'_>,
300            ) -> std::task::Poll<Option<Self::Item>> {
301                std::task::Poll::Ready(self.chunk.take())
302            }
303        }
304
305        let choices = response
306            .choices
307            .into_iter()
308            .map(|choice| crate::response::ChoiceDelta {
309                index: choice.index,
310                delta: crate::response::MessageDelta {
311                    role: Some(choice.message.role),
312                    content: Some(choice.message.content_text().to_string()),
313                    reasoning_content: None,
314                    tool_calls: None,
315                },
316                finish_reason: Some(choice.finish_reason),
317            })
318            .collect();
319
320        let chunk = CompletionChunk {
321            id: response.id,
322            model: response.model,
323            choices,
324            created: response.created,
325            usage: Some(response.usage),
326        };
327
328        Ok(Box::new(SingleChunkStream {
329            chunk: Some(Ok(chunk)),
330        }))
331    }
332}
333
334/// Opaque provider-specific request.
335///
336/// This type encapsulates all information needed to make an HTTP request
337/// to a provider, without committing to a specific HTTP client library.
338///
339/// Headers use `Cow<'static, str>` to avoid allocations for common headers.
340#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
341pub struct ProviderRequest {
342    /// Full URL to send request to
343    pub url: String,
344    /// HTTP headers (name, value pairs) using Cow to avoid allocations for static strings
345    #[serde(with = "header_serde")]
346    pub headers: Headers,
347    /// Request body (JSON)
348    pub body: serde_json::Value,
349    /// Optional request timeout override
350    #[serde(skip_serializing_if = "Option::is_none")]
351    pub timeout: Option<Duration>,
352}
353
354// Custom serde for Cow headers
355mod header_serde {
356    use super::Headers;
357    use serde::{Deserialize, Deserializer, Serialize, Serializer};
358    use std::borrow::Cow;
359
360    pub fn serialize<S>(
361        headers: &[(Cow<'static, str>, Cow<'static, str>)],
362        serializer: S,
363    ) -> Result<S::Ok, S::Error>
364    where
365        S: Serializer,
366    {
367        let string_headers: Vec<(&str, &str)> = headers
368            .iter()
369            .map(|(k, v)| (k.as_ref(), v.as_ref()))
370            .collect();
371        string_headers.serialize(serializer)
372    }
373
374    pub fn deserialize<'de, D>(deserializer: D) -> std::result::Result<Headers, D::Error>
375    where
376        D: Deserializer<'de>,
377    {
378        let string_headers: Vec<(String, String)> = Vec::deserialize(deserializer)?;
379        Ok(string_headers
380            .into_iter()
381            .map(|(k, v)| (Cow::Owned(k), Cow::Owned(v)))
382            .collect())
383    }
384}
385
386impl ProviderRequest {
387    /// Create a new provider request.
388    pub fn new(url: impl Into<String>) -> Self {
389        Self {
390            url: url.into(),
391            headers: Vec::new(),
392            body: serde_json::Value::Null,
393            timeout: None,
394        }
395    }
396
397    /// Add a header with owned strings.
398    pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
399        self.headers
400            .push((Cow::Owned(name.into()), Cow::Owned(value.into())));
401        self
402    }
403
404    /// Add a header with static strings (zero allocation).
405    pub fn with_static_header(mut self, name: &'static str, value: &'static str) -> Self {
406        self.headers
407            .push((Cow::Borrowed(name), Cow::Borrowed(value)));
408        self
409    }
410
411    /// Set the body.
412    pub fn with_body(mut self, body: serde_json::Value) -> Self {
413        self.body = body;
414        self
415    }
416
417    /// Set the timeout.
418    pub fn with_timeout(mut self, timeout: Duration) -> Self {
419        self.timeout = Some(timeout);
420        self
421    }
422}
423
424/// Opaque provider-specific response.
425///
426/// This type encapsulates the raw HTTP response from a provider.
427#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
428pub struct ProviderResponse {
429    /// HTTP status code
430    pub status: u16,
431    /// Response body (JSON)
432    pub body: serde_json::Value,
433    /// Optional response headers
434    #[serde(skip_serializing_if = "Option::is_none")]
435    pub headers: Option<Vec<(String, String)>>,
436}
437
438impl ProviderResponse {
439    /// Create a new provider response.
440    pub fn new(status: u16, body: serde_json::Value) -> Self {
441        Self {
442            status,
443            body,
444            headers: None,
445        }
446    }
447
448    /// Check if response is successful (2xx).
449    pub fn is_success(&self) -> bool {
450        (200..300).contains(&self.status)
451    }
452
453    /// Check if response is a client error (4xx).
454    pub fn is_client_error(&self) -> bool {
455        (400..500).contains(&self.status)
456    }
457
458    /// Check if response is a server error (5xx).
459    pub fn is_server_error(&self) -> bool {
460        (500..600).contains(&self.status)
461    }
462
463    /// Add headers.
464    pub fn with_headers(mut self, headers: Vec<(String, String)>) -> Self {
465        self.headers = Some(headers);
466        self
467    }
468}
469
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder methods populate URL, headers, body and timeout.
    #[test]
    fn test_provider_request_builder() {
        let url = "https://api.example.com/v1/completions";
        let timeout = Duration::from_secs(30);

        let req = ProviderRequest::new(url)
            .with_header("Authorization", "Bearer sk-test")
            .with_header("Content-Type", "application/json")
            .with_body(serde_json::json!({"model": "test"}))
            .with_timeout(timeout);

        assert_eq!(req.url, url);
        assert_eq!(req.headers.len(), 2);
        assert_eq!(req.body["model"], "test");
        assert_eq!(req.timeout, Some(timeout));
    }

    /// Each status predicate is true only for its own status class.
    #[test]
    fn test_provider_response_status_checks() {
        // (status, is_success, is_client_error, is_server_error)
        let cases = [
            (200, true, false, false),
            (404, false, true, false),
            (500, false, false, true),
        ];

        for &(status, success, client_error, server_error) in cases.iter() {
            let resp = ProviderResponse::new(status, serde_json::json!({}));
            assert_eq!(resp.is_success(), success);
            assert_eq!(resp.is_client_error(), client_error);
            assert_eq!(resp.is_server_error(), server_error);
        }
    }

    /// A request survives a JSON round trip unchanged.
    #[test]
    fn test_provider_request_serialization() {
        let original = ProviderRequest::new("https://api.example.com")
            .with_header("X-Test", "value")
            .with_body(serde_json::json!({"key": "value"}));

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ProviderRequest = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original, decoded);
    }

    /// A response, including its headers, survives a JSON round trip unchanged.
    #[test]
    fn test_provider_response_serialization() {
        let header = ("X-Request-ID".to_string(), "123".to_string());
        let original = ProviderResponse::new(200, serde_json::json!({"result": "success"}))
            .with_headers(vec![header]);

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ProviderResponse = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original, decoded);
    }

    // Compile-time check that the Provider trait is object-safe: the helper
    // only has to type-check, it is never called.
    #[test]
    fn test_provider_object_safety() {
        fn _assert_object_safe(_: &dyn Provider) {}
    }
}