1use crate::healing::{HealedJsonResponse, HealedSchemaResponse, HealingSettings};
4use crate::middleware::Middleware;
5use crate::routing::{RouterEngine, RoutingMode};
6use async_trait::async_trait;
7use futures_util::future::BoxFuture;
8use futures_util::stream::{self, Stream};
9use futures_util::StreamExt;
10use simple_agent_type::cache::Cache;
11use simple_agent_type::cache::CacheKey;
12use simple_agent_type::prelude::{
13 CompletionChunk, CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
14};
15use simple_agents_healing::coercion::CoercionEngine;
16use simple_agents_healing::parser::JsonishParser;
17use simple_agents_healing::schema::Schema;
18use std::collections::{HashMap, HashSet};
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21use tokio::sync::RwLock;
22use tracing::debug;
23
24#[derive(Clone)]
26pub enum CompletionMode {
27 Standard,
29 HealedJson,
31 CoercedSchema(Schema),
33}
34
35#[derive(Clone)]
37pub struct CompletionOptions {
38 pub mode: CompletionMode,
40}
41
42impl Default for CompletionOptions {
43 fn default() -> Self {
44 Self {
45 mode: CompletionMode::Standard,
46 }
47 }
48}
49
50pub enum CompletionOutcome {
52 Response(CompletionResponse),
54 Stream(Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>),
56 HealedJson(HealedJsonResponse),
58 CoercedSchema(HealedSchemaResponse),
60}
61
62struct ClientState {
63 providers: Vec<Arc<dyn Provider>>,
64 provider_map: HashMap<String, Arc<dyn Provider>>,
65 router: Arc<RouterEngine>,
66}
67
68pub struct SimpleAgentsClient {
70 state: RwLock<ClientState>,
71 routing_mode: RoutingMode,
72 cache: Option<Arc<dyn Cache>>,
73 cache_ttl: Duration,
74 healing: HealingSettings,
75 middleware: Vec<Arc<dyn Middleware>>,
76}
77
78impl SimpleAgentsClient {
79 pub fn builder() -> SimpleAgentsClientBuilder {
81 SimpleAgentsClientBuilder::new()
82 }
83
84 pub async fn provider_names(&self) -> Result<Vec<String>> {
86 let state = self.state.read().await;
87 Ok(state.provider_map.keys().cloned().collect())
88 }
89
90 pub async fn provider(&self, name: &str) -> Result<Option<Arc<dyn Provider>>> {
92 let state = self.state.read().await;
93 Ok(state.provider_map.get(name).cloned())
94 }
95
96 pub async fn register_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
98 let mut state = self.state.write().await;
99 let name = provider.name().to_string();
100
101 if state.provider_map.contains_key(&name) {
102 return Err(SimpleAgentsError::Config(format!(
103 "provider already registered: {}",
104 name
105 )));
106 }
107
108 state.provider_map.insert(name, provider.clone());
109 state.providers.push(provider);
110 state.router = Arc::new(self.routing_mode.build_router(state.providers.clone())?);
111 Ok(())
112 }
113
114 pub async fn complete(
116 &self,
117 request: &CompletionRequest,
118 options: CompletionOptions,
119 ) -> Result<CompletionOutcome> {
120 if request.stream.unwrap_or(false) {
121 let stream = self.stream(request).await?;
122 return Ok(CompletionOutcome::Stream(stream));
123 }
124
125 match options.mode {
126 CompletionMode::Standard => {
127 let response = self.complete_response(request).await?;
128 Ok(CompletionOutcome::Response(response))
129 }
130 CompletionMode::HealedJson => {
131 let healed = self.complete_json_internal(request).await?;
132 Ok(CompletionOutcome::HealedJson(healed))
133 }
134 CompletionMode::CoercedSchema(schema) => {
135 let healed = self.complete_with_schema_internal(request, &schema).await?;
136 Ok(CompletionOutcome::CoercedSchema(healed))
137 }
138 }
139 }
140
141 async fn complete_response(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
142 request.validate()?;
143 self.before_request(request).await?;
144
145 let cache_key = if let Some(cache) = &self.cache {
146 if cache.is_enabled() {
147 Some(self.cache_key(request)?)
148 } else {
149 None
150 }
151 } else {
152 None
153 };
154
155 if let (Some(cache), Some(key)) = (&self.cache, cache_key.as_deref()) {
156 if let Some(cached) = cache.get(key).await? {
157 let response: CompletionResponse = serde_json::from_slice(&cached)?;
158 self.on_cache_hit(request, &response).await?;
159 return Ok(response);
160 }
161 }
162
163 let start = Instant::now();
164 let router = {
165 let state = self.state.read().await;
166 state.router.clone()
167 };
168 let response = router.complete(request).await;
169
170 match response {
171 Ok(response) => {
172 self.after_response(request, &response, start.elapsed())
173 .await?;
174 if let (Some(cache), Some(key)) = (&self.cache, cache_key) {
175 let payload = serde_json::to_vec(&response)?;
176 cache.set(&key, payload, self.cache_ttl).await?;
177 }
178 Ok(response)
179 }
180 Err(error) => {
181 self.on_error(request, &error, start.elapsed()).await?;
182 Err(error)
183 }
184 }
185 }
186
187 async fn complete_json_internal(
189 &self,
190 request: &CompletionRequest,
191 ) -> Result<HealedJsonResponse> {
192 self.ensure_healing_enabled()?;
193 let response = self.complete_response(request).await?;
194 let content = response.content().ok_or_else(|| {
195 SimpleAgentsError::Healing(simple_agent_type::error::HealingError::ParseFailed {
196 error_message: "response contained no content".to_string(),
197 input: String::new(),
198 })
199 })?;
200
201 let parser = JsonishParser::with_config(self.healing.parser_config.clone());
202 let parsed = parser.parse(content)?;
203
204 Ok(HealedJsonResponse { response, parsed })
205 }
206
207 async fn complete_with_schema_internal(
209 &self,
210 request: &CompletionRequest,
211 schema: &Schema,
212 ) -> Result<HealedSchemaResponse> {
213 self.ensure_healing_enabled()?;
214 let healed = self.complete_json_internal(request).await?;
215 let engine = CoercionEngine::with_config(self.healing.coercion_config.clone());
216 let coerced = engine
217 .coerce(&healed.parsed.value, schema)
218 .map_err(SimpleAgentsError::Healing)?;
219
220 Ok(HealedSchemaResponse {
221 response: healed.response,
222 parsed: healed.parsed,
223 coerced,
224 })
225 }
226
227 async fn stream(
229 &self,
230 request: &CompletionRequest,
231 ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
232 request.validate()?;
233 self.before_request(request).await?;
234 debug!(
235 model = %request.model,
236 stream = ?request.stream,
237 "SimpleAgentsClient.stream start"
238 );
239
240 let router = {
241 let state = self.state.read().await;
242 state.router.clone()
243 };
244
245 let start = Instant::now();
246 let middleware = self.middleware.clone();
247 let instrumented_request = request.clone();
248 let inner = router.stream(request).await?;
249
250 let wrapped = Self::instrument_stream(inner, instrumented_request, middleware, start);
251 Ok(Box::new(wrapped))
252 }
253
254 fn ensure_healing_enabled(&self) -> Result<()> {
255 if self.healing.enabled {
256 Ok(())
257 } else {
258 Err(SimpleAgentsError::Config(
259 "healing is disabled for this client".to_string(),
260 ))
261 }
262 }
263
264 fn cache_key(&self, request: &CompletionRequest) -> Result<String> {
265 let serialized = serde_json::to_string(request)?;
266 Ok(CacheKey::from_parts("core", &request.model, &serialized))
267 }
268
269 async fn before_request(&self, request: &CompletionRequest) -> Result<()> {
270 for middleware in &self.middleware {
271 middleware.before_request(request).await?;
272 }
273 Ok(())
274 }
275
276 async fn after_response(
277 &self,
278 request: &CompletionRequest,
279 response: &CompletionResponse,
280 latency: Duration,
281 ) -> Result<()> {
282 for middleware in &self.middleware {
283 middleware
284 .after_response(request, response, latency)
285 .await?;
286 }
287 Ok(())
288 }
289
290 async fn on_cache_hit(
291 &self,
292 request: &CompletionRequest,
293 response: &CompletionResponse,
294 ) -> Result<()> {
295 for middleware in &self.middleware {
296 middleware.on_cache_hit(request, response).await?;
297 }
298 Ok(())
299 }
300
301 async fn on_error(
302 &self,
303 request: &CompletionRequest,
304 error: &SimpleAgentsError,
305 latency: Duration,
306 ) -> Result<()> {
307 for middleware in &self.middleware {
308 middleware.on_error(request, error, latency).await?;
309 }
310 Ok(())
311 }
312}
313
314impl SimpleAgentsClient {
315 fn instrument_stream(
316 inner: Box<dyn Stream<Item = Result<CompletionChunk>> + Send + Unpin>,
317 request: CompletionRequest,
318 middleware: Vec<Arc<dyn Middleware>>,
319 start: Instant,
320 ) -> impl Stream<Item = Result<CompletionChunk>> + Send + Unpin {
321 struct StreamState {
322 inner: Box<dyn Stream<Item = Result<CompletionChunk>> + Send + Unpin>,
323 middleware: Vec<Arc<dyn Middleware>>,
324 request: CompletionRequest,
325 start: Instant,
326 done: bool,
327 }
328
329 stream::unfold(
330 StreamState {
331 inner,
332 middleware,
333 request,
334 start,
335 done: false,
336 },
337 |mut state| -> BoxFuture<Option<(Result<CompletionChunk>, StreamState)>> {
338 Box::pin(async move {
339 if state.done {
340 return None;
341 }
342
343 match state.inner.next().await {
344 Some(Ok(chunk)) => Some((Ok(chunk), state)),
345 Some(Err(err)) => {
346 let latency = state.start.elapsed();
347 for middleware in &state.middleware {
348 if let Err(mw_err) =
349 middleware.on_error(&state.request, &err, latency).await
350 {
351 state.done = true;
352 return Some((Err(mw_err), state));
353 }
354 }
355 state.done = true;
356 Some((Err(err), state))
357 }
358 None => {
359 let latency = state.start.elapsed();
360 for middleware in &state.middleware {
361 if let Err(mw_err) =
362 middleware.after_stream(&state.request, latency).await
363 {
364 state.done = true;
365 return Some((Err(mw_err), state));
366 }
367 }
368 None
369 }
370 }
371 })
372 },
373 )
374 }
375}
376
377pub struct SimpleAgentsClientBuilder {
379 providers: Vec<Arc<dyn Provider>>,
380 routing_mode: RoutingMode,
381 cache: Option<Arc<dyn Cache>>,
382 cache_ttl: Duration,
383 healing: HealingSettings,
384 middleware: Vec<Arc<dyn Middleware>>,
385}
386
387impl SimpleAgentsClientBuilder {
388 pub fn new() -> Self {
390 Self {
391 providers: Vec::new(),
392 routing_mode: RoutingMode::default(),
393 cache: None,
394 cache_ttl: Duration::from_secs(60),
395 healing: HealingSettings::default(),
396 middleware: Vec::new(),
397 }
398 }
399
400 pub fn with_provider(mut self, provider: Arc<dyn Provider>) -> Self {
402 self.providers.push(provider);
403 self
404 }
405
406 pub fn with_providers(mut self, providers: Vec<Arc<dyn Provider>>) -> Self {
408 self.providers.extend(providers);
409 self
410 }
411
412 pub fn with_routing_mode(mut self, mode: RoutingMode) -> Self {
414 self.routing_mode = mode;
415 self
416 }
417
418 pub fn with_cache(mut self, cache: Arc<dyn Cache>) -> Self {
420 self.cache = Some(cache);
421 self
422 }
423
424 pub fn with_cache_ttl(mut self, ttl: Duration) -> Self {
426 self.cache_ttl = ttl;
427 self
428 }
429
430 pub fn with_healing_settings(mut self, settings: HealingSettings) -> Self {
432 self.healing = settings;
433 self
434 }
435
436 pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
438 self.middleware.push(middleware);
439 self
440 }
441
442 pub fn build(self) -> Result<SimpleAgentsClient> {
444 if self.providers.is_empty() {
445 return Err(SimpleAgentsError::Config(
446 "at least one provider is required".to_string(),
447 ));
448 }
449
450 let mut seen = HashSet::new();
451 for provider in &self.providers {
452 let name = provider.name();
453 if !seen.insert(name.to_string()) {
454 return Err(SimpleAgentsError::Config(format!(
455 "duplicate provider configured in builder: {}",
456 name
457 )));
458 }
459 }
460
461 let provider_map = self
462 .providers
463 .iter()
464 .map(|provider| (provider.name().to_string(), provider.clone()))
465 .collect::<HashMap<_, _>>();
466
467 let router = Arc::new(self.routing_mode.build_router(self.providers.clone())?);
468 let state = ClientState {
469 providers: self.providers,
470 provider_map,
471 router,
472 };
473
474 Ok(SimpleAgentsClient {
475 state: RwLock::new(state),
476 routing_mode: self.routing_mode,
477 cache: self.cache,
478 cache_ttl: self.cache_ttl,
479 healing: self.healing,
480 middleware: self.middleware,
481 })
482 }
483}
484
485impl Default for SimpleAgentsClientBuilder {
486 fn default() -> Self {
487 Self::new()
488 }
489}
490
491#[async_trait]
492impl Middleware for () {
493 async fn before_request(&self, _request: &CompletionRequest) -> Result<()> {
494 Ok(())
495 }
496}
497
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{stream, StreamExt};
    use simple_agent_type::error::ProviderError;
    use simple_agent_type::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Provider returning a fixed successful response and counting `execute` calls.
    struct MockProvider {
        name: &'static str,
        calls: AtomicUsize,
    }

    impl MockProvider {
        fn new(name: &'static str) -> Self {
            Self {
                name,
                calls: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            self.calls.fetch_add(1, Ordering::Relaxed);
            Ok(ProviderResponse::new(
                200,
                serde_json::json!({"content": "ok"}),
            ))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name.to_string()),
                healing_metadata: None,
            })
        }
    }

    #[tokio::test]
    async fn client_build_requires_provider() {
        let built = SimpleAgentsClientBuilder::new().build();
        assert!(built.is_err());
    }

    #[tokio::test]
    async fn register_provider_rebuilds_router() {
        let first = Arc::new(MockProvider::new("p1"));
        let client = SimpleAgentsClientBuilder::new()
            .with_provider(first)
            .build()
            .unwrap();

        let second = Arc::new(MockProvider::new("p2"));
        client.register_provider(second).await.unwrap();

        let names = client.provider_names().await.unwrap();
        assert!(names.contains(&"p1".to_string()));
        assert!(names.contains(&"p2".to_string()));
    }

    #[tokio::test]
    async fn duplicate_provider_registration_fails() {
        let provider = Arc::new(MockProvider::new("p1"));
        let client = SimpleAgentsClientBuilder::new()
            .with_provider(provider.clone())
            .build()
            .unwrap();

        let result = client.register_provider(provider).await;
        assert!(matches!(
            result,
            Err(SimpleAgentsError::Config(msg)) if msg.contains("provider already registered")
        ));
    }

    #[tokio::test]
    async fn duplicate_provider_in_builder_with_provider_fails() {
        let p1 = Arc::new(MockProvider::new("p1"));
        let p1_dup = Arc::new(MockProvider::new("p1"));

        let result = SimpleAgentsClientBuilder::new()
            .with_provider(p1)
            .with_provider(p1_dup)
            .build();

        assert!(matches!(
            result,
            Err(SimpleAgentsError::Config(msg)) if msg.contains("duplicate provider configured in builder")
        ));
    }

    #[tokio::test]
    async fn duplicate_provider_in_builder_with_providers_fails() {
        let result = SimpleAgentsClientBuilder::new()
            .with_providers(vec![
                Arc::new(MockProvider::new("p1")),
                Arc::new(MockProvider::new("p1")),
            ])
            .build();

        assert!(matches!(
            result,
            Err(SimpleAgentsError::Config(msg)) if msg.contains("duplicate provider configured in builder")
        ));
    }

    /// Middleware that counts how often each observed hook fires.
    #[derive(Default)]
    struct RecordingMiddleware {
        before: AtomicUsize,
        after_stream: AtomicUsize,
        errors: AtomicUsize,
    }

    #[async_trait]
    impl Middleware for RecordingMiddleware {
        async fn before_request(&self, _request: &CompletionRequest) -> Result<()> {
            self.before.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }

        async fn after_stream(
            &self,
            _request: &CompletionRequest,
            _latency: Duration,
        ) -> Result<()> {
            self.after_stream.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }

        async fn on_error(
            &self,
            _request: &CompletionRequest,
            _error: &SimpleAgentsError,
            _latency: Duration,
        ) -> Result<()> {
            self.errors.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }

        fn name(&self) -> &str {
            "recording"
        }
    }

    /// Provider whose stream yields one chunk, optionally followed by an error.
    struct StreamingProvider {
        name: &'static str,
        fail_after_first: bool,
    }

    impl StreamingProvider {
        fn new(name: &'static str, fail_after_first: bool) -> Self {
            Self {
                name,
                fail_after_first,
            }
        }

        fn build_chunk(id: &str, content: &str) -> CompletionChunk {
            let delta = MessageDelta {
                role: Some(Role::Assistant),
                content: Some(content.to_string()),
                reasoning_content: None,
                tool_calls: None,
            };
            CompletionChunk {
                id: id.to_string(),
                model: "test-model".to_string(),
                choices: vec![ChoiceDelta {
                    index: 0,
                    delta,
                    finish_reason: None,
                }],
                created: None,
                usage: None,
            }
        }
    }

    #[async_trait]
    impl Provider for StreamingProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            Ok(ProviderResponse::new(
                200,
                serde_json::json!({"content": "ok"}),
            ))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_stream".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name.to_string()),
                healing_metadata: None,
            })
        }

        async fn execute_stream(
            &self,
            _req: ProviderRequest,
        ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>>
        {
            // Always one good chunk; the failing variant appends a server error.
            let mut items: Vec<Result<CompletionChunk>> =
                vec![Ok(Self::build_chunk("chunk-1", "hello"))];
            if self.fail_after_first {
                items.push(Err(SimpleAgentsError::Provider(ProviderError::ServerError(
                    "stream error".to_string(),
                ))));
            }

            Ok(Box::new(stream::iter(items)))
        }
    }

    /// Builds the streaming request shared by the streaming tests.
    fn streaming_request() -> CompletionRequest {
        CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hi"))
            .stream(true)
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn streaming_invokes_after_stream_on_success() {
        let provider = Arc::new(StreamingProvider::new("p1", false));
        let middleware = Arc::new(RecordingMiddleware::default());

        let client = SimpleAgentsClientBuilder::new()
            .with_provider(provider)
            .with_middleware(middleware.clone())
            .build()
            .unwrap();

        let request = streaming_request();
        let outcome = client
            .complete(&request, CompletionOptions::default())
            .await
            .unwrap();

        let mut collected = Vec::new();
        match outcome {
            CompletionOutcome::Stream(mut chunks) => {
                while let Some(chunk) = chunks.next().await {
                    collected.push(chunk.unwrap());
                }
            }
            _ => panic!("expected stream outcome"),
        }

        assert_eq!(collected.len(), 1);
        assert_eq!(middleware.before.load(Ordering::Relaxed), 1);
        assert_eq!(middleware.after_stream.load(Ordering::Relaxed), 1);
        assert_eq!(middleware.errors.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn streaming_invokes_on_error_on_failure() {
        let provider = Arc::new(StreamingProvider::new("p1", true));
        let middleware = Arc::new(RecordingMiddleware::default());

        let client = SimpleAgentsClientBuilder::new()
            .with_provider(provider)
            .with_middleware(middleware.clone())
            .build()
            .unwrap();

        let request = streaming_request();
        let outcome = client
            .complete(&request, CompletionOptions::default())
            .await
            .unwrap();

        let mut items = Vec::new();
        match outcome {
            CompletionOutcome::Stream(mut chunks) => {
                while let Some(item) = chunks.next().await {
                    items.push(item);
                }
            }
            _ => panic!("expected stream outcome"),
        }

        assert_eq!(middleware.before.load(Ordering::Relaxed), 1);
        assert_eq!(middleware.after_stream.load(Ordering::Relaxed), 0);
        assert_eq!(middleware.errors.load(Ordering::Relaxed), 1);
        assert_eq!(items.len(), 2);
        assert!(items[0].as_ref().is_ok());
        assert!(items[1].is_err());
    }
}