Skip to main content

simple_agent_type/
provider.rs

1//! Provider trait and types.
2//!
3//! Defines the interface for LLM providers with transformation hooks.
4
5use crate::config::{Capabilities, RetryConfig};
6use crate::error::{ProviderError, Result, SimpleAgentsError};
7use crate::request::CompletionRequest;
8use crate::response::{CompletionChunk, CompletionResponse};
9use async_trait::async_trait;
10use serde::{Deserialize, Serialize};
11use std::borrow::Cow;
12use std::time::Duration;
13
/// HTTP headers as an ordered list of `(name, value)` pairs.
///
/// Each side is a `Cow<'static, str>`: `Cow::Borrowed` for names/values that are
/// compile-time constants (no allocation) and `Cow::Owned` for strings built at
/// runtime. A `Vec` is used rather than a map, so insertion order is preserved
/// and the same name may appear more than once.
pub type Headers = Vec<(Cow<'static, str>, Cow<'static, str>)>;
16
/// Well-known HTTP header names.
///
/// Plain `&str` constants, so they can be handed to
/// `ProviderRequest::with_static_header` and stored without allocating.
pub mod headers {
    /// The `Authorization` header name.
    pub const AUTHORIZATION: &str = "Authorization";

    /// The `Content-Type` header name.
    pub const CONTENT_TYPE: &str = "Content-Type";

