1use crate::healing::{HealedJsonResponse, HealedSchemaResponse, HealingSettings};
4use serde::{Deserialize, Serialize};
5use simple_agent_type::prelude::{
6 CompletionChunk, CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
7};
8use simple_agent_type::provider::RetryConfig;
9use simple_agent_type::telemetry::{ApiFormat, TelemetryConfig, TraceContext};
10use simple_agents_healing::coercion::CoercionEngine;
11use simple_agents_healing::parser::JsonishParser;
12use simple_agents_healing::schema::Schema;
13use std::sync::Arc;
14use std::time::Duration;
15use tracing::debug;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ClientConfig {
24 pub provider: String,
26 pub api_key: String,
28 pub base_url: Option<String>,
30 pub api_format: ApiFormat,
32 pub extra_headers: Option<Vec<(String, String)>>,
34 pub telemetry: Option<TelemetryConfig>,
36 pub default_retry: RetryConfig,
38}
39
40impl Default for ClientConfig {
41 fn default() -> Self {
42 Self {
43 provider: "openai".into(),
44 api_key: String::new(),
45 base_url: None,
46 api_format: ApiFormat::default(),
47 extra_headers: None,
48 telemetry: None,
49 default_retry: RetryConfig::default(),
50 }
51 }
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct ExecutionFlags {
57 pub workflow_streaming: bool,
59 pub node_llm_streaming: bool,
61}
62
63impl Default for ExecutionFlags {
64 fn default() -> Self {
65 Self {
66 workflow_streaming: false,
67 node_llm_streaming: true,
68 }
69 }
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct RunOptions {
75 pub nerdstats: bool,
77 pub telemetry_enabled: bool,
79 pub trace_context: Option<TraceContext>,
81 pub execution_flags: ExecutionFlags,
83}
84
85impl Default for RunOptions {
86 fn default() -> Self {
87 Self {
88 nerdstats: true,
89 telemetry_enabled: true,
90 trace_context: None,
91 execution_flags: ExecutionFlags::default(),
92 }
93 }
94}
95
96#[derive(Clone)]
98pub enum CompletionMode {
99 Standard,
101 HealedJson,
103 CoercedSchema(Schema),
105}
106
107#[derive(Clone)]
109pub struct CompletionOptions {
110 pub mode: CompletionMode,
112}
113
114impl Default for CompletionOptions {
115 fn default() -> Self {
116 Self {
117 mode: CompletionMode::Standard,
118 }
119 }
120}
121
122pub enum CompletionOutcome {
124 Response(CompletionResponse),
126 Stream(Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>),
128 HealedJson(HealedJsonResponse),
130 CoercedSchema(HealedSchemaResponse),
132}
133
134pub struct SimpleAgentsClient {
136 provider: Arc<dyn Provider>,
137 config: ClientConfig,
138 healing: HealingSettings,
139}
140
141impl SimpleAgentsClient {
142 pub fn new(provider: Arc<dyn Provider>) -> Self {
144 Self {
145 provider,
146 config: ClientConfig::default(),
147 healing: HealingSettings::default(),
148 }
149 }
150
151 pub fn from_config(provider: Arc<dyn Provider>, config: ClientConfig) -> Self {
153 Self {
154 provider,
155 config,
156 healing: HealingSettings::default(),
157 }
158 }
159
160 pub fn with_healing(provider: Arc<dyn Provider>, healing: HealingSettings) -> Self {
162 Self {
163 provider,
164 config: ClientConfig::default(),
165 healing,
166 }
167 }
168
169 pub fn config(&self) -> &ClientConfig {
171 &self.config
172 }
173
174 pub fn provider_name(&self) -> &str {
176 self.provider.name()
177 }
178
179 pub async fn complete(
181 &self,
182 request: &CompletionRequest,
183 options: CompletionOptions,
184 ) -> Result<CompletionOutcome> {
185 if request.stream.unwrap_or(false) {
186 let stream = self.stream(request).await?;
187 return Ok(CompletionOutcome::Stream(stream));
188 }
189
190 match options.mode {
191 CompletionMode::Standard => {
192 let response = self.complete_response(request).await?;
193 Ok(CompletionOutcome::Response(response))
194 }
195 CompletionMode::HealedJson => {
196 let healed = self.complete_json_internal(request).await?;
197 Ok(CompletionOutcome::HealedJson(healed))
198 }
199 CompletionMode::CoercedSchema(schema) => {
200 let healed = self.complete_with_schema_internal(request, &schema).await?;
201 Ok(CompletionOutcome::CoercedSchema(healed))
202 }
203 }
204 }
205
206 async fn complete_response(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
207 request.validate()?;
208
209 let provider_request = self.provider.transform_request(request)?;
210 let provider_response = self.execute_with_retries(provider_request).await?;
211 self.provider.transform_response(provider_response)
212 }
213
214 async fn execute_with_retries(
215 &self,
216 provider_request: simple_agent_type::provider::ProviderRequest,
217 ) -> Result<simple_agent_type::provider::ProviderResponse> {
218 let retry = &self.config.default_retry;
219 let max_attempts = retry.max_attempts.max(1);
220 let mut attempt = 1;
221
222 loop {
223 match self.provider.execute(provider_request.clone()).await {
224 Ok(response) => return Ok(response),
225 Err(error) => {
226 if attempt >= max_attempts || !is_retryable_error(&error) {
227 return Err(error);
228 }
229
230 let delay = retry_delay(retry, attempt, &error);
231 if !delay.is_zero() {
232 tokio::time::sleep(delay).await;
233 }
234 attempt += 1;
235 }
236 }
237 }
238 }
239
240 async fn complete_json_internal(
241 &self,
242 request: &CompletionRequest,
243 ) -> Result<HealedJsonResponse> {
244 self.ensure_healing_enabled()?;
245 let response = self.complete_response(request).await?;
246 let content = response.content().ok_or_else(|| {
247 SimpleAgentsError::Healing(simple_agent_type::error::HealingError::ParseFailed {
248 error_message: "response contained no content".to_string(),
249 input: String::new(),
250 })
251 })?;
252
253 let parser = JsonishParser::with_config(self.healing.parser_config.clone());
254 let parsed = parser.parse(content)?;
255
256 Ok(HealedJsonResponse { response, parsed })
257 }
258
259 async fn complete_with_schema_internal(
260 &self,
261 request: &CompletionRequest,
262 schema: &Schema,
263 ) -> Result<HealedSchemaResponse> {
264 self.ensure_healing_enabled()?;
265 let healed = self.complete_json_internal(request).await?;
266 let engine = CoercionEngine::with_config(self.healing.coercion_config.clone());
267 let coerced = engine
268 .coerce(&healed.parsed.value, schema)
269 .map_err(SimpleAgentsError::Healing)?;
270
271 Ok(HealedSchemaResponse {
272 response: healed.response,
273 parsed: healed.parsed,
274 coerced,
275 })
276 }
277
278 async fn stream(
279 &self,
280 request: &CompletionRequest,
281 ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
282 request.validate()?;
283 debug!(
284 model = %request.model,
285 stream = ?request.stream,
286 "SimpleAgentsClient.stream start"
287 );
288
289 let provider_request = self.provider.transform_request(request)?;
290 self.provider.execute_stream(provider_request).await
291 }
292
293 fn ensure_healing_enabled(&self) -> Result<()> {
294 if self.healing.enabled {
295 Ok(())
296 } else {
297 Err(SimpleAgentsError::Config(
298 "healing is disabled for this client".to_string(),
299 ))
300 }
301 }
302}
303
304fn is_retryable_error(error: &SimpleAgentsError) -> bool {
305 match error {
306 SimpleAgentsError::Provider(provider_error) => provider_error.is_retryable(),
307 SimpleAgentsError::Network(_) => true,
308 _ => false,
309 }
310}
311
312fn retry_after(error: &SimpleAgentsError) -> Option<Duration> {
313 match error {
314 SimpleAgentsError::Provider(simple_agent_type::error::ProviderError::RateLimit {
315 retry_after,
316 }) => *retry_after,
317 _ => None,
318 }
319}
320
321fn retry_delay(retry: &RetryConfig, failed_attempt: u32, error: &SimpleAgentsError) -> Duration {
322 if let Some(delay) = retry_after(error) {
323 return delay;
324 }
325
326 let factor = retry
327 .backoff_multiplier
328 .max(1.0)
329 .powi(failed_attempt.saturating_sub(1).min(31) as i32);
330 let delay = retry.initial_backoff.mul_f32(factor);
331 delay.min(retry.max_backoff.max(retry.initial_backoff))
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use futures_util::StreamExt;
    use simple_agent_type::error::ProviderError;
    use simple_agent_type::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // ---------------------------------------------------------------------
    // Shared fixtures
    // ---------------------------------------------------------------------

    /// Provider-level request every mock returns from `transform_request`.
    fn stub_provider_request() -> ProviderRequest {
        ProviderRequest::new("http://example.com")
    }

    /// Successful provider-level response every mock returns from `execute`.
    fn stub_provider_response() -> ProviderResponse {
        ProviderResponse::new(200, serde_json::json!({"content": "ok"}))
    }

    /// Single-choice completion answering "ok", tagged with `id` and `provider`.
    fn stub_completion(id: &str, provider: &str) -> CompletionResponse {
        let choice = CompletionChoice {
            index: 0,
            message: Message::assistant("ok"),
            finish_reason: FinishReason::Stop,
            logprobs: None,
        };

        CompletionResponse {
            id: id.to_string(),
            model: "test-model".to_string(),
            choices: vec![choice],
            usage: Usage::new(1, 1),
            created: None,
            provider: Some(provider.to_string()),
            healing_metadata: None,
        }
    }

    /// Minimal non-streaming chat request.
    fn chat_request() -> CompletionRequest {
        CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hi"))
            .build()
            .unwrap()
    }

    /// Same request as [`chat_request`], with streaming switched on.
    fn streaming_chat_request() -> CompletionRequest {
        CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hi"))
            .stream(true)
            .build()
            .unwrap()
    }

    /// Collects every item of a `Stream` outcome; panics on any other variant.
    async fn drain_stream(outcome: CompletionOutcome) -> Vec<Result<CompletionChunk>> {
        let mut items = Vec::new();
        match outcome {
            CompletionOutcome::Stream(mut stream) => {
                while let Some(item) = stream.next().await {
                    items.push(item);
                }
            }
            _ => panic!("expected stream outcome"),
        }
        items
    }

    // ---------------------------------------------------------------------
    // Always-successful provider
    // ---------------------------------------------------------------------

    struct MockProvider {
        name: &'static str,
        calls: AtomicUsize,
    }

    impl MockProvider {
        fn new(name: &'static str) -> Self {
            let calls = AtomicUsize::new(0);
            Self { name, calls }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(stub_provider_request())
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            self.calls.fetch_add(1, Ordering::Relaxed);
            Ok(stub_provider_response())
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(stub_completion("resp_test", self.name))
        }
    }

    #[tokio::test]
    async fn complete_returns_response() {
        let client = SimpleAgentsClient::new(Arc::new(MockProvider::new("p1")));
        let request = chat_request();

        let outcome = client
            .complete(&request, CompletionOptions::default())
            .await
            .unwrap();

        match outcome {
            CompletionOutcome::Response(resp) => {
                assert_eq!(resp.provider.as_deref(), Some("p1"));
            }
            _ => panic!("expected Response outcome"),
        }
    }

    // ---------------------------------------------------------------------
    // Provider that fails a fixed number of times before succeeding
    // ---------------------------------------------------------------------

    struct RetryProvider {
        name: &'static str,
        failures_before_success: usize,
        error: ProviderError,
        calls: AtomicUsize,
    }

    impl RetryProvider {
        fn new(name: &'static str, failures_before_success: usize, error: ProviderError) -> Self {
            let calls = AtomicUsize::new(0);
            Self {
                name,
                failures_before_success,
                error,
                calls,
            }
        }

        /// Number of `execute` calls observed so far.
        fn calls(&self) -> usize {
            self.calls.load(Ordering::Relaxed)
        }
    }

    #[async_trait]
    impl Provider for RetryProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(stub_provider_request())
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            // `fetch_add` returns the previous count, i.e. the 0-based call index.
            let call_index = self.calls.fetch_add(1, Ordering::Relaxed);
            if call_index >= self.failures_before_success {
                Ok(stub_provider_response())
            } else {
                Err(SimpleAgentsError::Provider(self.error.clone()))
            }
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(stub_completion("resp_retry", self.name))
        }
    }

    /// Client config whose retry policy never sleeps (all backoffs are zero).
    fn retry_test_config(max_attempts: u32, backoff_multiplier: f32) -> ClientConfig {
        let default_retry = RetryConfig {
            max_attempts,
            initial_backoff: Duration::ZERO,
            max_backoff: Duration::ZERO,
            backoff_multiplier,
            jitter: false,
        };

        ClientConfig {
            default_retry,
            ..ClientConfig::default()
        }
    }

    /// Builds a `RetryProvider` and a client around it with `max_attempts`.
    fn retry_fixture(
        failures_before_success: usize,
        error: ProviderError,
        max_attempts: u32,
    ) -> (Arc<RetryProvider>, SimpleAgentsClient) {
        let provider = Arc::new(RetryProvider::new("retry", failures_before_success, error));
        let client =
            SimpleAgentsClient::from_config(provider.clone(), retry_test_config(max_attempts, 1.0));
        (provider, client)
    }

    #[tokio::test]
    async fn complete_retries_retryable_provider_errors() {
        let (provider, client) =
            retry_fixture(2, ProviderError::ServerError("temporary".to_string()), 3);
        let request = chat_request();

        let outcome = client
            .complete(&request, CompletionOptions::default())
            .await
            .unwrap();

        assert!(matches!(outcome, CompletionOutcome::Response(_)));
        assert_eq!(provider.calls(), 3);
    }

    #[tokio::test]
    async fn complete_does_not_retry_non_retryable_provider_errors() {
        let (provider, client) = retry_fixture(1, ProviderError::InvalidApiKey, 3);
        let request = chat_request();

        let result = client
            .complete(&request, CompletionOptions::default())
            .await;

        assert!(result.is_err());
        assert_eq!(provider.calls(), 1);
    }

    #[tokio::test]
    async fn complete_does_not_retry_when_strategy_is_none() {
        let (provider, client) =
            retry_fixture(1, ProviderError::ServerError("temporary".to_string()), 1);
        let request = chat_request();

        let result = client
            .complete(&request, CompletionOptions::default())
            .await;

        assert!(result.is_err());
        assert_eq!(provider.calls(), 1);
    }

    #[test]
    fn retry_delay_uses_backoff_multiplier() {
        let error =
            SimpleAgentsError::Provider(ProviderError::ServerError("temporary".to_string()));
        let policy_with = |backoff_multiplier: f32| RetryConfig {
            max_attempts: 3,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_millis(1_000),
            backoff_multiplier,
            jitter: false,
        };
        let fixed = policy_with(1.0);
        let exponential = policy_with(2.0);

        assert_eq!(retry_delay(&fixed, 2, &error).as_millis(), 100);
        assert_eq!(retry_delay(&exponential, 1, &error).as_millis(), 100);
        assert_eq!(retry_delay(&exponential, 4, &error).as_millis(), 800);
    }

    // ---------------------------------------------------------------------
    // Streaming provider
    // ---------------------------------------------------------------------

    struct StreamingProvider {
        name: &'static str,
        fail_after_first: bool,
    }

    impl StreamingProvider {
        fn new(name: &'static str, fail_after_first: bool) -> Self {
            Self {
                name,
                fail_after_first,
            }
        }

        /// Assistant chunk with the given `id` and text `content`.
        fn build_chunk(id: &str, content: &str) -> CompletionChunk {
            let delta = MessageDelta {
                role: Some(Role::Assistant),
                content: Some(content.to_string()),
                reasoning_content: None,
                tool_calls: None,
            };
            let choice = ChoiceDelta {
                index: 0,
                delta,
                finish_reason: None,
            };

            CompletionChunk {
                id: id.to_string(),
                model: "test-model".to_string(),
                choices: vec![choice],
                created: None,
                usage: None,
            }
        }
    }

    #[async_trait]
    impl Provider for StreamingProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(stub_provider_request())
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            Ok(stub_provider_response())
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(stub_completion("resp_stream", self.name))
        }

        async fn execute_stream(
            &self,
            _req: ProviderRequest,
        ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>>
        {
            // One good chunk, optionally followed by a server error.
            let mut items: Vec<Result<CompletionChunk>> =
                vec![Ok(Self::build_chunk("chunk-1", "hello"))];
            if self.fail_after_first {
                items.push(Err(SimpleAgentsError::Provider(ProviderError::ServerError(
                    "stream error".to_string(),
                ))));
            }

            Ok(Box::new(futures_util::stream::iter(items)))
        }
    }

    #[tokio::test]
    async fn streaming_returns_chunks() {
        let client = SimpleAgentsClient::new(Arc::new(StreamingProvider::new("p1", false)));
        let request = streaming_chat_request();

        let outcome = client
            .complete(&request, CompletionOptions::default())
            .await
            .unwrap();

        let collected: Vec<CompletionChunk> = drain_stream(outcome)
            .await
            .into_iter()
            .map(|chunk| chunk.unwrap())
            .collect();

        assert_eq!(collected.len(), 1);
    }

    #[tokio::test]
    async fn streaming_propagates_error() {
        let client = SimpleAgentsClient::new(Arc::new(StreamingProvider::new("p1", true)));
        let request = streaming_chat_request();

        let outcome = client
            .complete(&request, CompletionOptions::default())
            .await
            .unwrap();

        let chunks = drain_stream(outcome).await;

        assert_eq!(chunks.len(), 2);
        assert!(chunks[0].is_ok());
        assert!(chunks[1].is_err());
    }
}
730}