1use crate::healing::{HealedJsonResponse, HealedSchemaResponse, HealingSettings};
4use crate::middleware::Middleware;
5use crate::routing::{RouterEngine, RoutingMode};
6use async_trait::async_trait;
7use futures_util::future::BoxFuture;
8use futures_util::stream::{self, Stream};
9use futures_util::StreamExt;
10use simple_agent_type::cache::Cache;
11use simple_agent_type::cache::CacheKey;
12use simple_agent_type::prelude::{
13 CompletionChunk, CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
14};
15use simple_agents_healing::coercion::CoercionEngine;
16use simple_agents_healing::parser::JsonishParser;
17use simple_agents_healing::schema::Schema;
18use std::collections::HashMap;
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21use tokio::sync::RwLock;
22use tracing::debug;
23
24pub enum CompletionMode {
26 Standard,
28 HealedJson,
30 CoercedSchema(Schema),
32}
33
34pub struct CompletionOptions {
36 pub mode: CompletionMode,
38}
39
40impl Default for CompletionOptions {
41 fn default() -> Self {
42 Self {
43 mode: CompletionMode::Standard,
44 }
45 }
46}
47
48pub enum CompletionOutcome {
50 Response(CompletionResponse),
52 Stream(Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>),
54 HealedJson(HealedJsonResponse),
56 CoercedSchema(HealedSchemaResponse),
58}
59
60struct ClientState {
61 providers: Vec<Arc<dyn Provider>>,
62 provider_map: HashMap<String, Arc<dyn Provider>>,
63 router: Arc<RouterEngine>,
64}
65
66pub struct SimpleAgentsClient {
68 state: RwLock<ClientState>,
69 routing_mode: RoutingMode,
70 cache: Option<Arc<dyn Cache>>,
71 cache_ttl: Duration,
72 healing: HealingSettings,
73 middleware: Vec<Arc<dyn Middleware>>,
74}
75
76impl SimpleAgentsClient {
77 pub fn builder() -> SimpleAgentsClientBuilder {
79 SimpleAgentsClientBuilder::new()
80 }
81
82 pub async fn provider_names(&self) -> Result<Vec<String>> {
84 let state = self.state.read().await;
85 Ok(state.provider_map.keys().cloned().collect())
86 }
87
88 pub async fn provider(&self, name: &str) -> Result<Option<Arc<dyn Provider>>> {
90 let state = self.state.read().await;
91 Ok(state.provider_map.get(name).cloned())
92 }
93
94 pub async fn register_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
96 let mut state = self.state.write().await;
97 let name = provider.name().to_string();
98
99 if state.provider_map.contains_key(&name) {
100 return Err(SimpleAgentsError::Config(format!(
101 "provider already registered: {}",
102 name
103 )));
104 }
105
106 state.provider_map.insert(name, provider.clone());
107 state.providers.push(provider);
108 state.router = Arc::new(self.routing_mode.build_router(state.providers.clone())?);
109 Ok(())
110 }
111
112 pub async fn complete(
114 &self,
115 request: &CompletionRequest,
116 options: CompletionOptions,
117 ) -> Result<CompletionOutcome> {
118 if request.stream.unwrap_or(false) {
119 let stream = self.stream(request).await?;
120 return Ok(CompletionOutcome::Stream(stream));
121 }
122
123 match options.mode {
124 CompletionMode::Standard => {
125 let response = self.complete_response(request).await?;
126 Ok(CompletionOutcome::Response(response))
127 }
128 CompletionMode::HealedJson => {
129 let healed = self.complete_json_internal(request).await?;
130 Ok(CompletionOutcome::HealedJson(healed))
131 }
132 CompletionMode::CoercedSchema(schema) => {
133 let healed = self.complete_with_schema_internal(request, &schema).await?;
134 Ok(CompletionOutcome::CoercedSchema(healed))
135 }
136 }
137 }
138
139 async fn complete_response(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
140 request.validate()?;
141 self.before_request(request).await?;
142
143 let cache_key = if let Some(cache) = &self.cache {
144 if cache.is_enabled() {
145 Some(self.cache_key(request)?)
146 } else {
147 None
148 }
149 } else {
150 None
151 };
152
153 if let (Some(cache), Some(key)) = (&self.cache, cache_key.as_deref()) {
154 if let Some(cached) = cache.get(key).await? {
155 let response: CompletionResponse = serde_json::from_slice(&cached)?;
156 self.on_cache_hit(request, &response).await?;
157 return Ok(response);
158 }
159 }
160
161 let start = Instant::now();
162 let router = {
163 let state = self.state.read().await;
164 state.router.clone()
165 };
166 let response = router.complete(request).await;
167
168 match response {
169 Ok(response) => {
170 self.after_response(request, &response, start.elapsed())
171 .await?;
172 if let (Some(cache), Some(key)) = (&self.cache, cache_key) {
173 let payload = serde_json::to_vec(&response)?;
174 cache.set(&key, payload, self.cache_ttl).await?;
175 }
176 Ok(response)
177 }
178 Err(error) => {
179 self.on_error(request, &error, start.elapsed()).await?;
180 Err(error)
181 }
182 }
183 }
184
185 async fn complete_json_internal(
187 &self,
188 request: &CompletionRequest,
189 ) -> Result<HealedJsonResponse> {
190 self.ensure_healing_enabled()?;
191 let response = self.complete_response(request).await?;
192 let content = response.content().ok_or_else(|| {
193 SimpleAgentsError::Healing(simple_agent_type::error::HealingError::ParseFailed {
194 error_message: "response contained no content".to_string(),
195 input: String::new(),
196 })
197 })?;
198
199 let parser = JsonishParser::with_config(self.healing.parser_config.clone());
200 let parsed = parser.parse(content)?;
201
202 Ok(HealedJsonResponse { response, parsed })
203 }
204
205 async fn complete_with_schema_internal(
207 &self,
208 request: &CompletionRequest,
209 schema: &Schema,
210 ) -> Result<HealedSchemaResponse> {
211 self.ensure_healing_enabled()?;
212 let healed = self.complete_json_internal(request).await?;
213 let engine = CoercionEngine::with_config(self.healing.coercion_config.clone());
214 let coerced = engine
215 .coerce(&healed.parsed.value, schema)
216 .map_err(SimpleAgentsError::Healing)?;
217
218 Ok(HealedSchemaResponse {
219 response: healed.response,
220 parsed: healed.parsed,
221 coerced,
222 })
223 }
224
225 async fn stream(
227 &self,
228 request: &CompletionRequest,
229 ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
230 request.validate()?;
231 self.before_request(request).await?;
232 debug!(
233 model = %request.model,
234 stream = ?request.stream,
235 "SimpleAgentsClient.stream start"
236 );
237
238 let router = {
239 let state = self.state.read().await;
240 state.router.clone()
241 };
242
243 let start = Instant::now();
244 let middleware = self.middleware.clone();
245 let instrumented_request = request.clone();
246 let inner = router.stream(request).await?;
247
248 let wrapped = Self::instrument_stream(inner, instrumented_request, middleware, start);
249 Ok(Box::new(wrapped))
250 }
251
252 fn ensure_healing_enabled(&self) -> Result<()> {
253 if self.healing.enabled {
254 Ok(())
255 } else {
256 Err(SimpleAgentsError::Config(
257 "healing is disabled for this client".to_string(),
258 ))
259 }
260 }
261
262 fn cache_key(&self, request: &CompletionRequest) -> Result<String> {
263 let serialized = serde_json::to_string(request)?;
264 Ok(CacheKey::from_parts("core", &request.model, &serialized))
265 }
266
267 async fn before_request(&self, request: &CompletionRequest) -> Result<()> {
268 for middleware in &self.middleware {
269 middleware.before_request(request).await?;
270 }
271 Ok(())
272 }
273
274 async fn after_response(
275 &self,
276 request: &CompletionRequest,
277 response: &CompletionResponse,
278 latency: Duration,
279 ) -> Result<()> {
280 for middleware in &self.middleware {
281 middleware
282 .after_response(request, response, latency)
283 .await?;
284 }
285 Ok(())
286 }
287
288 async fn on_cache_hit(
289 &self,
290 request: &CompletionRequest,
291 response: &CompletionResponse,
292 ) -> Result<()> {
293 for middleware in &self.middleware {
294 middleware.on_cache_hit(request, response).await?;
295 }
296 Ok(())
297 }
298
299 async fn on_error(
300 &self,
301 request: &CompletionRequest,
302 error: &SimpleAgentsError,
303 latency: Duration,
304 ) -> Result<()> {
305 for middleware in &self.middleware {
306 middleware.on_error(request, error, latency).await?;
307 }
308 Ok(())
309 }
310}
311
312impl SimpleAgentsClient {
313 fn instrument_stream(
314 inner: Box<dyn Stream<Item = Result<CompletionChunk>> + Send + Unpin>,
315 request: CompletionRequest,
316 middleware: Vec<Arc<dyn Middleware>>,
317 start: Instant,
318 ) -> impl Stream<Item = Result<CompletionChunk>> + Send + Unpin {
319 struct StreamState {
320 inner: Box<dyn Stream<Item = Result<CompletionChunk>> + Send + Unpin>,
321 middleware: Vec<Arc<dyn Middleware>>,
322 request: CompletionRequest,
323 start: Instant,
324 done: bool,
325 }
326
327 stream::unfold(
328 StreamState {
329 inner,
330 middleware,
331 request,
332 start,
333 done: false,
334 },
335 |mut state| -> BoxFuture<Option<(Result<CompletionChunk>, StreamState)>> {
336 Box::pin(async move {
337 if state.done {
338 return None;
339 }
340
341 match state.inner.next().await {
342 Some(Ok(chunk)) => Some((Ok(chunk), state)),
343 Some(Err(err)) => {
344 let latency = state.start.elapsed();
345 for middleware in &state.middleware {
346 if let Err(mw_err) =
347 middleware.on_error(&state.request, &err, latency).await
348 {
349 state.done = true;
350 return Some((Err(mw_err), state));
351 }
352 }
353 state.done = true;
354 Some((Err(err), state))
355 }
356 None => {
357 let latency = state.start.elapsed();
358 for middleware in &state.middleware {
359 if let Err(mw_err) =
360 middleware.after_stream(&state.request, latency).await
361 {
362 state.done = true;
363 return Some((Err(mw_err), state));
364 }
365 }
366 None
367 }
368 }
369 })
370 },
371 )
372 }
373}
374
375pub struct SimpleAgentsClientBuilder {
377 providers: Vec<Arc<dyn Provider>>,
378 routing_mode: RoutingMode,
379 cache: Option<Arc<dyn Cache>>,
380 cache_ttl: Duration,
381 healing: HealingSettings,
382 middleware: Vec<Arc<dyn Middleware>>,
383}
384
385impl SimpleAgentsClientBuilder {
386 pub fn new() -> Self {
388 Self {
389 providers: Vec::new(),
390 routing_mode: RoutingMode::default(),
391 cache: None,
392 cache_ttl: Duration::from_secs(60),
393 healing: HealingSettings::default(),
394 middleware: Vec::new(),
395 }
396 }
397
398 pub fn with_provider(mut self, provider: Arc<dyn Provider>) -> Self {
400 self.providers.push(provider);
401 self
402 }
403
404 pub fn with_providers(mut self, providers: Vec<Arc<dyn Provider>>) -> Self {
406 self.providers.extend(providers);
407 self
408 }
409
410 pub fn with_routing_mode(mut self, mode: RoutingMode) -> Self {
412 self.routing_mode = mode;
413 self
414 }
415
416 pub fn with_cache(mut self, cache: Arc<dyn Cache>) -> Self {
418 self.cache = Some(cache);
419 self
420 }
421
422 pub fn with_cache_ttl(mut self, ttl: Duration) -> Self {
424 self.cache_ttl = ttl;
425 self
426 }
427
428 pub fn with_healing_settings(mut self, settings: HealingSettings) -> Self {
430 self.healing = settings;
431 self
432 }
433
434 pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
436 self.middleware.push(middleware);
437 self
438 }
439
440 pub fn build(self) -> Result<SimpleAgentsClient> {
442 if self.providers.is_empty() {
443 return Err(SimpleAgentsError::Config(
444 "at least one provider is required".to_string(),
445 ));
446 }
447
448 let provider_map = self
449 .providers
450 .iter()
451 .map(|provider| (provider.name().to_string(), provider.clone()))
452 .collect::<HashMap<_, _>>();
453
454 let router = Arc::new(self.routing_mode.build_router(self.providers.clone())?);
455 let state = ClientState {
456 providers: self.providers,
457 provider_map,
458 router,
459 };
460
461 Ok(SimpleAgentsClient {
462 state: RwLock::new(state),
463 routing_mode: self.routing_mode,
464 cache: self.cache,
465 cache_ttl: self.cache_ttl,
466 healing: self.healing,
467 middleware: self.middleware,
468 })
469 }
470}
471
472impl Default for SimpleAgentsClientBuilder {
473 fn default() -> Self {
474 Self::new()
475 }
476}
477
478#[async_trait]
479impl Middleware for () {
480 async fn before_request(&self, _request: &CompletionRequest) -> Result<()> {
481 Ok(())
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{stream, StreamExt};
    use simple_agent_type::error::ProviderError;
    use simple_agent_type::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Canned single-choice response shared by the mock providers.
    fn canned_response(id: &str, provider: &str) -> CompletionResponse {
        CompletionResponse {
            id: id.to_string(),
            model: "test-model".to_string(),
            choices: vec![CompletionChoice {
                index: 0,
                message: Message::assistant("ok"),
                finish_reason: FinishReason::Stop,
                logprobs: None,
            }],
            usage: Usage::new(1, 1),
            created: None,
            provider: Some(provider.to_string()),
            healing_metadata: None,
        }
    }

    /// Canned HTTP-level response shared by the mock providers.
    fn canned_provider_response() -> ProviderResponse {
        ProviderResponse::new(200, serde_json::json!({"content": "ok"}))
    }

    /// Builds the streaming request used by the streaming tests.
    fn streaming_request() -> CompletionRequest {
        CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hi"))
            .stream(true)
            .build()
            .unwrap()
    }

    /// Builds a client with one provider and one middleware.
    fn client_with(
        provider: Arc<dyn Provider>,
        middleware: Arc<dyn Middleware>,
    ) -> SimpleAgentsClient {
        SimpleAgentsClientBuilder::new()
            .with_provider(provider)
            .with_middleware(middleware)
            .build()
            .unwrap()
    }

    /// Drains a streaming outcome into a vector; panics for any other outcome.
    async fn drain(outcome: CompletionOutcome) -> Vec<Result<CompletionChunk>> {
        match outcome {
            CompletionOutcome::Stream(mut stream) => {
                let mut items = Vec::new();
                while let Some(item) = stream.next().await {
                    items.push(item);
                }
                items
            }
            _ => panic!("expected stream outcome"),
        }
    }

    /// Non-streaming provider that counts `execute` calls.
    struct MockProvider {
        name: &'static str,
        calls: AtomicUsize,
    }

    impl MockProvider {
        fn new(name: &'static str) -> Self {
            let calls = AtomicUsize::new(0);
            Self { name, calls }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            self.calls.fetch_add(1, Ordering::Relaxed);
            Ok(canned_provider_response())
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(canned_response("resp_test", self.name))
        }
    }

    #[tokio::test]
    async fn client_build_requires_provider() {
        let built = SimpleAgentsClientBuilder::new().build();
        assert!(built.is_err());
    }

    #[tokio::test]
    async fn register_provider_rebuilds_router() {
        let client = SimpleAgentsClientBuilder::new()
            .with_provider(Arc::new(MockProvider::new("p1")))
            .build()
            .unwrap();

        client
            .register_provider(Arc::new(MockProvider::new("p2")))
            .await
            .unwrap();

        let names = client.provider_names().await.unwrap();
        assert!(names.contains(&"p1".to_string()));
        assert!(names.contains(&"p2".to_string()));
    }

    #[tokio::test]
    async fn duplicate_provider_registration_fails() {
        let provider = Arc::new(MockProvider::new("p1"));
        let client = SimpleAgentsClientBuilder::new()
            .with_provider(provider.clone())
            .build()
            .unwrap();

        let result = client.register_provider(provider).await;
        assert!(matches!(
            result,
            Err(SimpleAgentsError::Config(msg)) if msg.contains("provider already registered")
        ));
    }

    /// Middleware that counts how often each observed hook fires.
    #[derive(Default)]
    struct RecordingMiddleware {
        before: AtomicUsize,
        after_stream: AtomicUsize,
        errors: AtomicUsize,
    }

    #[async_trait]
    impl Middleware for RecordingMiddleware {
        async fn before_request(&self, _request: &CompletionRequest) -> Result<()> {
            self.before.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }

        async fn after_stream(
            &self,
            _request: &CompletionRequest,
            _latency: Duration,
        ) -> Result<()> {
            self.after_stream.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }

        async fn on_error(
            &self,
            _request: &CompletionRequest,
            _error: &SimpleAgentsError,
            _latency: Duration,
        ) -> Result<()> {
            self.errors.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }

        fn name(&self) -> &str {
            "recording"
        }
    }

    /// Provider whose stream yields one chunk and then, optionally, an error.
    struct StreamingProvider {
        name: &'static str,
        fail_after_first: bool,
    }

    impl StreamingProvider {
        fn new(name: &'static str, fail_after_first: bool) -> Self {
            Self {
                name,
                fail_after_first,
            }
        }

        fn build_chunk(id: &str, content: &str) -> CompletionChunk {
            let delta = MessageDelta {
                role: Some(Role::Assistant),
                content: Some(content.to_string()),
            };
            CompletionChunk {
                id: id.to_string(),
                model: "test-model".to_string(),
                choices: vec![ChoiceDelta {
                    index: 0,
                    delta,
                    finish_reason: None,
                }],
                created: None,
            }
        }
    }

    #[async_trait]
    impl Provider for StreamingProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            Ok(canned_provider_response())
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(canned_response("resp_stream", self.name))
        }

        async fn execute_stream(
            &self,
            _req: ProviderRequest,
        ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>>
        {
            // Always start with one good chunk; append the failure only when asked to.
            let mut items: Vec<Result<CompletionChunk>> =
                vec![Ok(Self::build_chunk("chunk-1", "hello"))];
            if self.fail_after_first {
                items.push(Err(SimpleAgentsError::Provider(ProviderError::ServerError(
                    "stream error".to_string(),
                ))));
            }

            Ok(Box::new(stream::iter(items)))
        }
    }

    #[tokio::test]
    async fn streaming_invokes_after_stream_on_success() {
        let middleware = Arc::new(RecordingMiddleware::default());
        let client = client_with(
            Arc::new(StreamingProvider::new("p1", false)),
            middleware.clone(),
        );

        let outcome = client
            .complete(&streaming_request(), CompletionOptions::default())
            .await
            .unwrap();

        let mut collected = Vec::new();
        for item in drain(outcome).await {
            collected.push(item.unwrap());
        }

        assert_eq!(collected.len(), 1);
        assert_eq!(middleware.before.load(Ordering::Relaxed), 1);
        assert_eq!(middleware.after_stream.load(Ordering::Relaxed), 1);
        assert_eq!(middleware.errors.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn streaming_invokes_on_error_on_failure() {
        let middleware = Arc::new(RecordingMiddleware::default());
        let client = client_with(
            Arc::new(StreamingProvider::new("p1", true)),
            middleware.clone(),
        );

        let outcome = client
            .complete(&streaming_request(), CompletionOptions::default())
            .await
            .unwrap();

        let chunks = drain(outcome).await;

        assert_eq!(middleware.before.load(Ordering::Relaxed), 1);
        assert_eq!(middleware.after_stream.load(Ordering::Relaxed), 0);
        assert_eq!(middleware.errors.load(Ordering::Relaxed), 1);
        assert_eq!(chunks.len(), 2);
        assert!(chunks[0].as_ref().is_ok());
        assert!(chunks[1].is_err());
    }
}
785}