1use crate::healing::{HealedJsonResponse, HealedSchemaResponse, HealingSettings};
4use serde::{Deserialize, Serialize};
5use simple_agent_type::prelude::{
6 CompletionChunk, CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
7};
8use simple_agent_type::telemetry::{ApiFormat, TelemetryConfig, TraceContext};
9use simple_agents_healing::coercion::CoercionEngine;
10use simple_agents_healing::parser::JsonishParser;
11use simple_agents_healing::schema::Schema;
12use std::sync::Arc;
13use tracing::debug;
14
/// Retry settings carried by [`ClientConfig::default_retry`].
///
/// NOTE(review): nothing in this module reads these values; confirm where the
/// retry policy is actually applied.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Maximum number of attempts (defaults to `3`); presumably includes the
    /// first try — confirm against the consumer.
    pub max_attempts: u8,
    /// Backoff delay in milliseconds (defaults to `1000`).
    pub backoff_ms: u64,
}
27
28impl Default for RetryConfig {
29 fn default() -> Self {
30 Self {
31 max_attempts: 3,
32 backoff_ms: 1000,
33 }
34 }
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct ClientConfig {
40 pub provider: String,
42 pub api_key: String,
44 pub base_url: Option<String>,
46 pub api_format: ApiFormat,
48 pub extra_headers: Option<Vec<(String, String)>>,
50 pub telemetry: Option<TelemetryConfig>,
52 pub default_retry: RetryConfig,
54}
55
56impl Default for ClientConfig {
57 fn default() -> Self {
58 Self {
59 provider: "openai".into(),
60 api_key: String::new(),
61 base_url: None,
62 api_format: ApiFormat::default(),
63 extra_headers: None,
64 telemetry: None,
65 default_retry: RetryConfig::default(),
66 }
67 }
68}
69
/// Streaming switches bundled into [`RunOptions::execution_flags`].
///
/// NOTE(review): neither flag is read in this module; their exact effect is
/// defined by the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionFlags {
    /// Workflow-level streaming toggle (defaults to `false`).
    pub workflow_streaming: bool,
    /// Per-node LLM streaming toggle (defaults to `true`).
    pub node_llm_streaming: bool,
}
78
79impl Default for ExecutionFlags {
80 fn default() -> Self {
81 Self {
82 workflow_streaming: false,
83 node_llm_streaming: true,
84 }
85 }
86}
87
/// Per-run options.
///
/// NOTE(review): not consumed in this module; the meaning of each flag is
/// defined by whichever caller reads it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunOptions {
    /// "Nerdstats" toggle (defaults to `true`); presumably extra run
    /// statistics — confirm.
    pub nerdstats: bool,
    /// Telemetry toggle for the run (defaults to `true`).
    pub telemetry_enabled: bool,
    /// Optional trace context for the run (defaults to `None`).
    pub trace_context: Option<TraceContext>,
    /// Streaming switches (defaults to `ExecutionFlags::default()`).
    pub execution_flags: ExecutionFlags,
}
100
101impl Default for RunOptions {
102 fn default() -> Self {
103 Self {
104 nerdstats: true,
105 telemetry_enabled: true,
106 trace_context: None,
107 execution_flags: ExecutionFlags::default(),
108 }
109 }
110}
111
/// Selects how [`SimpleAgentsClient::complete`] post-processes a
/// non-streaming response.
#[derive(Clone)]
pub enum CompletionMode {
    /// Return the provider response unchanged (`CompletionOutcome::Response`).
    Standard,
    /// Parse the response content with `JsonishParser`
    /// (`CompletionOutcome::HealedJson`); requires healing to be enabled.
    HealedJson,
    /// Parse as for `HealedJson`, then coerce the parsed value to this schema
    /// with `CoercionEngine` (`CompletionOutcome::CoercedSchema`); requires
    /// healing to be enabled.
    CoercedSchema(Schema),
}
122
/// Options for [`SimpleAgentsClient::complete`].
#[derive(Clone)]
pub struct CompletionOptions {
    /// Post-processing mode; defaults to `CompletionMode::Standard`.
    pub mode: CompletionMode,
}
129
130impl Default for CompletionOptions {
131 fn default() -> Self {
132 Self {
133 mode: CompletionMode::Standard,
134 }
135 }
136}
137
/// Result of [`SimpleAgentsClient::complete`]. The variant is chosen by the
/// request's `stream` flag and the requested [`CompletionMode`].
pub enum CompletionOutcome {
    /// Unprocessed response (`CompletionMode::Standard`, non-streaming request).
    Response(CompletionResponse),
    /// Raw provider chunk stream; returned whenever `request.stream` is
    /// `Some(true)`, regardless of the requested mode.
    Stream(Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>),
    /// Response together with its parsed JSON (`CompletionMode::HealedJson`).
    HealedJson(HealedJsonResponse),
    /// Response, parsed JSON and schema-coerced value
    /// (`CompletionMode::CoercedSchema`).
    CoercedSchema(HealedSchemaResponse),
}
149
/// Client that runs completion requests through a single [`Provider`] and can
/// optionally parse and schema-coerce the JSON in the response.
pub struct SimpleAgentsClient {
    /// Provider that transforms, executes and (for streams) streams requests.
    provider: Arc<dyn Provider>,
    /// Configuration supplied at construction; exposed via `config()`.
    config: ClientConfig,
    /// Healing settings: the `enabled` switch plus parser and coercion configs.
    healing: HealingSettings,
}
156
157impl SimpleAgentsClient {
158 pub fn new(provider: Arc<dyn Provider>) -> Self {
160 Self {
161 provider,
162 config: ClientConfig::default(),
163 healing: HealingSettings::default(),
164 }
165 }
166
167 pub fn from_config(provider: Arc<dyn Provider>, config: ClientConfig) -> Self {
169 Self {
170 provider,
171 config,
172 healing: HealingSettings::default(),
173 }
174 }
175
176 pub fn with_healing(provider: Arc<dyn Provider>, healing: HealingSettings) -> Self {
178 Self {
179 provider,
180 config: ClientConfig::default(),
181 healing,
182 }
183 }
184
185 pub fn config(&self) -> &ClientConfig {
187 &self.config
188 }
189
190 pub fn provider_name(&self) -> &str {
192 self.provider.name()
193 }
194
195 pub async fn complete(
197 &self,
198 request: &CompletionRequest,
199 options: CompletionOptions,
200 ) -> Result<CompletionOutcome> {
201 if request.stream.unwrap_or(false) {
202 let stream = self.stream(request).await?;
203 return Ok(CompletionOutcome::Stream(stream));
204 }
205
206 match options.mode {
207 CompletionMode::Standard => {
208 let response = self.complete_response(request).await?;
209 Ok(CompletionOutcome::Response(response))
210 }
211 CompletionMode::HealedJson => {
212 let healed = self.complete_json_internal(request).await?;
213 Ok(CompletionOutcome::HealedJson(healed))
214 }
215 CompletionMode::CoercedSchema(schema) => {
216 let healed = self.complete_with_schema_internal(request, &schema).await?;
217 Ok(CompletionOutcome::CoercedSchema(healed))
218 }
219 }
220 }
221
222 async fn complete_response(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
223 request.validate()?;
224
225 let provider_request = self.provider.transform_request(request)?;
226 let provider_response = self.provider.execute(provider_request).await?;
227 self.provider.transform_response(provider_response)
228 }
229
230 async fn complete_json_internal(
231 &self,
232 request: &CompletionRequest,
233 ) -> Result<HealedJsonResponse> {
234 self.ensure_healing_enabled()?;
235 let response = self.complete_response(request).await?;
236 let content = response.content().ok_or_else(|| {
237 SimpleAgentsError::Healing(simple_agent_type::error::HealingError::ParseFailed {
238 error_message: "response contained no content".to_string(),
239 input: String::new(),
240 })
241 })?;
242
243 let parser = JsonishParser::with_config(self.healing.parser_config.clone());
244 let parsed = parser.parse(content)?;
245
246 Ok(HealedJsonResponse { response, parsed })
247 }
248
249 async fn complete_with_schema_internal(
250 &self,
251 request: &CompletionRequest,
252 schema: &Schema,
253 ) -> Result<HealedSchemaResponse> {
254 self.ensure_healing_enabled()?;
255 let healed = self.complete_json_internal(request).await?;
256 let engine = CoercionEngine::with_config(self.healing.coercion_config.clone());
257 let coerced = engine
258 .coerce(&healed.parsed.value, schema)
259 .map_err(SimpleAgentsError::Healing)?;
260
261 Ok(HealedSchemaResponse {
262 response: healed.response,
263 parsed: healed.parsed,
264 coerced,
265 })
266 }
267
268 async fn stream(
269 &self,
270 request: &CompletionRequest,
271 ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
272 request.validate()?;
273 debug!(
274 model = %request.model,
275 stream = ?request.stream,
276 "SimpleAgentsClient.stream start"
277 );
278
279 let provider_request = self.provider.transform_request(request)?;
280 self.provider.execute_stream(provider_request).await
281 }
282
283 fn ensure_healing_enabled(&self) -> Result<()> {
284 if self.healing.enabled {
285 Ok(())
286 } else {
287 Err(SimpleAgentsError::Config(
288 "healing is disabled for this client".to_string(),
289 ))
290 }
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use futures_util::StreamExt;
    use simple_agent_type::error::ProviderError;
    use simple_agent_type::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Non-streaming provider that counts `execute` calls and always answers
    /// with a fixed `"ok"` completion.
    struct MockProvider {
        name: &'static str,
        // Number of `execute` invocations; asserted by the tests below.
        calls: AtomicUsize,
    }

    impl MockProvider {
        fn new(name: &'static str) -> Self {
            Self {
                name,
                calls: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            self.calls.fetch_add(1, Ordering::Relaxed);
            Ok(ProviderResponse::new(
                200,
                serde_json::json!({"content": "ok"}),
            ))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name.to_string()),
                healing_metadata: None,
            })
        }
    }

    #[tokio::test]
    async fn complete_returns_response() {
        let provider = Arc::new(MockProvider::new("p1"));
        // Keep a typed handle so the call counter can be inspected afterwards;
        // the clone is coerced to `Arc<dyn Provider>` at the call site.
        let client = SimpleAgentsClient::new(provider.clone());

        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hi"))
            .build()
            .unwrap();

        let outcome = client
            .complete(&request, CompletionOptions::default())
            .await
            .unwrap();

        match outcome {
            CompletionOutcome::Response(resp) => {
                assert_eq!(resp.provider.as_deref(), Some("p1"));
            }
            _ => panic!("expected Response outcome"),
        }

        // A standard, non-streaming completion must reach the provider exactly once.
        assert_eq!(provider.calls.load(Ordering::Relaxed), 1);
    }

    /// Provider whose stream yields one `"hello"` chunk, optionally followed by
    /// a server error.
    struct StreamingProvider {
        name: &'static str,
        fail_after_first: bool,
    }

    impl StreamingProvider {
        fn new(name: &'static str, fail_after_first: bool) -> Self {
            Self {
                name,
                fail_after_first,
            }
        }

        /// Builds a single-choice assistant delta chunk carrying `content`.
        fn build_chunk(id: &str, content: &str) -> CompletionChunk {
            CompletionChunk {
                id: id.to_string(),
                model: "test-model".to_string(),
                choices: vec![ChoiceDelta {
                    index: 0,
                    delta: MessageDelta {
                        role: Some(Role::Assistant),
                        content: Some(content.to_string()),
                        reasoning_content: None,
                        tool_calls: None,
                    },
                    finish_reason: None,
                }],
                created: None,
                usage: None,
            }
        }
    }

    #[async_trait]
    impl Provider for StreamingProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            Ok(ProviderResponse::new(
                200,
                serde_json::json!({"content": "ok"}),
            ))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_stream".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name.to_string()),
                healing_metadata: None,
            })
        }

        async fn execute_stream(
            &self,
            _req: ProviderRequest,
        ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>>
        {
            let stream = if self.fail_after_first {
                let items: Vec<Result<CompletionChunk>> = vec![
                    Ok(Self::build_chunk("chunk-1", "hello")),
                    Err(SimpleAgentsError::Provider(ProviderError::ServerError(
                        "stream error".to_string(),
                    ))),
                ];
                futures_util::stream::iter(items)
            } else {
                let items: Vec<Result<CompletionChunk>> =
                    vec![Ok(Self::build_chunk("chunk-1", "hello"))];
                futures_util::stream::iter(items)
            };

            Ok(Box::new(stream))
        }
    }

    #[tokio::test]
    async fn streaming_returns_chunks() {
        let provider = Arc::new(StreamingProvider::new("p1", false));
        let client = SimpleAgentsClient::new(provider);

        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hi"))
            .stream(true)
            .build()
            .unwrap();

        let outcome = client
            .complete(&request, CompletionOptions::default())
            .await
            .unwrap();

        let mut collected = Vec::new();
        match outcome {
            CompletionOutcome::Stream(mut stream) => {
                while let Some(chunk) = stream.next().await {
                    collected.push(chunk.unwrap());
                }
            }
            _ => panic!("expected stream outcome"),
        }

        assert_eq!(collected.len(), 1);
    }

    #[tokio::test]
    async fn streaming_propagates_error() {
        let provider = Arc::new(StreamingProvider::new("p1", true));
        let client = SimpleAgentsClient::new(provider);

        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hi"))
            .stream(true)
            .build()
            .unwrap();

        let outcome = client
            .complete(&request, CompletionOptions::default())
            .await
            .unwrap();

        let mut chunks = Vec::new();
        match outcome {
            CompletionOutcome::Stream(mut stream) => {
                while let Some(chunk) = stream.next().await {
                    chunks.push(chunk);
                }
            }
            _ => panic!("expected stream outcome"),
        }

        // The first chunk arrives intact; the mid-stream error is yielded as an
        // item rather than aborting the whole call.
        assert_eq!(chunks.len(), 2);
        assert!(chunks[0].is_ok());
        assert!(chunks[1].is_err());
    }
}