use crate::healing::{HealedJsonResponse, HealedSchemaResponse, HealingSettings};
use crate::middleware::Middleware;
use crate::routing::{RouterEngine, RoutingMode};
use async_trait::async_trait;
use simple_agent_type::cache::Cache;
use simple_agent_type::cache::CacheKey;
use simple_agent_type::prelude::{
    CompletionChunk, CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
};
use simple_agents_healing::coercion::CoercionEngine;
use simple_agents_healing::parser::JsonishParser;
use simple_agents_healing::schema::Schema;
use std::collections::HashMap;
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
use std::time::{Duration, Instant};

19struct ClientState {
20 providers: Vec<Arc<dyn Provider>>,
21 provider_map: HashMap<String, Arc<dyn Provider>>,
22 router: Arc<RouterEngine>,
23}
24
25pub struct SimpleAgentsClient {
27 state: RwLock<ClientState>,
28 routing_mode: RoutingMode,
29 cache: Option<Arc<dyn Cache>>,
30 cache_ttl: Duration,
31 healing: HealingSettings,
32 middleware: Vec<Arc<dyn Middleware>>,
33}
34
35impl SimpleAgentsClient {
36 pub fn builder() -> SimpleAgentsClientBuilder {
38 SimpleAgentsClientBuilder::new()
39 }
40
41 pub fn provider_names(&self) -> Result<Vec<String>> {
43 let state = self.state.read().map_err(|_| {
44 SimpleAgentsError::Config("provider registry lock poisoned".to_string())
45 })?;
46 Ok(state.provider_map.keys().cloned().collect())
47 }
48
49 pub fn provider(&self, name: &str) -> Result<Option<Arc<dyn Provider>>> {
51 let state = self.state.read().map_err(|_| {
52 SimpleAgentsError::Config("provider registry lock poisoned".to_string())
53 })?;
54 Ok(state.provider_map.get(name).cloned())
55 }
56
57 pub fn register_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
59 let mut state = self.state.write().map_err(|_| {
60 SimpleAgentsError::Config("provider registry lock poisoned".to_string())
61 })?;
62 state
63 .provider_map
64 .insert(provider.name().to_string(), provider.clone());
65 state.providers.push(provider);
66 state.router = Arc::new(self.routing_mode.build_router(state.providers.clone())?);
67 Ok(())
68 }
69
70 pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
72 request.validate()?;
73 self.before_request(request).await?;
74
75 let cache_key = if let Some(cache) = &self.cache {
76 if cache.is_enabled() {
77 Some(self.cache_key(request)?)
78 } else {
79 None
80 }
81 } else {
82 None
83 };
84
85 if let (Some(cache), Some(key)) = (&self.cache, cache_key.as_deref()) {
86 if let Some(cached) = cache.get(key).await? {
87 let response: CompletionResponse = serde_json::from_slice(&cached)?;
88 self.on_cache_hit(request, &response).await?;
89 return Ok(response);
90 }
91 }
92
93 let start = Instant::now();
94 let router = {
95 let state = self.state.read().map_err(|_| {
96 SimpleAgentsError::Config("provider registry lock poisoned".to_string())
97 })?;
98 state.router.clone()
99 };
100 let response = router.complete(request).await;
101
102 match response {
103 Ok(response) => {
104 self.after_response(request, &response, start.elapsed())
105 .await?;
106 if let (Some(cache), Some(key)) = (&self.cache, cache_key) {
107 let payload = serde_json::to_vec(&response)?;
108 cache.set(&key, payload, self.cache_ttl).await?;
109 }
110 Ok(response)
111 }
112 Err(error) => {
113 self.on_error(request, &error, start.elapsed()).await?;
114 Err(error)
115 }
116 }
117 }
118
119 pub async fn complete_json(&self, request: &CompletionRequest) -> Result<HealedJsonResponse> {
121 self.ensure_healing_enabled()?;
122 let response = self.complete(request).await?;
123 let content = response.content().ok_or_else(|| {
124 SimpleAgentsError::Healing(simple_agent_type::error::HealingError::ParseFailed {
125 error_message: "response contained no content".to_string(),
126 input: String::new(),
127 })
128 })?;
129
130 let parser = JsonishParser::with_config(self.healing.parser_config.clone());
131 let parsed = parser.parse(content)?;
132
133 Ok(HealedJsonResponse { response, parsed })
134 }
135
136 pub async fn complete_with_schema(
138 &self,
139 request: &CompletionRequest,
140 schema: &Schema,
141 ) -> Result<HealedSchemaResponse> {
142 self.ensure_healing_enabled()?;
143 let healed = self.complete_json(request).await?;
144 let engine = CoercionEngine::with_config(self.healing.coercion_config.clone());
145 let coerced = engine
146 .coerce(&healed.parsed.value, schema)
147 .map_err(SimpleAgentsError::Healing)?;
148
149 Ok(HealedSchemaResponse {
150 response: healed.response,
151 parsed: healed.parsed,
152 coerced,
153 })
154 }
155
156 pub async fn stream(
158 &self,
159 request: &CompletionRequest,
160 ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
161 request.validate()?;
162 self.before_request(request).await?;
163 eprintln!(
164 "SimpleAgentsClient.stream: model={}, stream={:?}",
165 request.model, request.stream
166 );
167
168 let router = {
169 let state = self.state.read().map_err(|_| {
170 SimpleAgentsError::Config("provider registry lock poisoned".to_string())
171 })?;
172 state.router.clone()
173 };
174
175 router.stream(request).await
176 }
177
178 fn ensure_healing_enabled(&self) -> Result<()> {
179 if self.healing.enabled {
180 Ok(())
181 } else {
182 Err(SimpleAgentsError::Config(
183 "healing is disabled for this client".to_string(),
184 ))
185 }
186 }
187
188 fn cache_key(&self, request: &CompletionRequest) -> Result<String> {
189 let serialized = serde_json::to_string(request)?;
190 Ok(CacheKey::from_parts("core", &request.model, &serialized))
191 }
192
193 async fn before_request(&self, request: &CompletionRequest) -> Result<()> {
194 for middleware in &self.middleware {
195 middleware.before_request(request).await?;
196 }
197 Ok(())
198 }
199
200 async fn after_response(
201 &self,
202 request: &CompletionRequest,
203 response: &CompletionResponse,
204 latency: Duration,
205 ) -> Result<()> {
206 for middleware in &self.middleware {
207 middleware
208 .after_response(request, response, latency)
209 .await?;
210 }
211 Ok(())
212 }
213
214 async fn on_cache_hit(
215 &self,
216 request: &CompletionRequest,
217 response: &CompletionResponse,
218 ) -> Result<()> {
219 for middleware in &self.middleware {
220 middleware.on_cache_hit(request, response).await?;
221 }
222 Ok(())
223 }
224
225 async fn on_error(
226 &self,
227 request: &CompletionRequest,
228 error: &SimpleAgentsError,
229 latency: Duration,
230 ) -> Result<()> {
231 for middleware in &self.middleware {
232 middleware.on_error(request, error, latency).await?;
233 }
234 Ok(())
235 }
236}
237
238pub struct SimpleAgentsClientBuilder {
240 providers: Vec<Arc<dyn Provider>>,
241 routing_mode: RoutingMode,
242 cache: Option<Arc<dyn Cache>>,
243 cache_ttl: Duration,
244 healing: HealingSettings,
245 middleware: Vec<Arc<dyn Middleware>>,
246}
247
248impl SimpleAgentsClientBuilder {
249 pub fn new() -> Self {
251 Self {
252 providers: Vec::new(),
253 routing_mode: RoutingMode::default(),
254 cache: None,
255 cache_ttl: Duration::from_secs(60),
256 healing: HealingSettings::default(),
257 middleware: Vec::new(),
258 }
259 }
260
261 pub fn with_provider(mut self, provider: Arc<dyn Provider>) -> Self {
263 self.providers.push(provider);
264 self
265 }
266
267 pub fn with_providers(mut self, providers: Vec<Arc<dyn Provider>>) -> Self {
269 self.providers.extend(providers);
270 self
271 }
272
273 pub fn with_routing_mode(mut self, mode: RoutingMode) -> Self {
275 self.routing_mode = mode;
276 self
277 }
278
279 pub fn with_cache(mut self, cache: Arc<dyn Cache>) -> Self {
281 self.cache = Some(cache);
282 self
283 }
284
285 pub fn with_cache_ttl(mut self, ttl: Duration) -> Self {
287 self.cache_ttl = ttl;
288 self
289 }
290
291 pub fn with_healing_settings(mut self, settings: HealingSettings) -> Self {
293 self.healing = settings;
294 self
295 }
296
297 pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
299 self.middleware.push(middleware);
300 self
301 }
302
303 pub fn build(self) -> Result<SimpleAgentsClient> {
305 if self.providers.is_empty() {
306 return Err(SimpleAgentsError::Config(
307 "at least one provider is required".to_string(),
308 ));
309 }
310
311 let provider_map = self
312 .providers
313 .iter()
314 .map(|provider| (provider.name().to_string(), provider.clone()))
315 .collect::<HashMap<_, _>>();
316
317 let router = Arc::new(self.routing_mode.build_router(self.providers.clone())?);
318 let state = ClientState {
319 providers: self.providers,
320 provider_map,
321 router,
322 };
323
324 Ok(SimpleAgentsClient {
325 state: RwLock::new(state),
326 routing_mode: self.routing_mode,
327 cache: self.cache,
328 cache_ttl: self.cache_ttl,
329 healing: self.healing,
330 middleware: self.middleware,
331 })
332 }
333}
334
335impl Default for SimpleAgentsClientBuilder {
336 fn default() -> Self {
337 Self::new()
338 }
339}
340
341#[async_trait]
342impl Middleware for () {
343 async fn before_request(&self, _request: &CompletionRequest) -> Result<()> {
344 Ok(())
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use simple_agent_type::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Provider stub that always answers "ok" and counts `execute` calls.
    struct MockProvider {
        name: &'static str,
        calls: AtomicUsize,
    }

    impl MockProvider {
        fn new(name: &'static str) -> Self {
            let calls = AtomicUsize::new(0);
            Self { name, calls }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            let request = ProviderRequest::new("http://example.com");
            Ok(request)
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            self.calls.fetch_add(1, Ordering::Relaxed);
            let body = serde_json::json!({"content": "ok"});
            Ok(ProviderResponse::new(200, body))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            let choice = CompletionChoice {
                index: 0,
                message: Message::assistant("ok"),
                finish_reason: FinishReason::Stop,
                logprobs: None,
            };
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![choice],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name.to_string()),
                healing_metadata: None,
            })
        }
    }

    // Kept as `tokio::test`: router construction may rely on a runtime being present.
    #[tokio::test]
    async fn client_build_requires_provider() {
        let outcome = SimpleAgentsClient::builder().build();
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn register_provider_rebuilds_router() {
        let client = SimpleAgentsClientBuilder::new()
            .with_provider(Arc::new(MockProvider::new("p1")))
            .build()
            .unwrap();

        client
            .register_provider(Arc::new(MockProvider::new("p2")))
            .unwrap();

        let names = client.provider_names().unwrap();
        for expected in ["p1", "p2"] {
            assert!(names.iter().any(|name| name == expected));
        }
    }
}