Skip to main content

simple_agents_core/
client.rs

1//! SimpleAgents client implementation.
2
3use crate::healing::{HealedJsonResponse, HealedSchemaResponse, HealingSettings};
4use crate::middleware::Middleware;
5use crate::routing::{RouterEngine, RoutingMode};
6use async_trait::async_trait;
7use futures_util::future::BoxFuture;
8use futures_util::stream::{self, Stream};
9use futures_util::StreamExt;
10use simple_agent_type::cache::Cache;
11use simple_agent_type::cache::CacheKey;
12use simple_agent_type::prelude::{
13    CompletionChunk, CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
14};
15use simple_agents_healing::coercion::CoercionEngine;
16use simple_agents_healing::parser::JsonishParser;
17use simple_agents_healing::schema::Schema;
18use std::collections::{HashMap, HashSet};
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21use tokio::sync::RwLock;
22use tracing::debug;
23
24/// Mode for completion post-processing.
25#[derive(Clone)]
26pub enum CompletionMode {
27    /// Return the raw completion response.
28    Standard,
29    /// Parse the response content as JSON using healing.
30    HealedJson,
31    /// Parse and coerce the response into the provided schema.
32    CoercedSchema(Schema),
33}
34
35/// Options that control completion behavior.
36#[derive(Clone)]
37pub struct CompletionOptions {
38    /// Completion post-processing mode.
39    pub mode: CompletionMode,
40}
41
42impl Default for CompletionOptions {
43    fn default() -> Self {
44        Self {
45            mode: CompletionMode::Standard,
46        }
47    }
48}
49
50/// Result of a unified completion call.
51pub enum CompletionOutcome {
52    /// A standard, non-streaming completion response.
53    Response(CompletionResponse),
54    /// A streaming response yielding completion chunks.
55    Stream(Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>),
56    /// A healed JSON response.
57    HealedJson(HealedJsonResponse),
58    /// A schema-coerced response.
59    CoercedSchema(HealedSchemaResponse),
60}
61
62struct ClientState {
63    providers: Vec<Arc<dyn Provider>>,
64    provider_map: HashMap<String, Arc<dyn Provider>>,
65    router: Arc<RouterEngine>,
66}
67
68/// Unified SimpleAgents client.
69pub struct SimpleAgentsClient {
70    state: RwLock<ClientState>,
71    routing_mode: RoutingMode,
72    cache: Option<Arc<dyn Cache>>,
73    cache_ttl: Duration,
74    healing: HealingSettings,
75    middleware: Vec<Arc<dyn Middleware>>,
76}
77
78impl SimpleAgentsClient {
79    /// Start a new client builder.
80    pub fn builder() -> SimpleAgentsClientBuilder {
81        SimpleAgentsClientBuilder::new()
82    }
83
84    /// List registered provider names.
85    pub async fn provider_names(&self) -> Result<Vec<String>> {
86        let state = self.state.read().await;
87        Ok(state.provider_map.keys().cloned().collect())
88    }
89
90    /// Retrieve a provider by name.
91    pub async fn provider(&self, name: &str) -> Result<Option<Arc<dyn Provider>>> {
92        let state = self.state.read().await;
93        Ok(state.provider_map.get(name).cloned())
94    }
95
96    /// Register an additional provider and rebuild the router.
97    pub async fn register_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
98        let mut state = self.state.write().await;
99        let name = provider.name().to_string();
100
101        if state.provider_map.contains_key(&name) {
102            return Err(SimpleAgentsError::Config(format!(
103                "provider already registered: {}",
104                name
105            )));
106        }
107
108        state.provider_map.insert(name, provider.clone());
109        state.providers.push(provider);
110        state.router = Arc::new(self.routing_mode.build_router(state.providers.clone())?);
111        Ok(())
112    }
113
114    /// Execute a completion request with routing, caching, and middleware.
115    pub async fn complete(
116        &self,
117        request: &CompletionRequest,
118        options: CompletionOptions,
119    ) -> Result<CompletionOutcome> {
120        if request.stream.unwrap_or(false) {
121            let stream = self.stream(request).await?;
122            return Ok(CompletionOutcome::Stream(stream));
123        }
124
125        match options.mode {
126            CompletionMode::Standard => {
127                let response = self.complete_response(request).await?;
128                Ok(CompletionOutcome::Response(response))
129            }
130            CompletionMode::HealedJson => {
131                let healed = self.complete_json_internal(request).await?;
132                Ok(CompletionOutcome::HealedJson(healed))
133            }
134            CompletionMode::CoercedSchema(schema) => {
135                let healed = self.complete_with_schema_internal(request, &schema).await?;
136                Ok(CompletionOutcome::CoercedSchema(healed))
137            }
138        }
139    }
140
141    async fn complete_response(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
142        request.validate()?;
143        self.before_request(request).await?;
144
145        let cache_key = if let Some(cache) = &self.cache {
146            if cache.is_enabled() {
147                Some(self.cache_key(request)?)
148            } else {
149                None
150            }
151        } else {
152            None
153        };
154
155        if let (Some(cache), Some(key)) = (&self.cache, cache_key.as_deref()) {
156            if let Some(cached) = cache.get(key).await? {
157                let response: CompletionResponse = serde_json::from_slice(&cached)?;
158                self.on_cache_hit(request, &response).await?;
159                return Ok(response);
160            }
161        }
162
163        let start = Instant::now();
164        let router = {
165            let state = self.state.read().await;
166            state.router.clone()
167        };
168        let response = router.complete(request).await;
169
170        match response {
171            Ok(response) => {
172                self.after_response(request, &response, start.elapsed())
173                    .await?;
174                if let (Some(cache), Some(key)) = (&self.cache, cache_key) {
175                    let payload = serde_json::to_vec(&response)?;
176                    cache.set(&key, payload, self.cache_ttl).await?;
177                }
178                Ok(response)
179            }
180            Err(error) => {
181                self.on_error(request, &error, start.elapsed()).await?;
182                Err(error)
183            }
184        }
185    }
186
187    /// Execute a completion request and parse the response content as JSON.
188    async fn complete_json_internal(
189        &self,
190        request: &CompletionRequest,
191    ) -> Result<HealedJsonResponse> {
192        self.ensure_healing_enabled()?;
193        let response = self.complete_response(request).await?;
194        let content = response.content().ok_or_else(|| {
195            SimpleAgentsError::Healing(simple_agent_type::error::HealingError::ParseFailed {
196                error_message: "response contained no content".to_string(),
197                input: String::new(),
198            })
199        })?;
200
201        let parser = JsonishParser::with_config(self.healing.parser_config.clone());
202        let parsed = parser.parse(content)?;
203
204        Ok(HealedJsonResponse { response, parsed })
205    }
206
207    /// Execute a completion request and coerce the response into a schema.
208    async fn complete_with_schema_internal(
209        &self,
210        request: &CompletionRequest,
211        schema: &Schema,
212    ) -> Result<HealedSchemaResponse> {
213        self.ensure_healing_enabled()?;
214        let healed = self.complete_json_internal(request).await?;
215        let engine = CoercionEngine::with_config(self.healing.coercion_config.clone());
216        let coerced = engine
217            .coerce(&healed.parsed.value, schema)
218            .map_err(SimpleAgentsError::Healing)?;
219
220        Ok(HealedSchemaResponse {
221            response: healed.response,
222            parsed: healed.parsed,
223            coerced,
224        })
225    }
226
227    /// Execute a streaming completion request.
228    async fn stream(
229        &self,
230        request: &CompletionRequest,
231    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
232        request.validate()?;
233        self.before_request(request).await?;
234        debug!(
235            model = %request.model,
236            stream = ?request.stream,
237            "SimpleAgentsClient.stream start"
238        );
239
240        let router = {
241            let state = self.state.read().await;
242            state.router.clone()
243        };
244
245        let start = Instant::now();
246        let middleware = self.middleware.clone();
247        let instrumented_request = request.clone();
248        let inner = router.stream(request).await?;
249
250        let wrapped = Self::instrument_stream(inner, instrumented_request, middleware, start);
251        Ok(Box::new(wrapped))
252    }
253
254    fn ensure_healing_enabled(&self) -> Result<()> {
255        if self.healing.enabled {
256            Ok(())
257        } else {
258            Err(SimpleAgentsError::Config(
259                "healing is disabled for this client".to_string(),
260            ))
261        }
262    }
263
264    fn cache_key(&self, request: &CompletionRequest) -> Result<String> {
265        let serialized = serde_json::to_string(request)?;
266        Ok(CacheKey::from_parts("core", &request.model, &serialized))
267    }
268
269    async fn before_request(&self, request: &CompletionRequest) -> Result<()> {
270        for middleware in &self.middleware {
271            middleware.before_request(request).await?;
272        }
273        Ok(())
274    }
275
276    async fn after_response(
277        &self,
278        request: &CompletionRequest,
279        response: &CompletionResponse,
280        latency: Duration,
281    ) -> Result<()> {
282        for middleware in &self.middleware {
283            middleware
284                .after_response(request, response, latency)
285                .await?;
286        }
287        Ok(())
288    }
289
290    async fn on_cache_hit(
291        &self,
292        request: &CompletionRequest,
293        response: &CompletionResponse,
294    ) -> Result<()> {
295        for middleware in &self.middleware {
296            middleware.on_cache_hit(request, response).await?;
297        }
298        Ok(())
299    }
300
301    async fn on_error(
302        &self,
303        request: &CompletionRequest,
304        error: &SimpleAgentsError,
305        latency: Duration,
306    ) -> Result<()> {
307        for middleware in &self.middleware {
308            middleware.on_error(request, error, latency).await?;
309        }
310        Ok(())
311    }
312}
313
impl SimpleAgentsClient {
    /// Wrap a provider stream so middleware observes how it terminates.
    ///
    /// Chunks from `inner` are forwarded unchanged.
    ///
    /// - On the first `Err` item, every middleware's `on_error` hook runs in
    ///   registration order, using the latency measured from `start`. The
    ///   error is then yielded and the stream ends, even if `inner` could keep
    ///   producing items.
    /// - When `inner` is exhausted, every middleware's `after_stream` hook
    ///   runs and the stream ends.
    /// - If any hook fails, that hook's error is yielded instead. The remaining
    ///   hooks are skipped and the stream ends.
    ///
    /// NOTE(review): hooks only run while the stream is being polled. A stream
    /// dropped before it terminates never triggers `after_stream` or `on_error`.
    fn instrument_stream(
        inner: Box<dyn Stream<Item = Result<CompletionChunk>> + Send + Unpin>,
        request: CompletionRequest,
        middleware: Vec<Arc<dyn Middleware>>,
        start: Instant,
    ) -> impl Stream<Item = Result<CompletionChunk>> + Send + Unpin {
        // State threaded through `stream::unfold`. Everything is owned, so the
        // returned stream does not borrow the client.
        struct StreamState {
            inner: Box<dyn Stream<Item = Result<CompletionChunk>> + Send + Unpin>,
            middleware: Vec<Arc<dyn Middleware>>,
            request: CompletionRequest,
            start: Instant,
            // Set once a terminal item (provider or hook error) has been
            // yielded, so the following poll ends the stream.
            done: bool,
        }

        stream::unfold(
            StreamState {
                inner,
                middleware,
                request,
                start,
                done: false,
            },
            // Each step's future is boxed. `BoxFuture` is a `Pin<Box<..>>` and
            // therefore `Unpin`, which presumably is what lets the unfolded
            // stream satisfy the `Unpin` bound in the return type.
            |mut state| -> BoxFuture<Option<(Result<CompletionChunk>, StreamState)>> {
                Box::pin(async move {
                    if state.done {
                        return None;
                    }

                    match state.inner.next().await {
                        // Pass data chunks straight through.
                        Some(Ok(chunk)) => Some((Ok(chunk), state)),
                        Some(Err(err)) => {
                            let latency = state.start.elapsed();
                            for middleware in &state.middleware {
                                if let Err(mw_err) =
                                    middleware.on_error(&state.request, &err, latency).await
                                {
                                    // The hook's error replaces the provider error.
                                    state.done = true;
                                    return Some((Err(mw_err), state));
                                }
                            }
                            state.done = true;
                            Some((Err(err), state))
                        }
                        None => {
                            let latency = state.start.elapsed();
                            for middleware in &state.middleware {
                                if let Err(mw_err) =
                                    middleware.after_stream(&state.request, latency).await
                                {
                                    // Surface the hook failure as a final item.
                                    state.done = true;
                                    return Some((Err(mw_err), state));
                                }
                            }
                            // Clean end of stream: no trailing item.
                            None
                        }
                    }
                })
            },
        )
    }
}
376
377/// Builder for `SimpleAgentsClient`.
378pub struct SimpleAgentsClientBuilder {
379    providers: Vec<Arc<dyn Provider>>,
380    routing_mode: RoutingMode,
381    cache: Option<Arc<dyn Cache>>,
382    cache_ttl: Duration,
383    healing: HealingSettings,
384    middleware: Vec<Arc<dyn Middleware>>,
385}
386
387impl SimpleAgentsClientBuilder {
388    /// Create a new builder with defaults.
389    pub fn new() -> Self {
390        Self {
391            providers: Vec::new(),
392            routing_mode: RoutingMode::default(),
393            cache: None,
394            cache_ttl: Duration::from_secs(60),
395            healing: HealingSettings::default(),
396            middleware: Vec::new(),
397        }
398    }
399
400    /// Register a provider.
401    pub fn with_provider(mut self, provider: Arc<dyn Provider>) -> Self {
402        self.providers.push(provider);
403        self
404    }
405
406    /// Register multiple providers at once.
407    pub fn with_providers(mut self, providers: Vec<Arc<dyn Provider>>) -> Self {
408        self.providers.extend(providers);
409        self
410    }
411
412    /// Configure routing mode.
413    pub fn with_routing_mode(mut self, mode: RoutingMode) -> Self {
414        self.routing_mode = mode;
415        self
416    }
417
418    /// Configure response cache.
419    pub fn with_cache(mut self, cache: Arc<dyn Cache>) -> Self {
420        self.cache = Some(cache);
421        self
422    }
423
424    /// Configure cache TTL.
425    pub fn with_cache_ttl(mut self, ttl: Duration) -> Self {
426        self.cache_ttl = ttl;
427        self
428    }
429
430    /// Configure healing settings.
431    pub fn with_healing_settings(mut self, settings: HealingSettings) -> Self {
432        self.healing = settings;
433        self
434    }
435
436    /// Register a middleware hook.
437    pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
438        self.middleware.push(middleware);
439        self
440    }
441
442    /// Build the client.
443    pub fn build(self) -> Result<SimpleAgentsClient> {
444        if self.providers.is_empty() {
445            return Err(SimpleAgentsError::Config(
446                "at least one provider is required".to_string(),
447            ));
448        }
449
450        let mut seen = HashSet::new();
451        for provider in &self.providers {
452            let name = provider.name();
453            if !seen.insert(name.to_string()) {
454                return Err(SimpleAgentsError::Config(format!(
455                    "duplicate provider configured in builder: {}",
456                    name
457                )));
458            }
459        }
460
461        let provider_map = self
462            .providers
463            .iter()
464            .map(|provider| (provider.name().to_string(), provider.clone()))
465            .collect::<HashMap<_, _>>();
466
467        let router = Arc::new(self.routing_mode.build_router(self.providers.clone())?);
468        let state = ClientState {
469            providers: self.providers,
470            provider_map,
471            router,
472        };
473
474        Ok(SimpleAgentsClient {
475            state: RwLock::new(state),
476            routing_mode: self.routing_mode,
477            cache: self.cache,
478            cache_ttl: self.cache_ttl,
479            healing: self.healing,
480            middleware: self.middleware,
481        })
482    }
483}
484
485impl Default for SimpleAgentsClientBuilder {
486    fn default() -> Self {
487        Self::new()
488    }
489}
490
491#[async_trait]
492impl Middleware for () {
493    async fn before_request(&self, _request: &CompletionRequest) -> Result<()> {
494        Ok(())
495    }
496}
497
#[cfg(test)]
mod tests {
    // Unit tests for the client: builder validation, provider registration,
    // and middleware hooks around streaming completions.
    use super::*;
    use futures_util::{stream, StreamExt};
    use simple_agent_type::error::ProviderError;
    use simple_agent_type::prelude::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Non-streaming provider that returns a fixed "ok" response and counts
    /// `execute` calls.
    struct MockProvider {
        name: &'static str,
        calls: AtomicUsize,
    }

