Skip to main content

simple_agents_core/
client.rs

1//! SimpleAgents client implementation.
2
3use crate::healing::{HealedJsonResponse, HealedSchemaResponse, HealingSettings};
4use crate::middleware::Middleware;
5use crate::routing::{RouterEngine, RoutingMode};
6use async_trait::async_trait;
7use futures_util::future::BoxFuture;
8use futures_util::stream::{self, Stream};
9use futures_util::StreamExt;
10use simple_agent_type::cache::Cache;
11use simple_agent_type::cache::CacheKey;
12use simple_agent_type::prelude::{
13    CompletionChunk, CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
14};
15use simple_agents_healing::coercion::CoercionEngine;
16use simple_agents_healing::parser::JsonishParser;
17use simple_agents_healing::schema::Schema;
18use std::collections::HashMap;
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21use tokio::sync::RwLock;
22use tracing::debug;
23
/// Mode for completion post-processing.
///
/// Carried in [`CompletionOptions::mode`] and matched by
/// `SimpleAgentsClient::complete` for non-streaming requests. Requests with
/// `stream == Some(true)` are answered with a stream before the mode is
/// inspected.
pub enum CompletionMode {
    /// Return the raw completion response.
    Standard,
    /// Parse the response content as JSON using healing.
    HealedJson,
    /// Parse and coerce the response into the provided schema.
    CoercedSchema(Schema),
}
33
/// Options that control completion behavior.
///
/// The `Default` implementation selects [`CompletionMode::Standard`].
pub struct CompletionOptions {
    /// Completion post-processing mode.
    pub mode: CompletionMode,
}
39
40impl Default for CompletionOptions {
41    fn default() -> Self {
42        Self {
43            mode: CompletionMode::Standard,
44        }
45    }
46}
47
/// Result of a unified completion call.
///
/// `SimpleAgentsClient::complete` returns `Stream` whenever the request asks
/// for streaming; otherwise the variant mirrors the requested
/// [`CompletionMode`].
pub enum CompletionOutcome {
    /// A standard, non-streaming completion response.
    Response(CompletionResponse),
    /// A streaming response yielding completion chunks.
    Stream(Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>),
    /// A healed JSON response.
    HealedJson(HealedJsonResponse),
    /// A schema-coerced response.
    CoercedSchema(HealedSchemaResponse),
}
59
/// Mutable part of the client, kept behind the client's `RwLock`.
struct ClientState {
    /// Registered providers; new providers are appended at the end.
    providers: Vec<Arc<dyn Provider>>,
    /// Lookup from provider name to provider.
    provider_map: HashMap<String, Arc<dyn Provider>>,
    /// Router built from `providers` by the client's routing mode.
    router: Arc<RouterEngine>,
}
65
/// Unified SimpleAgents client.
///
/// Constructed through [`SimpleAgentsClientBuilder`]; see
/// `SimpleAgentsClient::builder`.
pub struct SimpleAgentsClient {
    /// Providers and router; locked so providers can be registered via `&self`.
    state: RwLock<ClientState>,
    /// Routing mode used whenever the router is (re)built.
    routing_mode: RoutingMode,
    /// Optional response cache consulted for non-streaming completions.
    cache: Option<Arc<dyn Cache>>,
    /// TTL passed to the cache when a response is stored.
    cache_ttl: Duration,
    /// Healing configuration (enabled flag, parser and coercion configs).
    healing: HealingSettings,
    /// Middleware hooks, invoked in registration order.
    middleware: Vec<Arc<dyn Middleware>>,
}
75
76impl SimpleAgentsClient {
77    /// Start a new client builder.
78    pub fn builder() -> SimpleAgentsClientBuilder {
79        SimpleAgentsClientBuilder::new()
80    }
81
82    /// List registered provider names.
83    pub async fn provider_names(&self) -> Result<Vec<String>> {
84        let state = self.state.read().await;
85        Ok(state.provider_map.keys().cloned().collect())
86    }
87
88    /// Retrieve a provider by name.
89    pub async fn provider(&self, name: &str) -> Result<Option<Arc<dyn Provider>>> {
90        let state = self.state.read().await;
91        Ok(state.provider_map.get(name).cloned())
92    }
93
94    /// Register an additional provider and rebuild the router.
95    pub async fn register_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
96        let mut state = self.state.write().await;
97        let name = provider.name().to_string();
98
99        if state.provider_map.contains_key(&name) {
100            return Err(SimpleAgentsError::Config(format!(
101                "provider already registered: {}",
102                name
103            )));
104        }
105
106        state.provider_map.insert(name, provider.clone());
107        state.providers.push(provider);
108        state.router = Arc::new(self.routing_mode.build_router(state.providers.clone())?);
109        Ok(())
110    }
111
112    /// Execute a completion request with routing, caching, and middleware.
113    pub async fn complete(
114        &self,
115        request: &CompletionRequest,
116        options: CompletionOptions,
117    ) -> Result<CompletionOutcome> {
118        if request.stream.unwrap_or(false) {
119            let stream = self.stream(request).await?;
120            return Ok(CompletionOutcome::Stream(stream));
121        }
122
123        match options.mode {
124            CompletionMode::Standard => {
125                let response = self.complete_response(request).await?;
126                Ok(CompletionOutcome::Response(response))
127            }
128            CompletionMode::HealedJson => {
129                let healed = self.complete_json_internal(request).await?;
130                Ok(CompletionOutcome::HealedJson(healed))
131            }
132            CompletionMode::CoercedSchema(schema) => {
133                let healed = self.complete_with_schema_internal(request, &schema).await?;
134                Ok(CompletionOutcome::CoercedSchema(healed))
135            }
136        }
137    }
138
139    async fn complete_response(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
140        request.validate()?;
141        self.before_request(request).await?;
142
143        let cache_key = if let Some(cache) = &self.cache {
144            if cache.is_enabled() {
145                Some(self.cache_key(request)?)
146            } else {
147                None
148            }
149        } else {
150            None
151        };
152
153        if let (Some(cache), Some(key)) = (&self.cache, cache_key.as_deref()) {
154            if let Some(cached) = cache.get(key).await? {
155                let response: CompletionResponse = serde_json::from_slice(&cached)?;
156                self.on_cache_hit(request, &response).await?;
157                return Ok(response);
158            }
159        }
160
161        let start = Instant::now();
162        let router = {
163            let state = self.state.read().await;
164            state.router.clone()
165        };
166        let response = router.complete(request).await;
167
168        match response {
169            Ok(response) => {
170                self.after_response(request, &response, start.elapsed())
171                    .await?;
172                if let (Some(cache), Some(key)) = (&self.cache, cache_key) {
173                    let payload = serde_json::to_vec(&response)?;
174                    cache.set(&key, payload, self.cache_ttl).await?;
175                }
176                Ok(response)
177            }
178            Err(error) => {
179                self.on_error(request, &error, start.elapsed()).await?;
180                Err(error)
181            }
182        }
183    }
184
185    /// Execute a completion request and parse the response content as JSON.
186    async fn complete_json_internal(
187        &self,
188        request: &CompletionRequest,
189    ) -> Result<HealedJsonResponse> {
190        self.ensure_healing_enabled()?;
191        let response = self.complete_response(request).await?;
192        let content = response.content().ok_or_else(|| {
193            SimpleAgentsError::Healing(simple_agent_type::error::HealingError::ParseFailed {
194                error_message: "response contained no content".to_string(),
195                input: String::new(),
196            })
197        })?;
198
199        let parser = JsonishParser::with_config(self.healing.parser_config.clone());
200        let parsed = parser.parse(content)?;
201
202        Ok(HealedJsonResponse { response, parsed })
203    }
204
205    /// Execute a completion request and coerce the response into a schema.
206    async fn complete_with_schema_internal(
207        &self,
208        request: &CompletionRequest,
209        schema: &Schema,
210    ) -> Result<HealedSchemaResponse> {
211        self.ensure_healing_enabled()?;
212        let healed = self.complete_json_internal(request).await?;
213        let engine = CoercionEngine::with_config(self.healing.coercion_config.clone());
214        let coerced = engine
215            .coerce(&healed.parsed.value, schema)
216            .map_err(SimpleAgentsError::Healing)?;
217
218        Ok(HealedSchemaResponse {
219            response: healed.response,
220            parsed: healed.parsed,
221            coerced,
222        })
223    }
224
225    /// Execute a streaming completion request.
226    async fn stream(
227        &self,
228        request: &CompletionRequest,
229    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
230        request.validate()?;
231        self.before_request(request).await?;
232        debug!(
233            model = %request.model,
234            stream = ?request.stream,
235            "SimpleAgentsClient.stream start"
236        );
237
238        let router = {
239            let state = self.state.read().await;
240            state.router.clone()
241        };
242
243        let start = Instant::now();
244        let middleware = self.middleware.clone();
245        let instrumented_request = request.clone();
246        let inner = router.stream(request).await?;
247
248        let wrapped = Self::instrument_stream(inner, instrumented_request, middleware, start);
249        Ok(Box::new(wrapped))
250    }
251
252    fn ensure_healing_enabled(&self) -> Result<()> {
253        if self.healing.enabled {
254            Ok(())
255        } else {
256            Err(SimpleAgentsError::Config(
257                "healing is disabled for this client".to_string(),
258            ))
259        }
260    }
261
262    fn cache_key(&self, request: &CompletionRequest) -> Result<String> {
263        let serialized = serde_json::to_string(request)?;
264        Ok(CacheKey::from_parts("core", &request.model, &serialized))
265    }
266
267    async fn before_request(&self, request: &CompletionRequest) -> Result<()> {
268        for middleware in &self.middleware {
269            middleware.before_request(request).await?;
270        }
271        Ok(())
272    }
273
274    async fn after_response(
275        &self,
276        request: &CompletionRequest,
277        response: &CompletionResponse,
278        latency: Duration,
279    ) -> Result<()> {
280        for middleware in &self.middleware {
281            middleware
282                .after_response(request, response, latency)
283                .await?;
284        }
285        Ok(())
286    }
287
288    async fn on_cache_hit(
289        &self,
290        request: &CompletionRequest,
291        response: &CompletionResponse,
292    ) -> Result<()> {
293        for middleware in &self.middleware {
294            middleware.on_cache_hit(request, response).await?;
295        }
296        Ok(())
297    }
298
299    async fn on_error(
300        &self,
301        request: &CompletionRequest,
302        error: &SimpleAgentsError,
303        latency: Duration,
304    ) -> Result<()> {
305        for middleware in &self.middleware {
306            middleware.on_error(request, error, latency).await?;
307        }
308        Ok(())
309    }
310}
311
impl SimpleAgentsClient {
    /// Wraps a provider stream so middleware observes how it ends.
    ///
    /// Chunks are forwarded unchanged. When the inner stream yields an error,
    /// every middleware's `on_error` hook is called and the error is emitted as
    /// the final item. When the inner stream finishes, every middleware's
    /// `after_stream` hook is called. `latency` in both cases is the time
    /// elapsed since `start`.
    ///
    /// If a hook itself fails, the hook's error is emitted as the final item
    /// and the remaining hooks are not called.
    fn instrument_stream(
        inner: Box<dyn Stream<Item = Result<CompletionChunk>> + Send + Unpin>,
        request: CompletionRequest,
        middleware: Vec<Arc<dyn Middleware>>,
        start: Instant,
    ) -> impl Stream<Item = Result<CompletionChunk>> + Send + Unpin {
        // State threaded through `stream::unfold` between polls.
        struct StreamState {
            inner: Box<dyn Stream<Item = Result<CompletionChunk>> + Send + Unpin>,
            middleware: Vec<Arc<dyn Middleware>>,
            request: CompletionRequest,
            start: Instant,
            // Set once a terminal error has been emitted; the next step ends
            // the stream without polling `inner` again.
            done: bool,
        }

        stream::unfold(
            StreamState {
                inner,
                middleware,
                request,
                start,
                done: false,
            },
            // NOTE(review): the future is boxed, presumably so the resulting
            // stream satisfies the `Unpin` bound in the return type — confirm.
            |mut state| -> BoxFuture<Option<(Result<CompletionChunk>, StreamState)>> {
                Box::pin(async move {
                    if state.done {
                        return None;
                    }

                    match state.inner.next().await {
                        // Normal chunk: pass it through untouched.
                        Some(Ok(chunk)) => Some((Ok(chunk), state)),
                        Some(Err(err)) => {
                            let latency = state.start.elapsed();
                            for middleware in &state.middleware {
                                if let Err(mw_err) =
                                    middleware.on_error(&state.request, &err, latency).await
                                {
                                    // A failing hook replaces the provider error.
                                    state.done = true;
                                    return Some((Err(mw_err), state));
                                }
                            }
                            // Emit the error, then terminate on the next step.
                            state.done = true;
                            Some((Err(err), state))
                        }
                        None => {
                            let latency = state.start.elapsed();
                            for middleware in &state.middleware {
                                if let Err(mw_err) =
                                    middleware.after_stream(&state.request, latency).await
                                {
                                    // Surface the hook failure as one last item.
                                    state.done = true;
                                    return Some((Err(mw_err), state));
                                }
                            }
                            None
                        }
                    }
                })
            },
        )
    }
}
374
/// Builder for `SimpleAgentsClient`.
///
/// Each field is moved into the client by `build`.
pub struct SimpleAgentsClientBuilder {
    /// Providers to register; `build` requires at least one.
    providers: Vec<Arc<dyn Provider>>,
    /// Routing mode used to build the router.
    routing_mode: RoutingMode,
    /// Optional response cache.
    cache: Option<Arc<dyn Cache>>,
    /// TTL applied to cached responses.
    cache_ttl: Duration,
    /// Healing settings for JSON parsing and schema coercion.
    healing: HealingSettings,
    /// Middleware hooks, in registration order.
    middleware: Vec<Arc<dyn Middleware>>,
}
384
385impl SimpleAgentsClientBuilder {
386    /// Create a new builder with defaults.
387    pub fn new() -> Self {
388        Self {
389            providers: Vec::new(),
390            routing_mode: RoutingMode::default(),
391            cache: None,
392            cache_ttl: Duration::from_secs(60),
393            healing: HealingSettings::default(),
394            middleware: Vec::new(),
395        }
396    }
397
398    /// Register a provider.
399    pub fn with_provider(mut self, provider: Arc<dyn Provider>) -> Self {
400        self.providers.push(provider);
401        self
402    }
403
404    /// Register multiple providers at once.
405    pub fn with_providers(mut self, providers: Vec<Arc<dyn Provider>>) -> Self {
406        self.providers.extend(providers);
407        self
408    }
409
410    /// Configure routing mode.
411    pub fn with_routing_mode(mut self, mode: RoutingMode) -> Self {
412        self.routing_mode = mode;
413        self
414    }
415
416    /// Configure response cache.
417    pub fn with_cache(mut self, cache: Arc<dyn Cache>) -> Self {
418        self.cache = Some(cache);
419        self
420    }
421
422    /// Configure cache TTL.
423    pub fn with_cache_ttl(mut self, ttl: Duration) -> Self {
424        self.cache_ttl = ttl;
425        self
426    }
427
428    /// Configure healing settings.
429    pub fn with_healing_settings(mut self, settings: HealingSettings) -> Self {
430        self.healing = settings;
431        self
432    }
433
434    /// Register a middleware hook.
435    pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
436        self.middleware.push(middleware);
437        self
438    }
439
440    /// Build the client.
441    pub fn build(self) -> Result<SimpleAgentsClient> {
442        if self.providers.is_empty() {
443            return Err(SimpleAgentsError::Config(
444                "at least one provider is required".to_string(),
445            ));
446        }
447
448        let provider_map = self
449            .providers
450            .iter()
451            .map(|provider| (provider.name().to_string(), provider.clone()))
452            .collect::<HashMap<_, _>>();
453
454        let router = Arc::new(self.routing_mode.build_router(self.providers.clone())?);
455        let state = ClientState {
456            providers: self.providers,
457            provider_map,
458            router,
459        };
460
461        Ok(SimpleAgentsClient {
462            state: RwLock::new(state),
463            routing_mode: self.routing_mode,
464            cache: self.cache,
465            cache_ttl: self.cache_ttl,
466            healing: self.healing,
467            middleware: self.middleware,
468        })
469    }
470}
471
472impl Default for SimpleAgentsClientBuilder {
473    fn default() -> Self {
474        Self::new()
475    }
476}
477
/// No-op middleware: the unit type accepts every request.
///
/// Only `before_request` is defined here; no other hook is overridden.
#[async_trait]
impl Middleware for () {
    /// Always succeeds without inspecting the request.
    async fn before_request(&self, _request: &CompletionRequest) -> Result<()> {
        Ok(())
    }
}
484
485#[cfg(test)]
486mod tests {
487    use super::*;
488    use futures_util::{stream, StreamExt};
489    use simple_agent_type::error::ProviderError;
490    use simple_agent_type::prelude::*;
491    use std::sync::atomic::{AtomicUsize, Ordering};
492    use std::time::Duration;
493
    /// Test provider that returns a canned completion.
    struct MockProvider {
        /// Name reported through `Provider::name`.
        name: &'static str,
        /// Number of times `execute` has been called.
        calls: AtomicUsize,
    }
498
499    impl MockProvider {
500        fn new(name: &'static str) -> Self {
501            Self {
502                name,
503                calls: AtomicUsize::new(0),
504            }
505        }
506    }
507
508    #[async_trait]
509    impl Provider for MockProvider {
510        fn name(&self) -> &str {
511            self.name
512        }
513
514        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
515            Ok(ProviderRequest::new("http://example.com"))
516        }
517
518        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
519            self.calls.fetch_add(1, Ordering::Relaxed);
520            Ok(ProviderResponse::new(
521                200,
522                serde_json::json!({"content": "ok"}),
523            ))
524        }
525
526        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
527            Ok(CompletionResponse {
528                id: "resp_test".to_string(),
529                model: "test-model".to_string(),
530                choices: vec![CompletionChoice {
531                    index: 0,
532                    message: Message::assistant("ok"),
533                    finish_reason: FinishReason::Stop,
534                    logprobs: None,
535                }],
536                usage: Usage::new(1, 1),
537                created: None,
538                provider: Some(self.name.to_string()),
539                healing_metadata: None,
540            })
541        }
542    }
543
544    #[tokio::test]
545    async fn client_build_requires_provider() {
546        let result = SimpleAgentsClientBuilder::new().build();
547        assert!(result.is_err());
548    }
549
550    #[tokio::test]
551    async fn register_provider_rebuilds_router() {
552        let provider = Arc::new(MockProvider::new("p1"));
553        let client = SimpleAgentsClientBuilder::new()
554            .with_provider(provider)
555            .build()
556            .unwrap();
557
558        let second = Arc::new(MockProvider::new("p2"));
559        client.register_provider(second).await.unwrap();
560
561        let names = client.provider_names().await.unwrap();
562        assert!(names.contains(&"p1".to_string()));
563        assert!(names.contains(&"p2".to_string()));
564    }
565
566    #[tokio::test]
567    async fn duplicate_provider_registration_fails() {
568        let provider = Arc::new(MockProvider::new("p1"));
569        let client = SimpleAgentsClientBuilder::new()
570            .with_provider(provider.clone())
571            .build()
572            .unwrap();
573
574        let result = client.register_provider(provider).await;
575        assert!(matches!(
576            result,
577            Err(SimpleAgentsError::Config(msg)) if msg.contains("provider already registered")
578        ));
579    }
580
    /// Test middleware that counts how often each hook is invoked.
    #[derive(Default)]
    struct RecordingMiddleware {
        /// Number of `before_request` calls.
        before: AtomicUsize,
        /// Number of `after_stream` calls.
        after_stream: AtomicUsize,
        /// Number of `on_error` calls.
        errors: AtomicUsize,
    }
