1use crate::healing::{HealedJsonResponse, HealedSchemaResponse, HealingSettings};
4use crate::middleware::Middleware;
5use crate::routing::{RouterEngine, RoutingMode};
6use async_trait::async_trait;
7use futures_util::future::BoxFuture;
8use futures_util::stream::{self, Stream};
9use futures_util::StreamExt;
10use simple_agent_type::cache::Cache;
11use simple_agent_type::cache::CacheKey;
12use simple_agent_type::prelude::{
13 CompletionChunk, CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
14};
15use simple_agents_healing::coercion::CoercionEngine;
16use simple_agents_healing::parser::JsonishParser;
17use simple_agents_healing::schema::Schema;
18use std::collections::HashMap;
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21use tokio::sync::RwLock;
22use tracing::debug;
23
24#[derive(Clone)]
26pub enum CompletionMode {
27 Standard,
29 HealedJson,
31 CoercedSchema(Schema),
33}
34
35#[derive(Clone)]
37pub struct CompletionOptions {
38 pub mode: CompletionMode,
40}
41
42impl Default for CompletionOptions {
43 fn default() -> Self {
44 Self {
45 mode: CompletionMode::Standard,
46 }
47 }
48}
49
50pub enum CompletionOutcome {
52 Response(CompletionResponse),
54 Stream(Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>),
56 HealedJson(HealedJsonResponse),
58 CoercedSchema(HealedSchemaResponse),
60}
61
62struct ClientState {
63 providers: Vec<Arc<dyn Provider>>,
64 provider_map: HashMap<String, Arc<dyn Provider>>,
65 router: Arc<RouterEngine>,
66}
67
68pub struct SimpleAgentsClient {
70 state: RwLock<ClientState>,
71 routing_mode: RoutingMode,
72 cache: Option<Arc<dyn Cache>>,
73 cache_ttl: Duration,
74 healing: HealingSettings,
75 middleware: Vec<Arc<dyn Middleware>>,
76}
77
78impl SimpleAgentsClient {
79 pub fn builder() -> SimpleAgentsClientBuilder {
81 SimpleAgentsClientBuilder::new()
82 }
83
84 pub async fn provider_names(&self) -> Result<Vec<String>> {
86 let state = self.state.read().await;
87 Ok(state.provider_map.keys().cloned().collect())
88 }
89
90 pub async fn provider(&self, name: &str) -> Result<Option<Arc<dyn Provider>>> {
92 let state = self.state.read().await;
93 Ok(state.provider_map.get(name).cloned())
94 }
95
96 pub async fn register_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
98 let mut state = self.state.write().await;
99 let name = provider.name().to_string();
100
101 if state.provider_map.contains_key(&name) {
102 return Err(SimpleAgentsError::Config(format!(
103 "provider already registered: {}",
104 name
105 )));
106 }
107
108 state.provider_map.insert(name, provider.clone());
109 state.providers.push(provider);
110 state.router = Arc::new(self.routing_mode.build_router(state.providers.clone())?);
111 Ok(())
112 }
113
114 pub async fn complete(
116 &self,
117 request: &CompletionRequest,
118 options: CompletionOptions,
119 ) -> Result<CompletionOutcome> {
120 if request.stream.unwrap_or(false) {
121 let stream = self.stream(request).await?;
122 return Ok(CompletionOutcome::Stream(stream));
123 }
124
125 match options.mode {
126 CompletionMode::Standard => {
127 let response = self.complete_response(request).await?;
128 Ok(CompletionOutcome::Response(response))
129 }
130 CompletionMode::HealedJson => {
131 let healed = self.complete_json_internal(request).await?;
132 Ok(CompletionOutcome::HealedJson(healed))
133 }
134 CompletionMode::CoercedSchema(schema) => {
135 let healed = self.complete_with_schema_internal(request, &schema).await?;
136 Ok(CompletionOutcome::CoercedSchema(healed))
137 }
138 }
139 }
140
141 async fn complete_response(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
142 request.validate()?;
143 self.before_request(request).await?;
144
145 let cache_key = if let Some(cache) = &self.cache {
146 if cache.is_enabled() {
147 Some(self.cache_key(request)?)
148 } else {
149 None
150 }
151 } else {
152 None
153 };
154
155 if let (Some(cache), Some(key)) = (&self.cache, cache_key.as_deref()) {
156 if let Some(cached) = cache.get(key).await? {
157 let response: CompletionResponse = serde_json::from_slice(&cached)?;
158 self.on_cache_hit(request, &response).await?;
159 return Ok(response);
160 }
161 }
162
163 let start = Instant::now();
164 let router = {
165 let state = self.state.read().await;
166 state.router.clone()
167 };
168 let response = router.complete(request).await;
169
170 match response {
171 Ok(response) => {
172 self.after_response(request, &response, start.elapsed())
173 .await?;
174 if let (Some(cache), Some(key)) = (&self.cache, cache_key) {
175 let payload = serde_json::to_vec(&response)?;
176 cache.set(&key, payload, self.cache_ttl).await?;
177 }
178 Ok(response)
179 }
180 Err(error) => {
181 self.on_error(request, &error, start.elapsed()).await?;
182 Err(error)
183 }
184 }
185 }
186
187 async fn complete_json_internal(
189 &self,
190 request: &CompletionRequest,
191 ) -> Result<HealedJsonResponse> {
192 self.ensure_healing_enabled()?;
193 let response = self.complete_response(request).await?;
194 let content = response.content().ok_or_else(|| {
195 SimpleAgentsError::Healing(simple_agent_type::error::HealingError::ParseFailed {
196 error_message: "response contained no content".to_string(),
197 input: String::new(),
198 })
199 })?;
200
201 let parser = JsonishParser::with_config(self.healing.parser_config.clone());
202 let parsed = parser.parse(content)?;
203
204 Ok(HealedJsonResponse { response, parsed })
205 }
206
207 async fn complete_with_schema_internal(
209 &self,
210 request: &CompletionRequest,
211 schema: &Schema,
212 ) -> Result<HealedSchemaResponse> {
213 self.ensure_healing_enabled()?;
214 let healed = self.complete_json_internal(request).await?;
215 let engine = CoercionEngine::with_config(self.healing.coercion_config.clone());
216 let coerced = engine
217 .coerce(&healed.parsed.value, schema)
218 .map_err(SimpleAgentsError::Healing)?;
219
220 Ok(HealedSchemaResponse {
221 response: healed.response,
222 parsed: healed.parsed,
223 coerced,
224 })
225 }
226
227 async fn stream(
229 &self,
230 request: &CompletionRequest,
231 ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
232 request.validate()?;
233 self.before_request(request).await?;
234 debug!(
235 model = %request.model,
236 stream = ?request.stream,
237 "SimpleAgentsClient.stream start"
238 );
239
240 let router = {
241 let state = self.state.read().await;
242 state.router.clone()
243 };
244
245 let start = Instant::now();
246 let middleware = self.middleware.clone();
247 let instrumented_request = request.clone();
248 let inner = router.stream(request).await?;
249
250 let wrapped = Self::instrument_stream(inner, instrumented_request, middleware, start);
251 Ok(Box::new(wrapped))
252 }
253
254 fn ensure_healing_enabled(&self) -> Result<()> {
255 if self.healing.enabled {
256 Ok(())
257 } else {
258 Err(SimpleAgentsError::Config(
259 "healing is disabled for this client".to_string(),
260 ))
261 }
262 }
263
264 fn cache_key(&self, request: &CompletionRequest) -> Result<String> {
265 let serialized = serde_json::to_string(request)?;
266 Ok(CacheKey::from_parts("core", &request.model, &serialized))
267 }
268
269 async fn before_request(&self, request: &CompletionRequest) -> Result<()> {
270 for middleware in &self.middleware {
271 middleware.before_request(request).await?;
272 }
273 Ok(())
274 }
275
276 async fn after_response(
277 &self,
278 request: &CompletionRequest,
279 response: &CompletionResponse,
280 latency: Duration,
281 ) -> Result<()> {
282 for middleware in &self.middleware {
283 middleware
284 .after_response(request, response, latency)
285 .await?;
286 }
287 Ok(())
288 }
289
290 async fn on_cache_hit(
291 &self,
292 request: &CompletionRequest,
293 response: &CompletionResponse,
294 ) -> Result<()> {
295 for middleware in &self.middleware {
296 middleware.on_cache_hit(request, response).await?;
297 }
298 Ok(())
299 }
300
301 async fn on_error(
302 &self,
303 request: &CompletionRequest,
304 error: &SimpleAgentsError,
305 latency: Duration,
306 ) -> Result<()> {
307 for middleware in &self.middleware {
308 middleware.on_error(request, error, latency).await?;
309 }
310 Ok(())
311 }
312}
313
314impl SimpleAgentsClient {
315 fn instrument_stream(
316 inner: Box<dyn Stream<Item = Result<CompletionChunk>> + Send + Unpin>,
317 request: CompletionRequest,
318 middleware: Vec<Arc<dyn Middleware>>,
319 start: Instant,
320 ) -> impl Stream<Item = Result<CompletionChunk>> + Send + Unpin {
321 struct StreamState {
322 inner: Box<dyn Stream<Item = Result<CompletionChunk>> + Send + Unpin>,
323 middleware: Vec<Arc<dyn Middleware>>,
324 request: CompletionRequest,
325 start: Instant,
326 done: bool,
327 }
328
329 stream::unfold(
330 StreamState {
331 inner,
332 middleware,
333 request,
334 start,
335 done: false,
336 },
337 |mut state| -> BoxFuture<Option<(Result<CompletionChunk>, StreamState)>> {
338 Box::pin(async move {
339 if state.done {
340 return None;
341 }
342
343 match state.inner.next().await {
344 Some(Ok(chunk)) => Some((Ok(chunk), state)),
345 Some(Err(err)) => {
346 let latency = state.start.elapsed();
347 for middleware in &state.middleware {
348 if let Err(mw_err) =
349 middleware.on_error(&state.request, &err, latency).await
350 {
351 state.done = true;
352 return Some((Err(mw_err), state));
353 }
354 }
355 state.done = true;
356 Some((Err(err), state))
357 }
358 None => {
359 let latency = state.start.elapsed();
360 for middleware in &state.middleware {
361 if let Err(mw_err) =
362 middleware.after_stream(&state.request, latency).await
363 {
364 state.done = true;
365 return Some((Err(mw_err), state));
366 }
367 }
368 None
369 }
370 }
371 })
372 },
373 )
374 }
375}
376
377pub struct SimpleAgentsClientBuilder {
379 providers: Vec<Arc<dyn Provider>>,
380 routing_mode: RoutingMode,
381 cache: Option<Arc<dyn Cache>>,
382 cache_ttl: Duration,
383 healing: HealingSettings,
384 middleware: Vec<Arc<dyn Middleware>>,
385}
386
387impl SimpleAgentsClientBuilder {
388 pub fn new() -> Self {
390 Self {
391 providers: Vec::new(),
392 routing_mode: RoutingMode::default(),
393 cache: None,
394 cache_ttl: Duration::from_secs(60),
395 healing: HealingSettings::default(),
396 middleware: Vec::new(),
397 }
398 }
399
400 pub fn with_provider(mut self, provider: Arc<dyn Provider>) -> Self {
402 self.providers.push(provider);
403 self
404 }
405
406 pub fn with_providers(mut self, providers: Vec<Arc<dyn Provider>>) -> Self {
408 self.providers.extend(providers);
409 self
410 }
411
412 pub fn with_routing_mode(mut self, mode: RoutingMode) -> Self {
414 self.routing_mode = mode;
415 self
416 }
417
418 pub fn with_cache(mut self, cache: Arc<dyn Cache>) -> Self {
420 self.cache = Some(cache);
421 self
422 }
423
424 pub fn with_cache_ttl(mut self, ttl: Duration) -> Self {
426 self.cache_ttl = ttl;
427 self
428 }
429
430 pub fn with_healing_settings(mut self, settings: HealingSettings) -> Self {
432 self.healing = settings;
433 self
434 }
435
436 pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
438 self.middleware.push(middleware);
439 self
440 }
441
442 pub fn build(self) -> Result<SimpleAgentsClient> {
444 if self.providers.is_empty() {
445 return Err(SimpleAgentsError::Config(
446 "at least one provider is required".to_string(),
447 ));
448 }
449
450 let provider_map = self
451 .providers
452 .iter()
453 .map(|provider| (provider.name().to_string(), provider.clone()))
454 .collect::<HashMap<_, _>>();
455
456 let router = Arc::new(self.routing_mode.build_router(self.providers.clone())?);
457 let state = ClientState {
458 providers: self.providers,
459 provider_map,
460 router,
461 };
462
463 Ok(SimpleAgentsClient {
464 state: RwLock::new(state),
465 routing_mode: self.routing_mode,
466 cache: self.cache,
467 cache_ttl: self.cache_ttl,
468 healing: self.healing,
469 middleware: self.middleware,
470 })
471 }
472}
473
474impl Default for SimpleAgentsClientBuilder {
475 fn default() -> Self {
476 Self::new()
477 }
478}
479
480#[async_trait]
481impl Middleware for () {
482 async fn before_request(&self, _request: &CompletionRequest) -> Result<()> {
483 Ok(())
484 }
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{stream, StreamExt};
    use simple_agent_type::error::ProviderError;
    use simple_agent_type::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Canned HTTP-level response shared by both mock providers.
    fn ok_provider_response() -> ProviderResponse {
        ProviderResponse::new(200, serde_json::json!({"content": "ok"}))
    }

    /// Canned single-choice completion attributed to `provider`.
    fn canned_completion(id: &str, provider: &str) -> CompletionResponse {
        CompletionResponse {
            id: id.to_string(),
            model: "test-model".to_string(),
            choices: vec![CompletionChoice {
                index: 0,
                message: Message::assistant("ok"),
                finish_reason: FinishReason::Stop,
                logprobs: None,
            }],
            usage: Usage::new(1, 1),
            created: None,
            provider: Some(provider.to_string()),
            healing_metadata: None,
        }
    }

    /// Non-streaming provider that counts how often it is executed.
    struct MockProvider {
        name: &'static str,
        calls: AtomicUsize,
    }

    impl MockProvider {
        fn new(name: &'static str) -> Self {
            Self {
                name,
                calls: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            self.calls.fetch_add(1, Ordering::Relaxed);
            Ok(ok_provider_response())
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(canned_completion("resp_test", self.name))
        }
    }

    #[tokio::test]
    async fn client_build_requires_provider() {
        let built = SimpleAgentsClientBuilder::new().build();
        assert!(built.is_err());
    }

    #[tokio::test]
    async fn register_provider_rebuilds_router() {
        let first = Arc::new(MockProvider::new("p1"));
        let client = SimpleAgentsClientBuilder::new()
            .with_provider(first)
            .build()
            .unwrap();

        let second = Arc::new(MockProvider::new("p2"));
        client.register_provider(second).await.unwrap();

        let names = client.provider_names().await.unwrap();
        assert!(names.contains(&"p1".to_string()));
        assert!(names.contains(&"p2".to_string()));
    }

    #[tokio::test]
    async fn duplicate_provider_registration_fails() {
        let provider = Arc::new(MockProvider::new("p1"));
        let client = SimpleAgentsClientBuilder::new()
            .with_provider(provider.clone())
            .build()
            .unwrap();

        let outcome = client.register_provider(provider).await;
        assert!(matches!(
            outcome,
            Err(SimpleAgentsError::Config(msg)) if msg.contains("provider already registered")
        ));
    }

    /// Middleware that counts how often each hook fires.
    #[derive(Default)]
    struct RecordingMiddleware {
        before: AtomicUsize,
        after_stream: AtomicUsize,
        errors: AtomicUsize,
    }

    #[async_trait]
    impl Middleware for RecordingMiddleware {
        async fn before_request(&self, _request: &CompletionRequest) -> Result<()> {
            self.before.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }

        async fn after_stream(
            &self,
            _request: &CompletionRequest,
            _latency: Duration,
        ) -> Result<()> {
            self.after_stream.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }

        async fn on_error(
            &self,
            _request: &CompletionRequest,
            _error: &SimpleAgentsError,
            _latency: Duration,
        ) -> Result<()> {
            self.errors.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }

        fn name(&self) -> &str {
            "recording"
        }
    }

    /// Provider whose stream yields one chunk, optionally followed by an
    /// error.
    struct StreamingProvider {
        name: &'static str,
        fail_after_first: bool,
    }

    impl StreamingProvider {
        fn new(name: &'static str, fail_after_first: bool) -> Self {
            Self {
                name,
                fail_after_first,
            }
        }

        fn build_chunk(id: &str, content: &str) -> CompletionChunk {
            let delta = MessageDelta {
                role: Some(Role::Assistant),
                content: Some(content.to_string()),
            };
            CompletionChunk {
                id: id.to_string(),
                model: "test-model".to_string(),
                choices: vec![ChoiceDelta {
                    index: 0,
                    delta,
                    finish_reason: None,
                }],
                created: None,
            }
        }
    }

    #[async_trait]
    impl Provider for StreamingProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            Ok(ok_provider_response())
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(canned_completion("resp_stream", self.name))
        }

        async fn execute_stream(
            &self,
            _req: ProviderRequest,
        ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>>
        {
            let mut items: Vec<Result<CompletionChunk>> =
                vec![Ok(Self::build_chunk("chunk-1", "hello"))];
            if self.fail_after_first {
                items.push(Err(SimpleAgentsError::Provider(ProviderError::ServerError(
                    "stream error".to_string(),
                ))));
            }

            Ok(Box::new(stream::iter(items)))
        }
    }

    /// Streaming request used by both streaming tests.
    fn streaming_request() -> CompletionRequest {
        CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hi"))
            .stream(true)
            .build()
            .unwrap()
    }

    /// Runs a streaming completion against a `StreamingProvider` and returns
    /// every item the stream produced.
    async fn run_stream(
        fail_after_first: bool,
        middleware: Arc<RecordingMiddleware>,
    ) -> Vec<Result<CompletionChunk>> {
        let provider = Arc::new(StreamingProvider::new("p1", fail_after_first));
        let client = SimpleAgentsClientBuilder::new()
            .with_provider(provider)
            .with_middleware(middleware)
            .build()
            .unwrap();

        let request = streaming_request();
        let outcome = client
            .complete(&request, CompletionOptions::default())
            .await
            .unwrap();

        let mut items = Vec::new();
        match outcome {
            CompletionOutcome::Stream(mut chunks) => {
                while let Some(item) = chunks.next().await {
                    items.push(item);
                }
            }
            _ => panic!("expected stream outcome"),
        }
        items
    }

    #[tokio::test]
    async fn streaming_invokes_after_stream_on_success() {
        let middleware = Arc::new(RecordingMiddleware::default());

        let collected: Vec<CompletionChunk> = run_stream(false, middleware.clone())
            .await
            .into_iter()
            .map(|item| item.unwrap())
            .collect();

        assert_eq!(collected.len(), 1);
        assert_eq!(middleware.before.load(Ordering::Relaxed), 1);
        assert_eq!(middleware.after_stream.load(Ordering::Relaxed), 1);
        assert_eq!(middleware.errors.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn streaming_invokes_on_error_on_failure() {
        let middleware = Arc::new(RecordingMiddleware::default());

        let chunks = run_stream(true, middleware.clone()).await;

        assert_eq!(middleware.before.load(Ordering::Relaxed), 1);
        assert_eq!(middleware.after_stream.load(Ordering::Relaxed), 0);
        assert_eq!(middleware.errors.load(Ordering::Relaxed), 1);
        assert_eq!(chunks.len(), 2);
        assert!(chunks[0].as_ref().is_ok());
        assert!(chunks[1].is_err());
    }
}