Skip to main content

simple_agents_core/
client.rs

1//! SimpleAgents client implementation.
2
3use crate::healing::{HealedJsonResponse, HealedSchemaResponse, HealingSettings};
4use crate::middleware::Middleware;
5use crate::routing::{RouterEngine, RoutingMode};
6use async_trait::async_trait;
7use simple_agents_cache::Cache;
8use simple_agents_healing::coercion::CoercionEngine;
9use simple_agents_healing::parser::JsonishParser;
10use simple_agents_healing::schema::Schema;
11use simple_agent_type::cache::CacheKey;
12use simple_agent_type::prelude::{
13    CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
14};
15use std::collections::HashMap;
16use std::sync::{Arc, RwLock};
17use std::time::{Duration, Instant};
18
19struct ClientState {
20    providers: Vec<Arc<dyn Provider>>,
21    provider_map: HashMap<String, Arc<dyn Provider>>,
22    router: Arc<RouterEngine>,
23}
24
25/// Unified SimpleAgents client.
26pub struct SimpleAgentsClient {
27    state: RwLock<ClientState>,
28    routing_mode: RoutingMode,
29    cache: Option<Arc<dyn Cache>>,
30    cache_ttl: Duration,
31    healing: HealingSettings,
32    middleware: Vec<Arc<dyn Middleware>>,
33}
34
35impl SimpleAgentsClient {
36    /// Start a new client builder.
37    pub fn builder() -> SimpleAgentsClientBuilder {
38        SimpleAgentsClientBuilder::new()
39    }
40
41    /// List registered provider names.
42    pub fn provider_names(&self) -> Result<Vec<String>> {
43        let state = self.state.read().map_err(|_| {
44            SimpleAgentsError::Config("provider registry lock poisoned".to_string())
45        })?;
46        Ok(state.provider_map.keys().cloned().collect())
47    }
48
49    /// Retrieve a provider by name.
50    pub fn provider(&self, name: &str) -> Result<Option<Arc<dyn Provider>>> {
51        let state = self.state.read().map_err(|_| {
52            SimpleAgentsError::Config("provider registry lock poisoned".to_string())
53        })?;
54        Ok(state.provider_map.get(name).cloned())
55    }
56
57    /// Register an additional provider and rebuild the router.
58    pub fn register_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
59        let mut state = self.state.write().map_err(|_| {
60            SimpleAgentsError::Config("provider registry lock poisoned".to_string())
61        })?;
62        state
63            .provider_map
64            .insert(provider.name().to_string(), provider.clone());
65        state.providers.push(provider);
66        state.router = Arc::new(self.routing_mode.build_router(state.providers.clone())?);
67        Ok(())
68    }
69
70    /// Execute a completion request with routing, caching, and middleware.
71    pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
72        request.validate()?;
73        self.before_request(request).await?;
74
75        let cache_key = if let Some(cache) = &self.cache {
76            if cache.is_enabled() {
77                Some(self.cache_key(request)?)
78            } else {
79                None
80            }
81        } else {
82            None
83        };
84
85        if let (Some(cache), Some(key)) = (&self.cache, cache_key.as_deref()) {
86            if let Some(cached) = cache.get(key).await? {
87                let response: CompletionResponse = serde_json::from_slice(&cached)?;
88                self.on_cache_hit(request, &response).await?;
89                return Ok(response);
90            }
91        }
92
93        let start = Instant::now();
94        let router = {
95            let state = self.state.read().map_err(|_| {
96                SimpleAgentsError::Config("provider registry lock poisoned".to_string())
97            })?;
98            state.router.clone()
99        };
100        let response = router.complete(request).await;
101
102        match response {
103            Ok(response) => {
104                self.after_response(request, &response, start.elapsed())
105                    .await?;
106                if let (Some(cache), Some(key)) = (&self.cache, cache_key) {
107                    let payload = serde_json::to_vec(&response)?;
108                    cache.set(&key, payload, self.cache_ttl).await?;
109                }
110                Ok(response)
111            }
112            Err(error) => {
113                self.on_error(request, &error, start.elapsed()).await?;
114                Err(error)
115            }
116        }
117    }
118
119    /// Execute a completion request and parse the response content as JSON.
120    pub async fn complete_json(&self, request: &CompletionRequest) -> Result<HealedJsonResponse> {
121        self.ensure_healing_enabled()?;
122        let response = self.complete(request).await?;
123        let content = response.content().ok_or_else(|| {
124            SimpleAgentsError::Healing(simple_agent_type::error::HealingError::ParseFailed {
125                error_message: "response contained no content".to_string(),
126                input: String::new(),
127            })
128        })?;
129
130        let parser = JsonishParser::with_config(self.healing.parser_config.clone());
131        let parsed = parser.parse(content)?;
132
133        Ok(HealedJsonResponse { response, parsed })
134    }
135
136    /// Execute a completion request and coerce the response into a schema.
137    pub async fn complete_with_schema(
138        &self,
139        request: &CompletionRequest,
140        schema: &Schema,
141    ) -> Result<HealedSchemaResponse> {
142        self.ensure_healing_enabled()?;
143        let healed = self.complete_json(request).await?;
144        let engine = CoercionEngine::with_config(self.healing.coercion_config.clone());
145        let coerced = engine
146            .coerce(&healed.parsed.value, schema)
147            .map_err(SimpleAgentsError::Healing)?;
148
149        Ok(HealedSchemaResponse {
150            response: healed.response,
151            parsed: healed.parsed,
152            coerced,
153        })
154    }
155
156    fn ensure_healing_enabled(&self) -> Result<()> {
157        if self.healing.enabled {
158            Ok(())
159        } else {
160            Err(SimpleAgentsError::Config(
161                "healing is disabled for this client".to_string(),
162            ))
163        }
164    }
165
166    fn cache_key(&self, request: &CompletionRequest) -> Result<String> {
167        let serialized = serde_json::to_string(request)?;
168        Ok(CacheKey::from_parts("core", &request.model, &serialized))
169    }
170
171    async fn before_request(&self, request: &CompletionRequest) -> Result<()> {
172        for middleware in &self.middleware {
173            middleware.before_request(request).await?;
174        }
175        Ok(())
176    }
177
178    async fn after_response(
179        &self,
180        request: &CompletionRequest,
181        response: &CompletionResponse,
182        latency: Duration,
183    ) -> Result<()> {
184        for middleware in &self.middleware {
185            middleware
186                .after_response(request, response, latency)
187                .await?;
188        }
189        Ok(())
190    }
191
192    async fn on_cache_hit(
193        &self,
194        request: &CompletionRequest,
195        response: &CompletionResponse,
196    ) -> Result<()> {
197        for middleware in &self.middleware {
198            middleware.on_cache_hit(request, response).await?;
199        }
200        Ok(())
201    }
202
203    async fn on_error(
204        &self,
205        request: &CompletionRequest,
206        error: &SimpleAgentsError,
207        latency: Duration,
208    ) -> Result<()> {
209        for middleware in &self.middleware {
210            middleware.on_error(request, error, latency).await?;
211        }
212        Ok(())
213    }
214}
215
216/// Builder for `SimpleAgentsClient`.
217pub struct SimpleAgentsClientBuilder {
218    providers: Vec<Arc<dyn Provider>>,
219    routing_mode: RoutingMode,
220    cache: Option<Arc<dyn Cache>>,
221    cache_ttl: Duration,
222    healing: HealingSettings,
223    middleware: Vec<Arc<dyn Middleware>>,
224}
225
226impl SimpleAgentsClientBuilder {
227    /// Create a new builder with defaults.
228    pub fn new() -> Self {
229        Self {
230            providers: Vec::new(),
231            routing_mode: RoutingMode::default(),
232            cache: None,
233            cache_ttl: Duration::from_secs(60),
234            healing: HealingSettings::default(),
235            middleware: Vec::new(),
236        }
237    }
238
239    /// Register a provider.
240    pub fn with_provider(mut self, provider: Arc<dyn Provider>) -> Self {
241        self.providers.push(provider);
242        self
243    }
244
245    /// Register multiple providers at once.
246    pub fn with_providers(mut self, providers: Vec<Arc<dyn Provider>>) -> Self {
247        self.providers.extend(providers);
248        self
249    }
250
251    /// Configure routing mode.
252    pub fn with_routing_mode(mut self, mode: RoutingMode) -> Self {
253        self.routing_mode = mode;
254        self
255    }
256
257    /// Configure response cache.
258    pub fn with_cache(mut self, cache: Arc<dyn Cache>) -> Self {
259        self.cache = Some(cache);
260        self
261    }
262
263    /// Configure cache TTL.
264    pub fn with_cache_ttl(mut self, ttl: Duration) -> Self {
265        self.cache_ttl = ttl;
266        self
267    }
268
269    /// Configure healing settings.
270    pub fn with_healing_settings(mut self, settings: HealingSettings) -> Self {
271        self.healing = settings;
272        self
273    }
274
275    /// Register a middleware hook.
276    pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
277        self.middleware.push(middleware);
278        self
279    }
280
281    /// Build the client.
282    pub fn build(self) -> Result<SimpleAgentsClient> {
283        if self.providers.is_empty() {
284            return Err(SimpleAgentsError::Config(
285                "at least one provider is required".to_string(),
286            ));
287        }
288
289        let provider_map = self
290            .providers
291            .iter()
292            .map(|provider| (provider.name().to_string(), provider.clone()))
293            .collect::<HashMap<_, _>>();
294
295        let router = Arc::new(self.routing_mode.build_router(self.providers.clone())?);
296        let state = ClientState {
297            providers: self.providers,
298            provider_map,
299            router,
300        };
301
302        Ok(SimpleAgentsClient {
303            state: RwLock::new(state),
304            routing_mode: self.routing_mode,
305            cache: self.cache,
306            cache_ttl: self.cache_ttl,
307            healing: self.healing,
308            middleware: self.middleware,
309        })
310    }
311}
312
313impl Default for SimpleAgentsClientBuilder {
314    fn default() -> Self {
315        Self::new()
316    }
317}
318
319#[async_trait]
320impl Middleware for () {
321    async fn before_request(&self, _request: &CompletionRequest) -> Result<()> {
322        Ok(())
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use simple_agent_type::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Provider stub that always answers "ok" and counts `execute` calls.
    struct MockProvider {
        name: &'static str,
        calls: AtomicUsize,
    }

    impl MockProvider {
        fn new(name: &'static str) -> Self {
            Self {
                name,
                calls: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            self.calls.fetch_add(1, Ordering::Relaxed);
            Ok(ProviderResponse::new(
                200,
                serde_json::json!({"content": "ok"}),
            ))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name.to_string()),
                healing_metadata: None,
            })
        }
    }

    // Client construction and registration are synchronous, so plain `#[test]`
    // is enough; no async runtime is needed for these cases.
    #[test]
    fn client_build_requires_provider() {
        let result = SimpleAgentsClientBuilder::new().build();
        assert!(matches!(result, Err(SimpleAgentsError::Config(_))));
    }

    #[test]
    fn register_provider_rebuilds_router() {
        let client = SimpleAgentsClientBuilder::new()
            .with_provider(Arc::new(MockProvider::new("p1")))
            .build()
            .unwrap();

        client
            .register_provider(Arc::new(MockProvider::new("p2")))
            .unwrap();

        let names = client.provider_names().unwrap();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"p1".to_string()));
        assert!(names.contains(&"p2".to_string()));
        assert!(client.provider("p2").unwrap().is_some());
        // NOTE(review): dispatch through the rebuilt router is not exercised
        // here; that needs a `CompletionRequest` fixture.
    }

    #[test]
    fn provider_lookup_by_name() {
        let client = SimpleAgentsClientBuilder::new()
            .with_provider(Arc::new(MockProvider::new("p1")))
            .build()
            .unwrap();

        let found = client.provider("p1").unwrap();
        assert_eq!(found.map(|p| p.name().to_string()), Some("p1".to_string()));
        assert!(client.provider("missing").unwrap().is_none());
    }
}