// simple_agents_core/client.rs

1//! SimpleAgents client implementation.
2
3use crate::healing::{HealedJsonResponse, HealedSchemaResponse, HealingSettings};
4use crate::middleware::Middleware;
5use crate::routing::{RouterEngine, RoutingMode};
6use async_trait::async_trait;
7use futures_util::future::BoxFuture;
8use futures_util::stream::{self, Stream};
9use futures_util::StreamExt;
10use simple_agent_type::cache::Cache;
11use simple_agent_type::cache::CacheKey;
12use simple_agent_type::prelude::{
13    CompletionChunk, CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
14};
15use simple_agents_healing::coercion::CoercionEngine;
16use simple_agents_healing::parser::JsonishParser;
17use simple_agents_healing::schema::Schema;
18use std::collections::HashMap;
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21use tokio::sync::RwLock;
22use tracing::debug;
23
/// Mode for completion post-processing.
///
/// Selects what [`SimpleAgentsClient::complete`] does with a non-streaming
/// response before returning it. Streaming requests are returned as a stream
/// without consulting the mode.
#[derive(Clone)]
pub enum CompletionMode {
    /// Return the raw completion response.
    Standard,
    /// Parse the response content as JSON using healing.
    HealedJson,
    /// Parse and coerce the response into the provided schema.
    CoercedSchema(Schema),
}
34
/// Options that control completion behavior.
///
/// The `Default` implementation selects [`CompletionMode::Standard`].
#[derive(Clone)]
pub struct CompletionOptions {
    /// Completion post-processing mode.
    pub mode: CompletionMode,
}
41
42impl Default for CompletionOptions {
43    fn default() -> Self {
44        Self {
45            mode: CompletionMode::Standard,
46        }
47    }
48}
49
/// Result of a unified completion call.
///
/// Which variant [`SimpleAgentsClient::complete`] returns depends on the
/// request's `stream` flag and on the [`CompletionMode`] in the options.
pub enum CompletionOutcome {
    /// A standard, non-streaming completion response.
    Response(CompletionResponse),
    /// A streaming response yielding completion chunks.
    Stream(Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>),
    /// A healed JSON response.
    HealedJson(HealedJsonResponse),
    /// A schema-coerced response.
    CoercedSchema(HealedSchemaResponse),
}
61
/// Mutable client state kept behind the client's `RwLock`.
struct ClientState {
    /// Registered providers in registration order; this list is handed to
    /// the routing mode when the router is (re)built.
    providers: Vec<Arc<dyn Provider>>,
    /// Providers keyed by the string returned from `Provider::name`.
    provider_map: HashMap<String, Arc<dyn Provider>>,
    /// Router built from `providers`; cloned out of the lock per request.
    router: Arc<RouterEngine>,
}
67
/// Unified SimpleAgents client.
///
/// Construct one through [`SimpleAgentsClient::builder`].
pub struct SimpleAgentsClient {
    /// Providers and the router derived from them.
    state: RwLock<ClientState>,
    /// Routing mode used to build the router whenever providers change.
    routing_mode: RoutingMode,
    /// Optional response cache; `None` disables caching entirely.
    cache: Option<Arc<dyn Cache>>,
    /// TTL passed to the cache when a response is stored.
    cache_ttl: Duration,
    /// Healing configuration (enabled flag, parser and coercion configs).
    healing: HealingSettings,
    /// Middleware hooks, invoked in registration order.
    middleware: Vec<Arc<dyn Middleware>>,
}
77
78impl SimpleAgentsClient {
79    /// Start a new client builder.
80    pub fn builder() -> SimpleAgentsClientBuilder {
81        SimpleAgentsClientBuilder::new()
82    }
83
84    /// List registered provider names.
85    pub async fn provider_names(&self) -> Result<Vec<String>> {
86        let state = self.state.read().await;
87        Ok(state.provider_map.keys().cloned().collect())
88    }
89
90    /// Retrieve a provider by name.
91    pub async fn provider(&self, name: &str) -> Result<Option<Arc<dyn Provider>>> {
92        let state = self.state.read().await;
93        Ok(state.provider_map.get(name).cloned())
94    }
95
96    /// Register an additional provider and rebuild the router.
97    pub async fn register_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
98        let mut state = self.state.write().await;
99        let name = provider.name().to_string();
100
101        if state.provider_map.contains_key(&name) {
102            return Err(SimpleAgentsError::Config(format!(
103                "provider already registered: {}",
104                name
105            )));
106        }
107
108        state.provider_map.insert(name, provider.clone());
109        state.providers.push(provider);
110        state.router = Arc::new(self.routing_mode.build_router(state.providers.clone())?);
111        Ok(())
112    }
113
114    /// Execute a completion request with routing, caching, and middleware.
115    pub async fn complete(
116        &self,
117        request: &CompletionRequest,
118        options: CompletionOptions,
119    ) -> Result<CompletionOutcome> {
120        if request.stream.unwrap_or(false) {
121            let stream = self.stream(request).await?;
122            return Ok(CompletionOutcome::Stream(stream));
123        }
124
125        match options.mode {
126            CompletionMode::Standard => {
127                let response = self.complete_response(request).await?;
128                Ok(CompletionOutcome::Response(response))
129            }
130            CompletionMode::HealedJson => {
131                let healed = self.complete_json_internal(request).await?;
132                Ok(CompletionOutcome::HealedJson(healed))
133            }
134            CompletionMode::CoercedSchema(schema) => {
135                let healed = self.complete_with_schema_internal(request, &schema).await?;
136                Ok(CompletionOutcome::CoercedSchema(healed))
137            }
138        }
139    }
140
141    async fn complete_response(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
142        request.validate()?;
143        self.before_request(request).await?;
144
145        let cache_key = if let Some(cache) = &self.cache {
146            if cache.is_enabled() {
147                Some(self.cache_key(request)?)
148            } else {
149                None
150            }
151        } else {
152            None
153        };
154
155        if let (Some(cache), Some(key)) = (&self.cache, cache_key.as_deref()) {
156            if let Some(cached) = cache.get(key).await? {
157                let response: CompletionResponse = serde_json::from_slice(&cached)?;
158                self.on_cache_hit(request, &response).await?;
159                return Ok(response);
160            }
161        }
162
163        let start = Instant::now();
164        let router = {
165            let state = self.state.read().await;
166            state.router.clone()
167        };
168        let response = router.complete(request).await;
169
170        match response {
171            Ok(response) => {
172                self.after_response(request, &response, start.elapsed())
173                    .await?;
174                if let (Some(cache), Some(key)) = (&self.cache, cache_key) {
175                    let payload = serde_json::to_vec(&response)?;
176                    cache.set(&key, payload, self.cache_ttl).await?;
177                }
178                Ok(response)
179            }
180            Err(error) => {
181                self.on_error(request, &error, start.elapsed()).await?;
182                Err(error)
183            }
184        }
185    }
186
187    /// Execute a completion request and parse the response content as JSON.
188    async fn complete_json_internal(
189        &self,
190        request: &CompletionRequest,
191    ) -> Result<HealedJsonResponse> {
192        self.ensure_healing_enabled()?;
193        let response = self.complete_response(request).await?;
194        let content = response.content().ok_or_else(|| {
195            SimpleAgentsError::Healing(simple_agent_type::error::HealingError::ParseFailed {
196                error_message: "response contained no content".to_string(),
197                input: String::new(),
198            })
199        })?;
200
201        let parser = JsonishParser::with_config(self.healing.parser_config.clone());
202        let parsed = parser.parse(content)?;
203
204        Ok(HealedJsonResponse { response, parsed })
205    }
206
207    /// Execute a completion request and coerce the response into a schema.
208    async fn complete_with_schema_internal(
209        &self,
210        request: &CompletionRequest,
211        schema: &Schema,
212    ) -> Result<HealedSchemaResponse> {
213        self.ensure_healing_enabled()?;
214        let healed = self.complete_json_internal(request).await?;
215        let engine = CoercionEngine::with_config(self.healing.coercion_config.clone());
216        let coerced = engine
217            .coerce(&healed.parsed.value, schema)
218            .map_err(SimpleAgentsError::Healing)?;
219
220        Ok(HealedSchemaResponse {
221            response: healed.response,
222            parsed: healed.parsed,
223            coerced,
224        })
225    }
226
227    /// Execute a streaming completion request.
228    async fn stream(
229        &self,
230        request: &CompletionRequest,
231    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
232        request.validate()?;
233        self.before_request(request).await?;
234        debug!(
235            model = %request.model,
236            stream = ?request.stream,
237            "SimpleAgentsClient.stream start"
238        );
239
240        let router = {
241            let state = self.state.read().await;
242            state.router.clone()
243        };
244
245        let start = Instant::now();
246        let middleware = self.middleware.clone();
247        let instrumented_request = request.clone();
248        let inner = router.stream(request).await?;
249
250        let wrapped = Self::instrument_stream(inner, instrumented_request, middleware, start);
251        Ok(Box::new(wrapped))
252    }
253
254    fn ensure_healing_enabled(&self) -> Result<()> {
255        if self.healing.enabled {
256            Ok(())
257        } else {
258            Err(SimpleAgentsError::Config(
259                "healing is disabled for this client".to_string(),
260            ))
261        }
262    }
263
264    fn cache_key(&self, request: &CompletionRequest) -> Result<String> {
265        let serialized = serde_json::to_string(request)?;
266        Ok(CacheKey::from_parts("core", &request.model, &serialized))
267    }
268
269    async fn before_request(&self, request: &CompletionRequest) -> Result<()> {
270        for middleware in &self.middleware {
271            middleware.before_request(request).await?;
272        }
273        Ok(())
274    }
275
276    async fn after_response(
277        &self,
278        request: &CompletionRequest,
279        response: &CompletionResponse,
280        latency: Duration,
281    ) -> Result<()> {
282        for middleware in &self.middleware {
283            middleware
284                .after_response(request, response, latency)
285                .await?;
286        }
287        Ok(())
288    }
289
290    async fn on_cache_hit(
291        &self,
292        request: &CompletionRequest,
293        response: &CompletionResponse,
294    ) -> Result<()> {
295        for middleware in &self.middleware {
296            middleware.on_cache_hit(request, response).await?;
297        }
298        Ok(())
299    }
300
301    async fn on_error(
302        &self,
303        request: &CompletionRequest,
304        error: &SimpleAgentsError,
305        latency: Duration,
306    ) -> Result<()> {
307        for middleware in &self.middleware {
308            middleware.on_error(request, error, latency).await?;
309        }
310        Ok(())
311    }
312}
313
314impl SimpleAgentsClient {
315    fn instrument_stream(
316        inner: Box<dyn Stream<Item = Result<CompletionChunk>> + Send + Unpin>,
317        request: CompletionRequest,
318        middleware: Vec<Arc<dyn Middleware>>,
319        start: Instant,
320    ) -> impl Stream<Item = Result<CompletionChunk>> + Send + Unpin {
321        struct StreamState {
322            inner: Box<dyn Stream<Item = Result<CompletionChunk>> + Send + Unpin>,
323            middleware: Vec<Arc<dyn Middleware>>,
324            request: CompletionRequest,
325            start: Instant,
326            done: bool,
327        }
328
329        stream::unfold(
330            StreamState {
331                inner,
332                middleware,
333                request,
334                start,
335                done: false,
336            },
337            |mut state| -> BoxFuture<Option<(Result<CompletionChunk>, StreamState)>> {
338                Box::pin(async move {
339                    if state.done {
340                        return None;
341                    }
342
343                    match state.inner.next().await {
344                        Some(Ok(chunk)) => Some((Ok(chunk), state)),
345                        Some(Err(err)) => {
346                            let latency = state.start.elapsed();
347                            for middleware in &state.middleware {
348                                if let Err(mw_err) =
349                                    middleware.on_error(&state.request, &err, latency).await
350                                {
351                                    state.done = true;
352                                    return Some((Err(mw_err), state));
353                                }
354                            }
355                            state.done = true;
356                            Some((Err(err), state))
357                        }
358                        None => {
359                            let latency = state.start.elapsed();
360                            for middleware in &state.middleware {
361                                if let Err(mw_err) =
362                                    middleware.after_stream(&state.request, latency).await
363                                {
364                                    state.done = true;
365                                    return Some((Err(mw_err), state));
366                                }
367                            }
368                            None
369                        }
370                    }
371                })
372            },
373        )
374    }
375}
376
/// Builder for `SimpleAgentsClient`.
///
/// Collects providers and optional settings; `build` validates that at least
/// one provider was supplied and constructs the router.
pub struct SimpleAgentsClientBuilder {
    /// Providers to register, in the order they were added.
    providers: Vec<Arc<dyn Provider>>,
    /// Routing mode used to build the router.
    routing_mode: RoutingMode,
    /// Optional response cache.
    cache: Option<Arc<dyn Cache>>,
    /// TTL applied when responses are cached (60 seconds unless overridden).
    cache_ttl: Duration,
    /// Healing settings for JSON parsing and schema coercion.
    healing: HealingSettings,
    /// Middleware hooks, in the order they were added.
    middleware: Vec<Arc<dyn Middleware>>,
}
386
387impl SimpleAgentsClientBuilder {
388    /// Create a new builder with defaults.
389    pub fn new() -> Self {
390        Self {
391            providers: Vec::new(),
392            routing_mode: RoutingMode::default(),
393            cache: None,
394            cache_ttl: Duration::from_secs(60),
395            healing: HealingSettings::default(),
396            middleware: Vec::new(),
397        }
398    }
399
400    /// Register a provider.
401    pub fn with_provider(mut self, provider: Arc<dyn Provider>) -> Self {
402        self.providers.push(provider);
403        self
404    }
405
406    /// Register multiple providers at once.
407    pub fn with_providers(mut self, providers: Vec<Arc<dyn Provider>>) -> Self {
408        self.providers.extend(providers);
409        self
410    }
411
412    /// Configure routing mode.
413    pub fn with_routing_mode(mut self, mode: RoutingMode) -> Self {
414        self.routing_mode = mode;
415        self
416    }
417
418    /// Configure response cache.
419    pub fn with_cache(mut self, cache: Arc<dyn Cache>) -> Self {
420        self.cache = Some(cache);
421        self
422    }
423
424    /// Configure cache TTL.
425    pub fn with_cache_ttl(mut self, ttl: Duration) -> Self {
426        self.cache_ttl = ttl;
427        self
428    }
429
430    /// Configure healing settings.
431    pub fn with_healing_settings(mut self, settings: HealingSettings) -> Self {
432        self.healing = settings;
433        self
434    }
435
436    /// Register a middleware hook.
437    pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
438        self.middleware.push(middleware);
439        self
440    }
441
442    /// Build the client.
443    pub fn build(self) -> Result<SimpleAgentsClient> {
444        if self.providers.is_empty() {
445            return Err(SimpleAgentsError::Config(
446                "at least one provider is required".to_string(),
447            ));
448        }
449
450        let provider_map = self
451            .providers
452            .iter()
453            .map(|provider| (provider.name().to_string(), provider.clone()))
454            .collect::<HashMap<_, _>>();
455
456        let router = Arc::new(self.routing_mode.build_router(self.providers.clone())?);
457        let state = ClientState {
458            providers: self.providers,
459            provider_map,
460            router,
461        };
462
463        Ok(SimpleAgentsClient {
464            state: RwLock::new(state),
465            routing_mode: self.routing_mode,
466            cache: self.cache,
467            cache_ttl: self.cache_ttl,
468            healing: self.healing,
469            middleware: self.middleware,
470        })
471    }
472}
473
impl Default for SimpleAgentsClientBuilder {
    /// Equivalent to [`SimpleAgentsClientBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
479
/// No-op middleware: the unit type accepts every request unchanged.
///
/// Only `before_request` is overridden here; the remaining hooks presumably
/// fall back to the `Middleware` trait's default implementations.
#[async_trait]
impl Middleware for () {
    /// Always succeeds without inspecting the request.
    async fn before_request(&self, _request: &CompletionRequest) -> Result<()> {
        Ok(())
    }
}
486
// Unit tests covering builder validation, provider registration, and the
// middleware hooks fired around streaming completions.
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{stream, StreamExt};
    use simple_agent_type::error::ProviderError;
    use simple_agent_type::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Non-streaming provider that always answers "ok" and counts `execute` calls.
    struct MockProvider {
        name: &'static str,
        calls: AtomicUsize,
    }

