simple_agent_type/provider.rs
1//! Provider trait and types.
2//!
3//! Defines the interface for LLM providers with transformation hooks.
4
5use crate::config::{Capabilities, RetryConfig};
6use crate::error::{ProviderError, Result, SimpleAgentsError};
7use crate::request::CompletionRequest;
8use crate::response::{CompletionChunk, CompletionResponse};
9use async_trait::async_trait;
10use serde::{Deserialize, Serialize};
11use std::borrow::Cow;
12use std::time::Duration;
13
/// HTTP headers as an ordered list of `(name, value)` pairs.
///
/// Both halves are `Cow<'static, str>`:
/// - `Cow::Borrowed` holds `&'static str` literals (for example the constants in
///   [`headers`]) without allocating.
/// - `Cow::Owned` holds strings built at runtime, such as API keys.
///
/// A `Vec` is used rather than a map, so insertion order is preserved and the
/// same header name may appear more than once.
pub type Headers = Vec<(Cow<'static, str>, Cow<'static, str>)>;
16
/// Frequently used HTTP header names.
///
/// They are `&'static str` constants, so they can be attached with
/// `ProviderRequest::with_static_header` without any allocation.
pub mod headers {
    /// `Authorization` — typically carries a bearer token.
    pub const AUTHORIZATION: &str = "Authorization";

    /// `Content-Type` — media type of the request body.
    pub const CONTENT_TYPE: &str = "Content-Type";

    /// `x-api-key` — API-key header expected by some providers (e.g. Anthropic).
    pub const X_API_KEY: &str = "x-api-key";
}
26
/// Trait for LLM providers.
///
/// Providers implement this trait to support different LLM APIs while
/// presenting a unified interface to the rest of SimpleAgents.
///
/// The trait requires `Send + Sync`, and `#[async_trait]` keeps it object-safe,
/// so implementations can be used behind `&dyn Provider`.
///
/// # Architecture
///
/// The provider trait follows a three-phase architecture:
/// 1. **Transform Request**: Convert unified request to provider format
///    ([`Provider::transform_request`])
/// 2. **Execute**: Make the actual API call ([`Provider::execute`], or
///    [`Provider::execute_stream`] for streaming)
/// 3. **Transform Response**: Convert provider response to unified format
///    ([`Provider::transform_response`])
///
/// Implementors must provide `name`, `transform_request`, `execute` and
/// `transform_response`. The methods `retry_config`, `capabilities`, `timeout`
/// and `execute_stream` have default implementations that may be overridden.
///
/// This design allows for:
/// - Maximum flexibility in provider-specific transformations
/// - Clean separation between protocol logic and business logic
/// - Easy testing and mocking
///
/// # Example Implementation
///
/// ```rust
/// use simple_agent_type::provider::{Provider, ProviderRequest, ProviderResponse};
/// use simple_agent_type::request::CompletionRequest;
/// use simple_agent_type::response::{CompletionResponse, CompletionChoice, FinishReason, Usage};
/// use simple_agent_type::message::Message;
/// use simple_agent_type::error::Result;
/// use async_trait::async_trait;
///
/// struct MyProvider;
///
/// #[async_trait]
/// impl Provider for MyProvider {
///     fn name(&self) -> &str {
///         "my-provider"
///     }
///
///     fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
///         Ok(ProviderRequest::new("http://example.com"))
///     }
///
///     async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
///         Ok(ProviderResponse::new(200, serde_json::json!({"ok": true})))
///     }
///
///     fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
///         Ok(CompletionResponse {
///             id: "resp_1".to_string(),
///             model: "dummy".to_string(),
///             choices: vec![CompletionChoice {
///                 index: 0,
///                 message: Message::assistant("ok"),
///                 finish_reason: FinishReason::Stop,
///                 logprobs: None,
///             }],
///             usage: Usage::new(1, 1),
///             created: None,
///             provider: Some(self.name().to_string()),
///             healing_metadata: None,
///         })
///     }
/// }
///
/// let provider = MyProvider;
/// let request = CompletionRequest::builder()
///     .model("gpt-4")
///     .message(Message::user("Hello!"))
///     .build()
///     .unwrap();
///
/// let rt = tokio::runtime::Runtime::new().unwrap();
/// rt.block_on(async {
///     let provider_request = provider.transform_request(&request).unwrap();
///     let provider_response = provider.execute(provider_request).await.unwrap();
///     let response = provider.transform_response(provider_response).unwrap();
///     assert_eq!(response.content(), Some("ok"));
/// });
/// ```
#[async_trait]
pub trait Provider: Send + Sync {
    /// Provider name (e.g., "openai", "anthropic").
    fn name(&self) -> &str;

