1use crate::healing::{HealedJsonResponse, HealedSchemaResponse, HealingSettings};
4use rand::Rng;
5use serde::{Deserialize, Serialize};
6use simple_agent_type::prelude::{
7 CompletionChunk, CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
8};
9use simple_agent_type::provider::RetryConfig;
10use simple_agent_type::telemetry::{ApiFormat, TelemetryConfig, TraceContext};
11use simple_agents_healing::coercion::CoercionEngine;
12use simple_agents_healing::parser::JsonishParser;
13use simple_agents_healing::schema::Schema;
14use std::sync::Arc;
15use std::time::Duration;
16use tracing::debug;
17
18#[derive(Clone, Serialize, Deserialize)]
28pub struct ClientConfig {
29 pub provider: String,
31 #[serde(skip_serializing)]
36 pub api_key: String,
37 pub base_url: Option<String>,
39 pub api_format: ApiFormat,
41 pub extra_headers: Option<Vec<(String, String)>>,
43 pub telemetry: Option<TelemetryConfig>,
45 pub default_retry: RetryConfig,
47}
48
49impl std::fmt::Debug for ClientConfig {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 let redacted = if self.api_key.is_empty() {
52 "<empty>"
53 } else {
54 "[REDACTED]"
55 };
56
57 f.debug_struct("ClientConfig")
58 .field("provider", &self.provider)
59 .field("api_key", &redacted)
60 .field("base_url", &self.base_url)
61 .field("api_format", &self.api_format)
62 .field("extra_headers", &self.extra_headers)
63 .field("telemetry", &self.telemetry)
64 .field("default_retry", &self.default_retry)
65 .finish()
66 }
67}
68
69impl Default for ClientConfig {
70 fn default() -> Self {
71 Self {
72 provider: "openai".into(),
73 api_key: String::new(),
74 base_url: None,
75 api_format: ApiFormat::default(),
76 extra_headers: None,
77 telemetry: None,
78 default_retry: RetryConfig::default(),
79 }
80 }
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct ExecutionFlags {
86 pub workflow_streaming: bool,
88 pub node_llm_streaming: bool,
90}
91
92impl Default for ExecutionFlags {
93 fn default() -> Self {
94 Self {
95 workflow_streaming: false,
96 node_llm_streaming: true,
97 }
98 }
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct RunOptions {
104 pub nerdstats: bool,
106 pub telemetry_enabled: bool,
108 pub trace_context: Option<TraceContext>,
110 pub execution_flags: ExecutionFlags,
112}
113
114impl Default for RunOptions {
115 fn default() -> Self {
116 Self {
117 nerdstats: true,
118 telemetry_enabled: true,
119 trace_context: None,
120 execution_flags: ExecutionFlags::default(),
121 }
122 }
123}
124
125#[derive(Clone)]
127pub enum CompletionMode {
128 Standard,
130 HealedJson,
132 CoercedSchema(Schema),
134}
135
136#[derive(Clone)]
138pub struct CompletionOptions {
139 pub mode: CompletionMode,
141}
142
143impl Default for CompletionOptions {
144 fn default() -> Self {
145 Self {
146 mode: CompletionMode::Standard,
147 }
148 }
149}
150
151pub enum CompletionOutcome {
153 Response(CompletionResponse),
155 Stream(Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>),
157 HealedJson(HealedJsonResponse),
159 CoercedSchema(HealedSchemaResponse),
161}
162
163pub struct SimpleAgentsClient {
165 provider: Arc<dyn Provider>,
166 config: ClientConfig,
167 healing: HealingSettings,
168}
169
170impl SimpleAgentsClient {
171 pub fn new(provider: Arc<dyn Provider>) -> Self {
173 Self {
174 provider,
175 config: ClientConfig::default(),
176 healing: HealingSettings::default(),
177 }
178 }
179
180 pub fn from_config(provider: Arc<dyn Provider>, config: ClientConfig) -> Self {
182 Self {
183 provider,
184 config,
185 healing: HealingSettings::default(),
186 }
187 }
188
189 pub fn with_healing(provider: Arc<dyn Provider>, healing: HealingSettings) -> Self {
191 Self {
192 provider,
193 config: ClientConfig::default(),
194 healing,
195 }
196 }
197
198 pub fn config(&self) -> &ClientConfig {
200 &self.config
201 }
202
203 pub fn provider_name(&self) -> &str {
205 self.provider.name()
206 }
207
208 pub async fn complete(
210 &self,
211 request: &CompletionRequest,
212 options: CompletionOptions,
213 ) -> Result<CompletionOutcome> {
214 if request.stream.unwrap_or(false) {
215 if matches!(
216 options.mode,
217 CompletionMode::HealedJson | CompletionMode::CoercedSchema(_)
218 ) {
219 return Err(SimpleAgentsError::Config(
220 "streaming is incompatible with HealedJson/CoercedSchema modes; \
221 use Raw mode for streaming or disable streaming for structured output"
222 .to_string(),
223 ));
224 }
225 let stream = self.stream(request).await?;
226 return Ok(CompletionOutcome::Stream(stream));
227 }
228
229 match options.mode {
230 CompletionMode::Standard => {
231 let response = self.complete_response(request).await?;
232 Ok(CompletionOutcome::Response(response))
233 }
234 CompletionMode::HealedJson => {
235 let healed = self.complete_json_internal(request).await?;
236 Ok(CompletionOutcome::HealedJson(healed))
237 }
238 CompletionMode::CoercedSchema(schema) => {
239 let healed = self.complete_with_schema_internal(request, &schema).await?;
240 Ok(CompletionOutcome::CoercedSchema(healed))
241 }
242 }
243 }
244
245 async fn complete_response(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
246 request.validate()?;
247
248 let provider_request = self.provider.transform_request(request)?;
249 let provider_response = self.execute_with_retries(provider_request).await?;
250 self.provider.transform_response(provider_response)
251 }
252
253 async fn execute_with_retries(
254 &self,
255 provider_request: simple_agent_type::provider::ProviderRequest,
256 ) -> Result<simple_agent_type::provider::ProviderResponse> {
257 let retry = &self.config.default_retry;
258 let max_attempts = retry.max_attempts.max(1);
259 let mut attempt = 1;
260
261 loop {
262 match self.provider.execute(provider_request.clone()).await {
268 Ok(response) => return Ok(response),
269 Err(error) => {
270 if attempt >= max_attempts || !is_retryable_error(&error) {
271 return Err(error);
272 }
273
274 let delay = retry_delay(retry, attempt, &error);
275 if !delay.is_zero() {
276 tokio::time::sleep(delay).await;
277 }
278 attempt += 1;
279 }
280 }
281 }
282 }
283
284 async fn complete_json_internal(
285 &self,
286 request: &CompletionRequest,
287 ) -> Result<HealedJsonResponse> {
288 self.ensure_healing_enabled()?;
289 let response = self.complete_response(request).await?;
290 let content = response.content().ok_or_else(|| {
291 SimpleAgentsError::Healing(simple_agent_type::error::HealingError::ParseFailed {
292 error_message: "response contained no content".to_string(),
293 input: String::new(),
294 })
295 })?;
296
297 let parser = JsonishParser::with_config(self.healing.parser_config.clone());
298 let parsed = parser.parse(content)?;
299
300 Ok(HealedJsonResponse { response, parsed })
301 }
302
303 async fn complete_with_schema_internal(
304 &self,
305 request: &CompletionRequest,
306 schema: &Schema,
307 ) -> Result<HealedSchemaResponse> {
308 self.ensure_healing_enabled()?;
309 let healed = self.complete_json_internal(request).await?;
310 let engine = CoercionEngine::with_config(self.healing.coercion_config.clone());
311 let coerced = engine
312 .coerce(&healed.parsed.value, schema)
313 .map_err(SimpleAgentsError::Healing)?;
314
315 Ok(HealedSchemaResponse {
316 response: healed.response,
317 parsed: healed.parsed,
318 coerced,
319 })
320 }
321
322 async fn stream(
323 &self,
324 request: &CompletionRequest,
325 ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
326 request.validate()?;
327 debug!(
328 model = %request.model,
329 stream = ?request.stream,
330 "SimpleAgentsClient.stream start"
331 );
332
333 let provider_request = self.provider.transform_request(request)?;
334 self.provider.execute_stream(provider_request).await
335 }
336
337 fn ensure_healing_enabled(&self) -> Result<()> {
338 if self.healing.enabled {
339 Ok(())
340 } else {
341 Err(SimpleAgentsError::HealingDisabled)
342 }
343 }
344}
345
346fn is_retryable_error(error: &SimpleAgentsError) -> bool {
347 match error {
348 SimpleAgentsError::Provider(provider_error) => provider_error.is_retryable(),
349 SimpleAgentsError::Network(_) => true,
350 _ => false,
351 }
352}
353
354fn retry_after(error: &SimpleAgentsError) -> Option<Duration> {
355 match error {
356 SimpleAgentsError::Provider(simple_agent_type::error::ProviderError::RateLimit {
357 retry_after,
358 }) => *retry_after,
359 _ => None,
360 }
361}
362
363fn retry_delay(retry: &RetryConfig, failed_attempt: u32, error: &SimpleAgentsError) -> Duration {
364 if let Some(delay) = retry_after(error) {
365 return delay;
366 }
367
368 let factor = retry
369 .backoff_multiplier
370 .max(1.0)
371 .powi(failed_attempt.saturating_sub(1).min(31) as i32);
372 let delay = retry.initial_backoff.mul_f32(factor);
373 let delay = delay.min(retry.max_backoff.max(retry.initial_backoff));
374
375 if retry.jitter {
376 let jitter_factor = rand::thread_rng().gen_range(0.5..=1.5);
377 delay.mul_f64(jitter_factor)
378 } else {
379 delay
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use futures_util::StreamExt;
    use simple_agent_type::error::ProviderError;
    use simple_agent_type::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Successful single-choice assistant response attributed to `provider`.
    fn ok_response(id: &str, provider: &str) -> CompletionResponse {
        CompletionResponse {
            id: id.to_string(),
            model: "test-model".to_string(),
            choices: vec![CompletionChoice {
                index: 0,
                message: Message::assistant("ok"),
                finish_reason: FinishReason::Stop,
                logprobs: None,
            }],
            usage: Usage::new(1, 1),
            created: None,
            provider: Some(provider.to_string()),
            healing_metadata: None,
        }
    }

    /// Raw payload returned by every mock provider's successful `execute`.
    fn ok_provider_response() -> ProviderResponse {
        ProviderResponse::new(200, serde_json::json!({"content": "ok"}))
    }

    /// Minimal valid non-streaming request.
    fn user_request() -> CompletionRequest {
        CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hi"))
            .build()
            .unwrap()
    }

    /// Minimal valid streaming request.
    fn streaming_request() -> CompletionRequest {
        CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hi"))
            .stream(true)
            .build()
            .unwrap()
    }

    /// Provider that always succeeds.
    struct MockProvider {
        name: &'static str,
    }

    impl MockProvider {
        fn new(name: &'static str) -> Self {
            Self { name }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            Ok(ok_provider_response())
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(ok_response("resp_test", self.name))
        }
    }

    #[tokio::test]
    async fn complete_returns_response() {
        let client = SimpleAgentsClient::new(Arc::new(MockProvider::new("p1")));

        let outcome = client
            .complete(&user_request(), CompletionOptions::default())
            .await
            .unwrap();

        match outcome {
            CompletionOutcome::Response(resp) => {
                assert_eq!(resp.provider.as_deref(), Some("p1"));
            }
            _ => panic!("expected Response outcome"),
        }
    }

    #[tokio::test]
    async fn streaming_rejects_structured_modes() {
        let client = SimpleAgentsClient::new(Arc::new(MockProvider::new("p1")));

        let result = client
            .complete(
                &streaming_request(),
                CompletionOptions {
                    mode: CompletionMode::HealedJson,
                },
            )
            .await;

        assert!(matches!(result, Err(SimpleAgentsError::Config(_))));
    }

    /// Provider that fails with `error` for the first `failures_before_success`
    /// calls and succeeds afterwards, counting every call.
    struct RetryProvider {
        name: &'static str,
        failures_before_success: usize,
        error: ProviderError,
        calls: AtomicUsize,
    }

    impl RetryProvider {
        fn new(name: &'static str, failures_before_success: usize, error: ProviderError) -> Self {
            Self {
                name,
                failures_before_success,
                error,
                calls: AtomicUsize::new(0),
            }
        }

        fn calls(&self) -> usize {
            self.calls.load(Ordering::Relaxed)
        }
    }

    #[async_trait]
    impl Provider for RetryProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            let call = self.calls.fetch_add(1, Ordering::Relaxed);
            if call < self.failures_before_success {
                return Err(SimpleAgentsError::Provider(self.error.clone()));
            }
            Ok(ok_provider_response())
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(ok_response("resp_retry", self.name))
        }
    }

    /// Config with `max_attempts` and zero, un-jittered backoff, so tests never sleep.
    fn retry_test_config(max_attempts: u32) -> ClientConfig {
        ClientConfig {
            default_retry: RetryConfig {
                max_attempts,
                initial_backoff: Duration::ZERO,
                max_backoff: Duration::ZERO,
                backoff_multiplier: 1.0,
                jitter: false,
            },
            ..ClientConfig::default()
        }
    }

    #[tokio::test]
    async fn complete_retries_retryable_provider_errors() {
        let provider = Arc::new(RetryProvider::new(
            "retry",
            2,
            ProviderError::ServerError("temporary".to_string()),
        ));
        let client = SimpleAgentsClient::from_config(provider.clone(), retry_test_config(3));

        let outcome = client
            .complete(&user_request(), CompletionOptions::default())
            .await
            .unwrap();

        assert!(matches!(outcome, CompletionOutcome::Response(_)));
        assert_eq!(provider.calls(), 3);
    }

    #[tokio::test]
    async fn complete_does_not_retry_non_retryable_provider_errors() {
        let provider = Arc::new(RetryProvider::new("retry", 1, ProviderError::InvalidApiKey));
        let client = SimpleAgentsClient::from_config(provider.clone(), retry_test_config(3));

        let result = client
            .complete(&user_request(), CompletionOptions::default())
            .await;

        assert!(result.is_err());
        assert_eq!(provider.calls(), 1);
    }

    #[tokio::test]
    async fn complete_does_not_retry_when_max_attempts_is_one() {
        let provider = Arc::new(RetryProvider::new(
            "retry",
            1,
            ProviderError::ServerError("temporary".to_string()),
        ));
        let client = SimpleAgentsClient::from_config(provider.clone(), retry_test_config(1));

        let result = client
            .complete(&user_request(), CompletionOptions::default())
            .await;

        assert!(result.is_err());
        assert_eq!(provider.calls(), 1);
    }

    #[test]
    fn retry_delay_uses_backoff_multiplier() {
        let error =
            SimpleAgentsError::Provider(ProviderError::ServerError("temporary".to_string()));
        let fixed = RetryConfig {
            max_attempts: 3,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_millis(1_000),
            backoff_multiplier: 1.0,
            jitter: false,
        };
        let exponential = RetryConfig {
            backoff_multiplier: 2.0,
            ..fixed.clone()
        };

        assert_eq!(retry_delay(&fixed, 2, &error).as_millis(), 100);
        assert_eq!(retry_delay(&exponential, 1, &error).as_millis(), 100);
        assert_eq!(retry_delay(&exponential, 4, &error).as_millis(), 800);
    }

    #[test]
    fn retry_delay_with_jitter_stays_within_expected_range() {
        let error =
            SimpleAgentsError::Provider(ProviderError::ServerError("temporary".to_string()));
        let config = RetryConfig {
            max_attempts: 3,
            initial_backoff: Duration::from_millis(1_000),
            max_backoff: Duration::from_millis(10_000),
            backoff_multiplier: 1.0,
            jitter: true,
        };

        let base_ms = 1_000u128;
        let min_expected = base_ms / 2;
        let max_expected = base_ms * 3 / 2;
        for _ in 0..50 {
            let ms = retry_delay(&config, 1, &error).as_millis();
            assert!(
                ms >= min_expected && ms <= max_expected,
                "jittered delay {ms}ms outside expected range [{min_expected}, {max_expected}]",
            );
        }

        let delays: std::collections::HashSet<u128> = (0..20)
            .map(|_| retry_delay(&config, 1, &error).as_nanos())
            .collect();
        assert!(
            delays.len() > 1,
            "expected jitter to produce varying delays, but got {} distinct value(s)",
            delays.len(),
        );
    }

    /// Provider whose stream yields one chunk, optionally followed by an error.
    struct StreamingProvider {
        name: &'static str,
        fail_after_first: bool,
    }

    impl StreamingProvider {
        fn new(name: &'static str, fail_after_first: bool) -> Self {
            Self {
                name,
                fail_after_first,
            }
        }

        fn build_chunk(id: &str, content: &str) -> CompletionChunk {
            CompletionChunk {
                id: id.to_string(),
                model: "test-model".to_string(),
                choices: vec![ChoiceDelta {
                    index: 0,
                    delta: MessageDelta {
                        role: Some(Role::Assistant),
                        content: Some(content.to_string()),
                        reasoning_content: None,
                        tool_calls: None,
                    },
                    finish_reason: None,
                }],
                created: None,
                usage: None,
            }
        }
    }

    #[async_trait]
    impl Provider for StreamingProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            Ok(ok_provider_response())
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(ok_response("resp_stream", self.name))
        }

        async fn execute_stream(
            &self,
            _req: ProviderRequest,
        ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>>
        {
            let mut items: Vec<Result<CompletionChunk>> =
                vec![Ok(Self::build_chunk("chunk-1", "hello"))];
            if self.fail_after_first {
                items.push(Err(SimpleAgentsError::Provider(ProviderError::ServerError(
                    "stream error".to_string(),
                ))));
            }
            Ok(Box::new(futures_util::stream::iter(items)))
        }
    }

    #[tokio::test]
    async fn streaming_returns_chunks() {
        let client = SimpleAgentsClient::new(Arc::new(StreamingProvider::new("p1", false)));

        let outcome = client
            .complete(&streaming_request(), CompletionOptions::default())
            .await
            .unwrap();

        let mut collected = Vec::new();
        match outcome {
            CompletionOutcome::Stream(mut stream) => {
                while let Some(chunk) = stream.next().await {
                    collected.push(chunk.unwrap());
                }
            }
            _ => panic!("expected stream outcome"),
        }

        assert_eq!(collected.len(), 1);
    }

    #[tokio::test]
    async fn streaming_propagates_error() {
        let client = SimpleAgentsClient::new(Arc::new(StreamingProvider::new("p1", true)));

        let outcome = client
            .complete(&streaming_request(), CompletionOptions::default())
            .await
            .unwrap();

        let mut chunks = Vec::new();
        match outcome {
            CompletionOutcome::Stream(mut stream) => {
                while let Some(chunk) = stream.next().await {
                    chunks.push(chunk);
                }
            }
            _ => panic!("expected stream outcome"),
        }

        assert_eq!(chunks.len(), 2);
        assert!(chunks[0].is_ok());
        assert!(chunks[1].is_err());
    }
}
815}