    impl MockProvider {
        /// Create a mock provider with the given name and a zeroed call counter.
        fn new(name: &'static str) -> Self {
            Self {
                name,
                calls: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            self.calls.fetch_add(1, Ordering::Relaxed);
            Ok(ProviderResponse::new(
                200,
                serde_json::json!({"content": "ok"}),
            ))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name.to_string()),
                healing_metadata: None,
            })
        }
    }

    /// Building without any provider must be rejected.
    #[tokio::test]
    async fn client_build_requires_provider() {
        let result = SimpleAgentsClientBuilder::new().build();
        assert!(result.is_err());
    }

    /// A provider registered after construction shows up next to the original one.
    #[tokio::test]
    async fn register_provider_rebuilds_router() {
        let provider = Arc::new(MockProvider::new("p1"));
        let client = SimpleAgentsClientBuilder::new()
            .with_provider(provider)
            .build()
            .unwrap();

        let second = Arc::new(MockProvider::new("p2"));
        client.register_provider(second).await.unwrap();

        let names = client.provider_names().await.unwrap();
        assert!(names.contains(&"p1".to_string()));
        assert!(names.contains(&"p2".to_string()));
    }

    /// Registering a provider whose name is already taken yields a `Config` error.
    #[tokio::test]
    async fn duplicate_provider_registration_fails() {
        let provider = Arc::new(MockProvider::new("p1"));
        let client = SimpleAgentsClientBuilder::new()
            .with_provider(provider.clone())
            .build()
            .unwrap();

        let result = client.register_provider(provider).await;
        assert!(matches!(
            result,
            Err(SimpleAgentsError::Config(msg)) if msg.contains("provider already registered")
        ));
    }

    /// Middleware that counts how often each observed hook fires.
    #[derive(Default)]
    struct RecordingMiddleware {
        before: AtomicUsize,
        after_stream: AtomicUsize,
        errors: AtomicUsize,
    }

    #[async_trait]
    impl Middleware for RecordingMiddleware {
        async fn before_request(&self, _request: &CompletionRequest) -> Result<()> {
            self.before.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }

        async fn after_stream(
            &self,
            _request: &CompletionRequest,
            _latency: Duration,
        ) -> Result<()> {
            self.after_stream.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }

        async fn on_error(
            &self,
            _request: &CompletionRequest,
            _error: &SimpleAgentsError,
            _latency: Duration,
        ) -> Result<()> {
            self.errors.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }

        fn name(&self) -> &str {
            "recording"
        }
    }

    /// Provider whose stream yields one chunk and then, optionally, an error.
    struct StreamingProvider {
        name: &'static str,
        fail_after_first: bool,
    }

    impl StreamingProvider {
        /// Create a streaming provider; `fail_after_first` appends an error
        /// item after the first chunk.
        fn new(name: &'static str, fail_after_first: bool) -> Self {
            Self {
                name,
                fail_after_first,
            }
        }

        /// Build a single-choice assistant chunk carrying `content`.
        fn build_chunk(id: &str, content: &str) -> CompletionChunk {
            CompletionChunk {
                id: id.to_string(),
                model: "test-model".to_string(),
                choices: vec![ChoiceDelta {
                    index: 0,
                    delta: MessageDelta {
                        role: Some(Role::Assistant),
                        content: Some(content.to_string()),
                    },
                    finish_reason: None,
                }],
                created: None,
            }
        }
    }

    #[async_trait]
    impl Provider for StreamingProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            Ok(ProviderResponse::new(
                200,
                serde_json::json!({"content": "ok"}),
            ))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_stream".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name.to_string()),
                healing_metadata: None,
            })
        }

        async fn execute_stream(
            &self,
            _req: ProviderRequest,
        ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>>
        {
            // Either "chunk, error" or just "chunk", depending on the flag.
            let stream = if self.fail_after_first {
                let items: Vec<Result<CompletionChunk>> = vec![
                    Ok(Self::build_chunk("chunk-1", "hello")),
                    Err(SimpleAgentsError::Provider(ProviderError::ServerError(
                        "stream error".to_string(),
                    ))),
                ];
                stream::iter(items)
            } else {
                let items: Vec<Result<CompletionChunk>> =
                    vec![Ok(Self::build_chunk("chunk-1", "hello"))];
                stream::iter(items)
            };

            Ok(Box::new(stream))
        }
    }

    /// A stream that ends cleanly fires `before_request` and `after_stream`
    /// once each and never `on_error`.
    #[tokio::test]
    async fn streaming_invokes_after_stream_on_success() {
        let provider = Arc::new(StreamingProvider::new("p1", false));
        let middleware = Arc::new(RecordingMiddleware::default());

        let client = SimpleAgentsClientBuilder::new()
            .with_provider(provider)
            .with_middleware(middleware.clone())
            .build()
            .unwrap();

        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hi"))
            .stream(true)
            .build()
            .unwrap();

        let outcome = client
            .complete(&request, CompletionOptions::default())
            .await
            .unwrap();

        // Drain the stream; the end-of-stream hook only runs once it is exhausted.
        let mut collected = Vec::new();
        match outcome {
            CompletionOutcome::Stream(mut stream) => {
                while let Some(chunk) = stream.next().await {
                    collected.push(chunk.unwrap());
                }
            }
            _ => panic!("expected stream outcome"),
        }

        assert_eq!(collected.len(), 1);
        assert_eq!(middleware.before.load(Ordering::Relaxed), 1);
        assert_eq!(middleware.after_stream.load(Ordering::Relaxed), 1);
        assert_eq!(middleware.errors.load(Ordering::Relaxed), 0);
    }

    /// A stream that fails mid-way fires `on_error` once, skips
    /// `after_stream`, and surfaces the error as the final item.
    #[tokio::test]
    async fn streaming_invokes_on_error_on_failure() {
        let provider = Arc::new(StreamingProvider::new("p1", true));
        let middleware = Arc::new(RecordingMiddleware::default());

        let client = SimpleAgentsClientBuilder::new()
            .with_provider(provider)
            .with_middleware(middleware.clone())
            .build()
            .unwrap();

        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hi"))
            .stream(true)
            .build()
            .unwrap();

        let outcome = client
            .complete(&request, CompletionOptions::default())
            .await
            .unwrap();

        // Keep the raw `Result` items so the trailing error can be inspected.
        let mut chunks = Vec::new();
        match outcome {
            CompletionOutcome::Stream(mut stream) => {
                while let Some(chunk) = stream.next().await {
                    chunks.push(chunk);
                }
            }
            _ => panic!("expected stream outcome"),
        }

        assert_eq!(middleware.before.load(Ordering::Relaxed), 1);
        assert_eq!(middleware.after_stream.load(Ordering::Relaxed), 0);
        assert_eq!(middleware.errors.load(Ordering::Relaxed), 1);
        assert_eq!(chunks.len(), 2);
        assert!(chunks[0].as_ref().is_ok());
        assert!(chunks[1].is_err());
    }
}
787}