    /// Transform unified request to provider-specific format.
    ///
    /// This method converts the standardized `CompletionRequest` into
    /// the provider's native API format.
    ///
    /// # Errors
    /// Returns `Err` if the request cannot be converted.
    fn transform_request(&self, req: &CompletionRequest) -> Result<ProviderRequest>;

    /// Execute request against provider API.
    ///
    /// This method makes the actual HTTP request to the provider.
    /// Implementations should handle:
    /// - Authentication (API keys, tokens)
    /// - Rate limiting
    /// - Network errors
    /// - Provider-specific error codes
    ///
    /// # Errors
    /// Returns `Err` if the call fails, for example on network errors.
    async fn execute(&self, req: ProviderRequest) -> Result<ProviderResponse>;

    /// Transform provider response to unified format.
    ///
    /// This method converts the provider's native response format into
    /// the standardized `CompletionResponse`.
    ///
    /// # Errors
    /// Returns `Err` if the response cannot be converted.
    fn transform_response(&self, resp: ProviderResponse) -> Result<CompletionResponse>;

    /// Get retry configuration.
    ///
    /// Override to customize retry behavior for this provider. The default
    /// implementation returns `RetryConfig::default()`.
    fn retry_config(&self) -> RetryConfig {
        RetryConfig::default()
    }

    /// Get provider capabilities.
    ///
    /// Override to specify what features this provider supports. The default
    /// implementation returns `Capabilities::default()`.
    fn capabilities(&self) -> Capabilities {
        Capabilities::default()
    }

    /// Get default timeout.
    ///
    /// The default implementation returns 30 seconds.
    // NOTE(review): `ProviderRequest::timeout` is documented as an override;
    // confirm how callers combine the two values.
    fn timeout(&self) -> Duration {
        Duration::from_secs(30)
    }

