Skip to main content

simple_agent_type/
provider.rs

1//! Provider trait and types.
2//!
3//! Defines the interface for LLM providers with transformation hooks.
4
5use crate::config::{Capabilities, RetryConfig};
6use crate::error::Result;
7use crate::request::CompletionRequest;
8use crate::response::{CompletionChunk, CompletionResponse};
9use async_trait::async_trait;
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::borrow::Cow;
13use std::time::Duration;
14
/// HTTP headers as an ordered list of `(name, value)` pairs.
///
/// Both halves are `Cow<'static, str>`, so an entry can either borrow a
/// `&'static str` or own a `String`.
pub type Headers = Vec<(Cow<'static, str>, Cow<'static, str>)>;
17
/// Names of commonly used HTTP headers, exposed as `&'static str` constants
/// so they can be used as header names without allocating.
pub mod headers {
    /// Name of the `Authorization` header.
    pub const AUTHORIZATION: &str = "Authorization";
    /// Name of the `Content-Type` header.
    pub const CONTENT_TYPE: &str = "Content-Type";
    /// Name of the `x-api-key` header (used by some providers like Anthropic).
    pub const X_API_KEY: &str = "x-api-key";
}
27
28/// Trait for LLM providers.
29///
30/// Providers implement this trait to support different LLM APIs while
31/// presenting a unified interface to the rest of SimpleAgents.
32///
33/// # Architecture
34///
35/// The provider trait follows a three-phase architecture:
36/// 1. **Transform Request**: Convert unified request to provider format
37/// 2. **Execute**: Make the actual API call
38/// 3. **Transform Response**: Convert provider response to unified format
39///
40/// This design allows for:
41/// - Maximum flexibility in provider-specific transformations
42/// - Clean separation between protocol logic and business logic
43/// - Easy testing and mocking
44///
45/// # Example Implementation
46///
47/// ```rust
48/// use simple_agent_type::provider::{Provider, ProviderRequest, ProviderResponse};
49/// use simple_agent_type::request::CompletionRequest;
50/// use simple_agent_type::response::{CompletionResponse, CompletionChoice, FinishReason, Usage};
51/// use simple_agent_type::message::Message;
52/// use simple_agent_type::error::Result;
53/// use async_trait::async_trait;
54///
55/// struct MyProvider;
56///
57/// #[async_trait]
58/// impl Provider for MyProvider {
59///     fn name(&self) -> &str {
60///         "my-provider"
61///     }
62///
63///     fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
64///         Ok(ProviderRequest::new("http://example.com"))
65///     }
66///
67///     async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
68///         Ok(ProviderResponse::new(200, serde_json::json!({"ok": true})))
69///     }
70///
71///     fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
72///         Ok(CompletionResponse {
73///             id: "resp_1".to_string(),
74///             model: "dummy".to_string(),
75///             choices: vec![CompletionChoice {
76///                 index: 0,
77///                 message: Message::assistant("ok"),
78///                 finish_reason: FinishReason::Stop,
79///                 logprobs: None,
80///             }],
81///             usage: Usage::new(1, 1),
82///             created: None,
83///             provider: Some(self.name().to_string()),
84///             healing_metadata: None,
85///         })
86///     }
87/// }
88///
89/// let provider = MyProvider;
90/// let request = CompletionRequest::builder()
91///     .model("gpt-4")
92///     .message(Message::user("Hello!"))
93///     .build()
94///     .unwrap();
95///
96/// let rt = tokio::runtime::Runtime::new().unwrap();
97/// rt.block_on(async {
98///     let provider_request = provider.transform_request(&request).unwrap();
99///     let provider_response = provider.execute(provider_request).await.unwrap();
100///     let response = provider.transform_response(provider_response).unwrap();
101///     assert_eq!(response.content(), Some("ok"));
102/// });
103/// ```
104#[async_trait]
105pub trait Provider: Send + Sync {
106    /// Provider name (e.g., "openai", "anthropic").
107    fn name(&self) -> &str;
108
109    /// Transform unified request to provider-specific format.
110    ///
111    /// This method converts the standardized `CompletionRequest` into
112    /// the provider's native API format.
113    fn transform_request(&self, req: &CompletionRequest) -> Result<ProviderRequest>;
114
115    /// Execute request against provider API.
116    ///
117    /// This method makes the actual HTTP request to the provider.
118    /// Implementations should handle:
119    /// - Authentication (API keys, tokens)
120    /// - Rate limiting
121    /// - Network errors
122    /// - Provider-specific error codes
123    async fn execute(&self, req: ProviderRequest) -> Result<ProviderResponse>;
124
125    /// Transform provider response to unified format.
126    ///
127    /// This method converts the provider's native response format into
128    /// the standardized `CompletionResponse`.
129    fn transform_response(&self, resp: ProviderResponse) -> Result<CompletionResponse>;
130
131    /// Get retry configuration.
132    ///
133    /// Override to customize retry behavior for this provider.
134    fn retry_config(&self) -> RetryConfig {
135        RetryConfig::default()
136    }
137
138    /// Get provider capabilities.
139    ///
140    /// Override to specify what features this provider supports.
141    fn capabilities(&self) -> Capabilities {
142        Capabilities::default()
143    }
144
145    /// Get default timeout.
146    fn timeout(&self) -> Duration {
147        Duration::from_secs(30)
148    }
149
150    /// Execute streaming request against provider API.
151    ///
152    /// This method returns a stream of completion chunks for streaming responses.
153    /// Not all providers support streaming - default implementation returns an error.
154    ///
155    /// # Arguments
156    /// - `req`: The provider-specific request
157    ///
158    /// # Returns
159    /// A boxed stream of Result<CompletionChunk>
160    ///
161    /// # Example
162    /// ```rust
163    /// use simple_agent_type::provider::{Provider, ProviderRequest, ProviderResponse};
164    /// use simple_agent_type::request::CompletionRequest;
165    /// use simple_agent_type::response::{CompletionResponse, CompletionChunk, CompletionChoice, FinishReason, Usage};
166    /// use simple_agent_type::message::Message;
167    /// use simple_agent_type::error::Result;
168    /// use async_trait::async_trait;
169    /// use futures_core::Stream;
170    /// use std::pin::Pin;
171    /// use std::task::{Context, Poll};
172    ///
173    /// struct EmptyStream;
174    ///
175    /// impl Stream for EmptyStream {
176    ///     type Item = Result<CompletionChunk>;
177    ///
178    ///     fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
179    ///         Poll::Ready(None)
180    ///     }
181    /// }
182    ///
183    /// struct StreamingProvider;
184    ///
185    /// #[async_trait]
186    /// impl Provider for StreamingProvider {
187    ///     fn name(&self) -> &str {
188    ///         "streaming-provider"
189    ///     }
190    ///
191    ///     fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
192    ///         Ok(ProviderRequest::new("http://example.com"))
193    ///     }
194    ///
195    ///     async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
196    ///         Ok(ProviderResponse::new(200, serde_json::json!({"ok": true})))
197    ///     }
198    ///
199    ///     fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
200    ///         Ok(CompletionResponse {
201    ///             id: "resp_1".to_string(),
202    ///             model: "dummy".to_string(),
203    ///             choices: vec![CompletionChoice {
204    ///                 index: 0,
205    ///                 message: Message::assistant("ok"),
206    ///                 finish_reason: FinishReason::Stop,
207    ///                 logprobs: None,
208    ///             }],
209    ///             usage: Usage::new(1, 1),
210    ///             created: None,
211    ///             provider: None,
212    ///             healing_metadata: None,
213    ///         })
214    ///     }
215    ///
216    ///     async fn execute_stream(
217    ///         &self,
218    ///         _req: ProviderRequest,
219    ///     ) -> Result<Box<dyn Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
220    ///         Ok(Box::new(EmptyStream))
221    ///     }
222    /// }
223    ///
224    /// let provider = StreamingProvider;
225    /// let request = CompletionRequest::builder()
226    ///     .model("gpt-4")
227    ///     .message(Message::user("Hello!"))
228    ///     .build()
229    ///     .unwrap();
230    ///
231    /// let rt = tokio::runtime::Runtime::new().unwrap();
232    /// rt.block_on(async {
233    ///     let provider_request = provider.transform_request(&request).unwrap();
234    ///     let _stream = provider.execute_stream(provider_request).await.unwrap();
235    ///     // Use StreamExt::next to consume the stream in real usage.
236    /// });
237    /// ```
238    async fn execute_stream(
239        &self,
240        mut req: ProviderRequest,
241    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
242        if let Value::Object(map) = &mut req.body {
243            if let Some(stream_value) = map.get_mut("stream") {
244                *stream_value = Value::Bool(false);
245            }
246        }
247
248        let provider_response = self.execute(req).await?;
249        let response = self.transform_response(provider_response)?;
250
251        struct SingleChunkStream {
252            chunk: Option<Result<CompletionChunk>>,
253        }
254
255        impl futures_core::Stream for SingleChunkStream {
256            type Item = Result<CompletionChunk>;
257
258            fn poll_next(
259                mut self: std::pin::Pin<&mut Self>,
260                _cx: &mut std::task::Context<'_>,
261            ) -> std::task::Poll<Option<Self::Item>> {
262                std::task::Poll::Ready(self.chunk.take())
263            }
264        }
265
266        let choices = response
267            .choices
268            .into_iter()
269            .map(|choice| crate::response::ChoiceDelta {
270                index: choice.index,
271                delta: crate::response::MessageDelta {
272                    role: Some(choice.message.role),
273                    content: Some(choice.message.content),
274                },
275                finish_reason: Some(choice.finish_reason),
276            })
277            .collect();
278
279        let chunk = CompletionChunk {
280            id: response.id,
281            model: response.model,
282            choices,
283            created: response.created,
284        };
285
286        Ok(Box::new(SingleChunkStream {
287            chunk: Some(Ok(chunk)),
288        }))
289    }
290}
291
292/// Opaque provider-specific request.
293///
294/// This type encapsulates all information needed to make an HTTP request
295/// to a provider, without committing to a specific HTTP client library.
296///
297/// Headers use `Cow<'static, str>` to avoid allocations for common headers.
298#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
299pub struct ProviderRequest {
300    /// Full URL to send request to
301    pub url: String,
302    /// HTTP headers (name, value pairs) using Cow to avoid allocations for static strings
303    #[serde(with = "header_serde")]
304    pub headers: Headers,
305    /// Request body (JSON)
306    pub body: serde_json::Value,
307    /// Optional request timeout override
308    #[serde(skip_serializing_if = "Option::is_none")]
309    pub timeout: Option<Duration>,
310}
311
312// Custom serde for Cow headers
313mod header_serde {
314    use super::Headers;
315    use serde::{Deserialize, Deserializer, Serialize, Serializer};
316    use std::borrow::Cow;
317
318    pub fn serialize<S>(
319        headers: &[(Cow<'static, str>, Cow<'static, str>)],
320        serializer: S,
321    ) -> Result<S::Ok, S::Error>
322    where
323        S: Serializer,
324    {
325        let string_headers: Vec<(&str, &str)> = headers
326            .iter()
327            .map(|(k, v)| (k.as_ref(), v.as_ref()))
328            .collect();
329        string_headers.serialize(serializer)
330    }
331
332    pub fn deserialize<'de, D>(deserializer: D) -> std::result::Result<Headers, D::Error>
333    where
334        D: Deserializer<'de>,
335    {
336        let string_headers: Vec<(String, String)> = Vec::deserialize(deserializer)?;
337        Ok(string_headers
338            .into_iter()
339            .map(|(k, v)| (Cow::Owned(k), Cow::Owned(v)))
340            .collect())
341    }
342}
343
344impl ProviderRequest {
345    /// Create a new provider request.
346    pub fn new(url: impl Into<String>) -> Self {
347        Self {
348            url: url.into(),
349            headers: Vec::new(),
350            body: serde_json::Value::Null,
351            timeout: None,
352        }
353    }
354
355    /// Add a header with owned strings.
356    pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
357        self.headers
358            .push((Cow::Owned(name.into()), Cow::Owned(value.into())));
359        self
360    }
361
362    /// Add a header with static strings (zero allocation).
363    pub fn with_static_header(mut self, name: &'static str, value: &'static str) -> Self {
364        self.headers
365            .push((Cow::Borrowed(name), Cow::Borrowed(value)));
366        self
367    }
368
369    /// Set the body.
370    pub fn with_body(mut self, body: serde_json::Value) -> Self {
371        self.body = body;
372        self
373    }
374
375    /// Set the timeout.
376    pub fn with_timeout(mut self, timeout: Duration) -> Self {
377        self.timeout = Some(timeout);
378        self
379    }
380}
381
382/// Opaque provider-specific response.
383///
384/// This type encapsulates the raw HTTP response from a provider.
385#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
386pub struct ProviderResponse {
387    /// HTTP status code
388    pub status: u16,
389    /// Response body (JSON)
390    pub body: serde_json::Value,
391    /// Optional response headers
392    #[serde(skip_serializing_if = "Option::is_none")]
393    pub headers: Option<Vec<(String, String)>>,
394}
395
396impl ProviderResponse {
397    /// Create a new provider response.
398    pub fn new(status: u16, body: serde_json::Value) -> Self {
399        Self {
400            status,
401            body,
402            headers: None,
403        }
404    }
405
406    /// Check if response is successful (2xx).
407    pub fn is_success(&self) -> bool {
408        (200..300).contains(&self.status)
409    }
410
411    /// Check if response is a client error (4xx).
412    pub fn is_client_error(&self) -> bool {
413        (400..500).contains(&self.status)
414    }
415
416    /// Check if response is a server error (5xx).
417    pub fn is_server_error(&self) -> bool {
418        (500..600).contains(&self.status)
419    }
420
421    /// Add headers.
422    pub fn with_headers(mut self, headers: Vec<(String, String)>) -> Self {
423        self.headers = Some(headers);
424        self
425    }
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Each builder method fills in the field it is responsible for.
    #[test]
    fn test_provider_request_builder() {
        let url = "https://api.example.com/v1/completions";
        let timeout = Duration::from_secs(30);

        let request = ProviderRequest::new(url)
            .with_header("Authorization", "Bearer sk-test")
            .with_header("Content-Type", "application/json")
            .with_body(serde_json::json!({"model": "test"}))
            .with_timeout(timeout);

        assert_eq!(request.url, url);
        assert_eq!(request.headers.len(), 2);
        assert_eq!(request.body["model"], "test");
        assert_eq!(request.timeout, Some(timeout));
    }

    /// Status predicates classify 200, 404 and 500 into exactly one class each.
    #[test]
    fn test_provider_response_status_checks() {
        // (status, is_success, is_client_error, is_server_error)
        let cases: [(u16, bool, bool, bool); 3] = [
            (200, true, false, false),
            (404, false, true, false),
            (500, false, false, true),
        ];

        for &(status, success, client_error, server_error) in &cases {
            let response = ProviderResponse::new(status, serde_json::json!({}));
            assert_eq!(response.is_success(), success);
            assert_eq!(response.is_client_error(), client_error);
            assert_eq!(response.is_server_error(), server_error);
        }
    }

    /// A request survives a JSON round trip unchanged.
    #[test]
    fn test_provider_request_serialization() {
        let original = ProviderRequest::new("https://api.example.com")
            .with_header("X-Test", "value")
            .with_body(serde_json::json!({"key": "value"}));

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ProviderRequest = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// A response (including headers) survives a JSON round trip unchanged.
    #[test]
    fn test_provider_response_serialization() {
        let headers = vec![("X-Request-ID".to_string(), "123".to_string())];
        let original = ProviderResponse::new(200, serde_json::json!({"result": "success"}))
            .with_headers(headers);

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ProviderResponse = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// Compile-time check: `Provider` can be used as a trait object.
    #[test]
    fn test_provider_object_safety() {
        fn _assert_object_safe(_: &dyn Provider) {}
    }
}