Skip to main content

simple_agents_core/
client.rs

1//! SimpleAgents client implementation.
2
3use crate::healing::{HealedJsonResponse, HealedSchemaResponse, HealingSettings};
4use crate::middleware::Middleware;
5use crate::routing::{RouterEngine, RoutingMode};
6use async_trait::async_trait;
7use simple_agent_type::cache::Cache;
8use simple_agents_healing::coercion::CoercionEngine;
9use simple_agents_healing::parser::JsonishParser;
10use simple_agents_healing::schema::Schema;
11use simple_agent_type::cache::CacheKey;
12use simple_agent_type::prelude::{
13    CompletionChunk, CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
14};
15use std::collections::HashMap;
16use std::sync::{Arc, RwLock};
17use std::time::{Duration, Instant};
18
19struct ClientState {
20    providers: Vec<Arc<dyn Provider>>,
21    provider_map: HashMap<String, Arc<dyn Provider>>,
22    router: Arc<RouterEngine>,
23}
24
25/// Unified SimpleAgents client.
26pub struct SimpleAgentsClient {
27    state: RwLock<ClientState>,
28    routing_mode: RoutingMode,
29    cache: Option<Arc<dyn Cache>>,
30    cache_ttl: Duration,
31    healing: HealingSettings,
32    middleware: Vec<Arc<dyn Middleware>>,
33}
34
35impl SimpleAgentsClient {
36    /// Start a new client builder.
37    pub fn builder() -> SimpleAgentsClientBuilder {
38        SimpleAgentsClientBuilder::new()
39    }
40
41    /// List registered provider names.
42    pub fn provider_names(&self) -> Result<Vec<String>> {
43        let state = self.state.read().map_err(|_| {
44            SimpleAgentsError::Config("provider registry lock poisoned".to_string())
45        })?;
46        Ok(state.provider_map.keys().cloned().collect())
47    }
48
49    /// Retrieve a provider by name.
50    pub fn provider(&self, name: &str) -> Result<Option<Arc<dyn Provider>>> {
51        let state = self.state.read().map_err(|_| {
52            SimpleAgentsError::Config("provider registry lock poisoned".to_string())
53        })?;
54        Ok(state.provider_map.get(name).cloned())
55    }
56
57    /// Register an additional provider and rebuild the router.
58    pub fn register_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
59        let mut state = self.state.write().map_err(|_| {
60            SimpleAgentsError::Config("provider registry lock poisoned".to_string())
61        })?;
62        state
63            .provider_map
64            .insert(provider.name().to_string(), provider.clone());
65        state.providers.push(provider);
66        state.router = Arc::new(self.routing_mode.build_router(state.providers.clone())?);
67        Ok(())
68    }
69
70    /// Execute a completion request with routing, caching, and middleware.
71    pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
72        request.validate()?;
73        self.before_request(request).await?;
74
75        let cache_key = if let Some(cache) = &self.cache {
76            if cache.is_enabled() {
77                Some(self.cache_key(request)?)
78            } else {
79                None
80            }
81        } else {
82            None
83        };
84
85        if let (Some(cache), Some(key)) = (&self.cache, cache_key.as_deref()) {
86            if let Some(cached) = cache.get(key).await? {
87                let response: CompletionResponse = serde_json::from_slice(&cached)?;
88                self.on_cache_hit(request, &response).await?;
89                return Ok(response);
90            }
91        }
92
93        let start = Instant::now();
94        let router = {
95            let state = self.state.read().map_err(|_| {
96                SimpleAgentsError::Config("provider registry lock poisoned".to_string())
97            })?;
98            state.router.clone()
99        };
100        let response = router.complete(request).await;
101
102        match response {
103            Ok(response) => {
104                self.after_response(request, &response, start.elapsed())
105                    .await?;
106                if let (Some(cache), Some(key)) = (&self.cache, cache_key) {
107                    let payload = serde_json::to_vec(&response)?;
108                    cache.set(&key, payload, self.cache_ttl).await?;
109                }
110                Ok(response)
111            }
112            Err(error) => {
113                self.on_error(request, &error, start.elapsed()).await?;
114                Err(error)
115            }
116        }
117    }
118
119    /// Execute a completion request and parse the response content as JSON.
120    pub async fn complete_json(&self, request: &CompletionRequest) -> Result<HealedJsonResponse> {
121        self.ensure_healing_enabled()?;
122        let response = self.complete(request).await?;
123        let content = response.content().ok_or_else(|| {
124            SimpleAgentsError::Healing(simple_agent_type::error::HealingError::ParseFailed {
125                error_message: "response contained no content".to_string(),
126                input: String::new(),
127            })
128        })?;
129
130        let parser = JsonishParser::with_config(self.healing.parser_config.clone());
131        let parsed = parser.parse(content)?;
132
133        Ok(HealedJsonResponse { response, parsed })
134    }
135
136    /// Execute a completion request and coerce the response into a schema.
137    pub async fn complete_with_schema(
138        &self,
139        request: &CompletionRequest,
140        schema: &Schema,
141    ) -> Result<HealedSchemaResponse> {
142        self.ensure_healing_enabled()?;
143        let healed = self.complete_json(request).await?;
144        let engine = CoercionEngine::with_config(self.healing.coercion_config.clone());
145        let coerced = engine
146            .coerce(&healed.parsed.value, schema)
147            .map_err(SimpleAgentsError::Healing)?;
148
149        Ok(HealedSchemaResponse {
150            response: healed.response,
151            parsed: healed.parsed,
152            coerced,
153        })
154    }
155
156    /// Execute a streaming completion request.
157    pub async fn stream(
158        &self,
159        request: &CompletionRequest,
160    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
161        request.validate()?;
162        self.before_request(request).await?;
163        eprintln!(
164            "SimpleAgentsClient.stream: model={}, stream={:?}",
165            request.model, request.stream
166        );
167
168        let router = {
169            let state = self.state.read().map_err(|_| {
170                SimpleAgentsError::Config("provider registry lock poisoned".to_string())
171            })?;
172            state.router.clone()
173        };
174
175        router.stream(request).await
176    }
177
178    fn ensure_healing_enabled(&self) -> Result<()> {
179        if self.healing.enabled {
180            Ok(())
181        } else {
182            Err(SimpleAgentsError::Config(
183                "healing is disabled for this client".to_string(),
184            ))
185        }
186    }
187
188    fn cache_key(&self, request: &CompletionRequest) -> Result<String> {
189        let serialized = serde_json::to_string(request)?;
190        Ok(CacheKey::from_parts("core", &request.model, &serialized))
191    }
192
193    async fn before_request(&self, request: &CompletionRequest) -> Result<()> {
194        for middleware in &self.middleware {
195            middleware.before_request(request).await?;
196        }
197        Ok(())
198    }
199
200    async fn after_response(
201        &self,
202        request: &CompletionRequest,
203        response: &CompletionResponse,
204        latency: Duration,
205    ) -> Result<()> {
206        for middleware in &self.middleware {
207            middleware
208                .after_response(request, response, latency)
209                .await?;
210        }
211        Ok(())
212    }
213
214    async fn on_cache_hit(
215        &self,
216        request: &CompletionRequest,
217        response: &CompletionResponse,
218    ) -> Result<()> {
219        for middleware in &self.middleware {
220            middleware.on_cache_hit(request, response).await?;
221        }
222        Ok(())
223    }
224
225    async fn on_error(
226        &self,
227        request: &CompletionRequest,
228        error: &SimpleAgentsError,
229        latency: Duration,
230    ) -> Result<()> {
231        for middleware in &self.middleware {
232            middleware.on_error(request, error, latency).await?;
233        }
234        Ok(())
235    }
236}
237
238/// Builder for `SimpleAgentsClient`.
239pub struct SimpleAgentsClientBuilder {
240    providers: Vec<Arc<dyn Provider>>,
241    routing_mode: RoutingMode,
242    cache: Option<Arc<dyn Cache>>,
243    cache_ttl: Duration,
244    healing: HealingSettings,
245    middleware: Vec<Arc<dyn Middleware>>,
246}
247
248impl SimpleAgentsClientBuilder {
249    /// Create a new builder with defaults.
250    pub fn new() -> Self {
251        Self {
252            providers: Vec::new(),
253            routing_mode: RoutingMode::default(),
254            cache: None,
255            cache_ttl: Duration::from_secs(60),
256            healing: HealingSettings::default(),
257            middleware: Vec::new(),
258        }
259    }
260
261    /// Register a provider.
262    pub fn with_provider(mut self, provider: Arc<dyn Provider>) -> Self {
263        self.providers.push(provider);
264        self
265    }
266
267    /// Register multiple providers at once.
268    pub fn with_providers(mut self, providers: Vec<Arc<dyn Provider>>) -> Self {
269        self.providers.extend(providers);
270        self
271    }
272
273    /// Configure routing mode.
274    pub fn with_routing_mode(mut self, mode: RoutingMode) -> Self {
275        self.routing_mode = mode;
276        self
277    }
278
279    /// Configure response cache.
280    pub fn with_cache(mut self, cache: Arc<dyn Cache>) -> Self {
281        self.cache = Some(cache);
282        self
283    }
284
285    /// Configure cache TTL.
286    pub fn with_cache_ttl(mut self, ttl: Duration) -> Self {
287        self.cache_ttl = ttl;
288        self
289    }
290
291    /// Configure healing settings.
292    pub fn with_healing_settings(mut self, settings: HealingSettings) -> Self {
293        self.healing = settings;
294        self
295    }
296
297    /// Register a middleware hook.
298    pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
299        self.middleware.push(middleware);
300        self
301    }
302
303    /// Build the client.
304    pub fn build(self) -> Result<SimpleAgentsClient> {
305        if self.providers.is_empty() {
306            return Err(SimpleAgentsError::Config(
307                "at least one provider is required".to_string(),
308            ));
309        }
310
311        let provider_map = self
312            .providers
313            .iter()
314            .map(|provider| (provider.name().to_string(), provider.clone()))
315            .collect::<HashMap<_, _>>();
316
317        let router = Arc::new(self.routing_mode.build_router(self.providers.clone())?);
318        let state = ClientState {
319            providers: self.providers,
320            provider_map,
321            router,
322        };
323
324        Ok(SimpleAgentsClient {
325            state: RwLock::new(state),
326            routing_mode: self.routing_mode,
327            cache: self.cache,
328            cache_ttl: self.cache_ttl,
329            healing: self.healing,
330            middleware: self.middleware,
331        })
332    }
333}
334
335impl Default for SimpleAgentsClientBuilder {
336    fn default() -> Self {
337        Self::new()
338    }
339}
340
341#[async_trait]
342impl Middleware for () {
343    async fn before_request(&self, _request: &CompletionRequest) -> Result<()> {
344        Ok(())
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use simple_agent_type::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Provider stub: counts `execute` calls and always returns an "ok" completion.
    struct MockProvider {
        name: &'static str,
        calls: AtomicUsize,
    }

    impl MockProvider {
        fn new(name: &'static str) -> Self {
            Self {
                name,
                calls: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            self.calls.fetch_add(1, Ordering::Relaxed);
            Ok(ProviderResponse::new(
                200,
                serde_json::json!({"content": "ok"}),
            ))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name.to_string()),
                healing_metadata: None,
            })
        }
    }

    // `build` is synchronous and fails before any router is created, so no
    // async runtime is needed here.
    #[test]
    fn client_build_requires_provider() {
        let result = SimpleAgentsClientBuilder::new().build();
        assert!(matches!(result, Err(SimpleAgentsError::Config(_))));
    }

    #[tokio::test]
    async fn register_provider_rebuilds_router() {
        let provider = Arc::new(MockProvider::new("p1"));
        let client = SimpleAgentsClientBuilder::new()
            .with_provider(provider)
            .build()
            .unwrap();

        let second = Arc::new(MockProvider::new("p2"));
        client.register_provider(second).unwrap();

        let names = client.provider_names().unwrap();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"p1".to_string()));
        assert!(names.contains(&"p2".to_string()));
        assert!(client.provider("p2").unwrap().is_some());
        assert!(client.provider("missing").unwrap().is_none());
    }
}