    /// Execute streaming request against provider API.
    ///
    /// This method returns a stream of completion chunks for streaming responses.
    /// Not all providers support streaming; the default implementation returns an error.
    ///
    /// # Arguments
    /// - `req`: The provider-specific request
    ///
    /// # Returns
    /// A boxed `Stream` of `Result<CompletionChunk>` items that is `Send + Unpin`.
    ///
    /// # Errors
    /// The default implementation always returns
    /// `SimpleAgentsError::Provider(ProviderError::UnsupportedFeature("streaming"))`.
    /// Providers that support streaming must override this method.
    ///
    /// # Example
    /// ```rust
    /// use simple_agent_type::provider::{Provider, ProviderRequest, ProviderResponse};
    /// use simple_agent_type::request::CompletionRequest;
    /// use simple_agent_type::response::{CompletionResponse, CompletionChunk, CompletionChoice, FinishReason, Usage};
    /// use simple_agent_type::message::Message;
    /// use simple_agent_type::error::Result;
    /// use async_trait::async_trait;
    /// use futures_core::Stream;
    /// use std::pin::Pin;
    /// use std::task::{Context, Poll};
    ///
    /// struct EmptyStream;
    ///
    /// impl Stream for EmptyStream {
    ///     type Item = Result<CompletionChunk>;
    ///
    ///     fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
    ///         Poll::Ready(None)
    ///     }
    /// }
    ///
    /// struct StreamingProvider;
    ///
    /// #[async_trait]
    /// impl Provider for StreamingProvider {
    ///     fn name(&self) -> &str {
    ///         "streaming-provider"
    ///     }
    ///
    ///     fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
    ///         Ok(ProviderRequest::new("http://example.com"))
    ///     }
    ///
    ///     async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
    ///         Ok(ProviderResponse::new(200, serde_json::json!({"ok": true})))
    ///     }
    ///
    ///     fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
    ///         Ok(CompletionResponse {
    ///             id: "resp_1".to_string(),
    ///             model: "dummy".to_string(),
    ///             choices: vec![CompletionChoice {
    ///                 index: 0,
    ///                 message: Message::assistant("ok"),
    ///                 finish_reason: FinishReason::Stop,
    ///                 logprobs: None,
    ///             }],
    ///             usage: Usage::new(1, 1),
    ///             created: None,
    ///             provider: None,
    ///             healing_metadata: None,
    ///         })
    ///     }
    ///
    ///     async fn execute_stream(
    ///         &self,
    ///         _req: ProviderRequest,
    ///     ) -> Result<Box<dyn Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
    ///         Ok(Box::new(EmptyStream))
    ///     }
    /// }
    ///
    /// let provider = StreamingProvider;
    /// let request = CompletionRequest::builder()
    ///     .model("gpt-4")
    ///     .message(Message::user("Hello!"))
    ///     .build()
    ///     .unwrap();
    ///
    /// let rt = tokio::runtime::Runtime::new().unwrap();
    /// rt.block_on(async {
    ///     let provider_request = provider.transform_request(&request).unwrap();
    ///     let _stream = provider.execute_stream(provider_request).await.unwrap();
    ///     // Use StreamExt::next to consume the stream in real usage.
    /// });
    /// ```
    async fn execute_stream(
        &self,
        _req: ProviderRequest,
    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
        Err(SimpleAgentsError::Provider(
            ProviderError::UnsupportedFeature("streaming".to_string()),
        ))
    }
}
246
247/// Opaque provider-specific request.
248///
249/// This type encapsulates all information needed to make an HTTP request
250/// to a provider, without committing to a specific HTTP client library.
251///
252/// Headers use `Cow<'static, str>` to avoid allocations for common headers.
253#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
254pub struct ProviderRequest {
255 /// Full URL to send request to
256 pub url: String,
257 /// HTTP headers (name, value pairs) using Cow to avoid allocations for static strings
258 #[serde(with = "header_serde")]
259 pub headers: Headers,
260 /// Request body (JSON)
261 pub body: serde_json::Value,
262 /// Optional request timeout override
263 #[serde(skip_serializing_if = "Option::is_none")]
264 pub timeout: Option<Duration>,
265}
266
267// Custom serde for Cow headers
268mod header_serde {
269 use super::Headers;
270 use serde::{Deserialize, Deserializer, Serialize, Serializer};
271 use std::borrow::Cow;
272
273 pub fn serialize<S>(
274 headers: &[(Cow<'static, str>, Cow<'static, str>)],
275 serializer: S,
276 ) -> Result<S::Ok, S::Error>
277 where
278 S: Serializer,
279 {
280 let string_headers: Vec<(&str, &str)> = headers
281 .iter()
282 .map(|(k, v)| (k.as_ref(), v.as_ref()))
283 .collect();
284 string_headers.serialize(serializer)
285 }
286
287 pub fn deserialize<'de, D>(deserializer: D) -> std::result::Result<Headers, D::Error>
288 where
289 D: Deserializer<'de>,
290 {
291 let string_headers: Vec<(String, String)> = Vec::deserialize(deserializer)?;
292 Ok(string_headers
293 .into_iter()
294 .map(|(k, v)| (Cow::Owned(k), Cow::Owned(v)))
295 .collect())
296 }
297}
298
299impl ProviderRequest {
300 /// Create a new provider request.
301 pub fn new(url: impl Into<String>) -> Self {
302 Self {
303 url: url.into(),
304 headers: Vec::new(),
305 body: serde_json::Value::Null,
306 timeout: None,
307 }
308 }
309
310 /// Add a header with owned strings.
311 pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
312 self.headers
313 .push((Cow::Owned(name.into()), Cow::Owned(value.into())));
314 self
315 }
316
317 /// Add a header with static strings (zero allocation).
318 pub fn with_static_header(mut self, name: &'static str, value: &'static str) -> Self {
319 self.headers
320 .push((Cow::Borrowed(name), Cow::Borrowed(value)));
321 self
322 }
323
324 /// Set the body.
325 pub fn with_body(mut self, body: serde_json::Value) -> Self {
326 self.body = body;
327 self
328 }
329
330 /// Set the timeout.
331 pub fn with_timeout(mut self, timeout: Duration) -> Self {
332 self.timeout = Some(timeout);
333 self
334 }
335}
336
337/// Opaque provider-specific response.
338///
339/// This type encapsulates the raw HTTP response from a provider.
340#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
341pub struct ProviderResponse {
342 /// HTTP status code
343 pub status: u16,
344 /// Response body (JSON)
345 pub body: serde_json::Value,
346 /// Optional response headers
347 #[serde(skip_serializing_if = "Option::is_none")]
348 pub headers: Option<Vec<(String, String)>>,
349}
350
351impl ProviderResponse {
352 /// Create a new provider response.
353 pub fn new(status: u16, body: serde_json::Value) -> Self {
354 Self {
355 status,
356 body,
357 headers: None,
358 }
359 }
360
361 /// Check if response is successful (2xx).
362 pub fn is_success(&self) -> bool {
363 (200..300).contains(&self.status)
364 }
365
366 /// Check if response is a client error (4xx).
367 pub fn is_client_error(&self) -> bool {
368 (400..500).contains(&self.status)
369 }
370
371 /// Check if response is a server error (5xx).
372 pub fn is_server_error(&self) -> bool {
373 (500..600).contains(&self.status)
374 }
375
376 /// Add headers.
377 pub fn with_headers(mut self, headers: Vec<(String, String)>) -> Self {
378 self.headers = Some(headers);
379 self
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_request_builder() {
        let req = ProviderRequest::new("https://api.example.com/v1/completions")
            .with_header("Authorization", "Bearer sk-test")
            .with_header("Content-Type", "application/json")
            .with_body(serde_json::json!({"model": "test"}))
            .with_timeout(Duration::from_secs(30));

        assert_eq!(req.url, "https://api.example.com/v1/completions");
        assert_eq!(req.headers.len(), 2);
        assert_eq!(req.body["model"], "test");
        assert_eq!(req.timeout, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_static_header_is_borrowed() {
        let req = ProviderRequest::new("https://api.example.com")
            .with_static_header(headers::CONTENT_TYPE, "application/json")
            .with_header(headers::AUTHORIZATION, "Bearer sk-test");

        // Static headers must not allocate; owned headers always do.
        assert!(matches!(req.headers[0].0, Cow::Borrowed(_)));
        assert!(matches!(req.headers[0].1, Cow::Borrowed(_)));
        assert!(matches!(req.headers[1].0, Cow::Owned(_)));
        assert_eq!(req.headers[0].0, "Content-Type");
        assert_eq!(req.headers[1].1, "Bearer sk-test");
    }

    #[test]
    fn test_provider_response_status_checks() {
        let resp = ProviderResponse::new(200, serde_json::json!({}));
        assert!(resp.is_success());
        assert!(!resp.is_client_error());
        assert!(!resp.is_server_error());

        let resp = ProviderResponse::new(404, serde_json::json!({}));
        assert!(!resp.is_success());
        assert!(resp.is_client_error());
        assert!(!resp.is_server_error());

        let resp = ProviderResponse::new(500, serde_json::json!({}));
        assert!(!resp.is_success());
        assert!(!resp.is_client_error());
        assert!(resp.is_server_error());
    }

    #[test]
    fn test_provider_response_status_boundaries() {
        // (status, is_success, is_client_error, is_server_error)
        let cases: [(u16, bool, bool, bool); 10] = [
            (199, false, false, false),
            (200, true, false, false),
            (299, true, false, false),
            (300, false, false, false),
            (399, false, false, false),
            (400, false, true, false),
            (499, false, true, false),
            (500, false, false, true),
            (599, false, false, true),
            (600, false, false, false),
        ];
        for &(status, success, client, server) in &cases {
            let resp = ProviderResponse::new(status, serde_json::Value::Null);
            assert_eq!(resp.is_success(), success, "is_success({})", status);
            assert_eq!(resp.is_client_error(), client, "is_client_error({})", status);
            assert_eq!(resp.is_server_error(), server, "is_server_error({})", status);
        }
    }

    #[test]
    fn test_provider_request_serialization() {
        let req = ProviderRequest::new("https://api.example.com")
            .with_header("X-Test", "value")
            .with_body(serde_json::json!({"key": "value"}));

        let json = serde_json::to_string(&req).unwrap();
        let parsed: ProviderRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(req, parsed);
    }

    #[test]
    fn test_provider_request_header_wire_format() {
        let req = ProviderRequest::new("https://api.example.com")
            .with_static_header(headers::CONTENT_TYPE, "application/json")
            .with_header("X-Test", "value");

        let value = serde_json::to_value(&req).unwrap();
        assert_eq!(
            value["headers"],
            serde_json::json!([["Content-Type", "application/json"], ["X-Test", "value"]])
        );
        // A `None` timeout is skipped entirely rather than emitted as `null`.
        assert!(value.get("timeout").is_none());
    }

    #[test]
    fn test_provider_request_roundtrip_static_header_and_timeout() {
        let req = ProviderRequest::new("https://api.example.com")
            .with_static_header(headers::X_API_KEY, "static-key")
            .with_timeout(Duration::from_millis(1500));

        let json = serde_json::to_string(&req).unwrap();
        let parsed: ProviderRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(req, parsed);
        // Deserialized data cannot be borrowed for 'static, so headers come back owned.
        assert!(matches!(parsed.headers[0].0, Cow::Owned(_)));
        assert_eq!(parsed.timeout, Some(Duration::from_millis(1500)));
    }

    #[test]
    fn test_provider_response_serialization() {
        let resp = ProviderResponse::new(200, serde_json::json!({"result": "success"}))
            .with_headers(vec![("X-Request-ID".to_string(), "123".to_string())]);

        let json = serde_json::to_string(&resp).unwrap();
        let parsed: ProviderResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(resp, parsed);
    }

    // Test that Provider trait is object-safe
    #[test]
    fn test_provider_object_safety() {
        fn _assert_object_safe(_: &dyn Provider) {}
    }
}