simple_agent_type/provider.rs
1//! Provider trait and types.
2//!
3//! Defines the interface for LLM providers with transformation hooks.
4
5use crate::config::{Capabilities, RetryConfig};
6use crate::error::Result;
7use crate::request::CompletionRequest;
8use crate::response::{CompletionChunk, CompletionResponse};
9use async_trait::async_trait;
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::borrow::Cow;
13use std::time::Duration;
14
/// HTTP headers as an ordered list of `(name, value)` pairs.
///
/// Each component is a `Cow<'static, str>`: compile-time constants (such as the
/// names in the `headers` module) can be stored borrowed without allocating,
/// while values built at runtime (API keys, deserialized data) are stored owned.
/// Because this is a `Vec` rather than a map, insertion order is preserved and
/// the same name may appear more than once.
pub type Headers = Vec<(Cow<'static, str>, Cow<'static, str>)>;
17
pub mod headers {
    //! Well-known HTTP header names.
    //!
    //! Kept as `&'static str` constants so they can be placed into `Headers`
    //! as borrowed `Cow`s, avoiding a heap allocation per request.

    /// The `Authorization` request header.
    pub const AUTHORIZATION: &str = "Authorization";

    /// The `Content-Type` header describing the body's media type.
    pub const CONTENT_TYPE: &str = "Content-Type";