    impl MockProvider {
        fn new(name: &'static str) -> Self {
            Self {
                name,
                calls: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            self.calls.fetch_add(1, Ordering::Relaxed);
            Ok(ProviderResponse::new(
                200,
                serde_json::json!({"content": "ok"}),
            ))
        }

        // Ignores the raw response and always yields a single "ok" choice.
        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name.to_string()),
                healing_metadata: None,
            })
        }
    }

    /// Building with no providers is rejected.
    #[tokio::test]
    async fn client_build_requires_provider() {
        let result = SimpleAgentsClientBuilder::new().build();
        assert!(result.is_err());
    }

    /// A provider registered after `build` shows up alongside the original one.
    #[tokio::test]
    async fn register_provider_rebuilds_router() {
        let provider = Arc::new(MockProvider::new("p1"));
        let client = SimpleAgentsClientBuilder::new()
            .with_provider(provider)
            .build()
            .unwrap();

        let second = Arc::new(MockProvider::new("p2"));
        client.register_provider(second).await.unwrap();

        let names = client.provider_names().await.unwrap();
        assert!(names.contains(&"p1".to_string()));
        assert!(names.contains(&"p2".to_string()));
    }

    /// Registering a name that already exists yields a `Config` error.
    #[tokio::test]
    async fn duplicate_provider_registration_fails() {
        let provider = Arc::new(MockProvider::new("p1"));
        let client = SimpleAgentsClientBuilder::new()
            .with_provider(provider.clone())
            .build()
            .unwrap();

        let result = client.register_provider(provider).await;
        assert!(matches!(
            result,
            Err(SimpleAgentsError::Config(msg)) if msg.contains("provider already registered")
        ));
    }

    /// Duplicate names added one at a time are rejected by `build`.
    #[tokio::test]
    async fn duplicate_provider_in_builder_with_provider_fails() {
        let p1 = Arc::new(MockProvider::new("p1"));
        let p1_dup = Arc::new(MockProvider::new("p1"));

        let result = SimpleAgentsClientBuilder::new()
            .with_provider(p1)
            .with_provider(p1_dup)
            .build();

        assert!(matches!(
            result,
            Err(SimpleAgentsError::Config(msg)) if msg.contains("duplicate provider configured in builder")
        ));
    }

    /// Duplicate names added in one batch are rejected by `build`.
    #[tokio::test]
    async fn duplicate_provider_in_builder_with_providers_fails() {
        let result = SimpleAgentsClientBuilder::new()
            .with_providers(vec![
                Arc::new(MockProvider::new("p1")),
                Arc::new(MockProvider::new("p1")),
            ])
            .build();

        assert!(matches!(
            result,
            Err(SimpleAgentsError::Config(msg)) if msg.contains("duplicate provider configured in builder")
        ));
    }

    /// Middleware that counts how often each observed hook fires.
    #[derive(Default)]
    struct RecordingMiddleware {
        before: AtomicUsize,
        after_stream: AtomicUsize,
        errors: AtomicUsize,
    }

    #[async_trait]
    impl Middleware for RecordingMiddleware {
        async fn before_request(&self, _request: &CompletionRequest) -> Result<()> {
            self.before.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }

        async fn after_stream(
            &self,
            _request: &CompletionRequest,
            _latency: Duration,
        ) -> Result<()> {
            self.after_stream.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }

        async fn on_error(
            &self,
            _request: &CompletionRequest,
            _error: &SimpleAgentsError,
            _latency: Duration,
        ) -> Result<()> {
            self.errors.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }

        fn name(&self) -> &str {
            "recording"
        }
    }

    /// Provider whose stream yields one chunk, optionally followed by an error.
    struct StreamingProvider {
        name: &'static str,
        fail_after_first: bool,
    }

    impl StreamingProvider {
        fn new(name: &'static str, fail_after_first: bool) -> Self {
            Self {
                name,
                fail_after_first,
            }
        }

        // Single-choice assistant chunk carrying `content`.
        fn build_chunk(id: &str, content: &str) -> CompletionChunk {
            CompletionChunk {
                id: id.to_string(),
                model: "test-model".to_string(),
                choices: vec![ChoiceDelta {
                    index: 0,
                    delta: MessageDelta {
                        role: Some(Role::Assistant),
                        content: Some(content.to_string()),
                        reasoning_content: None,
                        tool_calls: None,
                    },
                    finish_reason: None,
                }],
                created: None,
                usage: None,
            }
        }
    }

    #[async_trait]
    impl Provider for StreamingProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            Ok(ProviderResponse::new(
                200,
                serde_json::json!({"content": "ok"}),
            ))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_stream".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name.to_string()),
                healing_metadata: None,
            })
        }

        // With `fail_after_first`, the stream is [Ok(chunk), Err(ServerError)];
        // otherwise it is a single Ok chunk.
        async fn execute_stream(
            &self,
            _req: ProviderRequest,
        ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>>
        {
            let stream = if self.fail_after_first {
                let items: Vec<Result<CompletionChunk>> = vec![
                    Ok(Self::build_chunk("chunk-1", "hello")),
                    Err(SimpleAgentsError::Provider(ProviderError::ServerError(
                        "stream error".to_string(),
                    ))),
                ];
                stream::iter(items)
            } else {
                let items: Vec<Result<CompletionChunk>> =
                    vec![Ok(Self::build_chunk("chunk-1", "hello"))];
                stream::iter(items)
            };

            Ok(Box::new(stream))
        }
    }

    /// A fully drained successful stream fires `before_request` and
    /// `after_stream` once each and never `on_error`.
    #[tokio::test]
    async fn streaming_invokes_after_stream_on_success() {
        let provider = Arc::new(StreamingProvider::new("p1", false));
        let middleware = Arc::new(RecordingMiddleware::default());

        let client = SimpleAgentsClientBuilder::new()
            .with_provider(provider)
            .with_middleware(middleware.clone())
            .build()
            .unwrap();

        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hi"))
            .stream(true)
            .build()
            .unwrap();

        let outcome = client
            .complete(&request, CompletionOptions::default())
            .await
            .unwrap();

        let mut collected = Vec::new();
        match outcome {
            CompletionOutcome::Stream(mut stream) => {
                while let Some(chunk) = stream.next().await {
                    collected.push(chunk.unwrap());
                }
            }
            _ => panic!("expected stream outcome"),
        }

        assert_eq!(collected.len(), 1);
        assert_eq!(middleware.before.load(Ordering::Relaxed), 1);
        assert_eq!(middleware.after_stream.load(Ordering::Relaxed), 1);
        assert_eq!(middleware.errors.load(Ordering::Relaxed), 0);
    }

    /// A stream that errors mid-way yields the chunk and then the error. It
    /// fires `on_error` once and never `after_stream`.
    #[tokio::test]
    async fn streaming_invokes_on_error_on_failure() {
        let provider = Arc::new(StreamingProvider::new("p1", true));
        let middleware = Arc::new(RecordingMiddleware::default());

        let client = SimpleAgentsClientBuilder::new()
            .with_provider(provider)
            .with_middleware(middleware.clone())
            .build()
            .unwrap();

        let request = CompletionRequest::builder()
            .model("gpt-4")
            .message(Message::user("Hi"))
            .stream(true)
            .build()
            .unwrap();

        let outcome = client
            .complete(&request, CompletionOptions::default())
            .await
            .unwrap();

        let mut chunks = Vec::new();
        match outcome {
            CompletionOutcome::Stream(mut stream) => {
                while let Some(chunk) = stream.next().await {
                    chunks.push(chunk);
                }
            }
            _ => panic!("expected stream outcome"),
        }

        assert_eq!(middleware.before.load(Ordering::Relaxed), 1);
        assert_eq!(middleware.after_stream.load(Ordering::Relaxed), 0);
        assert_eq!(middleware.errors.load(Ordering::Relaxed), 1);
        assert_eq!(chunks.len(), 2);
        assert!(chunks[0].as_ref().is_ok());
        assert!(chunks[1].is_err());
    }
}