1use crate::healing::{HealedJsonResponse, HealedSchemaResponse, HealingSettings};
4use crate::middleware::Middleware;
5use crate::routing::{RouterEngine, RoutingMode};
6use async_trait::async_trait;
7use simple_agents_cache::Cache;
8use simple_agents_healing::coercion::CoercionEngine;
9use simple_agents_healing::parser::JsonishParser;
10use simple_agents_healing::schema::Schema;
11use simple_agent_type::cache::CacheKey;
12use simple_agent_type::prelude::{
13 CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
14};
15use std::collections::HashMap;
16use std::sync::{Arc, RwLock};
17use std::time::{Duration, Instant};
18
19struct ClientState {
20 providers: Vec<Arc<dyn Provider>>,
21 provider_map: HashMap<String, Arc<dyn Provider>>,
22 router: Arc<RouterEngine>,
23}
24
25pub struct SimpleAgentsClient {
27 state: RwLock<ClientState>,
28 routing_mode: RoutingMode,
29 cache: Option<Arc<dyn Cache>>,
30 cache_ttl: Duration,
31 healing: HealingSettings,
32 middleware: Vec<Arc<dyn Middleware>>,
33}
34
35impl SimpleAgentsClient {
36 pub fn builder() -> SimpleAgentsClientBuilder {
38 SimpleAgentsClientBuilder::new()
39 }
40
41 pub fn provider_names(&self) -> Result<Vec<String>> {
43 let state = self.state.read().map_err(|_| {
44 SimpleAgentsError::Config("provider registry lock poisoned".to_string())
45 })?;
46 Ok(state.provider_map.keys().cloned().collect())
47 }
48
49 pub fn provider(&self, name: &str) -> Result<Option<Arc<dyn Provider>>> {
51 let state = self.state.read().map_err(|_| {
52 SimpleAgentsError::Config("provider registry lock poisoned".to_string())
53 })?;
54 Ok(state.provider_map.get(name).cloned())
55 }
56
57 pub fn register_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
59 let mut state = self.state.write().map_err(|_| {
60 SimpleAgentsError::Config("provider registry lock poisoned".to_string())
61 })?;
62 state
63 .provider_map
64 .insert(provider.name().to_string(), provider.clone());
65 state.providers.push(provider);
66 state.router = Arc::new(self.routing_mode.build_router(state.providers.clone())?);
67 Ok(())
68 }
69
70 pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
72 request.validate()?;
73 self.before_request(request).await?;
74
75 let cache_key = if let Some(cache) = &self.cache {
76 if cache.is_enabled() {
77 Some(self.cache_key(request)?)
78 } else {
79 None
80 }
81 } else {
82 None
83 };
84
85 if let (Some(cache), Some(key)) = (&self.cache, cache_key.as_deref()) {
86 if let Some(cached) = cache.get(key).await? {
87 let response: CompletionResponse = serde_json::from_slice(&cached)?;
88 self.on_cache_hit(request, &response).await?;
89 return Ok(response);
90 }
91 }
92
93 let start = Instant::now();
94 let router = {
95 let state = self.state.read().map_err(|_| {
96 SimpleAgentsError::Config("provider registry lock poisoned".to_string())
97 })?;
98 state.router.clone()
99 };
100 let response = router.complete(request).await;
101
102 match response {
103 Ok(response) => {
104 self.after_response(request, &response, start.elapsed())
105 .await?;
106 if let (Some(cache), Some(key)) = (&self.cache, cache_key) {
107 let payload = serde_json::to_vec(&response)?;
108 cache.set(&key, payload, self.cache_ttl).await?;
109 }
110 Ok(response)
111 }
112 Err(error) => {
113 self.on_error(request, &error, start.elapsed()).await?;
114 Err(error)
115 }
116 }
117 }
118
119 pub async fn complete_json(&self, request: &CompletionRequest) -> Result<HealedJsonResponse> {
121 self.ensure_healing_enabled()?;
122 let response = self.complete(request).await?;
123 let content = response.content().ok_or_else(|| {
124 SimpleAgentsError::Healing(simple_agent_type::error::HealingError::ParseFailed {
125 error_message: "response contained no content".to_string(),
126 input: String::new(),
127 })
128 })?;
129
130 let parser = JsonishParser::with_config(self.healing.parser_config.clone());
131 let parsed = parser.parse(content)?;
132
133 Ok(HealedJsonResponse { response, parsed })
134 }
135
136 pub async fn complete_with_schema(
138 &self,
139 request: &CompletionRequest,
140 schema: &Schema,
141 ) -> Result<HealedSchemaResponse> {
142 self.ensure_healing_enabled()?;
143 let healed = self.complete_json(request).await?;
144 let engine = CoercionEngine::with_config(self.healing.coercion_config.clone());
145 let coerced = engine
146 .coerce(&healed.parsed.value, schema)
147 .map_err(SimpleAgentsError::Healing)?;
148
149 Ok(HealedSchemaResponse {
150 response: healed.response,
151 parsed: healed.parsed,
152 coerced,
153 })
154 }
155
156 fn ensure_healing_enabled(&self) -> Result<()> {
157 if self.healing.enabled {
158 Ok(())
159 } else {
160 Err(SimpleAgentsError::Config(
161 "healing is disabled for this client".to_string(),
162 ))
163 }
164 }
165
166 fn cache_key(&self, request: &CompletionRequest) -> Result<String> {
167 let serialized = serde_json::to_string(request)?;
168 Ok(CacheKey::from_parts("core", &request.model, &serialized))
169 }
170
171 async fn before_request(&self, request: &CompletionRequest) -> Result<()> {
172 for middleware in &self.middleware {
173 middleware.before_request(request).await?;
174 }
175 Ok(())
176 }
177
178 async fn after_response(
179 &self,
180 request: &CompletionRequest,
181 response: &CompletionResponse,
182 latency: Duration,
183 ) -> Result<()> {
184 for middleware in &self.middleware {
185 middleware
186 .after_response(request, response, latency)
187 .await?;
188 }
189 Ok(())
190 }
191
192 async fn on_cache_hit(
193 &self,
194 request: &CompletionRequest,
195 response: &CompletionResponse,
196 ) -> Result<()> {
197 for middleware in &self.middleware {
198 middleware.on_cache_hit(request, response).await?;
199 }
200 Ok(())
201 }
202
203 async fn on_error(
204 &self,
205 request: &CompletionRequest,
206 error: &SimpleAgentsError,
207 latency: Duration,
208 ) -> Result<()> {
209 for middleware in &self.middleware {
210 middleware.on_error(request, error, latency).await?;
211 }
212 Ok(())
213 }
214}
215
216pub struct SimpleAgentsClientBuilder {
218 providers: Vec<Arc<dyn Provider>>,
219 routing_mode: RoutingMode,
220 cache: Option<Arc<dyn Cache>>,
221 cache_ttl: Duration,
222 healing: HealingSettings,
223 middleware: Vec<Arc<dyn Middleware>>,
224}
225
226impl SimpleAgentsClientBuilder {
227 pub fn new() -> Self {
229 Self {
230 providers: Vec::new(),
231 routing_mode: RoutingMode::default(),
232 cache: None,
233 cache_ttl: Duration::from_secs(60),
234 healing: HealingSettings::default(),
235 middleware: Vec::new(),
236 }
237 }
238
239 pub fn with_provider(mut self, provider: Arc<dyn Provider>) -> Self {
241 self.providers.push(provider);
242 self
243 }
244
245 pub fn with_providers(mut self, providers: Vec<Arc<dyn Provider>>) -> Self {
247 self.providers.extend(providers);
248 self
249 }
250
251 pub fn with_routing_mode(mut self, mode: RoutingMode) -> Self {
253 self.routing_mode = mode;
254 self
255 }
256
257 pub fn with_cache(mut self, cache: Arc<dyn Cache>) -> Self {
259 self.cache = Some(cache);
260 self
261 }
262
263 pub fn with_cache_ttl(mut self, ttl: Duration) -> Self {
265 self.cache_ttl = ttl;
266 self
267 }
268
269 pub fn with_healing_settings(mut self, settings: HealingSettings) -> Self {
271 self.healing = settings;
272 self
273 }
274
275 pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
277 self.middleware.push(middleware);
278 self
279 }
280
281 pub fn build(self) -> Result<SimpleAgentsClient> {
283 if self.providers.is_empty() {
284 return Err(SimpleAgentsError::Config(
285 "at least one provider is required".to_string(),
286 ));
287 }
288
289 let provider_map = self
290 .providers
291 .iter()
292 .map(|provider| (provider.name().to_string(), provider.clone()))
293 .collect::<HashMap<_, _>>();
294
295 let router = Arc::new(self.routing_mode.build_router(self.providers.clone())?);
296 let state = ClientState {
297 providers: self.providers,
298 provider_map,
299 router,
300 };
301
302 Ok(SimpleAgentsClient {
303 state: RwLock::new(state),
304 routing_mode: self.routing_mode,
305 cache: self.cache,
306 cache_ttl: self.cache_ttl,
307 healing: self.healing,
308 middleware: self.middleware,
309 })
310 }
311}
312
313impl Default for SimpleAgentsClientBuilder {
314 fn default() -> Self {
315 Self::new()
316 }
317}
318
319#[async_trait]
320impl Middleware for () {
321 async fn before_request(&self, _request: &CompletionRequest) -> Result<()> {
322 Ok(())
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use simple_agent_type::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Provider stub with a fixed name that counts `execute` calls and always
    /// answers with the same canned completion.
    struct MockProvider {
        name: &'static str,
        calls: AtomicUsize,
    }

    impl MockProvider {
        /// Creates a stub named `name` with a zeroed call counter.
        fn new(name: &'static str) -> Self {
            let calls = AtomicUsize::new(0);
            Self { name, calls }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            self.calls.fetch_add(1, Ordering::Relaxed);
            let body = serde_json::json!({"content": "ok"});
            Ok(ProviderResponse::new(200, body))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            let choice = CompletionChoice {
                index: 0,
                message: Message::assistant("ok"),
                finish_reason: FinishReason::Stop,
                logprobs: None,
            };

            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![choice],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name.to_string()),
                healing_metadata: None,
            })
        }
    }

    /// Building without any provider must be rejected.
    #[tokio::test]
    async fn client_build_requires_provider() {
        assert!(SimpleAgentsClientBuilder::new().build().is_err());
    }

    /// A provider registered after construction shows up next to the one
    /// supplied to the builder.
    #[tokio::test]
    async fn register_provider_rebuilds_router() {
        let first = Arc::new(MockProvider::new("p1"));
        let client = SimpleAgentsClientBuilder::new()
            .with_provider(first)
            .build()
            .unwrap();

        let late = Arc::new(MockProvider::new("p2"));
        client.register_provider(late).unwrap();

        let names = client.provider_names().unwrap();
        assert!(names.iter().any(|name| name == "p1"));
        assert!(names.iter().any(|name| name == "p2"));
    }
}