    /// The `x-api-key` header name (used by some providers, e.g. Anthropic).
    pub const X_API_KEY: &str = "x-api-key";
}
26
27/// Trait for LLM providers.
28///
29/// Providers implement this trait to support different LLM APIs while
30/// presenting a unified interface to the rest of SimpleAgents.
31///
32/// # Architecture
33///
34/// The provider trait follows a three-phase architecture:
35/// 1. **Transform Request**: Convert unified request to provider format
36/// 2. **Execute**: Make the actual API call
37/// 3. **Transform Response**: Convert provider response to unified format
38///
39/// This design allows for:
40/// - Maximum flexibility in provider-specific transformations
41/// - Clean separation between protocol logic and business logic
42/// - Easy testing and mocking
43///
44/// # Example Implementation
45///
46/// ```rust
47/// use simple_agent_type::provider::{Provider, ProviderRequest, ProviderResponse};
48/// use simple_agent_type::request::CompletionRequest;
49/// use simple_agent_type::response::{CompletionResponse, CompletionChoice, FinishReason, Usage};
50/// use simple_agent_type::message::Message;
51/// use simple_agent_type::error::Result;
52/// use async_trait::async_trait;
53///
54/// struct MyProvider;
55///
56/// #[async_trait]
57/// impl Provider for MyProvider {
58///     fn name(&self) -> &str {
59///         "my-provider"
60///     }
61///
62///     fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
63///         Ok(ProviderRequest::new("http://example.com"))
64///     }
65///
66///     async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
67///         Ok(ProviderResponse::new(200, serde_json::json!({"ok": true})))
68///     }
69///
70///     fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
71///         Ok(CompletionResponse {
72///             id: "resp_1".to_string(),
73///             model: "dummy".to_string(),
74///             choices: vec![CompletionChoice {
75///                 index: 0,
76///                 message: Message::assistant("ok"),
77///                 finish_reason: FinishReason::Stop,
78///                 logprobs: None,
79///             }],
80///             usage: Usage::new(1, 1),
81///             created: None,
82///             provider: Some(self.name().to_string()),
83///             healing_metadata: None,
84///         })
85///     }
86/// }
87///
88/// let provider = MyProvider;
89/// let request = CompletionRequest::builder()
90///     .model("gpt-4")
91///     .message(Message::user("Hello!"))
92///     .build()
93///     .unwrap();
94///
95/// let rt = tokio::runtime::Runtime::new().unwrap();
96/// rt.block_on(async {
97///     let provider_request = provider.transform_request(&request).unwrap();
98///     let provider_response = provider.execute(provider_request).await.unwrap();
99///     let response = provider.transform_response(provider_response).unwrap();
100///     assert_eq!(response.content(), Some("ok"));
101/// });
102/// ```
103#[async_trait]
104pub trait Provider: Send + Sync {
105    /// Provider name (e.g., "openai", "anthropic").
106    fn name(&self) -> &str;
107
108    /// Transform unified request to provider-specific format.
109    ///
110    /// This method converts the standardized `CompletionRequest` into
111    /// the provider's native API format.
112    fn transform_request(&self, req: &CompletionRequest) -> Result<ProviderRequest>;
113
114    /// Execute request against provider API.
115    ///
116    /// This method makes the actual HTTP request to the provider.
117    /// Implementations should handle:
118    /// - Authentication (API keys, tokens)
119    /// - Rate limiting
120    /// - Network errors
121    /// - Provider-specific error codes
122    async fn execute(&self, req: ProviderRequest) -> Result<ProviderResponse>;
123
124    /// Transform provider response to unified format.
125    ///
126    /// This method converts the provider's native response format into
127    /// the standardized `CompletionResponse`.
128    fn transform_response(&self, resp: ProviderResponse) -> Result<CompletionResponse>;
129
130    /// Get retry configuration.
131    ///
132    /// Override to customize retry behavior for this provider.
133    fn retry_config(&self) -> RetryConfig {
134        RetryConfig::default()
135    }
136
137    /// Get provider capabilities.
138    ///
139    /// Override to specify what features this provider supports.
140    fn capabilities(&self) -> Capabilities {
141        Capabilities::default()
142    }
143
144    /// Get default timeout.
145    fn timeout(&self) -> Duration {
146        Duration::from_secs(30)
147    }
148
149    /// Execute streaming request against provider API.
150    ///
151    /// This method returns a stream of completion chunks for streaming responses.
152    /// Not all providers support streaming - default implementation returns an error.
153    ///
154    /// # Arguments
155    /// - `req`: The provider-specific request
156    ///
157    /// # Returns
158    /// A boxed stream of Result<CompletionChunk>
159    ///
160    /// # Example
161    /// ```rust
162    /// use simple_agent_type::provider::{Provider, ProviderRequest, ProviderResponse};
163    /// use simple_agent_type::request::CompletionRequest;
164    /// use simple_agent_type::response::{CompletionResponse, CompletionChunk, CompletionChoice, FinishReason, Usage};
165    /// use simple_agent_type::message::Message;
166    /// use simple_agent_type::error::Result;
167    /// use async_trait::async_trait;
168    /// use futures_core::Stream;
169    /// use std::pin::Pin;
170    /// use std::task::{Context, Poll};
171    ///
172    /// struct EmptyStream;
173    ///
174    /// impl Stream for EmptyStream {
175    ///     type Item = Result<CompletionChunk>;
176    ///
177    ///     fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
178    ///         Poll::Ready(None)
179    ///     }
180    /// }
181    ///
182    /// struct StreamingProvider;
183    ///
184    /// #[async_trait]
185    /// impl Provider for StreamingProvider {
186    ///     fn name(&self) -> &str {
187    ///         "streaming-provider"
188    ///     }
189    ///
190    ///     fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
191    ///         Ok(ProviderRequest::new("http://example.com"))
192    ///     }
193    ///
194    ///     async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
195    ///         Ok(ProviderResponse::new(200, serde_json::json!({"ok": true})))
196    ///     }
197    ///
198    ///     fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
199    ///         Ok(CompletionResponse {
200    ///             id: "resp_1".to_string(),
201    ///             model: "dummy".to_string(),
202    ///             choices: vec![CompletionChoice {
203    ///                 index: 0,
204    ///                 message: Message::assistant("ok"),
205    ///                 finish_reason: FinishReason::Stop,
206    ///                 logprobs: None,
207    ///             }],
208    ///             usage: Usage::new(1, 1),
209    ///             created: None,
210    ///             provider: None,
211    ///             healing_metadata: None,
212    ///         })
213    ///     }
214    ///
215    ///     async fn execute_stream(
216    ///         &self,
217    ///         _req: ProviderRequest,
218    ///     ) -> Result<Box<dyn Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
219    ///         Ok(Box::new(EmptyStream))
220    ///     }
221    /// }
222    ///
223    /// let provider = StreamingProvider;
224    /// let request = CompletionRequest::builder()
225    ///     .model("gpt-4")
226    ///     .message(Message::user("Hello!"))
227    ///     .build()
228    ///     .unwrap();
229    ///
230    /// let rt = tokio::runtime::Runtime::new().unwrap();
231    /// rt.block_on(async {
232    ///     let provider_request = provider.transform_request(&request).unwrap();
233    ///     let _stream = provider.execute_stream(provider_request).await.unwrap();
234    ///     // Use StreamExt::next to consume the stream in real usage.
235    /// });
236    /// ```
237    async fn execute_stream(
238        &self,
239        _req: ProviderRequest,
240    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
241        Err(SimpleAgentsError::Provider(
242            ProviderError::UnsupportedFeature("streaming".to_string()),
243        ))
244    }
245}
246
247/// Opaque provider-specific request.
248///
249/// This type encapsulates all information needed to make an HTTP request
250/// to a provider, without committing to a specific HTTP client library.
251///
252/// Headers use `Cow<'static, str>` to avoid allocations for common headers.
253#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
254pub struct ProviderRequest {
255    /// Full URL to send request to
256    pub url: String,
257    /// HTTP headers (name, value pairs) using Cow to avoid allocations for static strings
258    #[serde(with = "header_serde")]
259    pub headers: Headers,
260    /// Request body (JSON)
261    pub body: serde_json::Value,
262    /// Optional request timeout override
263    #[serde(skip_serializing_if = "Option::is_none")]
264    pub timeout: Option<Duration>,
265}
266
267// Custom serde for Cow headers
268mod header_serde {
269    use super::Headers;
270    use serde::{Deserialize, Deserializer, Serialize, Serializer};
271    use std::borrow::Cow;
272
273    pub fn serialize<S>(
274        headers: &[(Cow<'static, str>, Cow<'static, str>)],
275        serializer: S,
276    ) -> Result<S::Ok, S::Error>
277    where
278        S: Serializer,
279    {
280        let string_headers: Vec<(&str, &str)> = headers
281            .iter()
282            .map(|(k, v)| (k.as_ref(), v.as_ref()))
283            .collect();
284        string_headers.serialize(serializer)
285    }
286
287    pub fn deserialize<'de, D>(deserializer: D) -> std::result::Result<Headers, D::Error>
288    where
289        D: Deserializer<'de>,
290    {
291        let string_headers: Vec<(String, String)> = Vec::deserialize(deserializer)?;
292        Ok(string_headers
293            .into_iter()
294            .map(|(k, v)| (Cow::Owned(k), Cow::Owned(v)))
295            .collect())
296    }
297}
298
299impl ProviderRequest {
300    /// Create a new provider request.
301    pub fn new(url: impl Into<String>) -> Self {
302        Self {
303            url: url.into(),
304            headers: Vec::new(),
305            body: serde_json::Value::Null,
306            timeout: None,
307        }
308    }
309
310    /// Add a header with owned strings.
311    pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
312        self.headers
313            .push((Cow::Owned(name.into()), Cow::Owned(value.into())));
314        self
315    }
316
317    /// Add a header with static strings (zero allocation).
318    pub fn with_static_header(mut self, name: &'static str, value: &'static str) -> Self {
319        self.headers
320            .push((Cow::Borrowed(name), Cow::Borrowed(value)));
321        self
322    }
323
324    /// Set the body.
325    pub fn with_body(mut self, body: serde_json::Value) -> Self {
326        self.body = body;
327        self
328    }
329
330    /// Set the timeout.
331    pub fn with_timeout(mut self, timeout: Duration) -> Self {
332        self.timeout = Some(timeout);
333        self
334    }
335}
336
337/// Opaque provider-specific response.
338///
339/// This type encapsulates the raw HTTP response from a provider.
340#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
341pub struct ProviderResponse {
342    /// HTTP status code
343    pub status: u16,
344    /// Response body (JSON)
345    pub body: serde_json::Value,
346    /// Optional response headers
347    #[serde(skip_serializing_if = "Option::is_none")]
348    pub headers: Option<Vec<(String, String)>>,
349}
350
351impl ProviderResponse {
352    /// Create a new provider response.
353    pub fn new(status: u16, body: serde_json::Value) -> Self {
354        Self {
355            status,
356            body,
357            headers: None,
358        }
359    }
360
361    /// Check if response is successful (2xx).
362    pub fn is_success(&self) -> bool {
363        (200..300).contains(&self.status)
364    }
365
366    /// Check if response is a client error (4xx).
367    pub fn is_client_error(&self) -> bool {
368        (400..500).contains(&self.status)
369    }
370
371    /// Check if response is a server error (5xx).
372    pub fn is_server_error(&self) -> bool {
373        (500..600).contains(&self.status)
374    }
375
376    /// Add headers.
377    pub fn with_headers(mut self, headers: Vec<(String, String)>) -> Self {
378        self.headers = Some(headers);
379        self
380    }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_request_builder() {
        let req = ProviderRequest::new("https://api.example.com/v1/completions")
            .with_header("Authorization", "Bearer sk-test")
            .with_header("Content-Type", "application/json")
            .with_body(serde_json::json!({"model": "test"}))
            .with_timeout(Duration::from_secs(30));

        assert_eq!(req.url, "https://api.example.com/v1/completions");
        assert_eq!(req.headers.len(), 2);
        assert_eq!(req.body["model"], "test");
        assert_eq!(req.timeout, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_provider_request_static_header() {
        let req = ProviderRequest::new("https://api.example.com")
            .with_static_header(headers::CONTENT_TYPE, "application/json");

        assert_eq!(req.headers.len(), 1);
        // Static headers are stored borrowed, i.e. without allocation.
        assert!(matches!(req.headers[0].0, Cow::Borrowed(_)));
        assert!(matches!(req.headers[0].1, Cow::Borrowed(_)));
        assert_eq!(req.headers[0].0, headers::CONTENT_TYPE);
        assert_eq!(req.headers[0].1, "application/json");
    }

    #[test]
    fn test_provider_response_status_checks() {
        let resp = ProviderResponse::new(200, serde_json::json!({}));
        assert!(resp.is_success());
        assert!(!resp.is_client_error());
        assert!(!resp.is_server_error());

        let resp = ProviderResponse::new(404, serde_json::json!({}));
        assert!(!resp.is_success());
        assert!(resp.is_client_error());
        assert!(!resp.is_server_error());

        let resp = ProviderResponse::new(500, serde_json::json!({}));
        assert!(!resp.is_success());
        assert!(!resp.is_client_error());
        assert!(resp.is_server_error());
    }

    #[test]
    fn test_provider_response_status_boundaries() {
        let classify = |status: u16| {
            let resp = ProviderResponse::new(status, serde_json::Value::Null);
            (resp.is_success(), resp.is_client_error(), resp.is_server_error())
        };

        // 1xx, 3xx and >= 600 fall into none of the three classes.
        assert_eq!(classify(199), (false, false, false));
        assert_eq!(classify(200), (true, false, false));
        assert_eq!(classify(299), (true, false, false));
        assert_eq!(classify(300), (false, false, false));
        assert_eq!(classify(399), (false, false, false));
        assert_eq!(classify(400), (false, true, false));
        assert_eq!(classify(499), (false, true, false));
        assert_eq!(classify(500), (false, false, true));
        assert_eq!(classify(599), (false, false, true));
        assert_eq!(classify(600), (false, false, false));
    }

    #[test]
    fn test_provider_request_serialization() {
        let req = ProviderRequest::new("https://api.example.com")
            .with_header("X-Test", "value")
            .with_body(serde_json::json!({"key": "value"}));

        let json = serde_json::to_string(&req).unwrap();
        let parsed: ProviderRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(req, parsed);
    }

    #[test]
    fn test_provider_request_headers_wire_format() {
        let req = ProviderRequest::new("https://api.example.com")
            .with_static_header(headers::CONTENT_TYPE, "application/json")
            .with_header("X-Test", "value");

        let value = serde_json::to_value(&req).unwrap();
        assert_eq!(
            value["headers"],
            serde_json::json!([["Content-Type", "application/json"], ["X-Test", "value"]])
        );
        // An unset timeout is skipped entirely.
        assert!(value.get("timeout").is_none());
    }

    #[test]
    fn test_provider_request_static_header_roundtrip() {
        let req = ProviderRequest::new("https://api.example.com")
            .with_static_header(headers::X_API_KEY, "sk-test")
            .with_timeout(Duration::from_millis(1500));

        let json = serde_json::to_string(&req).unwrap();
        let parsed: ProviderRequest = serde_json::from_str(&json).unwrap();

        // Borrowed and owned `Cow`s compare by content, so the round trip is lossless…
        assert_eq!(req, parsed);
        // …but deserialized headers are always owned.
        assert!(parsed
            .headers
            .iter()
            .all(|(k, v)| matches!(k, Cow::Owned(_)) && matches!(v, Cow::Owned(_))));
    }

    #[test]
    fn test_provider_response_serialization() {
        let resp = ProviderResponse::new(200, serde_json::json!({"result": "success"}))
            .with_headers(vec![("X-Request-ID".to_string(), "123".to_string())]);

        let json = serde_json::to_string(&resp).unwrap();
        let parsed: ProviderResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(resp, parsed);
    }

    #[test]
    fn test_provider_response_serialization_without_headers() {
        let resp = ProviderResponse::new(204, serde_json::Value::Null);

        let value = serde_json::to_value(&resp).unwrap();
        assert!(value.get("headers").is_none());

        let parsed: ProviderResponse = serde_json::from_value(value).unwrap();
        assert_eq!(resp, parsed);
    }

    // Test that Provider trait is object-safe
    #[test]
    fn test_provider_object_safety() {
        fn _assert_object_safe(_: &dyn Provider) {}
    }
}