1use crate::healing::{HealedJsonResponse, HealedSchemaResponse, HealingSettings};
4use crate::middleware::Middleware;
5use crate::routing::{RouterEngine, RoutingMode};
6use async_trait::async_trait;
7use simple_agent_type::cache::Cache;
8use simple_agent_type::cache::CacheKey;
9use simple_agent_type::prelude::{
10 CompletionChunk, CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
11};
12use simple_agents_healing::coercion::CoercionEngine;
13use simple_agents_healing::parser::JsonishParser;
14use simple_agents_healing::schema::Schema;
15use std::collections::HashMap;
16use std::sync::{Arc, RwLock};
17use std::time::{Duration, Instant};
18
19pub enum CompletionMode {
21 Standard,
23 HealedJson,
25 CoercedSchema(Schema),
27}
28
29pub struct CompletionOptions {
31 pub mode: CompletionMode,
33}
34
35impl Default for CompletionOptions {
36 fn default() -> Self {
37 Self {
38 mode: CompletionMode::Standard,
39 }
40 }
41}
42
43pub enum CompletionOutcome {
45 Response(CompletionResponse),
47 Stream(Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>),
49 HealedJson(HealedJsonResponse),
51 CoercedSchema(HealedSchemaResponse),
53}
54
55struct ClientState {
56 providers: Vec<Arc<dyn Provider>>,
57 provider_map: HashMap<String, Arc<dyn Provider>>,
58 router: Arc<RouterEngine>,
59}
60
61pub struct SimpleAgentsClient {
63 state: RwLock<ClientState>,
64 routing_mode: RoutingMode,
65 cache: Option<Arc<dyn Cache>>,
66 cache_ttl: Duration,
67 healing: HealingSettings,
68 middleware: Vec<Arc<dyn Middleware>>,
69}
70
71impl SimpleAgentsClient {
72 pub fn builder() -> SimpleAgentsClientBuilder {
74 SimpleAgentsClientBuilder::new()
75 }
76
77 pub fn provider_names(&self) -> Result<Vec<String>> {
79 let state = self.state.read().map_err(|_| {
80 SimpleAgentsError::Config("provider registry lock poisoned".to_string())
81 })?;
82 Ok(state.provider_map.keys().cloned().collect())
83 }
84
85 pub fn provider(&self, name: &str) -> Result<Option<Arc<dyn Provider>>> {
87 let state = self.state.read().map_err(|_| {
88 SimpleAgentsError::Config("provider registry lock poisoned".to_string())
89 })?;
90 Ok(state.provider_map.get(name).cloned())
91 }
92
93 pub fn register_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
95 let mut state = self.state.write().map_err(|_| {
96 SimpleAgentsError::Config("provider registry lock poisoned".to_string())
97 })?;
98 state
99 .provider_map
100 .insert(provider.name().to_string(), provider.clone());
101 state.providers.push(provider);
102 state.router = Arc::new(self.routing_mode.build_router(state.providers.clone())?);
103 Ok(())
104 }
105
106 pub async fn complete(
108 &self,
109 request: &CompletionRequest,
110 options: CompletionOptions,
111 ) -> Result<CompletionOutcome> {
112 if request.stream.unwrap_or(false) {
113 let stream = self.stream(request).await?;
114 return Ok(CompletionOutcome::Stream(stream));
115 }
116
117 match options.mode {
118 CompletionMode::Standard => {
119 let response = self.complete_response(request).await?;
120 Ok(CompletionOutcome::Response(response))
121 }
122 CompletionMode::HealedJson => {
123 let healed = self.complete_json_internal(request).await?;
124 Ok(CompletionOutcome::HealedJson(healed))
125 }
126 CompletionMode::CoercedSchema(schema) => {
127 let healed = self.complete_with_schema_internal(request, &schema).await?;
128 Ok(CompletionOutcome::CoercedSchema(healed))
129 }
130 }
131 }
132
133 async fn complete_response(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
134 request.validate()?;
135 self.before_request(request).await?;
136
137 let cache_key = if let Some(cache) = &self.cache {
138 if cache.is_enabled() {
139 Some(self.cache_key(request)?)
140 } else {
141 None
142 }
143 } else {
144 None
145 };
146
147 if let (Some(cache), Some(key)) = (&self.cache, cache_key.as_deref()) {
148 if let Some(cached) = cache.get(key).await? {
149 let response: CompletionResponse = serde_json::from_slice(&cached)?;
150 self.on_cache_hit(request, &response).await?;
151 return Ok(response);
152 }
153 }
154
155 let start = Instant::now();
156 let router = {
157 let state = self.state.read().map_err(|_| {
158 SimpleAgentsError::Config("provider registry lock poisoned".to_string())
159 })?;
160 state.router.clone()
161 };
162 let response = router.complete(request).await;
163
164 match response {
165 Ok(response) => {
166 self.after_response(request, &response, start.elapsed())
167 .await?;
168 if let (Some(cache), Some(key)) = (&self.cache, cache_key) {
169 let payload = serde_json::to_vec(&response)?;
170 cache.set(&key, payload, self.cache_ttl).await?;
171 }
172 Ok(response)
173 }
174 Err(error) => {
175 self.on_error(request, &error, start.elapsed()).await?;
176 Err(error)
177 }
178 }
179 }
180
181 async fn complete_json_internal(
183 &self,
184 request: &CompletionRequest,
185 ) -> Result<HealedJsonResponse> {
186 self.ensure_healing_enabled()?;
187 let response = self.complete_response(request).await?;
188 let content = response.content().ok_or_else(|| {
189 SimpleAgentsError::Healing(simple_agent_type::error::HealingError::ParseFailed {
190 error_message: "response contained no content".to_string(),
191 input: String::new(),
192 })
193 })?;
194
195 let parser = JsonishParser::with_config(self.healing.parser_config.clone());
196 let parsed = parser.parse(content)?;
197
198 Ok(HealedJsonResponse { response, parsed })
199 }
200
201 async fn complete_with_schema_internal(
203 &self,
204 request: &CompletionRequest,
205 schema: &Schema,
206 ) -> Result<HealedSchemaResponse> {
207 self.ensure_healing_enabled()?;
208 let healed = self.complete_json_internal(request).await?;
209 let engine = CoercionEngine::with_config(self.healing.coercion_config.clone());
210 let coerced = engine
211 .coerce(&healed.parsed.value, schema)
212 .map_err(SimpleAgentsError::Healing)?;
213
214 Ok(HealedSchemaResponse {
215 response: healed.response,
216 parsed: healed.parsed,
217 coerced,
218 })
219 }
220
221 async fn stream(
223 &self,
224 request: &CompletionRequest,
225 ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
226 request.validate()?;
227 self.before_request(request).await?;
228 eprintln!(
229 "SimpleAgentsClient.stream: model={}, stream={:?}",
230 request.model, request.stream
231 );
232
233 let router = {
234 let state = self.state.read().map_err(|_| {
235 SimpleAgentsError::Config("provider registry lock poisoned".to_string())
236 })?;
237 state.router.clone()
238 };
239
240 router.stream(request).await
241 }
242
243 fn ensure_healing_enabled(&self) -> Result<()> {
244 if self.healing.enabled {
245 Ok(())
246 } else {
247 Err(SimpleAgentsError::Config(
248 "healing is disabled for this client".to_string(),
249 ))
250 }
251 }
252
253 fn cache_key(&self, request: &CompletionRequest) -> Result<String> {
254 let serialized = serde_json::to_string(request)?;
255 Ok(CacheKey::from_parts("core", &request.model, &serialized))
256 }
257
258 async fn before_request(&self, request: &CompletionRequest) -> Result<()> {
259 for middleware in &self.middleware {
260 middleware.before_request(request).await?;
261 }
262 Ok(())
263 }
264
265 async fn after_response(
266 &self,
267 request: &CompletionRequest,
268 response: &CompletionResponse,
269 latency: Duration,
270 ) -> Result<()> {
271 for middleware in &self.middleware {
272 middleware
273 .after_response(request, response, latency)
274 .await?;
275 }
276 Ok(())
277 }
278
279 async fn on_cache_hit(
280 &self,
281 request: &CompletionRequest,
282 response: &CompletionResponse,
283 ) -> Result<()> {
284 for middleware in &self.middleware {
285 middleware.on_cache_hit(request, response).await?;
286 }
287 Ok(())
288 }
289
290 async fn on_error(
291 &self,
292 request: &CompletionRequest,
293 error: &SimpleAgentsError,
294 latency: Duration,
295 ) -> Result<()> {
296 for middleware in &self.middleware {
297 middleware.on_error(request, error, latency).await?;
298 }
299 Ok(())
300 }
301}
302
303pub struct SimpleAgentsClientBuilder {
305 providers: Vec<Arc<dyn Provider>>,
306 routing_mode: RoutingMode,
307 cache: Option<Arc<dyn Cache>>,
308 cache_ttl: Duration,
309 healing: HealingSettings,
310 middleware: Vec<Arc<dyn Middleware>>,
311}
312
313impl SimpleAgentsClientBuilder {
314 pub fn new() -> Self {
316 Self {
317 providers: Vec::new(),
318 routing_mode: RoutingMode::default(),
319 cache: None,
320 cache_ttl: Duration::from_secs(60),
321 healing: HealingSettings::default(),
322 middleware: Vec::new(),
323 }
324 }
325
326 pub fn with_provider(mut self, provider: Arc<dyn Provider>) -> Self {
328 self.providers.push(provider);
329 self
330 }
331
332 pub fn with_providers(mut self, providers: Vec<Arc<dyn Provider>>) -> Self {
334 self.providers.extend(providers);
335 self
336 }
337
338 pub fn with_routing_mode(mut self, mode: RoutingMode) -> Self {
340 self.routing_mode = mode;
341 self
342 }
343
344 pub fn with_cache(mut self, cache: Arc<dyn Cache>) -> Self {
346 self.cache = Some(cache);
347 self
348 }
349
350 pub fn with_cache_ttl(mut self, ttl: Duration) -> Self {
352 self.cache_ttl = ttl;
353 self
354 }
355
356 pub fn with_healing_settings(mut self, settings: HealingSettings) -> Self {
358 self.healing = settings;
359 self
360 }
361
362 pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
364 self.middleware.push(middleware);
365 self
366 }
367
368 pub fn build(self) -> Result<SimpleAgentsClient> {
370 if self.providers.is_empty() {
371 return Err(SimpleAgentsError::Config(
372 "at least one provider is required".to_string(),
373 ));
374 }
375
376 let provider_map = self
377 .providers
378 .iter()
379 .map(|provider| (provider.name().to_string(), provider.clone()))
380 .collect::<HashMap<_, _>>();
381
382 let router = Arc::new(self.routing_mode.build_router(self.providers.clone())?);
383 let state = ClientState {
384 providers: self.providers,
385 provider_map,
386 router,
387 };
388
389 Ok(SimpleAgentsClient {
390 state: RwLock::new(state),
391 routing_mode: self.routing_mode,
392 cache: self.cache,
393 cache_ttl: self.cache_ttl,
394 healing: self.healing,
395 middleware: self.middleware,
396 })
397 }
398}
399
400impl Default for SimpleAgentsClientBuilder {
401 fn default() -> Self {
402 Self::new()
403 }
404}
405
406#[async_trait]
407impl Middleware for () {
408 async fn before_request(&self, _request: &CompletionRequest) -> Result<()> {
409 Ok(())
410 }
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use simple_agent_type::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Minimal provider that always answers "ok" and counts `execute` calls.
    struct MockProvider {
        name: &'static str,
        calls: AtomicUsize,
    }

    impl MockProvider {
        fn new(name: &'static str) -> Self {
            let calls = AtomicUsize::new(0);
            Self { name, calls }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            self.calls.fetch_add(1, Ordering::Relaxed);
            let body = serde_json::json!({"content": "ok"});
            Ok(ProviderResponse::new(200, body))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            let choice = CompletionChoice {
                index: 0,
                message: Message::assistant("ok"),
                finish_reason: FinishReason::Stop,
                logprobs: None,
            };
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![choice],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name.to_string()),
                healing_metadata: None,
            })
        }
    }

    #[tokio::test]
    async fn client_build_requires_provider() {
        let outcome = SimpleAgentsClient::builder().build();
        assert!(outcome.is_err(), "building without providers must fail");
    }

    #[tokio::test]
    async fn register_provider_rebuilds_router() {
        let client = SimpleAgentsClientBuilder::new()
            .with_provider(Arc::new(MockProvider::new("p1")))
            .build()
            .expect("client with one provider should build");

        client
            .register_provider(Arc::new(MockProvider::new("p2")))
            .expect("registering a second provider should succeed");

        let names = client
            .provider_names()
            .expect("registry lock should not be poisoned");
        for expected in ["p1", "p2"] {
            assert!(
                names.iter().any(|name| name == expected),
                "missing provider {}",
                expected
            );
        }
    }
}