simple_agent_type/provider.rs
1//! Provider trait and types.
2//!
3//! Defines the interface for LLM providers with transformation hooks.
4
5use crate::config::{Capabilities, RetryConfig};
6use crate::error::Result;
7use crate::request::CompletionRequest;
8use crate::response::{CompletionChunk, CompletionResponse};
9use async_trait::async_trait;
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::borrow::Cow;
13use std::time::Duration;
14
/// Type alias for HTTP headers: an ordered list of `(name, value)` pairs.
///
/// Each string is a `Cow<'static, str>`, so it is either a borrowed
/// `&'static str` or an owned `String`.
pub type Headers = Vec<(Cow<'static, str>, Cow<'static, str>)>;
17
/// Names of frequently used HTTP headers.
///
/// They are exposed as `&'static str` constants so they can be stored in a
/// header list without allocating.
pub mod headers {
    /// The `Authorization` header name.
    pub const AUTHORIZATION: &str = "Authorization";

    /// The `Content-Type` header name.
    pub const CONTENT_TYPE: &str = "Content-Type";

    /// The `x-api-key` header name (used by some providers like Anthropic).
    pub const X_API_KEY: &str = "x-api-key";
}
27
28/// Trait for LLM providers.
29///
30/// Providers implement this trait to support different LLM APIs while
31/// presenting a unified interface to the rest of SimpleAgents.
32///
33/// # Architecture
34///
35/// The provider trait follows a three-phase architecture:
36/// 1. **Transform Request**: Convert unified request to provider format
37/// 2. **Execute**: Make the actual API call
38/// 3. **Transform Response**: Convert provider response to unified format
39///
40/// This design allows for:
41/// - Maximum flexibility in provider-specific transformations
42/// - Clean separation between protocol logic and business logic
43/// - Easy testing and mocking
44///
45/// # Example Implementation
46///
47/// ```rust
48/// use simple_agent_type::provider::{Provider, ProviderRequest, ProviderResponse};
49/// use simple_agent_type::request::CompletionRequest;
50/// use simple_agent_type::response::{CompletionResponse, CompletionChoice, FinishReason, Usage};
51/// use simple_agent_type::message::Message;
52/// use simple_agent_type::error::Result;
53/// use async_trait::async_trait;
54///
55/// struct MyProvider;
56///
57/// #[async_trait]
58/// impl Provider for MyProvider {
59/// fn name(&self) -> &str {
60/// "my-provider"
61/// }
62///
63/// fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
64/// Ok(ProviderRequest::new("http://example.com"))
65/// }
66///
67/// async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
68/// Ok(ProviderResponse::new(200, serde_json::json!({"ok": true})))
69/// }
70///
71/// fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
72/// Ok(CompletionResponse {
73/// id: "resp_1".to_string(),
74/// model: "dummy".to_string(),
75/// choices: vec![CompletionChoice {
76/// index: 0,
77/// message: Message::assistant("ok"),
78/// finish_reason: FinishReason::Stop,
79/// logprobs: None,
80/// }],
81/// usage: Usage::new(1, 1),
82/// created: None,
83/// provider: Some(self.name().to_string()),
84/// healing_metadata: None,
85/// })
86/// }
87/// }
88///
89/// let provider = MyProvider;
90/// let request = CompletionRequest::builder()
91/// .model("gpt-4")
92/// .message(Message::user("Hello!"))
93/// .build()
94/// .unwrap();
95///
96/// let rt = tokio::runtime::Runtime::new().unwrap();
97/// rt.block_on(async {
98/// let provider_request = provider.transform_request(&request).unwrap();
99/// let provider_response = provider.execute(provider_request).await.unwrap();
100/// let response = provider.transform_response(provider_response).unwrap();
101/// assert_eq!(response.content(), Some("ok"));
102/// });
103/// ```
104#[async_trait]
105pub trait Provider: Send + Sync {
106 /// Provider name (e.g., "openai", "anthropic").
107 fn name(&self) -> &str;
108
109 /// Transform unified request to provider-specific format.
110 ///
111 /// This method converts the standardized `CompletionRequest` into
112 /// the provider's native API format.
113 fn transform_request(&self, req: &CompletionRequest) -> Result<ProviderRequest>;
114
115 /// Execute request against provider API.
116 ///
117 /// This method makes the actual HTTP request to the provider.
118 /// Implementations should handle:
119 /// - Authentication (API keys, tokens)
120 /// - Rate limiting
121 /// - Network errors
122 /// - Provider-specific error codes
123 async fn execute(&self, req: ProviderRequest) -> Result<ProviderResponse>;
124
125 /// Transform provider response to unified format.
126 ///
127 /// This method converts the provider's native response format into
128 /// the standardized `CompletionResponse`.
129 fn transform_response(&self, resp: ProviderResponse) -> Result<CompletionResponse>;
130
131 /// Get retry configuration.
132 ///
133 /// Override to customize retry behavior for this provider.
134 fn retry_config(&self) -> RetryConfig {
135 RetryConfig::default()
136 }
137
138 /// Get provider capabilities.
139 ///
140 /// Override to specify what features this provider supports.
141 fn capabilities(&self) -> Capabilities {
142 Capabilities::default()
143 }
144
145 /// Get default timeout.
146 fn timeout(&self) -> Duration {
147 Duration::from_secs(30)
148 }
149
150 /// Execute streaming request against provider API.
151 ///
152 /// This method returns a stream of completion chunks for streaming responses.
153 /// Not all providers support streaming - default implementation returns an error.
154 ///
155 /// # Arguments
156 /// - `req`: The provider-specific request
157 ///
158 /// # Returns
159 /// A boxed stream of Result<CompletionChunk>
160 ///
161 /// # Example
162 /// ```rust
163 /// use simple_agent_type::provider::{Provider, ProviderRequest, ProviderResponse};
164 /// use simple_agent_type::request::CompletionRequest;
165 /// use simple_agent_type::response::{CompletionResponse, CompletionChunk, CompletionChoice, FinishReason, Usage};
166 /// use simple_agent_type::message::Message;
167 /// use simple_agent_type::error::Result;
168 /// use async_trait::async_trait;
169 /// use futures_core::Stream;
170 /// use std::pin::Pin;
171 /// use std::task::{Context, Poll};
172 ///
173 /// struct EmptyStream;
174 ///
175 /// impl Stream for EmptyStream {
176 /// type Item = Result<CompletionChunk>;
177 ///
178 /// fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
179 /// Poll::Ready(None)
180 /// }
181 /// }
182 ///
183 /// struct StreamingProvider;
184 ///
185 /// #[async_trait]
186 /// impl Provider for StreamingProvider {
187 /// fn name(&self) -> &str {
188 /// "streaming-provider"
189 /// }
190 ///
191 /// fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
192 /// Ok(ProviderRequest::new("http://example.com"))
193 /// }
194 ///
195 /// async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
196 /// Ok(ProviderResponse::new(200, serde_json::json!({"ok": true})))
197 /// }
198 ///
199 /// fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
200 /// Ok(CompletionResponse {
201 /// id: "resp_1".to_string(),
202 /// model: "dummy".to_string(),
203 /// choices: vec![CompletionChoice {
204 /// index: 0,
205 /// message: Message::assistant("ok"),
206 /// finish_reason: FinishReason::Stop,
207 /// logprobs: None,
208 /// }],
209 /// usage: Usage::new(1, 1),
210 /// created: None,
211 /// provider: None,
212 /// healing_metadata: None,
213 /// })
214 /// }
215 ///
216 /// async fn execute_stream(
217 /// &self,
218 /// _req: ProviderRequest,
219 /// ) -> Result<Box<dyn Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
220 /// Ok(Box::new(EmptyStream))
221 /// }
222 /// }
223 ///
224 /// let provider = StreamingProvider;
225 /// let request = CompletionRequest::builder()
226 /// .model("gpt-4")
227 /// .message(Message::user("Hello!"))
228 /// .build()
229 /// .unwrap();
230 ///
231 /// let rt = tokio::runtime::Runtime::new().unwrap();
232 /// rt.block_on(async {
233 /// let provider_request = provider.transform_request(&request).unwrap();
234 /// let _stream = provider.execute_stream(provider_request).await.unwrap();
235 /// // Use StreamExt::next to consume the stream in real usage.
236 /// });
237 /// ```
238 async fn execute_stream(
239 &self,
240 mut req: ProviderRequest,
241 ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
242 if let Value::Object(map) = &mut req.body {
243 if let Some(stream_value) = map.get_mut("stream") {
244 *stream_value = Value::Bool(false);
245 }
246 }
247
248 let provider_response = self.execute(req).await?;
249 let response = self.transform_response(provider_response)?;
250
251 struct SingleChunkStream {
252 chunk: Option<Result<CompletionChunk>>,
253 }
254
255 impl futures_core::Stream for SingleChunkStream {
256 type Item = Result<CompletionChunk>;
257
258 fn poll_next(
259 mut self: std::pin::Pin<&mut Self>,
260 _cx: &mut std::task::Context<'_>,
261 ) -> std::task::Poll<Option<Self::Item>> {
262 std::task::Poll::Ready(self.chunk.take())
263 }
264 }
265
266 let choices = response
267 .choices
268 .into_iter()
269 .map(|choice| crate::response::ChoiceDelta {
270 index: choice.index,
271 delta: crate::response::MessageDelta {
272 role: Some(choice.message.role),
273 content: Some(choice.message.content),
274 reasoning_content: None,
275 tool_calls: None,
276 },
277 finish_reason: Some(choice.finish_reason),
278 })
279 .collect();
280
281 let chunk = CompletionChunk {
282 id: response.id,
283 model: response.model,
284 choices,
285 created: response.created,
286 usage: Some(response.usage),
287 };
288
289 Ok(Box::new(SingleChunkStream {
290 chunk: Some(Ok(chunk)),
291 }))
292 }
293}
294
295/// Opaque provider-specific request.
296///
297/// This type encapsulates all information needed to make an HTTP request
298/// to a provider, without committing to a specific HTTP client library.
299///
300/// Headers use `Cow<'static, str>` to avoid allocations for common headers.
301#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
302pub struct ProviderRequest {
303 /// Full URL to send request to
304 pub url: String,
305 /// HTTP headers (name, value pairs) using Cow to avoid allocations for static strings
306 #[serde(with = "header_serde")]
307 pub headers: Headers,
308 /// Request body (JSON)
309 pub body: serde_json::Value,
310 /// Optional request timeout override
311 #[serde(skip_serializing_if = "Option::is_none")]
312 pub timeout: Option<Duration>,
313}
314
315// Custom serde for Cow headers
316mod header_serde {
317 use super::Headers;
318 use serde::{Deserialize, Deserializer, Serialize, Serializer};
319 use std::borrow::Cow;
320
321 pub fn serialize<S>(
322 headers: &[(Cow<'static, str>, Cow<'static, str>)],
323 serializer: S,
324 ) -> Result<S::Ok, S::Error>
325 where
326 S: Serializer,
327 {
328 let string_headers: Vec<(&str, &str)> = headers
329 .iter()
330 .map(|(k, v)| (k.as_ref(), v.as_ref()))
331 .collect();
332 string_headers.serialize(serializer)
333 }
334
335 pub fn deserialize<'de, D>(deserializer: D) -> std::result::Result<Headers, D::Error>
336 where
337 D: Deserializer<'de>,
338 {
339 let string_headers: Vec<(String, String)> = Vec::deserialize(deserializer)?;
340 Ok(string_headers
341 .into_iter()
342 .map(|(k, v)| (Cow::Owned(k), Cow::Owned(v)))
343 .collect())
344 }
345}
346
347impl ProviderRequest {
348 /// Create a new provider request.
349 pub fn new(url: impl Into<String>) -> Self {
350 Self {
351 url: url.into(),
352 headers: Vec::new(),
353 body: serde_json::Value::Null,
354 timeout: None,
355 }
356 }
357
358 /// Add a header with owned strings.
359 pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
360 self.headers
361 .push((Cow::Owned(name.into()), Cow::Owned(value.into())));
362 self
363 }
364
365 /// Add a header with static strings (zero allocation).
366 pub fn with_static_header(mut self, name: &'static str, value: &'static str) -> Self {
367 self.headers
368 .push((Cow::Borrowed(name), Cow::Borrowed(value)));
369 self
370 }
371
372 /// Set the body.
373 pub fn with_body(mut self, body: serde_json::Value) -> Self {
374 self.body = body;
375 self
376 }
377
378 /// Set the timeout.
379 pub fn with_timeout(mut self, timeout: Duration) -> Self {
380 self.timeout = Some(timeout);
381 self
382 }
383}
384
385/// Opaque provider-specific response.
386///
387/// This type encapsulates the raw HTTP response from a provider.
388#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
389pub struct ProviderResponse {
390 /// HTTP status code
391 pub status: u16,
392 /// Response body (JSON)
393 pub body: serde_json::Value,
394 /// Optional response headers
395 #[serde(skip_serializing_if = "Option::is_none")]
396 pub headers: Option<Vec<(String, String)>>,
397}
398
399impl ProviderResponse {
400 /// Create a new provider response.
401 pub fn new(status: u16, body: serde_json::Value) -> Self {
402 Self {
403 status,
404 body,
405 headers: None,
406 }
407 }
408
409 /// Check if response is successful (2xx).
410 pub fn is_success(&self) -> bool {
411 (200..300).contains(&self.status)
412 }
413
414 /// Check if response is a client error (4xx).
415 pub fn is_client_error(&self) -> bool {
416 (400..500).contains(&self.status)
417 }
418
419 /// Check if response is a server error (5xx).
420 pub fn is_server_error(&self) -> bool {
421 (500..600).contains(&self.status)
422 }
423
424 /// Add headers.
425 pub fn with_headers(mut self, headers: Vec<(String, String)>) -> Self {
426 self.headers = Some(headers);
427 self
428 }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder methods populate every field of `ProviderRequest`.
    #[test]
    fn test_provider_request_builder() {
        let url = "https://api.example.com/v1/completions";
        let timeout = Duration::from_secs(30);

        let req = ProviderRequest::new(url)
            .with_header("Authorization", "Bearer sk-test")
            .with_header("Content-Type", "application/json")
            .with_body(serde_json::json!({"model": "test"}))
            .with_timeout(timeout);

        assert_eq!(req.url, url);
        assert_eq!(req.headers.len(), 2);
        assert_eq!(req.body["model"], "test");
        assert_eq!(req.timeout, Some(timeout));
    }

    /// Each status class helper matches only its own range.
    #[test]
    fn test_provider_response_status_checks() {
        // (status, is_success, is_client_error, is_server_error)
        let cases = [
            (200u16, true, false, false),
            (404, false, true, false),
            (500, false, false, true),
        ];

        for &(status, success, client_error, server_error) in cases.iter() {
            let resp = ProviderResponse::new(status, serde_json::json!({}));
            assert_eq!(resp.is_success(), success);
            assert_eq!(resp.is_client_error(), client_error);
            assert_eq!(resp.is_server_error(), server_error);
        }
    }

    /// A request survives a JSON round trip unchanged.
    #[test]
    fn test_provider_request_serialization() {
        let original = ProviderRequest::new("https://api.example.com")
            .with_header("X-Test", "value")
            .with_body(serde_json::json!({"key": "value"}));

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ProviderRequest = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original, decoded);
    }

    /// A response (including headers) survives a JSON round trip unchanged.
    #[test]
    fn test_provider_response_serialization() {
        let response_headers = vec![("X-Request-ID".to_string(), "123".to_string())];
        let original = ProviderResponse::new(200, serde_json::json!({"result": "success"}))
            .with_headers(response_headers);

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ProviderResponse = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original, decoded);
    }

    /// Compile-time check that `Provider` can be used as a trait object.
    #[test]
    fn test_provider_object_safety() {
        fn _assert_object_safe(_: &dyn Provider) {}
    }
}