simple_agent_type/provider.rs
1//! Provider trait and types.
2//!
3//! Defines the interface for LLM providers with transformation hooks.
4
5use crate::error::Result;
6use crate::request::CompletionRequest;
7use crate::response::{CompletionChunk, CompletionResponse};
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::borrow::Cow;
12use std::time::Duration;
13
14/// Retry configuration for failed requests.
15#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
16pub struct RetryConfig {
17 /// Maximum number of retry attempts
18 pub max_attempts: u32,
19 /// Initial backoff duration
20 pub initial_backoff: Duration,
21 /// Maximum backoff duration
22 pub max_backoff: Duration,
23 /// Backoff multiplier for exponential backoff
24 pub backoff_multiplier: f32,
25 /// Add random jitter to backoff
26 pub jitter: bool,
27}
28
29impl Default for RetryConfig {
30 fn default() -> Self {
31 Self {
32 max_attempts: 3,
33 initial_backoff: Duration::from_millis(100),
34 max_backoff: Duration::from_secs(10),
35 backoff_multiplier: 2.0,
36 jitter: true,
37 }
38 }
39}
40
41/// Provider capabilities.
42#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
43pub struct Capabilities {
44 /// Supports streaming responses
45 pub streaming: bool,
46 /// Supports function/tool calling
47 pub function_calling: bool,
48 /// Supports vision/image inputs
49 pub vision: bool,
50 /// Maximum output tokens
51 pub max_tokens: u32,
52}
53
/// HTTP headers as an ordered list of `(name, value)` pairs.
///
/// Each half is a `Cow<'static, str>`: `Cow::Borrowed` for `&'static str`
/// literals (no allocation) or `Cow::Owned` for runtime `String`s. Duplicate
/// names are allowed; nothing here de-duplicates or normalizes case.
pub type Headers = Vec<(Cow<'static, str>, Cow<'static, str>)>;
56
/// Well-known HTTP header names as `&'static str` constants.
///
/// Being `'static`, they can be passed to `ProviderRequest::with_static_header`,
/// which stores them as `Cow::Borrowed` without allocating.
pub mod headers {
    /// The `Content-Type` header name.
    pub const CONTENT_TYPE: &str = "Content-Type";
    /// The `Authorization` header name.
    pub const AUTHORIZATION: &str = "Authorization";
    /// The `x-api-key` header name (used by some providers, such as Anthropic).
    pub const X_API_KEY: &str = "x-api-key";
}
66
67/// Trait for LLM providers.
68///
69/// Providers implement this trait to support different LLM APIs while
70/// presenting a unified interface to the rest of SimpleAgents.
71///
72/// # Architecture
73///
74/// The provider trait follows a three-phase architecture:
75/// 1. **Transform Request**: Convert unified request to provider format
76/// 2. **Execute**: Make the actual API call
77/// 3. **Transform Response**: Convert provider response to unified format
78///
79/// This design allows for:
80/// - Maximum flexibility in provider-specific transformations
81/// - Clean separation between protocol logic and business logic
82/// - Easy testing and mocking
83///
84/// # Example Implementation
85///
86/// ```rust
87/// use simple_agent_type::provider::{Provider, ProviderRequest, ProviderResponse};
88/// use simple_agent_type::request::CompletionRequest;
89/// use simple_agent_type::response::{CompletionResponse, CompletionChoice, FinishReason, Usage};
90/// use simple_agent_type::message::Message;
91/// use simple_agent_type::error::Result;
92/// use async_trait::async_trait;
93///
94/// struct MyProvider;
95///
96/// #[async_trait]
97/// impl Provider for MyProvider {
98/// fn name(&self) -> &str {
99/// "my-provider"
100/// }
101///
102/// fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
103/// Ok(ProviderRequest::new("http://example.com"))
104/// }
105///
106/// async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
107/// Ok(ProviderResponse::new(200, serde_json::json!({"ok": true})))
108/// }
109///
110/// fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
111/// Ok(CompletionResponse {
112/// id: "resp_1".to_string(),
113/// model: "dummy".to_string(),
114/// choices: vec![CompletionChoice {
115/// index: 0,
116/// message: Message::assistant("ok"),
117/// finish_reason: FinishReason::Stop,
118/// logprobs: None,
119/// }],
120/// usage: Usage::new(1, 1),
121/// created: None,
122/// provider: Some(self.name().to_string()),
123/// healing_metadata: None,
124/// })
125/// }
126/// }
127///
128/// let provider = MyProvider;
129/// let request = CompletionRequest::builder()
130/// .model("gpt-4")
131/// .message(Message::user("Hello!"))
132/// .build()
133/// .unwrap();
134///
135/// let rt = tokio::runtime::Runtime::new().unwrap();
136/// rt.block_on(async {
137/// let provider_request = provider.transform_request(&request).unwrap();
138/// let provider_response = provider.execute(provider_request).await.unwrap();
139/// let response = provider.transform_response(provider_response).unwrap();
140/// assert_eq!(response.content(), Some("ok"));
141/// });
142/// ```
143#[async_trait]
144pub trait Provider: Send + Sync {
145 /// Provider name (e.g., "openai", "anthropic").
146 fn name(&self) -> &str;
147
148 /// Transform unified request to provider-specific format.
149 ///
150 /// This method converts the standardized `CompletionRequest` into
151 /// the provider's native API format.
152 fn transform_request(&self, req: &CompletionRequest) -> Result<ProviderRequest>;
153
154 /// Execute request against provider API.
155 ///
156 /// This method makes the actual HTTP request to the provider.
157 /// Implementations should handle:
158 /// - Authentication (API keys, tokens)
159 /// - Rate limiting
160 /// - Network errors
161 /// - Provider-specific error codes
162 async fn execute(&self, req: ProviderRequest) -> Result<ProviderResponse>;
163
164 /// Transform provider response to unified format.
165 ///
166 /// This method converts the provider's native response format into
167 /// the standardized `CompletionResponse`.
168 fn transform_response(&self, resp: ProviderResponse) -> Result<CompletionResponse>;
169
170 /// Get retry configuration.
171 ///
172 /// Override to customize retry behavior for this provider.
173 fn retry_config(&self) -> RetryConfig {
174 RetryConfig::default()
175 }
176
177 /// Get provider capabilities.
178 ///
179 /// Override to specify what features this provider supports.
180 fn capabilities(&self) -> Capabilities {
181 Capabilities::default()
182 }
183
184 /// Get default timeout.
185 fn timeout(&self) -> Duration {
186 Duration::from_secs(30)
187 }
188
189 /// Execute streaming request against provider API.
190 ///
191 /// This method returns a stream of completion chunks for streaming responses.
192 /// Not all providers support streaming - default implementation returns an error.
193 ///
194 /// # Arguments
195 /// - `req`: The provider-specific request
196 ///
197 /// # Returns
198 /// A boxed stream of Result<CompletionChunk>
199 ///
200 /// # Example
201 /// ```rust
202 /// use simple_agent_type::provider::{Provider, ProviderRequest, ProviderResponse};
203 /// use simple_agent_type::request::CompletionRequest;
204 /// use simple_agent_type::response::{CompletionResponse, CompletionChunk, CompletionChoice, FinishReason, Usage};
205 /// use simple_agent_type::message::Message;
206 /// use simple_agent_type::error::Result;
207 /// use async_trait::async_trait;
208 /// use futures_core::Stream;
209 /// use std::pin::Pin;
210 /// use std::task::{Context, Poll};
211 ///
212 /// struct EmptyStream;
213 ///
214 /// impl Stream for EmptyStream {
215 /// type Item = Result<CompletionChunk>;
216 ///
217 /// fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
218 /// Poll::Ready(None)
219 /// }
220 /// }
221 ///
222 /// struct StreamingProvider;
223 ///
224 /// #[async_trait]
225 /// impl Provider for StreamingProvider {
226 /// fn name(&self) -> &str {
227 /// "streaming-provider"
228 /// }
229 ///
230 /// fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
231 /// Ok(ProviderRequest::new("http://example.com"))
232 /// }
233 ///
234 /// async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
235 /// Ok(ProviderResponse::new(200, serde_json::json!({"ok": true})))
236 /// }
237 ///
238 /// fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
239 /// Ok(CompletionResponse {
240 /// id: "resp_1".to_string(),
241 /// model: "dummy".to_string(),
242 /// choices: vec![CompletionChoice {
243 /// index: 0,
244 /// message: Message::assistant("ok"),
245 /// finish_reason: FinishReason::Stop,
246 /// logprobs: None,
247 /// }],
248 /// usage: Usage::new(1, 1),
249 /// created: None,
250 /// provider: None,
251 /// healing_metadata: None,
252 /// })
253 /// }
254 ///
255 /// async fn execute_stream(
256 /// &self,
257 /// _req: ProviderRequest,
258 /// ) -> Result<Box<dyn Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
259 /// Ok(Box::new(EmptyStream))
260 /// }
261 /// }
262 ///
263 /// let provider = StreamingProvider;
264 /// let request = CompletionRequest::builder()
265 /// .model("gpt-4")
266 /// .message(Message::user("Hello!"))
267 /// .build()
268 /// .unwrap();
269 ///
270 /// let rt = tokio::runtime::Runtime::new().unwrap();
271 /// rt.block_on(async {
272 /// let provider_request = provider.transform_request(&request).unwrap();
273 /// let _stream = provider.execute_stream(provider_request).await.unwrap();
274 /// // Use StreamExt::next to consume the stream in real usage.
275 /// });
276 /// ```
277 async fn execute_stream(
278 &self,
279 mut req: ProviderRequest,
280 ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
281 if let Value::Object(map) = &mut req.body {
282 if let Some(stream_value) = map.get_mut("stream") {
283 *stream_value = Value::Bool(false);
284 }
285 }
286
287 let provider_response = self.execute(req).await?;
288 let response = self.transform_response(provider_response)?;
289
290 struct SingleChunkStream {
291 chunk: Option<Result<CompletionChunk>>,
292 }
293
294 impl futures_core::Stream for SingleChunkStream {
295 type Item = Result<CompletionChunk>;
296
297 fn poll_next(
298 mut self: std::pin::Pin<&mut Self>,
299 _cx: &mut std::task::Context<'_>,
300 ) -> std::task::Poll<Option<Self::Item>> {
301 std::task::Poll::Ready(self.chunk.take())
302 }
303 }
304
305 let choices = response
306 .choices
307 .into_iter()
308 .map(|choice| crate::response::ChoiceDelta {
309 index: choice.index,
310 delta: crate::response::MessageDelta {
311 role: Some(choice.message.role),
312 content: Some(choice.message.content_text().to_string()),
313 reasoning_content: None,
314 tool_calls: None,
315 },
316 finish_reason: Some(choice.finish_reason),
317 })
318 .collect();
319
320 let chunk = CompletionChunk {
321 id: response.id,
322 model: response.model,
323 choices,
324 created: response.created,
325 usage: Some(response.usage),
326 };
327
328 Ok(Box::new(SingleChunkStream {
329 chunk: Some(Ok(chunk)),
330 }))
331 }
332}
333
334/// Opaque provider-specific request.
335///
336/// This type encapsulates all information needed to make an HTTP request
337/// to a provider, without committing to a specific HTTP client library.
338///
339/// Headers use `Cow<'static, str>` to avoid allocations for common headers.
340#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
341pub struct ProviderRequest {
342 /// Full URL to send request to
343 pub url: String,
344 /// HTTP headers (name, value pairs) using Cow to avoid allocations for static strings
345 #[serde(with = "header_serde")]
346 pub headers: Headers,
347 /// Request body (JSON)
348 pub body: serde_json::Value,
349 /// Optional request timeout override
350 #[serde(skip_serializing_if = "Option::is_none")]
351 pub timeout: Option<Duration>,
352}
353
354// Custom serde for Cow headers
355mod header_serde {
356 use super::Headers;
357 use serde::{Deserialize, Deserializer, Serialize, Serializer};
358 use std::borrow::Cow;
359
360 pub fn serialize<S>(
361 headers: &[(Cow<'static, str>, Cow<'static, str>)],
362 serializer: S,
363 ) -> Result<S::Ok, S::Error>
364 where
365 S: Serializer,
366 {
367 let string_headers: Vec<(&str, &str)> = headers
368 .iter()
369 .map(|(k, v)| (k.as_ref(), v.as_ref()))
370 .collect();
371 string_headers.serialize(serializer)
372 }
373
374 pub fn deserialize<'de, D>(deserializer: D) -> std::result::Result<Headers, D::Error>
375 where
376 D: Deserializer<'de>,
377 {
378 let string_headers: Vec<(String, String)> = Vec::deserialize(deserializer)?;
379 Ok(string_headers
380 .into_iter()
381 .map(|(k, v)| (Cow::Owned(k), Cow::Owned(v)))
382 .collect())
383 }
384}
385
386impl ProviderRequest {
387 /// Create a new provider request.
388 pub fn new(url: impl Into<String>) -> Self {
389 Self {
390 url: url.into(),
391 headers: Vec::new(),
392 body: serde_json::Value::Null,
393 timeout: None,
394 }
395 }
396
397 /// Add a header with owned strings.
398 pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
399 self.headers
400 .push((Cow::Owned(name.into()), Cow::Owned(value.into())));
401 self
402 }
403
404 /// Add a header with static strings (zero allocation).
405 pub fn with_static_header(mut self, name: &'static str, value: &'static str) -> Self {
406 self.headers
407 .push((Cow::Borrowed(name), Cow::Borrowed(value)));
408 self
409 }
410
411 /// Set the body.
412 pub fn with_body(mut self, body: serde_json::Value) -> Self {
413 self.body = body;
414 self
415 }
416
417 /// Set the timeout.
418 pub fn with_timeout(mut self, timeout: Duration) -> Self {
419 self.timeout = Some(timeout);
420 self
421 }
422}
423
424/// Opaque provider-specific response.
425///
426/// This type encapsulates the raw HTTP response from a provider.
427#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
428pub struct ProviderResponse {
429 /// HTTP status code
430 pub status: u16,
431 /// Response body (JSON)
432 pub body: serde_json::Value,
433 /// Optional response headers
434 #[serde(skip_serializing_if = "Option::is_none")]
435 pub headers: Option<Vec<(String, String)>>,
436}
437
438impl ProviderResponse {
439 /// Create a new provider response.
440 pub fn new(status: u16, body: serde_json::Value) -> Self {
441 Self {
442 status,
443 body,
444 headers: None,
445 }
446 }
447
448 /// Check if response is successful (2xx).
449 pub fn is_success(&self) -> bool {
450 (200..300).contains(&self.status)
451 }
452
453 /// Check if response is a client error (4xx).
454 pub fn is_client_error(&self) -> bool {
455 (400..500).contains(&self.status)
456 }
457
458 /// Check if response is a server error (5xx).
459 pub fn is_server_error(&self) -> bool {
460 (500..600).contains(&self.status)
461 }
462
463 /// Add headers.
464 pub fn with_headers(mut self, headers: Vec<(String, String)>) -> Self {
465 self.headers = Some(headers);
466 self
467 }
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::Message;
    use crate::response::{CompletionChoice, FinishReason, Usage};
    use futures_core::Stream;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::task::{Context, Poll, Wake, Waker};

    /// Drive `fut` to completion on the current thread by busy-polling it with
    /// a no-op waker.
    ///
    /// Only suitable for futures that never wait on external events, such as
    /// the in-memory provider in this module; avoids a runtime dependency.
    fn block_on<F: Future>(fut: F) -> F::Output {
        struct NoopWake;

        impl Wake for NoopWake {
            fn wake(self: Arc<Self>) {}
        }

        let waker = Waker::from(Arc::new(NoopWake));
        let mut cx = Context::from_waker(&waker);
        let mut fut = Box::pin(fut);
        loop {
            if let Poll::Ready(output) = fut.as_mut().poll(&mut cx) {
                return output;
            }
        }
    }

    /// Pull the next item out of `stream` (same caveats as `block_on`).
    fn next_item<S: Stream + Unpin>(stream: &mut S) -> Option<S::Item> {
        block_on(std::future::poll_fn(|cx| Pin::new(&mut *stream).poll_next(cx)))
    }

    /// Minimal non-streaming provider used to exercise the trait's defaults.
    struct EchoProvider;

    #[async_trait::async_trait]
    impl Provider for EchoProvider {
        fn name(&self) -> &str {
            "echo"
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com")
                .with_body(serde_json::json!({"model": "m", "stream": true})))
        }

        async fn execute(&self, req: ProviderRequest) -> Result<ProviderResponse> {
            // The default `execute_stream` must disable streaming first.
            assert_eq!(req.body["stream"], serde_json::json!(false));
            Ok(ProviderResponse::new(200, serde_json::json!({})))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_1".to_string(),
                model: "dummy".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: None,
                healing_metadata: None,
            })
        }
    }

    #[test]
    fn test_provider_request_builder() {
        let req = ProviderRequest::new("https://api.example.com/v1/completions")
            .with_header("Authorization", "Bearer sk-test")
            .with_header("Content-Type", "application/json")
            .with_body(serde_json::json!({"model": "test"}))
            .with_timeout(Duration::from_secs(30));

        assert_eq!(req.url, "https://api.example.com/v1/completions");
        assert_eq!(req.headers.len(), 2);
        assert_eq!(req.body["model"], "test");
        assert_eq!(req.timeout, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_static_headers_are_borrowed_and_owned_headers_are_owned() {
        let req = ProviderRequest::new("https://api.example.com")
            .with_static_header(headers::CONTENT_TYPE, "application/json")
            .with_header("X-Owned", "value");

        assert!(matches!(req.headers[0].0, Cow::Borrowed("Content-Type")));
        assert!(matches!(req.headers[0].1, Cow::Borrowed("application/json")));
        assert!(matches!(req.headers[1].0, Cow::Owned(_)));
        assert!(matches!(req.headers[1].1, Cow::Owned(_)));
    }

    #[test]
    fn test_provider_response_status_checks() {
        let resp = ProviderResponse::new(200, serde_json::json!({}));
        assert!(resp.is_success());
        assert!(!resp.is_client_error());
        assert!(!resp.is_server_error());

        let resp = ProviderResponse::new(404, serde_json::json!({}));
        assert!(!resp.is_success());
        assert!(resp.is_client_error());
        assert!(!resp.is_server_error());

        let resp = ProviderResponse::new(500, serde_json::json!({}));
        assert!(!resp.is_success());
        assert!(!resp.is_client_error());
        assert!(resp.is_server_error());
    }

    #[test]
    fn test_provider_response_status_boundaries() {
        // (status, is_success, is_client_error, is_server_error)
        let cases = [
            (199, false, false, false),
            (200, true, false, false),
            (299, true, false, false),
            (300, false, false, false),
            (399, false, false, false),
            (400, false, true, false),
            (499, false, true, false),
            (500, false, false, true),
            (599, false, false, true),
            (600, false, false, false),
        ];
        for (status, success, client, server) in cases {
            let resp = ProviderResponse::new(status, serde_json::json!({}));
            assert_eq!(resp.is_success(), success, "is_success({})", status);
            assert_eq!(resp.is_client_error(), client, "is_client_error({})", status);
            assert_eq!(resp.is_server_error(), server, "is_server_error({})", status);
        }
    }

    #[test]
    fn test_provider_request_serialization() {
        let req = ProviderRequest::new("https://api.example.com")
            .with_header("X-Test", "value")
            .with_body(serde_json::json!({"key": "value"}));

        let json = serde_json::to_string(&req).unwrap();
        let parsed: ProviderRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(req, parsed);
    }

    #[test]
    fn test_provider_request_serialized_shape() {
        let req = ProviderRequest::new("https://api.example.com")
            .with_static_header(headers::AUTHORIZATION, "Bearer sk-test");

        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(
            json["headers"],
            serde_json::json!([["Authorization", "Bearer sk-test"]])
        );
        // `timeout` is skipped entirely when unset.
        assert!(json.get("timeout").is_none());

        // Borrowed headers come back owned but still compare equal.
        let parsed: ProviderRequest = serde_json::from_value(json).unwrap();
        assert_eq!(req, parsed);
    }

    #[test]
    fn test_provider_response_serialization() {
        let resp = ProviderResponse::new(200, serde_json::json!({"result": "success"}))
            .with_headers(vec![("X-Request-ID".to_string(), "123".to_string())]);

        let json = serde_json::to_string(&resp).unwrap();
        let parsed: ProviderResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(resp, parsed);
    }

    #[test]
    fn test_retry_and_capability_defaults() {
        assert_eq!(
            RetryConfig::default(),
            RetryConfig {
                max_attempts: 3,
                initial_backoff: Duration::from_millis(100),
                max_backoff: Duration::from_secs(10),
                backoff_multiplier: 2.0,
                jitter: true,
            }
        );

        let caps = Capabilities::default();
        assert!(!caps.streaming);
        assert!(!caps.function_calling);
        assert!(!caps.vision);
        assert_eq!(caps.max_tokens, 0);
    }

    #[test]
    fn test_provider_default_methods() {
        let provider = EchoProvider;
        assert_eq!(provider.retry_config(), RetryConfig::default());
        assert_eq!(provider.capabilities(), Capabilities::default());
        assert_eq!(provider.timeout(), Duration::from_secs(30));
    }

    #[test]
    fn test_default_execute_stream_yields_single_chunk() {
        let provider = EchoProvider;
        let req = ProviderRequest::new("http://example.com")
            .with_body(serde_json::json!({"model": "m", "stream": true}));

        let mut stream = block_on(provider.execute_stream(req)).unwrap();

        let chunk = next_item(&mut stream)
            .expect("fallback stream should yield exactly one chunk")
            .unwrap();
        assert_eq!(chunk.id, "resp_1");
        assert_eq!(chunk.model, "dummy");
        assert_eq!(chunk.choices.len(), 1);
        assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("ok"));
        assert!(chunk.usage.is_some());

        // The fallback stream ends after its single chunk.
        assert!(next_item(&mut stream).is_none());
    }

    // Test that Provider trait is object-safe
    #[test]
    fn test_provider_object_safety() {
        fn _assert_object_safe(_: &dyn Provider) {}
    }
}
532}