587
    #[async_trait]
    impl Middleware for RecordingMiddleware {
        /// Counts the call and lets the request proceed.
        async fn before_request(&self, _request: &CompletionRequest) -> Result<()> {
            self.before.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }

        /// Counts a stream that ended without an error.
        async fn after_stream(
            &self,
            _request: &CompletionRequest,
            _latency: Duration,
        ) -> Result<()> {
            self.after_stream.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }

        /// Counts a reported error; never fails itself.
        async fn on_error(
            &self,
            _request: &CompletionRequest,
            _error: &SimpleAgentsError,
            _latency: Duration,
        ) -> Result<()> {
            self.errors.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }

        /// Fixed middleware name.
        fn name(&self) -> &str {
            "recording"
        }
    }
618
    /// Test provider whose stream yields one chunk and optionally an error.
    struct StreamingProvider {
        /// Name reported through `Provider::name`.
        name: &'static str,
        /// When true, the stream yields an error after its first chunk.
        fail_after_first: bool,
    }
623
624    impl StreamingProvider {
625        fn new(name: &'static str, fail_after_first: bool) -> Self {
626            Self {
627                name,
628                fail_after_first,
629            }
630        }
631
632        fn build_chunk(id: &str, content: &str) -> CompletionChunk {
633            CompletionChunk {
634                id: id.to_string(),
635                model: "test-model".to_string(),
636                choices: vec![ChoiceDelta {
637                    index: 0,
638                    delta: MessageDelta {
639                        role: Some(Role::Assistant),
640                        content: Some(content.to_string()),
641                    },
642                    finish_reason: None,
643                }],
644                created: None,
645            }
646        }
647    }
648
649    #[async_trait]
650    impl Provider for StreamingProvider {
651        fn name(&self) -> &str {
652            self.name
653        }
654
655        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
656            Ok(ProviderRequest::new("http://example.com"))
657        }
658
659        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
660            Ok(ProviderResponse::new(
661                200,
662                serde_json::json!({"content": "ok"}),
663            ))
664        }
665
666        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
667            Ok(CompletionResponse {
668                id: "resp_stream".to_string(),
669                model: "test-model".to_string(),
670                choices: vec![CompletionChoice {
671                    index: 0,
672                    message: Message::assistant("ok"),
673                    finish_reason: FinishReason::Stop,
674                    logprobs: None,
675                }],
676                usage: Usage::new(1, 1),
677                created: None,
678                provider: Some(self.name.to_string()),
679                healing_metadata: None,
680            })
681        }
682
683        async fn execute_stream(
684            &self,
685            _req: ProviderRequest,
686        ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>>
687        {
688            let stream = if self.fail_after_first {
689                let items: Vec<Result<CompletionChunk>> = vec![
690                    Ok(Self::build_chunk("chunk-1", "hello")),
691                    Err(SimpleAgentsError::Provider(ProviderError::ServerError(
692                        "stream error".to_string(),
693                    ))),
694                ];
695                stream::iter(items)
696            } else {
697                let items: Vec<Result<CompletionChunk>> =
698                    vec![Ok(Self::build_chunk("chunk-1", "hello"))];
699                stream::iter(items)
700            };
701
702            Ok(Box::new(stream))
703        }
704    }
705
706    #[tokio::test]
707    async fn streaming_invokes_after_stream_on_success() {
708        let provider = Arc::new(StreamingProvider::new("p1", false));
709        let middleware = Arc::new(RecordingMiddleware::default());
710
711        let client = SimpleAgentsClientBuilder::new()
712            .with_provider(provider)
713            .with_middleware(middleware.clone())
714            .build()
715            .unwrap();
716
717        let request = CompletionRequest::builder()
718            .model("gpt-4")
719            .message(Message::user("Hi"))
720            .stream(true)
721            .build()
722            .unwrap();
723
724        let outcome = client
725            .complete(&request, CompletionOptions::default())
726            .await
727            .unwrap();
728
729        let mut collected = Vec::new();
730        match outcome {
731            CompletionOutcome::Stream(mut stream) => {
732                while let Some(chunk) = stream.next().await {
733                    collected.push(chunk.unwrap());
734                }
735            }
736            _ => panic!("expected stream outcome"),
737        }
738
739        assert_eq!(collected.len(), 1);
740        assert_eq!(middleware.before.load(Ordering::Relaxed), 1);
741        assert_eq!(middleware.after_stream.load(Ordering::Relaxed), 1);
742        assert_eq!(middleware.errors.load(Ordering::Relaxed), 0);
743    }
744
745    #[tokio::test]
746    async fn streaming_invokes_on_error_on_failure() {
747        let provider = Arc::new(StreamingProvider::new("p1", true));
748        let middleware = Arc::new(RecordingMiddleware::default());
749
750        let client = SimpleAgentsClientBuilder::new()
751            .with_provider(provider)
752            .with_middleware(middleware.clone())
753            .build()
754            .unwrap();
755
756        let request = CompletionRequest::builder()
757            .model("gpt-4")
758            .message(Message::user("Hi"))
759            .stream(true)
760            .build()
761            .unwrap();
762
763        let outcome = client
764            .complete(&request, CompletionOptions::default())
765            .await
766            .unwrap();
767
768        let mut chunks = Vec::new();
769        match outcome {
770            CompletionOutcome::Stream(mut stream) => {
771                while let Some(chunk) = stream.next().await {
772                    chunks.push(chunk);
773                }
774            }
775            _ => panic!("expected stream outcome"),
776        }
777
778        assert_eq!(middleware.before.load(Ordering::Relaxed), 1);
779        assert_eq!(middleware.after_stream.load(Ordering::Relaxed), 0);
780        assert_eq!(middleware.errors.load(Ordering::Relaxed), 1);
781        assert_eq!(chunks.len(), 2);
782        assert!(chunks[0].as_ref().is_ok());
783        assert!(chunks[1].is_err());
784    }
785}