simple_agent_type/provider.rs
1//! Provider trait and types.
2//!
3//! Defines the interface for LLM providers with transformation hooks.
4
5use crate::config::{Capabilities, RetryConfig};
6use crate::error::Result;
7use crate::request::CompletionRequest;
8use crate::response::{CompletionChunk, CompletionResponse};
9use async_trait::async_trait;
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::borrow::Cow;
13use std::time::Duration;
14
/// HTTP headers as an ordered list of `(name, value)` pairs.
///
/// Each string is a `Cow<'static, str>`: either a borrowed `&'static str`
/// (stored without allocating, e.g. the constants in the `headers` module) or
/// an owned `String`. Being a plain `Vec`, it keeps insertion order and does
/// not merge duplicate names.
pub type Headers = Vec<(Cow<'static, str>, Cow<'static, str>)>;
17
/// Well-known HTTP header names, provided as `&'static str` constants.
///
/// Because they are `'static`, they can be passed to
/// `ProviderRequest::with_static_header` without any allocation.
pub mod headers {
    /// The `x-api-key` header, used by some providers such as Anthropic.
    pub const X_API_KEY: &str = "x-api-key";

    /// The standard `Content-Type` header.
    pub const CONTENT_TYPE: &str = "Content-Type";

    /// The standard `Authorization` header.
    pub const AUTHORIZATION: &str = "Authorization";
}
27
28/// Trait for LLM providers.
29///
30/// Providers implement this trait to support different LLM APIs while
31/// presenting a unified interface to the rest of SimpleAgents.
32///
33/// # Architecture
34///
35/// The provider trait follows a three-phase architecture:
36/// 1. **Transform Request**: Convert unified request to provider format
37/// 2. **Execute**: Make the actual API call
38/// 3. **Transform Response**: Convert provider response to unified format
39///
40/// This design allows for:
41/// - Maximum flexibility in provider-specific transformations
42/// - Clean separation between protocol logic and business logic
43/// - Easy testing and mocking
44///
45/// # Example Implementation
46///
47/// ```rust
48/// use simple_agent_type::provider::{Provider, ProviderRequest, ProviderResponse};
49/// use simple_agent_type::request::CompletionRequest;
50/// use simple_agent_type::response::{CompletionResponse, CompletionChoice, FinishReason, Usage};
51/// use simple_agent_type::message::Message;
52/// use simple_agent_type::error::Result;
53/// use async_trait::async_trait;
54///
55/// struct MyProvider;
56///
57/// #[async_trait]
58/// impl Provider for MyProvider {
59/// fn name(&self) -> &str {
60/// "my-provider"
61/// }
62///
63/// fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
64/// Ok(ProviderRequest::new("http://example.com"))
65/// }
66///
67/// async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
68/// Ok(ProviderResponse::new(200, serde_json::json!({"ok": true})))
69/// }
70///
71/// fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
72/// Ok(CompletionResponse {
73/// id: "resp_1".to_string(),
74/// model: "dummy".to_string(),
75/// choices: vec![CompletionChoice {
76/// index: 0,
77/// message: Message::assistant("ok"),
78/// finish_reason: FinishReason::Stop,
79/// logprobs: None,
80/// }],
81/// usage: Usage::new(1, 1),
82/// created: None,
83/// provider: Some(self.name().to_string()),
84/// healing_metadata: None,
85/// })
86/// }
87/// }
88///
89/// let provider = MyProvider;
90/// let request = CompletionRequest::builder()
91/// .model("gpt-4")
92/// .message(Message::user("Hello!"))
93/// .build()
94/// .unwrap();
95///
96/// let rt = tokio::runtime::Runtime::new().unwrap();
97/// rt.block_on(async {
98/// let provider_request = provider.transform_request(&request).unwrap();
99/// let provider_response = provider.execute(provider_request).await.unwrap();
100/// let response = provider.transform_response(provider_response).unwrap();
101/// assert_eq!(response.content(), Some("ok"));
102/// });
103/// ```
104#[async_trait]
105pub trait Provider: Send + Sync {
106 /// Provider name (e.g., "openai", "anthropic").
107 fn name(&self) -> &str;
108
109 /// Transform unified request to provider-specific format.
110 ///
111 /// This method converts the standardized `CompletionRequest` into
112 /// the provider's native API format.
113 fn transform_request(&self, req: &CompletionRequest) -> Result<ProviderRequest>;
114
115 /// Execute request against provider API.
116 ///
117 /// This method makes the actual HTTP request to the provider.
118 /// Implementations should handle:
119 /// - Authentication (API keys, tokens)
120 /// - Rate limiting
121 /// - Network errors
122 /// - Provider-specific error codes
123 async fn execute(&self, req: ProviderRequest) -> Result<ProviderResponse>;
124
125 /// Transform provider response to unified format.
126 ///
127 /// This method converts the provider's native response format into
128 /// the standardized `CompletionResponse`.
129 fn transform_response(&self, resp: ProviderResponse) -> Result<CompletionResponse>;
130
131 /// Get retry configuration.
132 ///
133 /// Override to customize retry behavior for this provider.
134 fn retry_config(&self) -> RetryConfig {
135 RetryConfig::default()
136 }
137
138 /// Get provider capabilities.
139 ///
140 /// Override to specify what features this provider supports.
141 fn capabilities(&self) -> Capabilities {
142 Capabilities::default()
143 }
144
145 /// Get default timeout.
146 fn timeout(&self) -> Duration {
147 Duration::from_secs(30)
148 }
149
150 /// Execute streaming request against provider API.
151 ///
152 /// This method returns a stream of completion chunks for streaming responses.
153 /// Not all providers support streaming - default implementation returns an error.
154 ///
155 /// # Arguments
156 /// - `req`: The provider-specific request
157 ///
158 /// # Returns
159 /// A boxed stream of Result<CompletionChunk>
160 ///
161 /// # Example
162 /// ```rust
163 /// use simple_agent_type::provider::{Provider, ProviderRequest, ProviderResponse};
164 /// use simple_agent_type::request::CompletionRequest;
165 /// use simple_agent_type::response::{CompletionResponse, CompletionChunk, CompletionChoice, FinishReason, Usage};
166 /// use simple_agent_type::message::Message;
167 /// use simple_agent_type::error::Result;
168 /// use async_trait::async_trait;
169 /// use futures_core::Stream;
170 /// use std::pin::Pin;
171 /// use std::task::{Context, Poll};
172 ///
173 /// struct EmptyStream;
174 ///
175 /// impl Stream for EmptyStream {
176 /// type Item = Result<CompletionChunk>;
177 ///
178 /// fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
179 /// Poll::Ready(None)
180 /// }
181 /// }
182 ///
183 /// struct StreamingProvider;
184 ///
185 /// #[async_trait]
186 /// impl Provider for StreamingProvider {
187 /// fn name(&self) -> &str {
188 /// "streaming-provider"
189 /// }
190 ///
191 /// fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
192 /// Ok(ProviderRequest::new("http://example.com"))
193 /// }
194 ///
195 /// async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
196 /// Ok(ProviderResponse::new(200, serde_json::json!({"ok": true})))
197 /// }
198 ///
199 /// fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
200 /// Ok(CompletionResponse {
201 /// id: "resp_1".to_string(),
202 /// model: "dummy".to_string(),
203 /// choices: vec![CompletionChoice {
204 /// index: 0,
205 /// message: Message::assistant("ok"),
206 /// finish_reason: FinishReason::Stop,
207 /// logprobs: None,
208 /// }],
209 /// usage: Usage::new(1, 1),
210 /// created: None,
211 /// provider: None,
212 /// healing_metadata: None,
213 /// })
214 /// }
215 ///
216 /// async fn execute_stream(
217 /// &self,
218 /// _req: ProviderRequest,
219 /// ) -> Result<Box<dyn Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
220 /// Ok(Box::new(EmptyStream))
221 /// }
222 /// }
223 ///
224 /// let provider = StreamingProvider;
225 /// let request = CompletionRequest::builder()
226 /// .model("gpt-4")
227 /// .message(Message::user("Hello!"))
228 /// .build()
229 /// .unwrap();
230 ///
231 /// let rt = tokio::runtime::Runtime::new().unwrap();
232 /// rt.block_on(async {
233 /// let provider_request = provider.transform_request(&request).unwrap();
234 /// let _stream = provider.execute_stream(provider_request).await.unwrap();
235 /// // Use StreamExt::next to consume the stream in real usage.
236 /// });
237 /// ```
238 async fn execute_stream(
239 &self,
240 mut req: ProviderRequest,
241 ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
242 if let Value::Object(map) = &mut req.body {
243 if let Some(stream_value) = map.get_mut("stream") {
244 *stream_value = Value::Bool(false);
245 }
246 }
247
248 let provider_response = self.execute(req).await?;
249 let response = self.transform_response(provider_response)?;
250
251 struct SingleChunkStream {
252 chunk: Option<Result<CompletionChunk>>,
253 }
254
255 impl futures_core::Stream for SingleChunkStream {
256 type Item = Result<CompletionChunk>;
257
258 fn poll_next(
259 mut self: std::pin::Pin<&mut Self>,
260 _cx: &mut std::task::Context<'_>,
261 ) -> std::task::Poll<Option<Self::Item>> {
262 std::task::Poll::Ready(self.chunk.take())
263 }
264 }
265
266 let choices = response
267 .choices
268 .into_iter()
269 .map(|choice| crate::response::ChoiceDelta {
270 index: choice.index,
271 delta: crate::response::MessageDelta {
272 role: Some(choice.message.role),
273 content: Some(choice.message.content),
274 },
275 finish_reason: Some(choice.finish_reason),
276 })
277 .collect();
278
279 let chunk = CompletionChunk {
280 id: response.id,
281 model: response.model,
282 choices,
283 created: response.created,
284 };
285
286 Ok(Box::new(SingleChunkStream {
287 chunk: Some(Ok(chunk)),
288 }))
289 }
290}
291
292/// Opaque provider-specific request.
293///
294/// This type encapsulates all information needed to make an HTTP request
295/// to a provider, without committing to a specific HTTP client library.
296///
297/// Headers use `Cow<'static, str>` to avoid allocations for common headers.
298#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
299pub struct ProviderRequest {
300 /// Full URL to send request to
301 pub url: String,
302 /// HTTP headers (name, value pairs) using Cow to avoid allocations for static strings
303 #[serde(with = "header_serde")]
304 pub headers: Headers,
305 /// Request body (JSON)
306 pub body: serde_json::Value,
307 /// Optional request timeout override
308 #[serde(skip_serializing_if = "Option::is_none")]
309 pub timeout: Option<Duration>,
310}
311
312// Custom serde for Cow headers
313mod header_serde {
314 use super::Headers;
315 use serde::{Deserialize, Deserializer, Serialize, Serializer};
316 use std::borrow::Cow;
317
318 pub fn serialize<S>(
319 headers: &[(Cow<'static, str>, Cow<'static, str>)],
320 serializer: S,
321 ) -> Result<S::Ok, S::Error>
322 where
323 S: Serializer,
324 {
325 let string_headers: Vec<(&str, &str)> = headers
326 .iter()
327 .map(|(k, v)| (k.as_ref(), v.as_ref()))
328 .collect();
329 string_headers.serialize(serializer)
330 }
331
332 pub fn deserialize<'de, D>(deserializer: D) -> std::result::Result<Headers, D::Error>
333 where
334 D: Deserializer<'de>,
335 {
336 let string_headers: Vec<(String, String)> = Vec::deserialize(deserializer)?;
337 Ok(string_headers
338 .into_iter()
339 .map(|(k, v)| (Cow::Owned(k), Cow::Owned(v)))
340 .collect())
341 }
342}
343
344impl ProviderRequest {
345 /// Create a new provider request.
346 pub fn new(url: impl Into<String>) -> Self {
347 Self {
348 url: url.into(),
349 headers: Vec::new(),
350 body: serde_json::Value::Null,
351 timeout: None,
352 }
353 }
354
355 /// Add a header with owned strings.
356 pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
357 self.headers
358 .push((Cow::Owned(name.into()), Cow::Owned(value.into())));
359 self
360 }
361
362 /// Add a header with static strings (zero allocation).
363 pub fn with_static_header(mut self, name: &'static str, value: &'static str) -> Self {
364 self.headers
365 .push((Cow::Borrowed(name), Cow::Borrowed(value)));
366 self
367 }
368
369 /// Set the body.
370 pub fn with_body(mut self, body: serde_json::Value) -> Self {
371 self.body = body;
372 self
373 }
374
375 /// Set the timeout.
376 pub fn with_timeout(mut self, timeout: Duration) -> Self {
377 self.timeout = Some(timeout);
378 self
379 }
380}
381
382/// Opaque provider-specific response.
383///
384/// This type encapsulates the raw HTTP response from a provider.
385#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
386pub struct ProviderResponse {
387 /// HTTP status code
388 pub status: u16,
389 /// Response body (JSON)
390 pub body: serde_json::Value,
391 /// Optional response headers
392 #[serde(skip_serializing_if = "Option::is_none")]
393 pub headers: Option<Vec<(String, String)>>,
394}
395
396impl ProviderResponse {
397 /// Create a new provider response.
398 pub fn new(status: u16, body: serde_json::Value) -> Self {
399 Self {
400 status,
401 body,
402 headers: None,
403 }
404 }
405
406 /// Check if response is successful (2xx).
407 pub fn is_success(&self) -> bool {
408 (200..300).contains(&self.status)
409 }
410
411 /// Check if response is a client error (4xx).
412 pub fn is_client_error(&self) -> bool {
413 (400..500).contains(&self.status)
414 }
415
416 /// Check if response is a server error (5xx).
417 pub fn is_server_error(&self) -> bool {
418 (500..600).contains(&self.status)
419 }
420
421 /// Add headers.
422 pub fn with_headers(mut self, headers: Vec<(String, String)>) -> Self {
423 self.headers = Some(headers);
424 self
425 }
426}
427
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::Message;
    use crate::response::{CompletionChoice, FinishReason, Usage};
    use futures_core::Stream;
    use std::pin::Pin;
    use std::sync::Mutex;

    #[test]
    fn test_provider_request_builder() {
        let req = ProviderRequest::new("https://api.example.com/v1/completions")
            .with_header("Authorization", "Bearer sk-test")
            .with_header("Content-Type", "application/json")
            .with_body(serde_json::json!({"model": "test"}))
            .with_timeout(Duration::from_secs(30));

        assert_eq!(req.url, "https://api.example.com/v1/completions");
        assert_eq!(req.headers.len(), 2);
        assert_eq!(req.body["model"], "test");
        assert_eq!(req.timeout, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_provider_request_static_header_is_borrowed() {
        let req = ProviderRequest::new("https://api.example.com")
            .with_static_header(headers::AUTHORIZATION, "Bearer sk-test");

        assert_eq!(req.headers.len(), 1);
        assert!(matches!(req.headers[0].0, Cow::Borrowed(_)));
        assert!(matches!(req.headers[0].1, Cow::Borrowed(_)));
        assert_eq!(req.headers[0].0, "Authorization");
        assert_eq!(req.headers[0].1, "Bearer sk-test");
    }

    #[test]
    fn test_provider_response_status_checks() {
        let resp = ProviderResponse::new(200, serde_json::json!({}));
        assert!(resp.is_success());
        assert!(!resp.is_client_error());
        assert!(!resp.is_server_error());

        let resp = ProviderResponse::new(404, serde_json::json!({}));
        assert!(!resp.is_success());
        assert!(resp.is_client_error());
        assert!(!resp.is_server_error());

        let resp = ProviderResponse::new(500, serde_json::json!({}));
        assert!(!resp.is_success());
        assert!(!resp.is_client_error());
        assert!(resp.is_server_error());
    }

    #[test]
    fn test_provider_response_status_boundaries() {
        // Codes just outside each range, plus 1xx/3xx, match no predicate.
        for status in [100, 199, 300, 399, 600] {
            let resp = ProviderResponse::new(status, serde_json::json!({}));
            assert!(!resp.is_success(), "status {}", status);
            assert!(!resp.is_client_error(), "status {}", status);
            assert!(!resp.is_server_error(), "status {}", status);
        }
        // Upper edges of each range are inclusive.
        assert!(ProviderResponse::new(299, serde_json::json!({})).is_success());
        assert!(ProviderResponse::new(499, serde_json::json!({})).is_client_error());
        assert!(ProviderResponse::new(599, serde_json::json!({})).is_server_error());
    }

    #[test]
    fn test_provider_request_serialization() {
        let req = ProviderRequest::new("https://api.example.com")
            .with_header("X-Test", "value")
            .with_body(serde_json::json!({"key": "value"}));

        let json = serde_json::to_string(&req).unwrap();
        let parsed: ProviderRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(req, parsed);
    }

    #[test]
    fn test_provider_request_wire_format() {
        let req = ProviderRequest::new("https://api.example.com")
            .with_static_header(headers::CONTENT_TYPE, "application/json")
            .with_header("X-Test", "value");

        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(
            json["headers"],
            serde_json::json!([["Content-Type", "application/json"], ["X-Test", "value"]])
        );
        // An unset timeout is omitted on output and restored as `None` on input.
        assert!(json.get("timeout").is_none());

        let parsed: ProviderRequest = serde_json::from_value(json).unwrap();
        assert_eq!(parsed.timeout, None);
        assert_eq!(parsed, req);
    }

    #[test]
    fn test_provider_response_serialization() {
        let resp = ProviderResponse::new(200, serde_json::json!({"result": "success"}))
            .with_headers(vec![("X-Request-ID".to_string(), "123".to_string())]);

        let json = serde_json::to_string(&resp).unwrap();
        let parsed: ProviderResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(resp, parsed);
    }

    /// Provider whose `execute` records the request body it receives, so the
    /// default `execute_stream` fallback can be observed.
    struct RecordingProvider {
        last_body: Mutex<Option<Value>>,
    }

    impl RecordingProvider {
        fn new() -> Self {
            Self {
                last_body: Mutex::new(None),
            }
        }

        fn take_last_body(&self) -> Value {
            self.last_body
                .lock()
                .unwrap()
                .take()
                .expect("execute should have been called")
        }
    }

    #[async_trait]
    impl Provider for RecordingProvider {
        fn name(&self) -> &str {
            "recording"
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, req: ProviderRequest) -> Result<ProviderResponse> {
            *self.last_body.lock().unwrap() = Some(req.body);
            Ok(ProviderResponse::new(200, serde_json::json!({"ok": true})))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_1".to_string(),
                model: "dummy".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: None,
                healing_metadata: None,
            })
        }
    }

    #[test]
    fn test_default_execute_stream_yields_single_chunk() {
        let provider = RecordingProvider::new();
        let request = ProviderRequest::new("http://example.com")
            .with_body(serde_json::json!({"model": "test", "stream": true}));

        let rt = tokio::runtime::Runtime::new().unwrap();
        let mut stream = rt.block_on(provider.execute_stream(request)).unwrap();
        let first = rt.block_on(std::future::poll_fn(|cx| {
            Pin::new(&mut *stream).poll_next(cx)
        }));
        let second = rt.block_on(std::future::poll_fn(|cx| {
            Pin::new(&mut *stream).poll_next(cx)
        }));

        let chunk = first.expect("fallback stream yields one chunk").unwrap();
        assert_eq!(chunk.id, "resp_1");
        assert_eq!(chunk.model, "dummy");
        assert_eq!(chunk.choices.len(), 1);
        assert!(second.is_none(), "fallback stream ends after one chunk");

        // The non-streaming round trip must not ask the provider to stream.
        let body = provider.take_last_body();
        assert_eq!(body["stream"], false);
        assert_eq!(body["model"], "test");
    }

    #[test]
    fn test_default_execute_stream_leaves_body_without_stream_key() {
        let provider = RecordingProvider::new();
        let request = ProviderRequest::new("http://example.com")
            .with_body(serde_json::json!({"model": "test"}));

        let rt = tokio::runtime::Runtime::new().unwrap();
        let _stream = rt.block_on(provider.execute_stream(request)).unwrap();

        let body = provider.take_last_body();
        assert!(body.get("stream").is_none());
        assert_eq!(body["model"], "test");
    }

    // Test that Provider trait is object-safe
    #[test]
    fn test_provider_object_safety() {
        fn _assert_object_safe(_: &dyn Provider) {}
    }
}