    /// The lowercase `x-api-key` header some providers (e.g. Anthropic) use
    /// for API-key authentication.
    pub const X_API_KEY: &str = "x-api-key";
}
27
28/// Trait for LLM providers.
29///
30/// Providers implement this trait to support different LLM APIs while
31/// presenting a unified interface to the rest of SimpleAgents.
32///
33/// # Architecture
34///
35/// The provider trait follows a three-phase architecture:
36/// 1. **Transform Request**: Convert unified request to provider format
37/// 2. **Execute**: Make the actual API call
38/// 3. **Transform Response**: Convert provider response to unified format
39///
40/// This design allows for:
41/// - Maximum flexibility in provider-specific transformations
42/// - Clean separation between protocol logic and business logic
43/// - Easy testing and mocking
44///
45/// # Example Implementation
46///
47/// ```rust
48/// use simple_agent_type::provider::{Provider, ProviderRequest, ProviderResponse};
49/// use simple_agent_type::request::CompletionRequest;
50/// use simple_agent_type::response::{CompletionResponse, CompletionChoice, FinishReason, Usage};
51/// use simple_agent_type::message::Message;
52/// use simple_agent_type::error::Result;
53/// use async_trait::async_trait;
54///
55/// struct MyProvider;
56///
57/// #[async_trait]
58/// impl Provider for MyProvider {
59/// fn name(&self) -> &str {
60/// "my-provider"
61/// }
62///
63/// fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
64/// Ok(ProviderRequest::new("http://example.com"))
65/// }
66///
67/// async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
68/// Ok(ProviderResponse::new(200, serde_json::json!({"ok": true})))
69/// }
70///
71/// fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
72/// Ok(CompletionResponse {
73/// id: "resp_1".to_string(),
74/// model: "dummy".to_string(),
75/// choices: vec![CompletionChoice {
76/// index: 0,
77/// message: Message::assistant("ok"),
78/// finish_reason: FinishReason::Stop,
79/// logprobs: None,
80/// }],
81/// usage: Usage::new(1, 1),
82/// created: None,
83/// provider: Some(self.name().to_string()),
84/// healing_metadata: None,
85/// })
86/// }
87/// }
88///
89/// let provider = MyProvider;
90/// let request = CompletionRequest::builder()
91/// .model("gpt-4")
92/// .message(Message::user("Hello!"))
93/// .build()
94/// .unwrap();
95///
96/// let rt = tokio::runtime::Runtime::new().unwrap();
97/// rt.block_on(async {
98/// let provider_request = provider.transform_request(&request).unwrap();
99/// let provider_response = provider.execute(provider_request).await.unwrap();
100/// let response = provider.transform_response(provider_response).unwrap();
101/// assert_eq!(response.content(), Some("ok"));
102/// });
103/// ```
104#[async_trait]
105pub trait Provider: Send + Sync {
106 /// Provider name (e.g., "openai", "anthropic").
107 fn name(&self) -> &str;
108
109 /// Transform unified request to provider-specific format.
110 ///
111 /// This method converts the standardized `CompletionRequest` into
112 /// the provider's native API format.
113 fn transform_request(&self, req: &CompletionRequest) -> Result<ProviderRequest>;
114
115 /// Execute request against provider API.
116 ///
117 /// This method makes the actual HTTP request to the provider.
118 /// Implementations should handle:
119 /// - Authentication (API keys, tokens)
120 /// - Rate limiting
121 /// - Network errors
122 /// - Provider-specific error codes
123 async fn execute(&self, req: ProviderRequest) -> Result<ProviderResponse>;
124
125 /// Transform provider response to unified format.
126 ///
127 /// This method converts the provider's native response format into
128 /// the standardized `CompletionResponse`.
129 fn transform_response(&self, resp: ProviderResponse) -> Result<CompletionResponse>;
130
131 /// Get retry configuration.
132 ///
133 /// Override to customize retry behavior for this provider.
134 fn retry_config(&self) -> RetryConfig {
135 RetryConfig::default()
136 }
137
138 /// Get provider capabilities.
139 ///
140 /// Override to specify what features this provider supports.
141 fn capabilities(&self) -> Capabilities {
142 Capabilities::default()
143 }
144
145 /// Get default timeout.
146 fn timeout(&self) -> Duration {
147 Duration::from_secs(30)
148 }
149
150 /// Execute streaming request against provider API.
151 ///
152 /// This method returns a stream of completion chunks for streaming responses.
153 /// Not all providers support streaming - default implementation returns an error.
154 ///
155 /// # Arguments
156 /// - `req`: The provider-specific request
157 ///
158 /// # Returns
159 /// A boxed stream of Result<CompletionChunk>
160 ///
161 /// # Example
162 /// ```rust
163 /// use simple_agent_type::provider::{Provider, ProviderRequest, ProviderResponse};
164 /// use simple_agent_type::request::CompletionRequest;
165 /// use simple_agent_type::response::{CompletionResponse, CompletionChunk, CompletionChoice, FinishReason, Usage};
166 /// use simple_agent_type::message::Message;
167 /// use simple_agent_type::error::Result;
168 /// use async_trait::async_trait;
169 /// use futures_core::Stream;
170 /// use std::pin::Pin;
171 /// use std::task::{Context, Poll};
172 ///
173 /// struct EmptyStream;
174 ///
175 /// impl Stream for EmptyStream {
176 /// type Item = Result<CompletionChunk>;
177 ///
178 /// fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
179 /// Poll::Ready(None)
180 /// }
181 /// }
182 ///
183 /// struct StreamingProvider;
184 ///
185 /// #[async_trait]
186 /// impl Provider for StreamingProvider {
187 /// fn name(&self) -> &str {
188 /// "streaming-provider"
189 /// }
190 ///
191 /// fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
192 /// Ok(ProviderRequest::new("http://example.com"))
193 /// }
194 ///
195 /// async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
196 /// Ok(ProviderResponse::new(200, serde_json::json!({"ok": true})))
197 /// }
198 ///
199 /// fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
200 /// Ok(CompletionResponse {
201 /// id: "resp_1".to_string(),
202 /// model: "dummy".to_string(),
203 /// choices: vec![CompletionChoice {
204 /// index: 0,
205 /// message: Message::assistant("ok"),
206 /// finish_reason: FinishReason::Stop,
207 /// logprobs: None,
208 /// }],
209 /// usage: Usage::new(1, 1),
210 /// created: None,
211 /// provider: None,
212 /// healing_metadata: None,
213 /// })
214 /// }
215 ///
216 /// async fn execute_stream(
217 /// &self,
218 /// _req: ProviderRequest,
219 /// ) -> Result<Box<dyn Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
220 /// Ok(Box::new(EmptyStream))
221 /// }
222 /// }
223 ///
224 /// let provider = StreamingProvider;
225 /// let request = CompletionRequest::builder()
226 /// .model("gpt-4")
227 /// .message(Message::user("Hello!"))
228 /// .build()
229 /// .unwrap();
230 ///
231 /// let rt = tokio::runtime::Runtime::new().unwrap();
232 /// rt.block_on(async {
233 /// let provider_request = provider.transform_request(&request).unwrap();
234 /// let _stream = provider.execute_stream(provider_request).await.unwrap();
235 /// // Use StreamExt::next to consume the stream in real usage.
236 /// });
237 /// ```
238 async fn execute_stream(
239 &self,
240 mut req: ProviderRequest,
241 ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
242 if let Value::Object(map) = &mut req.body {
243 if let Some(stream_value) = map.get_mut("stream") {
244 *stream_value = Value::Bool(false);
245 }
246 }
247
248 let provider_response = self.execute(req).await?;
249 let response = self.transform_response(provider_response)?;
250
251 struct SingleChunkStream {
252 chunk: Option<Result<CompletionChunk>>,
253 }
254
255 impl futures_core::Stream for SingleChunkStream {
256 type Item = Result<CompletionChunk>;
257
258 fn poll_next(
259 mut self: std::pin::Pin<&mut Self>,
260 _cx: &mut std::task::Context<'_>,
261 ) -> std::task::Poll<Option<Self::Item>> {
262 std::task::Poll::Ready(self.chunk.take())
263 }
264 }
265
266 let choices = response
267 .choices
268 .into_iter()
269 .map(|choice| crate::response::ChoiceDelta {
270 index: choice.index,
271 delta: crate::response::MessageDelta {
272 role: Some(choice.message.role),
273 content: Some(choice.message.content),
274 reasoning_content: None,
275 },
276 finish_reason: Some(choice.finish_reason),
277 })
278 .collect();
279
280 let chunk = CompletionChunk {
281 id: response.id,
282 model: response.model,
283 choices,
284 created: response.created,
285 usage: Some(response.usage),
286 };
287
288 Ok(Box::new(SingleChunkStream {
289 chunk: Some(Ok(chunk)),
290 }))
291 }
292}
293
294/// Opaque provider-specific request.
295///
296/// This type encapsulates all information needed to make an HTTP request
297/// to a provider, without committing to a specific HTTP client library.
298///
299/// Headers use `Cow<'static, str>` to avoid allocations for common headers.
300#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
301pub struct ProviderRequest {
302 /// Full URL to send request to
303 pub url: String,
304 /// HTTP headers (name, value pairs) using Cow to avoid allocations for static strings
305 #[serde(with = "header_serde")]
306 pub headers: Headers,
307 /// Request body (JSON)
308 pub body: serde_json::Value,
309 /// Optional request timeout override
310 #[serde(skip_serializing_if = "Option::is_none")]
311 pub timeout: Option<Duration>,
312}
313
314// Custom serde for Cow headers
315mod header_serde {
316 use super::Headers;
317 use serde::{Deserialize, Deserializer, Serialize, Serializer};
318 use std::borrow::Cow;
319
320 pub fn serialize<S>(
321 headers: &[(Cow<'static, str>, Cow<'static, str>)],
322 serializer: S,
323 ) -> Result<S::Ok, S::Error>
324 where
325 S: Serializer,
326 {
327 let string_headers: Vec<(&str, &str)> = headers
328 .iter()
329 .map(|(k, v)| (k.as_ref(), v.as_ref()))
330 .collect();
331 string_headers.serialize(serializer)
332 }
333
334 pub fn deserialize<'de, D>(deserializer: D) -> std::result::Result<Headers, D::Error>
335 where
336 D: Deserializer<'de>,
337 {
338 let string_headers: Vec<(String, String)> = Vec::deserialize(deserializer)?;
339 Ok(string_headers
340 .into_iter()
341 .map(|(k, v)| (Cow::Owned(k), Cow::Owned(v)))
342 .collect())
343 }
344}
345
346impl ProviderRequest {
347 /// Create a new provider request.
348 pub fn new(url: impl Into<String>) -> Self {
349 Self {
350 url: url.into(),
351 headers: Vec::new(),
352 body: serde_json::Value::Null,
353 timeout: None,
354 }
355 }
356
357 /// Add a header with owned strings.
358 pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
359 self.headers
360 .push((Cow::Owned(name.into()), Cow::Owned(value.into())));
361 self
362 }
363
364 /// Add a header with static strings (zero allocation).
365 pub fn with_static_header(mut self, name: &'static str, value: &'static str) -> Self {
366 self.headers
367 .push((Cow::Borrowed(name), Cow::Borrowed(value)));
368 self
369 }
370
371 /// Set the body.
372 pub fn with_body(mut self, body: serde_json::Value) -> Self {
373 self.body = body;
374 self
375 }
376
377 /// Set the timeout.
378 pub fn with_timeout(mut self, timeout: Duration) -> Self {
379 self.timeout = Some(timeout);
380 self
381 }
382}
383
384/// Opaque provider-specific response.
385///
386/// This type encapsulates the raw HTTP response from a provider.
387#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
388pub struct ProviderResponse {
389 /// HTTP status code
390 pub status: u16,
391 /// Response body (JSON)
392 pub body: serde_json::Value,
393 /// Optional response headers
394 #[serde(skip_serializing_if = "Option::is_none")]
395 pub headers: Option<Vec<(String, String)>>,
396}
397
398impl ProviderResponse {
399 /// Create a new provider response.
400 pub fn new(status: u16, body: serde_json::Value) -> Self {
401 Self {
402 status,
403 body,
404 headers: None,
405 }
406 }
407
408 /// Check if response is successful (2xx).
409 pub fn is_success(&self) -> bool {
410 (200..300).contains(&self.status)
411 }
412
413 /// Check if response is a client error (4xx).
414 pub fn is_client_error(&self) -> bool {
415 (400..500).contains(&self.status)
416 }
417
418 /// Check if response is a server error (5xx).
419 pub fn is_server_error(&self) -> bool {
420 (500..600).contains(&self.status)
421 }
422
423 /// Add headers.
424 pub fn with_headers(mut self, headers: Vec<(String, String)>) -> Self {
425 self.headers = Some(headers);
426 self
427 }
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_request_builder() {
        let req = ProviderRequest::new("https://api.example.com/v1/completions")
            .with_header("Authorization", "Bearer sk-test")
            .with_header("Content-Type", "application/json")
            .with_body(serde_json::json!({"model": "test"}))
            .with_timeout(Duration::from_secs(30));

        assert_eq!(req.url, "https://api.example.com/v1/completions");
        assert_eq!(req.headers.len(), 2);
        assert_eq!(req.body["model"], "test");
        assert_eq!(req.timeout, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_provider_request_static_and_owned_headers() {
        let req = ProviderRequest::new("https://api.example.com")
            .with_static_header(headers::CONTENT_TYPE, "application/json")
            .with_header(headers::AUTHORIZATION, "Bearer sk-test");

        // Static headers are stored borrowed; `with_header` always allocates.
        assert!(matches!(req.headers[0].0, Cow::Borrowed(_)));
        assert!(matches!(req.headers[0].1, Cow::Borrowed(_)));
        assert!(matches!(req.headers[1].0, Cow::Owned(_)));
        assert_eq!(req.headers[0].0, "Content-Type");
        assert_eq!(req.headers[0].1, "application/json");
        assert_eq!(req.headers[1].0, "Authorization");
        assert_eq!(req.headers[1].1, "Bearer sk-test");
    }

    #[test]
    fn test_provider_response_status_checks() {
        let resp = ProviderResponse::new(200, serde_json::json!({}));
        assert!(resp.is_success());
        assert!(!resp.is_client_error());
        assert!(!resp.is_server_error());

        let resp = ProviderResponse::new(404, serde_json::json!({}));
        assert!(!resp.is_success());
        assert!(resp.is_client_error());
        assert!(!resp.is_server_error());

        let resp = ProviderResponse::new(500, serde_json::json!({}));
        assert!(!resp.is_success());
        assert!(!resp.is_client_error());
        assert!(resp.is_server_error());
    }

    #[test]
    fn test_provider_response_status_boundaries() {
        // (status, is_success, is_client_error, is_server_error)
        let cases = [
            (199, false, false, false),
            (200, true, false, false),
            (299, true, false, false),
            (300, false, false, false),
            (399, false, false, false),
            (400, false, true, false),
            (499, false, true, false),
            (500, false, false, true),
            (599, false, false, true),
            (600, false, false, false),
        ];
        for &(status, success, client, server) in cases.iter() {
            let resp = ProviderResponse::new(status, serde_json::json!({}));
            assert_eq!(resp.is_success(), success, "is_success for {}", status);
            assert_eq!(resp.is_client_error(), client, "is_client_error for {}", status);
            assert_eq!(resp.is_server_error(), server, "is_server_error for {}", status);
        }
    }

    #[test]
    fn test_provider_request_serialization() {
        let req = ProviderRequest::new("https://api.example.com")
            .with_header("X-Test", "value")
            .with_body(serde_json::json!({"key": "value"}));

        let json = serde_json::to_string(&req).unwrap();
        let parsed: ProviderRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(req, parsed);
    }

    #[test]
    fn test_provider_request_static_header_round_trip() {
        let req = ProviderRequest::new("https://api.example.com")
            .with_static_header(headers::X_API_KEY, "static-key")
            .with_timeout(Duration::from_millis(1500));

        let json = serde_json::to_string(&req).unwrap();
        let parsed: ProviderRequest = serde_json::from_str(&json).unwrap();

        // Borrowed headers come back owned, but compare equal by content.
        assert_eq!(req, parsed);
        assert!(matches!(parsed.headers[0].0, Cow::Owned(_)));
    }

    #[test]
    fn test_provider_request_deserialize_without_timeout() {
        let parsed: ProviderRequest = serde_json::from_str(
            r#"{"url":"https://api.example.com","headers":[["X-Test","value"]],"body":null}"#,
        )
        .unwrap();
        let expected =
            ProviderRequest::new("https://api.example.com").with_header("X-Test", "value");
        assert_eq!(parsed, expected);
        assert_eq!(parsed.timeout, None);
    }

    #[test]
    fn test_provider_response_serialization() {
        let resp = ProviderResponse::new(200, serde_json::json!({"result": "success"}))
            .with_headers(vec![("X-Request-ID".to_string(), "123".to_string())]);

        let json = serde_json::to_string(&resp).unwrap();
        let parsed: ProviderResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(resp, parsed);
    }

    #[test]
    fn test_provider_response_deserialize_without_headers() {
        let parsed: ProviderResponse =
            serde_json::from_str(r#"{"status":204,"body":null}"#).unwrap();
        assert_eq!(parsed, ProviderResponse::new(204, serde_json::Value::Null));
        assert!(parsed.headers.is_none());
    }

    // Test that Provider trait is object-safe
    #[test]
    fn test_provider_object_safety() {
        fn _assert_object_safe(_: &dyn Provider) {}
    }
}