Skip to main content

simple_agents_core/
client.rs

1//! SimpleAgents client implementation.
2
3use crate::healing::{HealedJsonResponse, HealedSchemaResponse, HealingSettings};
4use serde::{Deserialize, Serialize};
5use simple_agent_type::prelude::{
6    CompletionChunk, CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
7};
8use simple_agent_type::provider::RetryConfig;
9use simple_agent_type::telemetry::{ApiFormat, TelemetryConfig, TraceContext};
10use simple_agents_healing::coercion::CoercionEngine;
11use simple_agents_healing::parser::JsonishParser;
12use simple_agents_healing::schema::Schema;
13use std::sync::Arc;
14use std::time::Duration;
15use tracing::debug;
16
17// ---------------------------------------------------------------------------
18// Configuration types
19// ---------------------------------------------------------------------------
20
21/// Top-level configuration for a [`SimpleAgentsClient`].
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ClientConfig {
24    /// Provider identifier (e.g. `"openai"`, `"anthropic"`).
25    pub provider: String,
26    /// API key used for authentication with the provider.
27    pub api_key: String,
28    /// Optional custom base URL for the provider API.
29    pub base_url: Option<String>,
30    /// Wire format for the provider API.
31    pub api_format: ApiFormat,
32    /// Extra HTTP headers sent with every request.
33    pub extra_headers: Option<Vec<(String, String)>>,
34    /// Optional telemetry / OTEL export configuration.
35    pub telemetry: Option<TelemetryConfig>,
36    /// Default retry policy applied to all requests.
37    pub default_retry: RetryConfig,
38}
39
40impl Default for ClientConfig {
41    fn default() -> Self {
42        Self {
43            provider: "openai".into(),
44            api_key: String::new(),
45            base_url: None,
46            api_format: ApiFormat::default(),
47            extra_headers: None,
48            telemetry: None,
49            default_retry: RetryConfig::default(),
50        }
51    }
52}
53
54/// Flags that control execution behaviour for a single run.
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct ExecutionFlags {
57    /// Whether the workflow-level orchestrator streams between nodes.
58    pub workflow_streaming: bool,
59    /// Whether individual LLM calls within a node use streaming.
60    pub node_llm_streaming: bool,
61}
62
63impl Default for ExecutionFlags {
64    fn default() -> Self {
65        Self {
66            workflow_streaming: false,
67            node_llm_streaming: true,
68        }
69    }
70}
71
72/// Per-run options passed to the executor.
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct RunOptions {
75    /// Print timing / token-usage statistics after the run.
76    pub nerdstats: bool,
77    /// Whether to emit telemetry spans for this run.
78    pub telemetry_enabled: bool,
79    /// Distributed trace context to propagate.
80    pub trace_context: Option<TraceContext>,
81    /// Execution behaviour flags.
82    pub execution_flags: ExecutionFlags,
83}
84
85impl Default for RunOptions {
86    fn default() -> Self {
87        Self {
88            nerdstats: true,
89            telemetry_enabled: true,
90            trace_context: None,
91            execution_flags: ExecutionFlags::default(),
92        }
93    }
94}
95
96/// Mode for completion post-processing.
97#[derive(Clone)]
98pub enum CompletionMode {
99    /// Return the raw completion response.
100    Standard,
101    /// Parse the response content as JSON using healing.
102    HealedJson,
103    /// Parse and coerce the response into the provided schema.
104    CoercedSchema(Schema),
105}
106
107/// Options that control completion behavior.
108#[derive(Clone)]
109pub struct CompletionOptions {
110    /// Completion post-processing mode.
111    pub mode: CompletionMode,
112}
113
114impl Default for CompletionOptions {
115    fn default() -> Self {
116        Self {
117            mode: CompletionMode::Standard,
118        }
119    }
120}
121
122/// Result of a unified completion call.
123pub enum CompletionOutcome {
124    /// A standard, non-streaming completion response.
125    Response(CompletionResponse),
126    /// A streaming response yielding completion chunks.
127    Stream(Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>),
128    /// A healed JSON response.
129    HealedJson(HealedJsonResponse),
130    /// A schema-coerced response.
131    CoercedSchema(HealedSchemaResponse),
132}
133
134/// Unified SimpleAgents client.
135pub struct SimpleAgentsClient {
136    provider: Arc<dyn Provider>,
137    config: ClientConfig,
138    healing: HealingSettings,
139}
140
141impl SimpleAgentsClient {
142    /// Create a new client wrapping a single provider with default config and healing.
143    pub fn new(provider: Arc<dyn Provider>) -> Self {
144        Self {
145            provider,
146            config: ClientConfig::default(),
147            healing: HealingSettings::default(),
148        }
149    }
150
151    /// Create a new client from a [`ClientConfig`], using the supplied provider.
152    pub fn from_config(provider: Arc<dyn Provider>, config: ClientConfig) -> Self {
153        Self {
154            provider,
155            config,
156            healing: HealingSettings::default(),
157        }
158    }
159
160    /// Create a new client with custom healing settings.
161    pub fn with_healing(provider: Arc<dyn Provider>, healing: HealingSettings) -> Self {
162        Self {
163            provider,
164            config: ClientConfig::default(),
165            healing,
166        }
167    }
168
169    /// Return a reference to the client's configuration.
170    pub fn config(&self) -> &ClientConfig {
171        &self.config
172    }
173
174    /// Return the name of the underlying provider.
175    pub fn provider_name(&self) -> &str {
176        self.provider.name()
177    }
178
179    /// Execute a completion request.
180    pub async fn complete(
181        &self,
182        request: &CompletionRequest,
183        options: CompletionOptions,
184    ) -> Result<CompletionOutcome> {
185        if request.stream.unwrap_or(false) {
186            let stream = self.stream(request).await?;
187            return Ok(CompletionOutcome::Stream(stream));
188        }
189
190        match options.mode {
191            CompletionMode::Standard => {
192                let response = self.complete_response(request).await?;
193                Ok(CompletionOutcome::Response(response))
194            }
195            CompletionMode::HealedJson => {
196                let healed = self.complete_json_internal(request).await?;
197                Ok(CompletionOutcome::HealedJson(healed))
198            }
199            CompletionMode::CoercedSchema(schema) => {
200                let healed = self.complete_with_schema_internal(request, &schema).await?;
201                Ok(CompletionOutcome::CoercedSchema(healed))
202            }
203        }
204    }
205
206    async fn complete_response(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
207        request.validate()?;
208
209        let provider_request = self.provider.transform_request(request)?;
210        let provider_response = self.execute_with_retries(provider_request).await?;
211        self.provider.transform_response(provider_response)
212    }
213
214    async fn execute_with_retries(
215        &self,
216        provider_request: simple_agent_type::provider::ProviderRequest,
217    ) -> Result<simple_agent_type::provider::ProviderResponse> {
218        let retry = &self.config.default_retry;
219        let max_attempts = retry.max_attempts.max(1);
220        let mut attempt = 1;
221
222        loop {
223            match self.provider.execute(provider_request.clone()).await {
224                Ok(response) => return Ok(response),
225                Err(error) => {
226                    if attempt >= max_attempts || !is_retryable_error(&error) {
227                        return Err(error);
228                    }
229
230                    let delay = retry_delay(retry, attempt, &error);
231                    if !delay.is_zero() {
232                        tokio::time::sleep(delay).await;
233                    }
234                    attempt += 1;
235                }
236            }
237        }
238    }
239
240    async fn complete_json_internal(
241        &self,
242        request: &CompletionRequest,
243    ) -> Result<HealedJsonResponse> {
244        self.ensure_healing_enabled()?;
245        let response = self.complete_response(request).await?;
246        let content = response.content().ok_or_else(|| {
247            SimpleAgentsError::Healing(simple_agent_type::error::HealingError::ParseFailed {
248                error_message: "response contained no content".to_string(),
249                input: String::new(),
250            })
251        })?;
252
253        let parser = JsonishParser::with_config(self.healing.parser_config.clone());
254        let parsed = parser.parse(content)?;
255
256        Ok(HealedJsonResponse { response, parsed })
257    }
258
259    async fn complete_with_schema_internal(
260        &self,
261        request: &CompletionRequest,
262        schema: &Schema,
263    ) -> Result<HealedSchemaResponse> {
264        self.ensure_healing_enabled()?;
265        let healed = self.complete_json_internal(request).await?;
266        let engine = CoercionEngine::with_config(self.healing.coercion_config.clone());
267        let coerced = engine
268            .coerce(&healed.parsed.value, schema)
269            .map_err(SimpleAgentsError::Healing)?;
270
271        Ok(HealedSchemaResponse {
272            response: healed.response,
273            parsed: healed.parsed,
274            coerced,
275        })
276    }
277
278    async fn stream(
279        &self,
280        request: &CompletionRequest,
281    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
282        request.validate()?;
283        debug!(
284            model = %request.model,
285            stream = ?request.stream,
286            "SimpleAgentsClient.stream start"
287        );
288
289        let provider_request = self.provider.transform_request(request)?;
290        self.provider.execute_stream(provider_request).await
291    }
292
293    fn ensure_healing_enabled(&self) -> Result<()> {
294        if self.healing.enabled {
295            Ok(())
296        } else {
297            Err(SimpleAgentsError::Config(
298                "healing is disabled for this client".to_string(),
299            ))
300        }
301    }
302}
303
304fn is_retryable_error(error: &SimpleAgentsError) -> bool {
305    match error {
306        SimpleAgentsError::Provider(provider_error) => provider_error.is_retryable(),
307        SimpleAgentsError::Network(_) => true,
308        _ => false,
309    }
310}
311
312fn retry_after(error: &SimpleAgentsError) -> Option<Duration> {
313    match error {
314        SimpleAgentsError::Provider(simple_agent_type::error::ProviderError::RateLimit {
315            retry_after,
316        }) => *retry_after,
317        _ => None,
318    }
319}
320
321fn retry_delay(retry: &RetryConfig, failed_attempt: u32, error: &SimpleAgentsError) -> Duration {
322    if let Some(delay) = retry_after(error) {
323        return delay;
324    }
325
326    let factor = retry
327        .backoff_multiplier
328        .max(1.0)
329        .powi(failed_attempt.saturating_sub(1).min(31) as i32);
330    let delay = retry.initial_backoff.mul_f32(factor);
331    delay.min(retry.max_backoff.max(retry.initial_backoff))
332}
333
334#[cfg(test)]
335mod tests {
336    use super::*;
337    use async_trait::async_trait;
338    use futures_util::StreamExt;
339    use simple_agent_type::error::ProviderError;
340    use simple_agent_type::prelude::*;
341    use std::sync::atomic::{AtomicUsize, Ordering};
342
343    struct MockProvider {
344        name: &'static str,
345        calls: AtomicUsize,
346    }
347
348    impl MockProvider {
349        fn new(name: &'static str) -> Self {
350            Self {
351                name,
352                calls: AtomicUsize::new(0),
353            }
354        }
355    }
356
357    #[async_trait]
358    impl Provider for MockProvider {
359        fn name(&self) -> &str {
360            self.name
361        }
362
363        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
364            Ok(ProviderRequest::new("http://example.com"))
365        }
366
367        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
368            self.calls.fetch_add(1, Ordering::Relaxed);
369            Ok(ProviderResponse::new(
370                200,
371                serde_json::json!({"content": "ok"}),
372            ))
373        }
374
375        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
376            Ok(CompletionResponse {
377                id: "resp_test".to_string(),
378                model: "test-model".to_string(),
379                choices: vec![CompletionChoice {
380                    index: 0,
381                    message: Message::assistant("ok"),
382                    finish_reason: FinishReason::Stop,
383                    logprobs: None,
384                }],
385                usage: Usage::new(1, 1),
386                created: None,
387                provider: Some(self.name.to_string()),
388                healing_metadata: None,
389            })
390        }
391    }
392
393    #[tokio::test]
394    async fn complete_returns_response() {
395        let provider = Arc::new(MockProvider::new("p1"));
396        let client = SimpleAgentsClient::new(provider);
397
398        let request = CompletionRequest::builder()
399            .model("gpt-4")
400            .message(Message::user("Hi"))
401            .build()
402            .unwrap();
403
404        let outcome = client
405            .complete(&request, CompletionOptions::default())
406            .await
407            .unwrap();
408
409        match outcome {
410            CompletionOutcome::Response(resp) => {
411                assert_eq!(resp.provider.as_deref(), Some("p1"));
412            }
413            _ => panic!("expected Response outcome"),
414        }
415    }
416
417    struct RetryProvider {
418        name: &'static str,
419        failures_before_success: usize,
420        error: ProviderError,
421        calls: AtomicUsize,
422    }
423
424    impl RetryProvider {
425        fn new(name: &'static str, failures_before_success: usize, error: ProviderError) -> Self {
426            Self {
427                name,
428                failures_before_success,
429                error,
430                calls: AtomicUsize::new(0),
431            }
432        }
433
434        fn calls(&self) -> usize {
435            self.calls.load(Ordering::Relaxed)
436        }
437    }
438
439    #[async_trait]
440    impl Provider for RetryProvider {
441        fn name(&self) -> &str {
442            self.name
443        }
444
445        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
446            Ok(ProviderRequest::new("http://example.com"))
447        }
448
449        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
450            let call = self.calls.fetch_add(1, Ordering::Relaxed);
451            if call < self.failures_before_success {
452                return Err(SimpleAgentsError::Provider(self.error.clone()));
453            }
454
455            Ok(ProviderResponse::new(
456                200,
457                serde_json::json!({"content": "ok"}),
458            ))
459        }
460
461        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
462            Ok(CompletionResponse {
463                id: "resp_retry".to_string(),
464                model: "test-model".to_string(),
465                choices: vec![CompletionChoice {
466                    index: 0,
467                    message: Message::assistant("ok"),
468                    finish_reason: FinishReason::Stop,
469                    logprobs: None,
470                }],
471                usage: Usage::new(1, 1),
472                created: None,
473                provider: Some(self.name.to_string()),
474                healing_metadata: None,
475            })
476        }
477    }
478
479    fn retry_test_config(max_attempts: u32, backoff_multiplier: f32) -> ClientConfig {
480        ClientConfig {
481            default_retry: RetryConfig {
482                max_attempts,
483                initial_backoff: Duration::ZERO,
484                max_backoff: Duration::ZERO,
485                backoff_multiplier,
486                jitter: false,
487            },
488            ..ClientConfig::default()
489        }
490    }
491
492    #[tokio::test]
493    async fn complete_retries_retryable_provider_errors() {
494        let provider = Arc::new(RetryProvider::new(
495            "retry",
496            2,
497            ProviderError::ServerError("temporary".to_string()),
498        ));
499        let client = SimpleAgentsClient::from_config(provider.clone(), retry_test_config(3, 1.0));
500
501        let request = CompletionRequest::builder()
502            .model("gpt-4")
503            .message(Message::user("Hi"))
504            .build()
505            .unwrap();
506
507        let outcome = client
508            .complete(&request, CompletionOptions::default())
509            .await
510            .unwrap();
511
512        assert!(matches!(outcome, CompletionOutcome::Response(_)));
513        assert_eq!(provider.calls(), 3);
514    }
515
516    #[tokio::test]
517    async fn complete_does_not_retry_non_retryable_provider_errors() {
518        let provider = Arc::new(RetryProvider::new("retry", 1, ProviderError::InvalidApiKey));
519        let client = SimpleAgentsClient::from_config(provider.clone(), retry_test_config(3, 1.0));
520
521        let request = CompletionRequest::builder()
522            .model("gpt-4")
523            .message(Message::user("Hi"))
524            .build()
525            .unwrap();
526
527        let result = client
528            .complete(&request, CompletionOptions::default())
529            .await;
530
531        assert!(result.is_err());
532        assert_eq!(provider.calls(), 1);
533    }
534
535    #[tokio::test]
536    async fn complete_does_not_retry_when_strategy_is_none() {
537        let provider = Arc::new(RetryProvider::new(
538            "retry",
539            1,
540            ProviderError::ServerError("temporary".to_string()),
541        ));
542        let client = SimpleAgentsClient::from_config(provider.clone(), retry_test_config(1, 1.0));
543
544        let request = CompletionRequest::builder()
545            .model("gpt-4")
546            .message(Message::user("Hi"))
547            .build()
548            .unwrap();
549
550        let result = client
551            .complete(&request, CompletionOptions::default())
552            .await;
553
554        assert!(result.is_err());
555        assert_eq!(provider.calls(), 1);
556    }
557
558    #[test]
559    fn retry_delay_uses_backoff_multiplier() {
560        let error =
561            SimpleAgentsError::Provider(ProviderError::ServerError("temporary".to_string()));
562        let fixed = RetryConfig {
563            max_attempts: 3,
564            initial_backoff: Duration::from_millis(100),
565            max_backoff: Duration::from_millis(1_000),
566            backoff_multiplier: 1.0,
567            jitter: false,
568        };
569        let exponential = RetryConfig {
570            backoff_multiplier: 2.0,
571            ..fixed.clone()
572        };
573
574        assert_eq!(retry_delay(&fixed, 2, &error).as_millis(), 100);
575        assert_eq!(retry_delay(&exponential, 1, &error).as_millis(), 100);
576        assert_eq!(retry_delay(&exponential, 4, &error).as_millis(), 800);
577    }
578
579    struct StreamingProvider {
580        name: &'static str,
581        fail_after_first: bool,
582    }
583
584    impl StreamingProvider {
585        fn new(name: &'static str, fail_after_first: bool) -> Self {
586            Self {
587                name,
588                fail_after_first,
589            }
590        }
591
592        fn build_chunk(id: &str, content: &str) -> CompletionChunk {
593            CompletionChunk {
594                id: id.to_string(),
595                model: "test-model".to_string(),
596                choices: vec![ChoiceDelta {
597                    index: 0,
598                    delta: MessageDelta {
599                        role: Some(Role::Assistant),
600                        content: Some(content.to_string()),
601                        reasoning_content: None,
602                        tool_calls: None,
603                    },
604                    finish_reason: None,
605                }],
606                created: None,
607                usage: None,
608            }
609        }
610    }
611
612    #[async_trait]
613    impl Provider for StreamingProvider {
614        fn name(&self) -> &str {
615            self.name
616        }
617
618        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
619            Ok(ProviderRequest::new("http://example.com"))
620        }
621
622        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
623            Ok(ProviderResponse::new(
624                200,
625                serde_json::json!({"content": "ok"}),
626            ))
627        }
628
629        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
630            Ok(CompletionResponse {
631                id: "resp_stream".to_string(),
632                model: "test-model".to_string(),
633                choices: vec![CompletionChoice {
634                    index: 0,
635                    message: Message::assistant("ok"),
636                    finish_reason: FinishReason::Stop,
637                    logprobs: None,
638                }],
639                usage: Usage::new(1, 1),
640                created: None,
641                provider: Some(self.name.to_string()),
642                healing_metadata: None,
643            })
644        }
645
646        async fn execute_stream(
647            &self,
648            _req: ProviderRequest,
649        ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>>
650        {
651            let stream = if self.fail_after_first {
652                let items: Vec<Result<CompletionChunk>> = vec![
653                    Ok(Self::build_chunk("chunk-1", "hello")),
654                    Err(SimpleAgentsError::Provider(ProviderError::ServerError(
655                        "stream error".to_string(),
656                    ))),
657                ];
658                futures_util::stream::iter(items)
659            } else {
660                let items: Vec<Result<CompletionChunk>> =
661                    vec![Ok(Self::build_chunk("chunk-1", "hello"))];
662                futures_util::stream::iter(items)
663            };
664
665            Ok(Box::new(stream))
666        }
667    }
668
669    #[tokio::test]
670    async fn streaming_returns_chunks() {
671        let provider = Arc::new(StreamingProvider::new("p1", false));
672        let client = SimpleAgentsClient::new(provider);
673
674        let request = CompletionRequest::builder()
675            .model("gpt-4")
676            .message(Message::user("Hi"))
677            .stream(true)
678            .build()
679            .unwrap();
680
681        let outcome = client
682            .complete(&request, CompletionOptions::default())
683            .await
684            .unwrap();
685
686        let mut collected = Vec::new();
687        match outcome {
688            CompletionOutcome::Stream(mut stream) => {
689                while let Some(chunk) = stream.next().await {
690                    collected.push(chunk.unwrap());
691                }
692            }
693            _ => panic!("expected stream outcome"),
694        }
695
696        assert_eq!(collected.len(), 1);
697    }
698
699    #[tokio::test]
700    async fn streaming_propagates_error() {
701        let provider = Arc::new(StreamingProvider::new("p1", true));
702        let client = SimpleAgentsClient::new(provider);
703
704        let request = CompletionRequest::builder()
705            .model("gpt-4")
706            .message(Message::user("Hi"))
707            .stream(true)
708            .build()
709            .unwrap();
710
711        let outcome = client
712            .complete(&request, CompletionOptions::default())
713            .await
714            .unwrap();
715
716        let mut chunks = Vec::new();
717        match outcome {
718            CompletionOutcome::Stream(mut stream) => {
719                while let Some(chunk) = stream.next().await {
720                    chunks.push(chunk);
721                }
722            }
723            _ => panic!("expected stream outcome"),
724        }
725
726        assert_eq!(chunks.len(), 2);
727        assert!(chunks[0].is_ok());
728        assert!(chunks[1].is_err());
729    }
730}