Skip to main content

simple_agents_core/
client.rs

1//! SimpleAgents client implementation.
2
3use crate::healing::{HealedJsonResponse, HealedSchemaResponse, HealingSettings};
4use crate::middleware::Middleware;
5use crate::routing::{RouterEngine, RoutingMode};
6use async_trait::async_trait;
7use simple_agent_type::cache::Cache;
8use simple_agent_type::cache::CacheKey;
9use simple_agent_type::prelude::{
10    CompletionChunk, CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
11};
12use simple_agents_healing::coercion::CoercionEngine;
13use simple_agents_healing::parser::JsonishParser;
14use simple_agents_healing::schema::Schema;
15use std::collections::HashMap;
16use std::sync::{Arc, RwLock};
17use std::time::{Duration, Instant};
18
19/// Mode for completion post-processing.
20pub enum CompletionMode {
21    /// Return the raw completion response.
22    Standard,
23    /// Parse the response content as JSON using healing.
24    HealedJson,
25    /// Parse and coerce the response into the provided schema.
26    CoercedSchema(Schema),
27}
28
29/// Options that control completion behavior.
30pub struct CompletionOptions {
31    /// Completion post-processing mode.
32    pub mode: CompletionMode,
33}
34
35impl Default for CompletionOptions {
36    fn default() -> Self {
37        Self {
38            mode: CompletionMode::Standard,
39        }
40    }
41}
42
43/// Result of a unified completion call.
44pub enum CompletionOutcome {
45    /// A standard, non-streaming completion response.
46    Response(CompletionResponse),
47    /// A streaming response yielding completion chunks.
48    Stream(Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>),
49    /// A healed JSON response.
50    HealedJson(HealedJsonResponse),
51    /// A schema-coerced response.
52    CoercedSchema(HealedSchemaResponse),
53}
54
55struct ClientState {
56    providers: Vec<Arc<dyn Provider>>,
57    provider_map: HashMap<String, Arc<dyn Provider>>,
58    router: Arc<RouterEngine>,
59}
60
61/// Unified SimpleAgents client.
62pub struct SimpleAgentsClient {
63    state: RwLock<ClientState>,
64    routing_mode: RoutingMode,
65    cache: Option<Arc<dyn Cache>>,
66    cache_ttl: Duration,
67    healing: HealingSettings,
68    middleware: Vec<Arc<dyn Middleware>>,
69}
70
71impl SimpleAgentsClient {
72    /// Start a new client builder.
73    pub fn builder() -> SimpleAgentsClientBuilder {
74        SimpleAgentsClientBuilder::new()
75    }
76
77    /// List registered provider names.
78    pub fn provider_names(&self) -> Result<Vec<String>> {
79        let state = self.state.read().map_err(|_| {
80            SimpleAgentsError::Config("provider registry lock poisoned".to_string())
81        })?;
82        Ok(state.provider_map.keys().cloned().collect())
83    }
84
85    /// Retrieve a provider by name.
86    pub fn provider(&self, name: &str) -> Result<Option<Arc<dyn Provider>>> {
87        let state = self.state.read().map_err(|_| {
88            SimpleAgentsError::Config("provider registry lock poisoned".to_string())
89        })?;
90        Ok(state.provider_map.get(name).cloned())
91    }
92
93    /// Register an additional provider and rebuild the router.
94    pub fn register_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
95        let mut state = self.state.write().map_err(|_| {
96            SimpleAgentsError::Config("provider registry lock poisoned".to_string())
97        })?;
98        state
99            .provider_map
100            .insert(provider.name().to_string(), provider.clone());
101        state.providers.push(provider);
102        state.router = Arc::new(self.routing_mode.build_router(state.providers.clone())?);
103        Ok(())
104    }
105
106    /// Execute a completion request with routing, caching, and middleware.
107    pub async fn complete(
108        &self,
109        request: &CompletionRequest,
110        options: CompletionOptions,
111    ) -> Result<CompletionOutcome> {
112        if request.stream.unwrap_or(false) {
113            let stream = self.stream(request).await?;
114            return Ok(CompletionOutcome::Stream(stream));
115        }
116
117        match options.mode {
118            CompletionMode::Standard => {
119                let response = self.complete_response(request).await?;
120                Ok(CompletionOutcome::Response(response))
121            }
122            CompletionMode::HealedJson => {
123                let healed = self.complete_json_internal(request).await?;
124                Ok(CompletionOutcome::HealedJson(healed))
125            }
126            CompletionMode::CoercedSchema(schema) => {
127                let healed = self.complete_with_schema_internal(request, &schema).await?;
128                Ok(CompletionOutcome::CoercedSchema(healed))
129            }
130        }
131    }
132
133    async fn complete_response(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
134        request.validate()?;
135        self.before_request(request).await?;
136
137        let cache_key = if let Some(cache) = &self.cache {
138            if cache.is_enabled() {
139                Some(self.cache_key(request)?)
140            } else {
141                None
142            }
143        } else {
144            None
145        };
146
147        if let (Some(cache), Some(key)) = (&self.cache, cache_key.as_deref()) {
148            if let Some(cached) = cache.get(key).await? {
149                let response: CompletionResponse = serde_json::from_slice(&cached)?;
150                self.on_cache_hit(request, &response).await?;
151                return Ok(response);
152            }
153        }
154
155        let start = Instant::now();
156        let router = {
157            let state = self.state.read().map_err(|_| {
158                SimpleAgentsError::Config("provider registry lock poisoned".to_string())
159            })?;
160            state.router.clone()
161        };
162        let response = router.complete(request).await;
163
164        match response {
165            Ok(response) => {
166                self.after_response(request, &response, start.elapsed())
167                    .await?;
168                if let (Some(cache), Some(key)) = (&self.cache, cache_key) {
169                    let payload = serde_json::to_vec(&response)?;
170                    cache.set(&key, payload, self.cache_ttl).await?;
171                }
172                Ok(response)
173            }
174            Err(error) => {
175                self.on_error(request, &error, start.elapsed()).await?;
176                Err(error)
177            }
178        }
179    }
180
181    /// Execute a completion request and parse the response content as JSON.
182    async fn complete_json_internal(
183        &self,
184        request: &CompletionRequest,
185    ) -> Result<HealedJsonResponse> {
186        self.ensure_healing_enabled()?;
187        let response = self.complete_response(request).await?;
188        let content = response.content().ok_or_else(|| {
189            SimpleAgentsError::Healing(simple_agent_type::error::HealingError::ParseFailed {
190                error_message: "response contained no content".to_string(),
191                input: String::new(),
192            })
193        })?;
194
195        let parser = JsonishParser::with_config(self.healing.parser_config.clone());
196        let parsed = parser.parse(content)?;
197
198        Ok(HealedJsonResponse { response, parsed })
199    }
200
201    /// Execute a completion request and coerce the response into a schema.
202    async fn complete_with_schema_internal(
203        &self,
204        request: &CompletionRequest,
205        schema: &Schema,
206    ) -> Result<HealedSchemaResponse> {
207        self.ensure_healing_enabled()?;
208        let healed = self.complete_json_internal(request).await?;
209        let engine = CoercionEngine::with_config(self.healing.coercion_config.clone());
210        let coerced = engine
211            .coerce(&healed.parsed.value, schema)
212            .map_err(SimpleAgentsError::Healing)?;
213
214        Ok(HealedSchemaResponse {
215            response: healed.response,
216            parsed: healed.parsed,
217            coerced,
218        })
219    }
220
221    /// Execute a streaming completion request.
222    async fn stream(
223        &self,
224        request: &CompletionRequest,
225    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
226        request.validate()?;
227        self.before_request(request).await?;
228        eprintln!(
229            "SimpleAgentsClient.stream: model={}, stream={:?}",
230            request.model, request.stream
231        );
232
233        let router = {
234            let state = self.state.read().map_err(|_| {
235                SimpleAgentsError::Config("provider registry lock poisoned".to_string())
236            })?;
237            state.router.clone()
238        };
239
240        router.stream(request).await
241    }
242
243    fn ensure_healing_enabled(&self) -> Result<()> {
244        if self.healing.enabled {
245            Ok(())
246        } else {
247            Err(SimpleAgentsError::Config(
248                "healing is disabled for this client".to_string(),
249            ))
250        }
251    }
252
253    fn cache_key(&self, request: &CompletionRequest) -> Result<String> {
254        let serialized = serde_json::to_string(request)?;
255        Ok(CacheKey::from_parts("core", &request.model, &serialized))
256    }
257
258    async fn before_request(&self, request: &CompletionRequest) -> Result<()> {
259        for middleware in &self.middleware {
260            middleware.before_request(request).await?;
261        }
262        Ok(())
263    }
264
265    async fn after_response(
266        &self,
267        request: &CompletionRequest,
268        response: &CompletionResponse,
269        latency: Duration,
270    ) -> Result<()> {
271        for middleware in &self.middleware {
272            middleware
273                .after_response(request, response, latency)
274                .await?;
275        }
276        Ok(())
277    }
278
279    async fn on_cache_hit(
280        &self,
281        request: &CompletionRequest,
282        response: &CompletionResponse,
283    ) -> Result<()> {
284        for middleware in &self.middleware {
285            middleware.on_cache_hit(request, response).await?;
286        }
287        Ok(())
288    }
289
290    async fn on_error(
291        &self,
292        request: &CompletionRequest,
293        error: &SimpleAgentsError,
294        latency: Duration,
295    ) -> Result<()> {
296        for middleware in &self.middleware {
297            middleware.on_error(request, error, latency).await?;
298        }
299        Ok(())
300    }
301}
302
303/// Builder for `SimpleAgentsClient`.
304pub struct SimpleAgentsClientBuilder {
305    providers: Vec<Arc<dyn Provider>>,
306    routing_mode: RoutingMode,
307    cache: Option<Arc<dyn Cache>>,
308    cache_ttl: Duration,
309    healing: HealingSettings,
310    middleware: Vec<Arc<dyn Middleware>>,
311}
312
313impl SimpleAgentsClientBuilder {
314    /// Create a new builder with defaults.
315    pub fn new() -> Self {
316        Self {
317            providers: Vec::new(),
318            routing_mode: RoutingMode::default(),
319            cache: None,
320            cache_ttl: Duration::from_secs(60),
321            healing: HealingSettings::default(),
322            middleware: Vec::new(),
323        }
324    }
325
326    /// Register a provider.
327    pub fn with_provider(mut self, provider: Arc<dyn Provider>) -> Self {
328        self.providers.push(provider);
329        self
330    }
331
332    /// Register multiple providers at once.
333    pub fn with_providers(mut self, providers: Vec<Arc<dyn Provider>>) -> Self {
334        self.providers.extend(providers);
335        self
336    }
337
338    /// Configure routing mode.
339    pub fn with_routing_mode(mut self, mode: RoutingMode) -> Self {
340        self.routing_mode = mode;
341        self
342    }
343
344    /// Configure response cache.
345    pub fn with_cache(mut self, cache: Arc<dyn Cache>) -> Self {
346        self.cache = Some(cache);
347        self
348    }
349
350    /// Configure cache TTL.
351    pub fn with_cache_ttl(mut self, ttl: Duration) -> Self {
352        self.cache_ttl = ttl;
353        self
354    }
355
356    /// Configure healing settings.
357    pub fn with_healing_settings(mut self, settings: HealingSettings) -> Self {
358        self.healing = settings;
359        self
360    }
361
362    /// Register a middleware hook.
363    pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
364        self.middleware.push(middleware);
365        self
366    }
367
368    /// Build the client.
369    pub fn build(self) -> Result<SimpleAgentsClient> {
370        if self.providers.is_empty() {
371            return Err(SimpleAgentsError::Config(
372                "at least one provider is required".to_string(),
373            ));
374        }
375
376        let provider_map = self
377            .providers
378            .iter()
379            .map(|provider| (provider.name().to_string(), provider.clone()))
380            .collect::<HashMap<_, _>>();
381
382        let router = Arc::new(self.routing_mode.build_router(self.providers.clone())?);
383        let state = ClientState {
384            providers: self.providers,
385            provider_map,
386            router,
387        };
388
389        Ok(SimpleAgentsClient {
390            state: RwLock::new(state),
391            routing_mode: self.routing_mode,
392            cache: self.cache,
393            cache_ttl: self.cache_ttl,
394            healing: self.healing,
395            middleware: self.middleware,
396        })
397    }
398}
399
400impl Default for SimpleAgentsClientBuilder {
401    fn default() -> Self {
402        Self::new()
403    }
404}
405
406#[async_trait]
407impl Middleware for () {
408    async fn before_request(&self, _request: &CompletionRequest) -> Result<()> {
409        Ok(())
410    }
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use simple_agent_type::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Provider stub that always returns a fixed "ok" completion and counts
    /// how many times `execute` is called.
    struct MockProvider {
        name: &'static str,
        calls: AtomicUsize,
    }

    impl MockProvider {
        fn new(name: &'static str) -> Self {
            Self {
                name,
                calls: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            self.calls.fetch_add(1, Ordering::Relaxed);
            Ok(ProviderResponse::new(
                200,
                serde_json::json!({"content": "ok"}),
            ))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name.to_string()),
                healing_metadata: None,
            })
        }
    }

    // `build` is synchronous, so no async runtime is needed here.
    #[test]
    fn client_build_requires_provider() {
        let result = SimpleAgentsClientBuilder::new().build();
        assert!(matches!(result, Err(SimpleAgentsError::Config(_))));
    }

    #[tokio::test]
    async fn register_provider_rebuilds_router() {
        let provider = Arc::new(MockProvider::new("p1"));
        let client = SimpleAgentsClientBuilder::new()
            .with_provider(provider)
            .build()
            .unwrap();

        let second = Arc::new(MockProvider::new("p2"));
        client.register_provider(second).unwrap();

        let names = client.provider_names().unwrap();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"p1".to_string()));
        assert!(names.contains(&"p2".to_string()));

        // The newly registered provider must be reachable by name.
        let found = client.provider("p2").unwrap();
        assert_eq!(found.map(|p| p.name().to_string()), Some("p2".to_string()));
        assert!(client.provider("missing").unwrap().is_none());
    }
}