// simple_agents_core/client.rs
1//! SimpleAgents client implementation.
2
3use crate::healing::{HealedJsonResponse, HealedSchemaResponse, HealingSettings};
4use serde::{Deserialize, Serialize};
5use simple_agent_type::prelude::{
6    CompletionChunk, CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
7};
8use simple_agent_type::telemetry::{ApiFormat, TelemetryConfig, TraceContext};
9use simple_agents_healing::coercion::CoercionEngine;
10use simple_agents_healing::parser::JsonishParser;
11use simple_agents_healing::schema::Schema;
12use std::sync::Arc;
13use tracing::debug;
14
15// ---------------------------------------------------------------------------
16// Configuration types
17// ---------------------------------------------------------------------------
18
/// Retry behaviour for failed LLM requests.
///
/// NOTE(review): carried on [`ClientConfig::default_retry`] but not consumed
/// anywhere in this file — confirm a higher layer (executor?) applies it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Maximum number of attempts (including the initial one).
    pub max_attempts: u8,
    /// Base back-off between retries in milliseconds.
    pub backoff_ms: u64,
}
27
28impl Default for RetryConfig {
29    fn default() -> Self {
30        Self {
31            max_attempts: 3,
32            backoff_ms: 1000,
33        }
34    }
35}
36
/// Top-level configuration for a [`SimpleAgentsClient`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientConfig {
    /// Provider identifier (e.g. `"openai"`, `"anthropic"`).
    pub provider: String,
    /// API key used for authentication with the provider.
    pub api_key: String,
    /// Optional custom base URL for the provider API.
    pub base_url: Option<String>,
    /// Wire format for the provider API.
    pub api_format: ApiFormat,
    /// Extra HTTP headers sent with every request.
    pub extra_headers: Option<Vec<(String, String)>>,
    /// Optional telemetry / OTEL export configuration.
    pub telemetry: Option<TelemetryConfig>,
    /// Default retry policy applied to all requests.
    ///
    /// NOTE(review): not read anywhere in this file — verify the retry loop
    /// that honors it actually exists in the executor.
    pub default_retry: RetryConfig,
}
55
56impl Default for ClientConfig {
57    fn default() -> Self {
58        Self {
59            provider: "openai".into(),
60            api_key: String::new(),
61            base_url: None,
62            api_format: ApiFormat::default(),
63            extra_headers: None,
64            telemetry: None,
65            default_retry: RetryConfig::default(),
66        }
67    }
68}
69
/// Flags that control execution behaviour for a single run.
///
/// Workflow-level and node-level streaming are controlled independently.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionFlags {
    /// Whether the workflow-level orchestrator streams between nodes.
    pub workflow_streaming: bool,
    /// Whether individual LLM calls within a node use streaming.
    pub node_llm_streaming: bool,
}
78
79impl Default for ExecutionFlags {
80    fn default() -> Self {
81        Self {
82            workflow_streaming: false,
83            node_llm_streaming: true,
84        }
85    }
86}
87
/// Per-run options passed to the executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunOptions {
    /// Print timing / token-usage statistics after the run.
    pub nerdstats: bool,
    /// Whether to emit telemetry spans for this run.
    pub telemetry_enabled: bool,
    /// Distributed trace context to propagate.
    pub trace_context: Option<TraceContext>,
    /// Execution behaviour flags.
    pub execution_flags: ExecutionFlags,
}
100
101impl Default for RunOptions {
102    fn default() -> Self {
103        Self {
104            nerdstats: true,
105            telemetry_enabled: true,
106            trace_context: None,
107            execution_flags: ExecutionFlags::default(),
108        }
109    }
110}
111
/// Mode for completion post-processing.
///
/// NOTE(review): derives only `Clone` (no `Debug`/serde) — presumably because
/// `Schema` lacks those impls; confirm before adding derives.
#[derive(Clone)]
pub enum CompletionMode {
    /// Return the raw completion response.
    Standard,
    /// Parse the response content as JSON using healing.
    HealedJson,
    /// Parse and coerce the response into the provided schema.
    CoercedSchema(Schema),
}
122
/// Options that control completion behavior.
#[derive(Clone)]
pub struct CompletionOptions {
    /// Completion post-processing mode.
    pub mode: CompletionMode,
}
129
130impl Default for CompletionOptions {
131    fn default() -> Self {
132        Self {
133            mode: CompletionMode::Standard,
134        }
135    }
136}
137
/// Result of a unified completion call.
///
/// Which variant is produced depends on `request.stream` and on the
/// [`CompletionMode`] passed to [`SimpleAgentsClient::complete`].
pub enum CompletionOutcome {
    /// A standard, non-streaming completion response.
    Response(CompletionResponse),
    /// A streaming response yielding completion chunks.
    Stream(Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>),
    /// A healed JSON response.
    HealedJson(HealedJsonResponse),
    /// A schema-coerced response.
    CoercedSchema(HealedSchemaResponse),
}
149
/// Unified SimpleAgents client.
///
/// Wraps a single [`Provider`] together with client-wide configuration and
/// the healing settings used for JSON parsing / schema coercion.
pub struct SimpleAgentsClient {
    /// Provider backend used for every request.
    provider: Arc<dyn Provider>,
    /// Client-wide configuration (provider id, headers, telemetry, ...).
    config: ClientConfig,
    /// Settings controlling JSON healing and schema coercion.
    healing: HealingSettings,
}
156
157impl SimpleAgentsClient {
158    /// Create a new client wrapping a single provider with default config and healing.
159    pub fn new(provider: Arc<dyn Provider>) -> Self {
160        Self {
161            provider,
162            config: ClientConfig::default(),
163            healing: HealingSettings::default(),
164        }
165    }
166
167    /// Create a new client from a [`ClientConfig`], using the supplied provider.
168    pub fn from_config(provider: Arc<dyn Provider>, config: ClientConfig) -> Self {
169        Self {
170            provider,
171            config,
172            healing: HealingSettings::default(),
173        }
174    }
175
176    /// Create a new client with custom healing settings.
177    pub fn with_healing(provider: Arc<dyn Provider>, healing: HealingSettings) -> Self {
178        Self {
179            provider,
180            config: ClientConfig::default(),
181            healing,
182        }
183    }
184
185    /// Return a reference to the client's configuration.
186    pub fn config(&self) -> &ClientConfig {
187        &self.config
188    }
189
190    /// Return the name of the underlying provider.
191    pub fn provider_name(&self) -> &str {
192        self.provider.name()
193    }
194
195    /// Execute a completion request.
196    pub async fn complete(
197        &self,
198        request: &CompletionRequest,
199        options: CompletionOptions,
200    ) -> Result<CompletionOutcome> {
201        if request.stream.unwrap_or(false) {
202            let stream = self.stream(request).await?;
203            return Ok(CompletionOutcome::Stream(stream));
204        }
205
206        match options.mode {
207            CompletionMode::Standard => {
208                let response = self.complete_response(request).await?;
209                Ok(CompletionOutcome::Response(response))
210            }
211            CompletionMode::HealedJson => {
212                let healed = self.complete_json_internal(request).await?;
213                Ok(CompletionOutcome::HealedJson(healed))
214            }
215            CompletionMode::CoercedSchema(schema) => {
216                let healed = self.complete_with_schema_internal(request, &schema).await?;
217                Ok(CompletionOutcome::CoercedSchema(healed))
218            }
219        }
220    }
221
222    async fn complete_response(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
223        request.validate()?;
224
225        let provider_request = self.provider.transform_request(request)?;
226        let provider_response = self.provider.execute(provider_request).await?;
227        self.provider.transform_response(provider_response)
228    }
229
230    async fn complete_json_internal(
231        &self,
232        request: &CompletionRequest,
233    ) -> Result<HealedJsonResponse> {
234        self.ensure_healing_enabled()?;
235        let response = self.complete_response(request).await?;
236        let content = response.content().ok_or_else(|| {
237            SimpleAgentsError::Healing(simple_agent_type::error::HealingError::ParseFailed {
238                error_message: "response contained no content".to_string(),
239                input: String::new(),
240            })
241        })?;
242
243        let parser = JsonishParser::with_config(self.healing.parser_config.clone());
244        let parsed = parser.parse(content)?;
245
246        Ok(HealedJsonResponse { response, parsed })
247    }
248
249    async fn complete_with_schema_internal(
250        &self,
251        request: &CompletionRequest,
252        schema: &Schema,
253    ) -> Result<HealedSchemaResponse> {
254        self.ensure_healing_enabled()?;
255        let healed = self.complete_json_internal(request).await?;
256        let engine = CoercionEngine::with_config(self.healing.coercion_config.clone());
257        let coerced = engine
258            .coerce(&healed.parsed.value, schema)
259            .map_err(SimpleAgentsError::Healing)?;
260
261        Ok(HealedSchemaResponse {
262            response: healed.response,
263            parsed: healed.parsed,
264            coerced,
265        })
266    }
267
268    async fn stream(
269        &self,
270        request: &CompletionRequest,
271    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
272        request.validate()?;
273        debug!(
274            model = %request.model,
275            stream = ?request.stream,
276            "SimpleAgentsClient.stream start"
277        );
278
279        let provider_request = self.provider.transform_request(request)?;
280        self.provider.execute_stream(provider_request).await
281    }
282
283    fn ensure_healing_enabled(&self) -> Result<()> {
284        if self.healing.enabled {
285            Ok(())
286        } else {
287            Err(SimpleAgentsError::Config(
288                "healing is disabled for this client".to_string(),
289            ))
290        }
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use futures_util::StreamExt;
    use simple_agent_type::error::ProviderError;
    use simple_agent_type::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Minimal non-streaming provider that fabricates a fixed response.
    struct MockProvider {
        name: &'static str,
        // Count of `execute` invocations.
        // NOTE(review): incremented but never asserted in these tests —
        // presumably reserved for future retry tests; confirm or remove.
        calls: AtomicUsize,
    }

    impl MockProvider {
        fn new(name: &'static str) -> Self {
            Self {
                name,
                calls: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        // Stub transform: the URL is never actually dereferenced in tests.
        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            // Relaxed ordering is enough for a simple invocation counter.
            self.calls.fetch_add(1, Ordering::Relaxed);
            Ok(ProviderResponse::new(
                200,
                serde_json::json!({"content": "ok"}),
            ))
        }

        // Ignores the raw payload and returns a canned single-choice response
        // tagged with this provider's name (asserted on by the tests).
        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name.to_string()),
                healing_metadata: None,
            })
        }
    }

    /// A non-streaming request yields a `Response` outcome carrying the
    /// provider's name.
    #[tokio::test]
    async fn complete_returns_response() {
        let provider = Arc::new(MockProvider::new("p1"));
        let client = SimpleAgentsClient::new(provider);

        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hi"))
            .build()
            .unwrap();

        let outcome = client
            .complete(&request, CompletionOptions::default())
            .await
            .unwrap();

        match outcome {
            CompletionOutcome::Response(resp) => {
                assert_eq!(resp.provider.as_deref(), Some("p1"));
            }
            _ => panic!("expected Response outcome"),
        }
    }

    /// Streaming provider yielding one chunk, optionally followed by an error.
    struct StreamingProvider {
        name: &'static str,
        // When true, the stream emits an Err after its first chunk.
        fail_after_first: bool,
    }

    impl StreamingProvider {
        fn new(name: &'static str, fail_after_first: bool) -> Self {
            Self {
                name,
                fail_after_first,
            }
        }

        // Build a single assistant delta chunk with the given id and content.
        fn build_chunk(id: &str, content: &str) -> CompletionChunk {
            CompletionChunk {
                id: id.to_string(),
                model: "test-model".to_string(),
                choices: vec![ChoiceDelta {
                    index: 0,
                    delta: MessageDelta {
                        role: Some(Role::Assistant),
                        content: Some(content.to_string()),
                        reasoning_content: None,
                        tool_calls: None,
                    },
                    finish_reason: None,
                }],
                created: None,
                usage: None,
            }
        }
    }

    #[async_trait]
    impl Provider for StreamingProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        // Non-streaming path is present only to satisfy the trait; the
        // streaming tests never exercise it.
        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            Ok(ProviderResponse::new(
                200,
                serde_json::json!({"content": "ok"}),
            ))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_stream".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name.to_string()),
                healing_metadata: None,
            })
        }

        // In-memory chunk stream: one "hello" chunk, plus a trailing server
        // error when `fail_after_first` is set.
        async fn execute_stream(
            &self,
            _req: ProviderRequest,
        ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>>
        {
            let stream = if self.fail_after_first {
                let items: Vec<Result<CompletionChunk>> = vec![
                    Ok(Self::build_chunk("chunk-1", "hello")),
                    Err(SimpleAgentsError::Provider(ProviderError::ServerError(
                        "stream error".to_string(),
                    ))),
                ];
                futures_util::stream::iter(items)
            } else {
                let items: Vec<Result<CompletionChunk>> =
                    vec![Ok(Self::build_chunk("chunk-1", "hello"))];
                futures_util::stream::iter(items)
            };

            Ok(Box::new(stream))
        }
    }

    /// `stream: true` requests produce a `Stream` outcome whose chunks can be
    /// drained to completion.
    #[tokio::test]
    async fn streaming_returns_chunks() {
        let provider = Arc::new(StreamingProvider::new("p1", false));
        let client = SimpleAgentsClient::new(provider);

        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hi"))
            .stream(true)
            .build()
            .unwrap();

        let outcome = client
            .complete(&request, CompletionOptions::default())
            .await
            .unwrap();

        let mut collected = Vec::new();
        match outcome {
            CompletionOutcome::Stream(mut stream) => {
                while let Some(chunk) = stream.next().await {
                    collected.push(chunk.unwrap());
                }
            }
            _ => panic!("expected stream outcome"),
        }

        assert_eq!(collected.len(), 1);
    }

    /// Mid-stream errors are surfaced as `Err` items rather than truncating
    /// the stream silently.
    #[tokio::test]
    async fn streaming_propagates_error() {
        let provider = Arc::new(StreamingProvider::new("p1", true));
        let client = SimpleAgentsClient::new(provider);

        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hi"))
            .stream(true)
            .build()
            .unwrap();

        let outcome = client
            .complete(&request, CompletionOptions::default())
            .await
            .unwrap();

        let mut chunks = Vec::new();
        match outcome {
            CompletionOutcome::Stream(mut stream) => {
                while let Some(chunk) = stream.next().await {
                    chunks.push(chunk);
                }
            }
            _ => panic!("expected stream outcome"),
        }

        assert_eq!(chunks.len(), 2);
        assert!(chunks[0].is_ok());
        assert!(chunks[1].is_err());
    }
}