Skip to main content

simple_agent_type/
provider.rs

1//! Provider trait and types.
2//!
3//! Defines the interface for LLM providers with transformation hooks.
4
5use crate::config::{Capabilities, RetryConfig};
6use crate::error::Result;
7use crate::request::CompletionRequest;
8use crate::response::{CompletionChunk, CompletionResponse};
9use async_trait::async_trait;
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::borrow::Cow;
13use std::time::Duration;
14
/// HTTP headers as an ordered list of `(name, value)` pairs.
///
/// Each half is a `Cow<'static, str>`: header names/values known at compile
/// time can be stored as `Cow::Borrowed` (no allocation), while dynamic ones
/// (API keys, deserialized headers) are stored as `Cow::Owned`.
pub type Headers = Vec<(Cow<'static, str>, Cow<'static, str>)>;
17
/// Well-known HTTP header names, kept as `&'static str` constants so they can
/// be attached to requests without allocating.
pub mod headers {
    /// The standard `Authorization` header.
    pub const AUTHORIZATION: &'static str = "Authorization";

    /// The standard `Content-Type` header.
    pub const CONTENT_TYPE: &'static str = "Content-Type";

    /// Provider-specific API key header (for example, Anthropic uses it).
    pub const X_API_KEY: &'static str = "x-api-key";
}
27
28/// Trait for LLM providers.
29///
30/// Providers implement this trait to support different LLM APIs while
31/// presenting a unified interface to the rest of SimpleAgents.
32///
33/// # Architecture
34///
35/// The provider trait follows a three-phase architecture:
36/// 1. **Transform Request**: Convert unified request to provider format
37/// 2. **Execute**: Make the actual API call
38/// 3. **Transform Response**: Convert provider response to unified format
39///
40/// This design allows for:
41/// - Maximum flexibility in provider-specific transformations
42/// - Clean separation between protocol logic and business logic
43/// - Easy testing and mocking
44///
45/// # Example Implementation
46///
47/// ```rust
48/// use simple_agent_type::provider::{Provider, ProviderRequest, ProviderResponse};
49/// use simple_agent_type::request::CompletionRequest;
50/// use simple_agent_type::response::{CompletionResponse, CompletionChoice, FinishReason, Usage};
51/// use simple_agent_type::message::Message;
52/// use simple_agent_type::error::Result;
53/// use async_trait::async_trait;
54///
55/// struct MyProvider;
56///
57/// #[async_trait]
58/// impl Provider for MyProvider {
59///     fn name(&self) -> &str {
60///         "my-provider"
61///     }
62///
63///     fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
64///         Ok(ProviderRequest::new("http://example.com"))
65///     }
66///
67///     async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
68///         Ok(ProviderResponse::new(200, serde_json::json!({"ok": true})))
69///     }
70///
71///     fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
72///         Ok(CompletionResponse {
73///             id: "resp_1".to_string(),
74///             model: "dummy".to_string(),
75///             choices: vec![CompletionChoice {
76///                 index: 0,
77///                 message: Message::assistant("ok"),
78///                 finish_reason: FinishReason::Stop,
79///                 logprobs: None,
80///             }],
81///             usage: Usage::new(1, 1),
82///             created: None,
83///             provider: Some(self.name().to_string()),
84///             healing_metadata: None,
85///         })
86///     }
87/// }
88///
89/// let provider = MyProvider;
90/// let request = CompletionRequest::builder()
91///     .model("gpt-4")
92///     .message(Message::user("Hello!"))
93///     .build()
94///     .unwrap();
95///
96/// let rt = tokio::runtime::Runtime::new().unwrap();
97/// rt.block_on(async {
98///     let provider_request = provider.transform_request(&request).unwrap();
99///     let provider_response = provider.execute(provider_request).await.unwrap();
100///     let response = provider.transform_response(provider_response).unwrap();
101///     assert_eq!(response.content(), Some("ok"));
102/// });
103/// ```
104#[async_trait]
105pub trait Provider: Send + Sync {
106    /// Provider name (e.g., "openai", "anthropic").
107    fn name(&self) -> &str;
108
109    /// Transform unified request to provider-specific format.
110    ///
111    /// This method converts the standardized `CompletionRequest` into
112    /// the provider's native API format.
113    fn transform_request(&self, req: &CompletionRequest) -> Result<ProviderRequest>;
114
115    /// Execute request against provider API.
116    ///
117    /// This method makes the actual HTTP request to the provider.
118    /// Implementations should handle:
119    /// - Authentication (API keys, tokens)
120    /// - Rate limiting
121    /// - Network errors
122    /// - Provider-specific error codes
123    async fn execute(&self, req: ProviderRequest) -> Result<ProviderResponse>;
124
125    /// Transform provider response to unified format.
126    ///
127    /// This method converts the provider's native response format into
128    /// the standardized `CompletionResponse`.
129    fn transform_response(&self, resp: ProviderResponse) -> Result<CompletionResponse>;
130
131    /// Get retry configuration.
132    ///
133    /// Override to customize retry behavior for this provider.
134    fn retry_config(&self) -> RetryConfig {
135        RetryConfig::default()
136    }
137
138    /// Get provider capabilities.
139    ///
140    /// Override to specify what features this provider supports.
141    fn capabilities(&self) -> Capabilities {
142        Capabilities::default()
143    }
144
145    /// Get default timeout.
146    fn timeout(&self) -> Duration {
147        Duration::from_secs(30)
148    }
149
150    /// Execute streaming request against provider API.
151    ///
152    /// This method returns a stream of completion chunks for streaming responses.
153    /// Not all providers support streaming - default implementation returns an error.
154    ///
155    /// # Arguments
156    /// - `req`: The provider-specific request
157    ///
158    /// # Returns
159    /// A boxed stream of Result<CompletionChunk>
160    ///
161    /// # Example
162    /// ```rust
163    /// use simple_agent_type::provider::{Provider, ProviderRequest, ProviderResponse};
164    /// use simple_agent_type::request::CompletionRequest;
165    /// use simple_agent_type::response::{CompletionResponse, CompletionChunk, CompletionChoice, FinishReason, Usage};
166    /// use simple_agent_type::message::Message;
167    /// use simple_agent_type::error::Result;
168    /// use async_trait::async_trait;
169    /// use futures_core::Stream;
170    /// use std::pin::Pin;
171    /// use std::task::{Context, Poll};
172    ///
173    /// struct EmptyStream;
174    ///
175    /// impl Stream for EmptyStream {
176    ///     type Item = Result<CompletionChunk>;
177    ///
178    ///     fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
179    ///         Poll::Ready(None)
180    ///     }
181    /// }
182    ///
183    /// struct StreamingProvider;
184    ///
185    /// #[async_trait]
186    /// impl Provider for StreamingProvider {
187    ///     fn name(&self) -> &str {
188    ///         "streaming-provider"
189    ///     }
190    ///
191    ///     fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
192    ///         Ok(ProviderRequest::new("http://example.com"))
193    ///     }
194    ///
195    ///     async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
196    ///         Ok(ProviderResponse::new(200, serde_json::json!({"ok": true})))
197    ///     }
198    ///
199    ///     fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
200    ///         Ok(CompletionResponse {
201    ///             id: "resp_1".to_string(),
202    ///             model: "dummy".to_string(),
203    ///             choices: vec![CompletionChoice {
204    ///                 index: 0,
205    ///                 message: Message::assistant("ok"),
206    ///                 finish_reason: FinishReason::Stop,
207    ///                 logprobs: None,
208    ///             }],
209    ///             usage: Usage::new(1, 1),
210    ///             created: None,
211    ///             provider: None,
212    ///             healing_metadata: None,
213    ///         })
214    ///     }
215    ///
216    ///     async fn execute_stream(
217    ///         &self,
218    ///         _req: ProviderRequest,
219    ///     ) -> Result<Box<dyn Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
220    ///         Ok(Box::new(EmptyStream))
221    ///     }
222    /// }
223    ///
224    /// let provider = StreamingProvider;
225    /// let request = CompletionRequest::builder()
226    ///     .model("gpt-4")
227    ///     .message(Message::user("Hello!"))
228    ///     .build()
229    ///     .unwrap();
230    ///
231    /// let rt = tokio::runtime::Runtime::new().unwrap();
232    /// rt.block_on(async {
233    ///     let provider_request = provider.transform_request(&request).unwrap();
234    ///     let _stream = provider.execute_stream(provider_request).await.unwrap();
235    ///     // Use StreamExt::next to consume the stream in real usage.
236    /// });
237    /// ```
238    async fn execute_stream(
239        &self,
240        mut req: ProviderRequest,
241    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
242        if let Value::Object(map) = &mut req.body {
243            if let Some(stream_value) = map.get_mut("stream") {
244                *stream_value = Value::Bool(false);
245            }
246        }
247
248        let provider_response = self.execute(req).await?;
249        let response = self.transform_response(provider_response)?;
250
251        struct SingleChunkStream {
252            chunk: Option<Result<CompletionChunk>>,
253        }
254
255        impl futures_core::Stream for SingleChunkStream {
256            type Item = Result<CompletionChunk>;
257
258            fn poll_next(
259                mut self: std::pin::Pin<&mut Self>,
260                _cx: &mut std::task::Context<'_>,
261            ) -> std::task::Poll<Option<Self::Item>> {
262                std::task::Poll::Ready(self.chunk.take())
263            }
264        }
265
266        let choices = response
267            .choices
268            .into_iter()
269            .map(|choice| crate::response::ChoiceDelta {
270                index: choice.index,
271                delta: crate::response::MessageDelta {
272                    role: Some(choice.message.role),
273                    content: Some(choice.message.content),
274                    reasoning_content: None,
275                    tool_calls: None,
276                },
277                finish_reason: Some(choice.finish_reason),
278            })
279            .collect();
280
281        let chunk = CompletionChunk {
282            id: response.id,
283            model: response.model,
284            choices,
285            created: response.created,
286            usage: Some(response.usage),
287        };
288
289        Ok(Box::new(SingleChunkStream {
290            chunk: Some(Ok(chunk)),
291        }))
292    }
293}
294
295/// Opaque provider-specific request.
296///
297/// This type encapsulates all information needed to make an HTTP request
298/// to a provider, without committing to a specific HTTP client library.
299///
300/// Headers use `Cow<'static, str>` to avoid allocations for common headers.
301#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
302pub struct ProviderRequest {
303    /// Full URL to send request to
304    pub url: String,
305    /// HTTP headers (name, value pairs) using Cow to avoid allocations for static strings
306    #[serde(with = "header_serde")]
307    pub headers: Headers,
308    /// Request body (JSON)
309    pub body: serde_json::Value,
310    /// Optional request timeout override
311    #[serde(skip_serializing_if = "Option::is_none")]
312    pub timeout: Option<Duration>,
313}
314
315// Custom serde for Cow headers
316mod header_serde {
317    use super::Headers;
318    use serde::{Deserialize, Deserializer, Serialize, Serializer};
319    use std::borrow::Cow;
320
321    pub fn serialize<S>(
322        headers: &[(Cow<'static, str>, Cow<'static, str>)],
323        serializer: S,
324    ) -> Result<S::Ok, S::Error>
325    where
326        S: Serializer,
327    {
328        let string_headers: Vec<(&str, &str)> = headers
329            .iter()
330            .map(|(k, v)| (k.as_ref(), v.as_ref()))
331            .collect();
332        string_headers.serialize(serializer)
333    }
334
335    pub fn deserialize<'de, D>(deserializer: D) -> std::result::Result<Headers, D::Error>
336    where
337        D: Deserializer<'de>,
338    {
339        let string_headers: Vec<(String, String)> = Vec::deserialize(deserializer)?;
340        Ok(string_headers
341            .into_iter()
342            .map(|(k, v)| (Cow::Owned(k), Cow::Owned(v)))
343            .collect())
344    }
345}
346
347impl ProviderRequest {
348    /// Create a new provider request.
349    pub fn new(url: impl Into<String>) -> Self {
350        Self {
351            url: url.into(),
352            headers: Vec::new(),
353            body: serde_json::Value::Null,
354            timeout: None,
355        }
356    }
357
358    /// Add a header with owned strings.
359    pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
360        self.headers
361            .push((Cow::Owned(name.into()), Cow::Owned(value.into())));
362        self
363    }
364
365    /// Add a header with static strings (zero allocation).
366    pub fn with_static_header(mut self, name: &'static str, value: &'static str) -> Self {
367        self.headers
368            .push((Cow::Borrowed(name), Cow::Borrowed(value)));
369        self
370    }
371
372    /// Set the body.
373    pub fn with_body(mut self, body: serde_json::Value) -> Self {
374        self.body = body;
375        self
376    }
377
378    /// Set the timeout.
379    pub fn with_timeout(mut self, timeout: Duration) -> Self {
380        self.timeout = Some(timeout);
381        self
382    }
383}
384
385/// Opaque provider-specific response.
386///
387/// This type encapsulates the raw HTTP response from a provider.
388#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
389pub struct ProviderResponse {
390    /// HTTP status code
391    pub status: u16,
392    /// Response body (JSON)
393    pub body: serde_json::Value,
394    /// Optional response headers
395    #[serde(skip_serializing_if = "Option::is_none")]
396    pub headers: Option<Vec<(String, String)>>,
397}
398
399impl ProviderResponse {
400    /// Create a new provider response.
401    pub fn new(status: u16, body: serde_json::Value) -> Self {
402        Self {
403            status,
404            body,
405            headers: None,
406        }
407    }
408
409    /// Check if response is successful (2xx).
410    pub fn is_success(&self) -> bool {
411        (200..300).contains(&self.status)
412    }
413
414    /// Check if response is a client error (4xx).
415    pub fn is_client_error(&self) -> bool {
416        (400..500).contains(&self.status)
417    }
418
419    /// Check if response is a server error (5xx).
420    pub fn is_server_error(&self) -> bool {
421        (500..600).contains(&self.status)
422    }
423
424    /// Add headers.
425    pub fn with_headers(mut self, headers: Vec<(String, String)>) -> Self {
426        self.headers = Some(headers);
427        self
428    }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_request_builder() {
        let req = ProviderRequest::new("https://api.example.com/v1/completions")
            .with_header("Authorization", "Bearer sk-test")
            .with_header("Content-Type", "application/json")
            .with_body(serde_json::json!({"model": "test"}))
            .with_timeout(Duration::from_secs(30));

        assert_eq!(req.url, "https://api.example.com/v1/completions");
        assert_eq!(req.headers.len(), 2);
        assert_eq!(req.body["model"], "test");
        assert_eq!(req.timeout, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_provider_response_status_checks() {
        let resp = ProviderResponse::new(200, serde_json::json!({}));
        assert!(resp.is_success());
        assert!(!resp.is_client_error());
        assert!(!resp.is_server_error());

        let resp = ProviderResponse::new(404, serde_json::json!({}));
        assert!(!resp.is_success());
        assert!(resp.is_client_error());
        assert!(!resp.is_server_error());

        let resp = ProviderResponse::new(500, serde_json::json!({}));
        assert!(!resp.is_success());
        assert!(!resp.is_client_error());
        assert!(resp.is_server_error());
    }

    // Pins the exact edges of each status class (ranges are half-open in
    // the implementation, so off-by-one regressions show up here).
    #[test]
    fn test_provider_response_status_boundaries() {
        let classify = |status: u16| {
            let resp = ProviderResponse::new(status, serde_json::Value::Null);
            (
                resp.is_success(),
                resp.is_client_error(),
                resp.is_server_error(),
            )
        };

        assert_eq!(classify(199), (false, false, false));
        assert_eq!(classify(200), (true, false, false));
        assert_eq!(classify(299), (true, false, false));
        assert_eq!(classify(300), (false, false, false));
        assert_eq!(classify(399), (false, false, false));
        assert_eq!(classify(400), (false, true, false));
        assert_eq!(classify(499), (false, true, false));
        assert_eq!(classify(500), (false, false, true));
        assert_eq!(classify(599), (false, false, true));
        assert_eq!(classify(600), (false, false, false));
    }

    #[test]
    fn test_provider_request_serialization() {
        let req = ProviderRequest::new("https://api.example.com")
            .with_header("X-Test", "value")
            .with_body(serde_json::json!({"key": "value"}));

        let json = serde_json::to_string(&req).unwrap();
        let parsed: ProviderRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(req, parsed);
    }

    #[test]
    fn test_static_header_is_borrowed_and_round_trips() {
        let req = ProviderRequest::new("https://api.example.com")
            .with_static_header(headers::CONTENT_TYPE, "application/json")
            .with_header("X-Owned", String::from("value"));

        // Static headers must not allocate; dynamic ones are owned.
        assert!(matches!(req.headers[0].0, Cow::Borrowed("Content-Type")));
        assert!(matches!(req.headers[0].1, Cow::Borrowed("application/json")));
        assert!(matches!(req.headers[1].0, Cow::Owned(_)));
        assert!(matches!(req.headers[1].1, Cow::Owned(_)));

        // Wire format: a sequence of [name, value] pairs.
        let value = serde_json::to_value(&req).unwrap();
        assert_eq!(
            value["headers"],
            serde_json::json!([["Content-Type", "application/json"], ["X-Owned", "value"]])
        );

        // Deserialized headers are owned, but `Cow` compares by content,
        // so the round trip is still equal.
        let json = serde_json::to_string(&req).unwrap();
        let parsed: ProviderRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(req, parsed);
    }

    #[test]
    fn test_provider_response_serialization() {
        let resp = ProviderResponse::new(200, serde_json::json!({"result": "success"}))
            .with_headers(vec![("X-Request-ID".to_string(), "123".to_string())]);

        let json = serde_json::to_string(&resp).unwrap();
        let parsed: ProviderResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(resp, parsed);
    }

    // Test that Provider trait is object-safe
    #[test]
    fn test_provider_object_safety() {
        fn _assert_object_safe(_: &dyn Provider) {}
    }
}