tower_llm/streaming/
mod.rs

//! Streaming step/agent variants
//!
//! What this module provides (spec)
//! - A streaming version of the step and loop that emits tokens/tool events incrementally
//! - Tap APIs for UIs without breaking Tower composition
//!
//! Exports
//! - Models
//!   - `StepChunk::{Token(String), ToolCallStart { id, name, arguments }, ToolCallEnd { id, output }, UsageDelta { .. }, StepComplete { outcome }, Error(String)}`
//!   - `AgentEvent::{Step(n), Item(StepChunk), RunComplete(AgentRun)}` at agent layer boundaries
//! - Services
//!   - `StepStreamService: Service<CreateChatCompletionRequest, Response = Pin<Box<dyn Stream<Item = StepChunk> + Send>>>`
//! - Layers
//!   - `AgentLoopStreamLayer<P>` (policy `P`) wrapping `S: Service<CreateChatCompletionRequest, Response = Pin<Box<dyn Stream<Item = StepChunk> + Send>>>`
//!   - `StreamTapLayer` to tee `AgentEvent`s to an injected sink (observer)
//! - Utils
//!   - `collect_final(&mut stream) -> Option<AgentRun>` to remain API-compatible with non-streaming callers
//!
//! Implementation strategy
//! - Provider adapter (`StepProvider`) translates an SSE/streaming API into a `StepChunk` stream
//! - Loop layer buffers minimal state (e.g., current messages, pending tool calls) and evaluates the policy after each step
//! - Ensure back-pressure: do not buffer entire streams; forward items as they arrive
//! - Error handling: surface provider/tool errors as a terminal `StepChunk::Error`, wrapped in `AgentEvent::Item` at the agent layer
//!
//! Composition
//! - `ServiceBuilder::new().layer(StreamTapLayer::new(sink)).layer(AgentLoopStreamLayer::new(policy)).service(step_stream)`
//!
//! Testing strategy
//! - Fake provider that yields a scripted sequence of chunks (tokens → tool_call → outputs → final)
//! - Assert policy-controlled termination (e.g., until tool_called("x"))
//! - Verify the tap receives the exact event sequence, with no extra buffering or reordering

33use std::future::Future;
34use std::pin::Pin;
35use std::sync::Arc;
36
37use async_openai::types::{
38    ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
39    ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
40    CreateChatCompletionRequest, CreateChatCompletionRequestArgs,
41};
42use futures::{Stream, StreamExt};
43use serde_json::Value;
44use tokio::sync::{mpsc, Semaphore};
45use tokio_stream::wrappers::ReceiverStream;
46use tower::{BoxError, Layer, Service, ServiceExt};
47
48use crate::core::{
49    AgentPolicy, AgentRun, LoopState, StepAux, StepOutcome, ToolInvocation, ToolJoinPolicy,
50    ToolOutput,
51};
52
53/// Streaming step-level items.
54#[derive(Debug, Clone)]
55pub enum StepChunk {
56    Token(String),
57    ToolCallStart {
58        id: String,
59        name: String,
60        arguments: Value,
61    },
62    ToolCallEnd {
63        id: String,
64        output: Value,
65    },
66    UsageDelta {
67        prompt_tokens: usize,
68        completion_tokens: usize,
69    },
70    /// Terminal item that signals the end of a step and carries the outcome
71    StepComplete {
72        outcome: StepOutcome,
73    },
74    /// Non-fatal error surfaced as a terminal event
75    Error(String),
76}
77
78/// A provider that yields an assistant response as a stream of `StepChunk`s.
79///
80/// This remains abstract so tests can inject a fake provider; a real provider
81/// can live in `next::provider` and adapt OpenAI SSE to this interface.
82pub trait StepProvider: Send + Sync + 'static {
83    type Stream: Stream<Item = StepChunk> + Send + 'static;
84    fn stream_step(
85        &self,
86        req: CreateChatCompletionRequest,
87    ) -> Pin<Box<dyn Future<Output = Result<Self::Stream, BoxError>> + Send>>;
88}
89
90/// Service that executes a single step and returns a stream of `StepChunk`s.
91///
92/// It delegates token/tool-call streaming to a `StepProvider` and is responsible
93/// for invoking tools when requested and yielding `ToolCallEnd` events and the
94/// final `StepComplete` outcome.
95pub struct StepStreamService<P, T> {
96    provider: Arc<P>,
97    tools: T, // routed tool service (clonable for parallel)
98    instructions: Option<String>,
99    parallel_tools: bool,
100    tool_concurrency_limit: Option<usize>,
101    join_policy: ToolJoinPolicy,
102}
103
104impl<P, T> StepStreamService<P, T> {
105    pub fn new(provider: Arc<P>, tools: T) -> Self {
106        Self {
107            provider,
108            tools,
109            instructions: None,
110            parallel_tools: false,
111            tool_concurrency_limit: None,
112            join_policy: ToolJoinPolicy::FailFast,
113        }
114    }
115
116    pub fn parallel_tools(mut self, enabled: bool) -> Self {
117        self.parallel_tools = enabled;
118        self
119    }
120
121    pub fn tool_concurrency_limit(mut self, limit: usize) -> Self {
122        self.tool_concurrency_limit = Some(limit);
123        self
124    }
125
126    pub fn tool_join_policy(mut self, policy: ToolJoinPolicy) -> Self {
127        self.join_policy = policy;
128        self
129    }
130
131    pub fn instructions(mut self, text: impl Into<String>) -> Self {
132        self.instructions = Some(text.into());
133        self
134    }
135}
136
137impl<P, T> Service<CreateChatCompletionRequest> for StepStreamService<P, T>
138where
139    P: StepProvider,
140    T: Service<ToolInvocation, Response = ToolOutput, Error = BoxError> + Clone + Send + 'static,
141    T::Future: Send + 'static,
142{
143    type Response = Pin<Box<dyn Stream<Item = StepChunk> + Send>>;
144    type Error = BoxError;
145    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
146
147    fn poll_ready(
148        &mut self,
149        _cx: &mut std::task::Context<'_>,
150    ) -> std::task::Poll<Result<(), Self::Error>> {
151        std::task::Poll::Ready(Ok(()))
152    }
153
154    fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
155        let provider = self.provider.clone();
156        let tools = self.tools.clone();
157        let parallel = self.parallel_tools;
158        let _limit = self.tool_concurrency_limit;
159        let join_policy = self.join_policy;
160        let instructions = self.instructions.clone();
161        Box::pin(async move {
162            let mut token_buf = String::new();
163            let mut tool_calls: Vec<(String, String, Value)> = Vec::new();
164            let mut invoked_tool_names: Vec<String> = Vec::new();
165            let mut aux = StepAux::default();
166            let base_model = req.model.clone();
167
168            // Ensure model/messages present using builder semantics to normalize
169            let mut builder = CreateChatCompletionRequestArgs::default();
170            builder.model(base_model.clone());
171            // Inject instructions if provided
172            let mut injected_messages = req.messages.clone();
173            if let Some(instr) = instructions.clone() {
174                let sys_msg = ChatCompletionRequestSystemMessageArgs::default()
175                    .content(instr)
176                    .build()
177                    .map(ChatCompletionRequestMessage::from)
178                    .map_err(|e| format!("system msg build error: {}", e))?;
179                if let Some(pos) = injected_messages
180                    .iter()
181                    .position(|m| matches!(m, ChatCompletionRequestMessage::System(_)))
182                {
183                    injected_messages.remove(pos);
184                }
185                injected_messages.insert(0, sys_msg);
186            }
187            builder.messages(injected_messages);
188            let normalized_req = builder.build().map_err(|e| format!("build req: {}", e))?;
189
190            let stream = provider.stream_step(normalized_req).await?;
191
192            let (tx, rx) = mpsc::channel::<StepChunk>(32);
193            tokio::spawn(async move {
194                futures::pin_mut!(stream);
195                while let Some(item) = stream.next().await {
196                    match &item {
197                        StepChunk::Token(t) => {
198                            token_buf.push_str(t);
199                        }
200                        StepChunk::ToolCallStart {
201                            id,
202                            name,
203                            arguments,
204                        } => {
205                            tool_calls.push((id.clone(), name.clone(), arguments.clone()));
206                        }
207                        StepChunk::UsageDelta {
208                            prompt_tokens,
209                            completion_tokens,
210                        } => {
211                            aux.prompt_tokens += *prompt_tokens;
212                            aux.completion_tokens += *completion_tokens;
213                        }
214                        _ => {}
215                    }
216                    // forward every item as-is
217                    if tx.send(item).await.is_err() {
218                        return;
219                    }
220                }
221
222                // Provider stream ended; construct assistant + maybe run tools
223                let mut messages: Vec<ChatCompletionRequestMessage> = req.messages.clone();
224                // Build assistant message
225                let mut asst = ChatCompletionRequestAssistantMessageArgs::default();
226                asst.content(token_buf.clone());
227                if !tool_calls.is_empty() {
228                    // Build tool_calls list for assistant message
229                    let tool_calls_for_msg: Vec<
230                        async_openai::types::ChatCompletionMessageToolCall,
231                    > = tool_calls
232                        .iter()
233                        .map(|(id, name, arguments)| {
234                            async_openai::types::ChatCompletionMessageToolCall {
235                                id: id.clone(),
236                                r#type: async_openai::types::ChatCompletionToolType::Function,
237                                function: async_openai::types::FunctionCall {
238                                    name: name.clone(),
239                                    arguments: arguments.to_string(),
240                                },
241                            }
242                        })
243                        .collect();
244                    asst.tool_calls(tool_calls_for_msg);
245                }
246                match asst.build() {
247                    Ok(msg) => messages.push(msg.into()),
248                    Err(e) => {
249                        let _ = tx
250                            .send(StepChunk::Error(format!("assistant build: {}", e)))
251                            .await;
252                        return;
253                    }
254                }
255
256                // Execute tools (parallel if enabled) and emit ends preserving order
257                if tool_calls.len() > 1 && parallel {
258                    let sem = _limit.map(|n| Arc::new(Semaphore::new(n)));
259                    let mut futures = Vec::with_capacity(tool_calls.len());
260                    for (idx, (id, name, args)) in tool_calls.iter().cloned().enumerate() {
261                        invoked_tool_names.push(name.clone());
262                        let inv = ToolInvocation {
263                            id,
264                            name,
265                            arguments: args,
266                        };
267                        let mut svc = tools.clone();
268                        let sem_cl = sem.clone();
269                        futures.push(async move {
270                            let _permit = match &sem_cl {
271                                Some(s) => {
272                                    Some(s.clone().acquire_owned().await.expect("semaphore"))
273                                }
274                                None => None,
275                            };
276                            let ToolOutput { id: out_id, result } =
277                                ServiceExt::ready(&mut svc).await?.call(inv).await?;
278                            Ok::<(usize, String, Value), BoxError>((idx, out_id, result))
279                        });
280                    }
281                    match join_policy {
282                        ToolJoinPolicy::FailFast => {
283                            let results = futures::future::try_join_all(futures).await;
284                            match results {
285                                Ok(mut items) => {
286                                    items.sort_by_key(|(idx, _, _)| *idx);
287                                    for (_idx, out_id, result) in items.into_iter() {
288                                        aux.tool_invocations += 1;
289                                        match ChatCompletionRequestToolMessageArgs::default()
290                                            .tool_call_id(out_id.clone())
291                                            .content(result.to_string())
292                                            .build()
293                                        {
294                                            Ok(tool_msg) => messages.push(tool_msg.into()),
295                                            Err(e) => {
296                                                let _ = tx
297                                                    .send(StepChunk::Error(format!(
298                                                        "tool msg build: {}",
299                                                        e
300                                                    )))
301                                                    .await;
302                                                return;
303                                            }
304                                        }
305                                        let _ = tx
306                                            .send(StepChunk::ToolCallEnd {
307                                                id: out_id,
308                                                output: result,
309                                            })
310                                            .await;
311                                    }
312                                }
313                                Err(e) => {
314                                    let _ = tx
315                                        .send(StepChunk::Error(format!("tool error: {}", e)))
316                                        .await;
317                                    return;
318                                }
319                            }
320                        }
321                        ToolJoinPolicy::JoinAll => {
322                            // Wait for all, emit successes, then emit a single aggregated error if any failed
323                            let results = futures::future::join_all(futures).await;
324                            let mut items: Vec<(usize, String, Value)> = Vec::new();
325                            let mut errors: Vec<String> = Vec::new();
326                            for r in results.into_iter() {
327                                match r {
328                                    Ok((idx, id, result)) => items.push((idx, id, result)),
329                                    Err(e) => errors.push(format!("{}", e)),
330                                }
331                            }
332                            items.sort_by_key(|(idx, _, _)| *idx);
333                            for (_idx, out_id, result) in items.into_iter() {
334                                aux.tool_invocations += 1;
335                                match ChatCompletionRequestToolMessageArgs::default()
336                                    .tool_call_id(out_id.clone())
337                                    .content(result.to_string())
338                                    .build()
339                                {
340                                    Ok(tool_msg) => messages.push(tool_msg.into()),
341                                    Err(e) => {
342                                        let _ = tx
343                                            .send(StepChunk::Error(format!(
344                                                "tool msg build: {}",
345                                                e
346                                            )))
347                                            .await;
348                                        return;
349                                    }
350                                }
351                                let _ = tx
352                                    .send(StepChunk::ToolCallEnd {
353                                        id: out_id,
354                                        output: result,
355                                    })
356                                    .await;
357                            }
358                            if !errors.is_empty() {
359                                let _ = tx
360                                    .send(StepChunk::Error(format!(
361                                        "one or more tools failed: {}",
362                                        errors.join("; ")
363                                    )))
364                                    .await;
365                                return;
366                            }
367                        }
368                    }
369                } else {
370                    for (id, name, args) in tool_calls.into_iter() {
371                        invoked_tool_names.push(name.clone());
372                        let inv = ToolInvocation {
373                            id: id.clone(),
374                            name: name.clone(),
375                            arguments: args,
376                        };
377                        let mut svc = tools.clone();
378                        match ServiceExt::ready(&mut svc).await {
379                            Ok(ready) => match ready.call(inv).await {
380                                Ok(ToolOutput { id: out_id, result }) => {
381                                    aux.tool_invocations += 1;
382                                    match ChatCompletionRequestToolMessageArgs::default()
383                                        .tool_call_id(out_id.clone())
384                                        .content(result.to_string())
385                                        .build()
386                                    {
387                                        Ok(tool_msg) => messages.push(tool_msg.into()),
388                                        Err(e) => {
389                                            let _ = tx
390                                                .send(StepChunk::Error(format!(
391                                                    "tool msg build: {}",
392                                                    e
393                                                )))
394                                                .await;
395                                            return;
396                                        }
397                                    }
398                                    let _ = tx
399                                        .send(StepChunk::ToolCallEnd {
400                                            id: out_id,
401                                            output: result,
402                                        })
403                                        .await;
404                                }
405                                Err(e) => {
406                                    let _ = tx
407                                        .send(StepChunk::Error(format!("tool error: {}", e)))
408                                        .await;
409                                    return;
410                                }
411                            },
412                            Err(e) => {
413                                let _ = tx
414                                    .send(StepChunk::Error(format!("tool not ready: {}", e)))
415                                    .await;
416                                return;
417                            }
418                        }
419                    }
420                }
421
422                // Final outcome
423                let outcome = if invoked_tool_names.is_empty() {
424                    StepOutcome::Done { messages, aux }
425                } else {
426                    StepOutcome::Next {
427                        messages,
428                        aux,
429                        invoked_tools: invoked_tool_names,
430                    }
431                };
432                let _ = tx.send(StepChunk::StepComplete { outcome }).await;
433            });
434
435            Ok(Box::pin(ReceiverStream::new(rx)) as Pin<Box<dyn Stream<Item = StepChunk> + Send>>)
436        })
437    }
438}
439
440/// Events emitted by the agent-level streaming loop.
441#[derive(Debug, Clone)]
442pub enum AgentEvent {
443    Step(usize),
444    Item(StepChunk),
445    RunComplete(AgentRun),
446}
447
/// Tower layer that turns a streaming step service into a multi-step agent stream.
///
/// The wrapped service is invoked step after step until `policy` decides to stop. The layer
/// is `Clone` (when the policy is) so it can be reused in cloned `ServiceBuilder` stacks.
#[derive(Debug, Clone)]
pub struct AgentLoopStreamLayer<P> {
    policy: P,
}
452
453impl<P> AgentLoopStreamLayer<P> {
454    pub fn new(policy: P) -> Self {
455        Self { policy }
456    }
457}
458
459pub struct AgentLoopStream<S, P> {
460    inner: Arc<tokio::sync::Mutex<S>>,
461    policy: P,
462}
463
464impl<S, P> Layer<S> for AgentLoopStreamLayer<P>
465where
466    P: Clone,
467{
468    type Service = AgentLoopStream<S, P>;
469    fn layer(&self, inner: S) -> Self::Service {
470        AgentLoopStream {
471            inner: Arc::new(tokio::sync::Mutex::new(inner)),
472            policy: self.policy.clone(),
473        }
474    }
475}
476
477impl<S, P> Service<CreateChatCompletionRequest> for AgentLoopStream<S, P>
478where
479    S: Service<
480            CreateChatCompletionRequest,
481            Response = Pin<Box<dyn Stream<Item = StepChunk> + Send>>,
482            Error = BoxError,
483        > + Send
484        + 'static,
485    S::Future: Send + 'static,
486    P: AgentPolicy + Send + Sync + Clone + 'static,
487{
488    type Response = Pin<Box<dyn Stream<Item = AgentEvent> + Send>>;
489    type Error = BoxError;
490    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
491
492    fn poll_ready(
493        &mut self,
494        _cx: &mut std::task::Context<'_>,
495    ) -> std::task::Poll<Result<(), Self::Error>> {
496        std::task::Poll::Ready(Ok(()))
497    }
498
499    fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
500        let inner = self.inner.clone();
501        let policy = self.policy.clone();
502        Box::pin(async move {
503            let (tx, rx) = mpsc::channel::<AgentEvent>(64);
504            tokio::spawn(async move {
505                let base_model = req.model.clone();
506                let mut current_messages = req.messages.clone();
507                let mut state = LoopState::default();
508                let mut step_index: usize = 0;
509
510                loop {
511                    // Build current request
512                    let mut b = CreateChatCompletionRequestArgs::default();
513                    b.model(&base_model);
514                    b.messages(current_messages.clone());
515                    let current_req = match b.build() {
516                        Ok(r) => r,
517                        Err(e) => {
518                            let _ = tx
519                                .send(AgentEvent::Item(StepChunk::Error(format!(
520                                    "build req: {}",
521                                    e
522                                ))))
523                                .await;
524                            break;
525                        }
526                    };
527
528                    let mut guard = inner.lock().await;
529                    let stream = match guard.ready().await {
530                        Ok(svc) => match svc.call(current_req).await {
531                            Ok(st) => st,
532                            Err(e) => {
533                                let _ = tx
534                                    .send(AgentEvent::Item(StepChunk::Error(format!(
535                                        "step stream: {}",
536                                        e
537                                    ))))
538                                    .await;
539                                break;
540                            }
541                        },
542                        Err(e) => {
543                            let _ = tx
544                                .send(AgentEvent::Item(StepChunk::Error(format!(
545                                    "step not ready: {}",
546                                    e
547                                ))))
548                                .await;
549                            break;
550                        }
551                    };
552                    drop(guard);
553
554                    step_index += 1;
555                    if tx.send(AgentEvent::Step(step_index)).await.is_err() {
556                        break;
557                    }
558
559                    // Forward inner items until StepComplete
560                    futures::pin_mut!(stream);
561                    let mut last_outcome: Option<StepOutcome> = None;
562                    while let Some(item) = stream.next().await {
563                        let is_complete = matches!(item, StepChunk::StepComplete { .. });
564                        if let StepChunk::StepComplete { outcome } = item.clone() {
565                            last_outcome = Some(outcome);
566                        }
567                        if tx.send(AgentEvent::Item(item)).await.is_err() {
568                            return;
569                        }
570                        if is_complete {
571                            break;
572                        }
573                    }
574
575                    state.steps += 1;
576                    match last_outcome {
577                        Some(outcome) => {
578                            if let Some(stop) = policy.decide(&state, &outcome) {
579                                // Extract messages for final run
580                                let messages = match outcome {
581                                    StepOutcome::Next { messages, .. } => messages,
582                                    StepOutcome::Done { messages, .. } => messages,
583                                };
584                                let run = AgentRun {
585                                    messages,
586                                    steps: state.steps,
587                                    stop,
588                                };
589                                let _ = tx.send(AgentEvent::RunComplete(run)).await;
590                                break;
591                            }
592                            // Continue with updated messages
593                            current_messages = match outcome {
594                                StepOutcome::Next { messages, .. } => messages,
595                                StepOutcome::Done { messages, .. } => messages,
596                            };
597                        }
598                        None => {
599                            // No completion seen; treat as error
600                            let _ = tx
601                                .send(AgentEvent::Item(StepChunk::Error(
602                                    "missing StepComplete".into(),
603                                )))
604                                .await;
605                            break;
606                        }
607                    }
608                }
609            });
610            Ok(Box::pin(ReceiverStream::new(rx)) as Pin<Box<dyn Stream<Item = AgentEvent> + Send>>)
611        })
612    }
613}
614
615/// Layer that tees a stream of `AgentEvent`s to a sink function.
616pub struct StreamTapLayer {
617    sink: Arc<dyn Fn(&AgentEvent) + Send + Sync + 'static>,
618}
619
620impl StreamTapLayer {
621    pub fn new<F>(f: F) -> Self
622    where
623        F: Fn(&AgentEvent) + Send + Sync + 'static,
624    {
625        Self { sink: Arc::new(f) }
626    }
627}
628
629pub struct StreamTap<S> {
630    inner: S,
631    sink: Arc<dyn Fn(&AgentEvent) + Send + Sync + 'static>,
632}
633
634impl<S> Layer<S> for StreamTapLayer {
635    type Service = StreamTap<S>;
636    fn layer(&self, inner: S) -> Self::Service {
637        StreamTap {
638            inner,
639            sink: self.sink.clone(),
640        }
641    }
642}
643
644impl<S> Service<CreateChatCompletionRequest> for StreamTap<S>
645where
646    S: Service<
647            CreateChatCompletionRequest,
648            Response = Pin<Box<dyn Stream<Item = AgentEvent> + Send>>,
649            Error = BoxError,
650        > + Send
651        + 'static,
652    S::Future: Send + 'static,
653{
654    type Response = Pin<Box<dyn Stream<Item = AgentEvent> + Send>>;
655    type Error = BoxError;
656    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
657
658    fn poll_ready(
659        &mut self,
660        _cx: &mut std::task::Context<'_>,
661    ) -> std::task::Poll<Result<(), Self::Error>> {
662        std::task::Poll::Ready(Ok(()))
663    }
664
665    fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
666        let sink = self.sink.clone();
667        let fut = self.inner.call(req);
668        Box::pin(async move {
669            let stream = fut.await?;
670            let (tx, rx) = mpsc::channel::<AgentEvent>(32);
671            tokio::spawn(async move {
672                futures::pin_mut!(stream);
673                while let Some(item) = stream.next().await {
674                    (sink)(&item);
675                    if tx.send(item).await.is_err() {
676                        return;
677                    }
678                }
679            });
680            Ok(Box::pin(ReceiverStream::new(rx)) as Pin<Box<dyn Stream<Item = AgentEvent> + Send>>)
681        })
682    }
683}
684
685/// Utility: collect a streaming agent run and return the final `AgentRun`.
686pub async fn collect_final<S>(stream: &mut S) -> Option<AgentRun>
687where
688    S: Stream<Item = AgentEvent> + Unpin,
689{
690    let mut final_run: Option<AgentRun> = None;
691    while let Some(ev) = stream.next().await {
692        if let AgentEvent::RunComplete(run) = ev {
693            final_run = Some(run);
694        }
695    }
696    final_run
697}
698
#[cfg(test)]
mod tests {
    use super::*;
    use crate::validation::{validate_conversation, ValidationPolicy};
    use async_openai::types::ChatCompletionRequestUserMessageArgs;
    use futures::stream;
    use serde_json::json;
    use tokio::time::{sleep, Duration};
    use tower::service_fn;

    /// Builds a request with a fixed model and no messages, for tests that do
    /// not care about the prompt contents.
    fn empty_request() -> CreateChatCompletionRequest {
        CreateChatCompletionRequestArgs::default()
            .model("gpt-4o")
            .messages(vec![])
            .build()
            .unwrap()
    }

    /// Validation policy relaxed for the synthetic conversations produced in
    /// these tests (no user turn, repeated roles allowed).
    fn lenient_validation_policy() -> ValidationPolicy {
        ValidationPolicy {
            allow_repeated_roles: true,
            require_user_first: false,
            require_user_present: false,
            ..Default::default()
        }
    }

    /// Provider that replays a fixed script of chunks on every call.
    struct FakeProvider {
        items: Vec<StepChunk>,
    }

    impl StepProvider for FakeProvider {
        type Stream = Pin<Box<dyn Stream<Item = StepChunk> + Send>>;
        fn stream_step(
            &self,
            _req: CreateChatCompletionRequest,
        ) -> Pin<Box<dyn Future<Output = Result<Self::Stream, BoxError>> + Send>> {
            let s = stream::iter(self.items.clone());
            Box::pin(
                async move { Ok(Box::pin(s) as Pin<Box<dyn Stream<Item = StepChunk> + Send>>) },
            )
        }
    }

    /// Provider that records the request it receives and yields no chunks.
    struct CapturingProvider {
        captured: Arc<tokio::sync::Mutex<Option<CreateChatCompletionRequest>>>,
    }

    impl StepProvider for CapturingProvider {
        type Stream = Pin<Box<dyn Stream<Item = StepChunk> + Send>>;
        fn stream_step(
            &self,
            req: CreateChatCompletionRequest,
        ) -> Pin<Box<dyn Future<Output = Result<Self::Stream, BoxError>> + Send>> {
            let captured = self.captured.clone();
            Box::pin(async move {
                *captured.lock().await = Some(req);
                let s = stream::iter(Vec::<StepChunk>::new());
                Ok(Box::pin(s) as Pin<Box<dyn Stream<Item = StepChunk> + Send>>)
            })
        }
    }

    #[tokio::test]
    async fn step_stream_invokes_tool_and_finishes() {
        // Provider emits tokens, then a tool call
        let provider = Arc::new(FakeProvider {
            items: vec![
                StepChunk::Token("Hello ".into()),
                StepChunk::Token("world".into()),
                StepChunk::ToolCallStart {
                    id: "call_1".into(),
                    name: "echo".into(),
                    arguments: json!({"x": 1}),
                },
            ],
        });

        // Tool acknowledges every call with a fixed payload, keyed by the call id
        let tool = service_fn(|inv: ToolInvocation| async move {
            Ok::<_, BoxError>(ToolOutput {
                id: inv.id,
                result: json!({"ok": true}),
            })
        });

        let mut svc = StepStreamService::new(provider, tool);
        let mut stream = svc.call(empty_request()).await.unwrap();
        let mut got_tool_end = false;
        let mut got_complete = false;
        while let Some(item) = stream.next().await {
            match item {
                StepChunk::ToolCallEnd { id, output } => {
                    assert_eq!(id, "call_1");
                    assert_eq!(output, json!({"ok": true}));
                    got_tool_end = true;
                }
                StepChunk::StepComplete { outcome } => {
                    match outcome {
                        StepOutcome::Next {
                            messages,
                            invoked_tools,
                            ..
                        } => {
                            assert!(messages.len() >= 2); // assistant + tool
                            assert_eq!(invoked_tools, vec!["echo".to_string()]);
                            assert!(
                                validate_conversation(&messages, &lenient_validation_policy())
                                    .is_none()
                            );
                        }
                        _ => panic!("expected Next"),
                    }
                    got_complete = true;
                }
                _ => {}
            }
        }
        assert!(got_tool_end && got_complete);
    }

    #[tokio::test]
    async fn loop_stream_runs_until_policy() {
        // Provider that yields just tokens and finishes (no tool calls)
        let provider = Arc::new(FakeProvider {
            items: vec![StepChunk::Token("ok".into())],
        });
        // No-op tool service
        let tool = service_fn(|_inv: ToolInvocation| async move {
            Ok::<_, BoxError>(ToolOutput {
                id: "x".into(),
                result: json!({}),
            })
        });
        // Wrap into StepStreamService
        let step = StepStreamService::new(provider, tool);
        // Build a layer that turns it into an agent loop stream with a policy that stops on Done
        let loop_layer = AgentLoopStreamLayer::new(crate::core::policies::until_no_tool_calls());
        let mut agent_stream = loop_layer.layer(step);

        let mut stream = agent_stream.call(empty_request()).await.unwrap();
        let mut saw_run_complete = false;
        while let Some(ev) = stream.next().await {
            if let AgentEvent::RunComplete(run) = ev {
                saw_run_complete = true;
                assert_eq!(run.steps, 1);
                assert!(matches!(
                    run.stop,
                    crate::core::AgentStopReason::DoneNoToolCalls
                ));
                assert!(
                    validate_conversation(&run.messages, &lenient_validation_policy()).is_none()
                );
            }
        }
        assert!(saw_run_complete);
    }

    #[tokio::test]
    async fn tap_layer_receives_every_event() {
        let provider = Arc::new(FakeProvider {
            items: vec![StepChunk::Token("a".into()), StepChunk::Token("b".into())],
        });
        let tool = service_fn(|_inv: ToolInvocation| async move {
            Ok::<_, BoxError>(ToolOutput {
                id: "i".into(),
                result: json!({}),
            })
        });
        let step = StepStreamService::new(provider, tool);
        let loop_layer = AgentLoopStreamLayer::new(crate::core::policies::max_steps(1));
        let agent = loop_layer.layer(step);
        // Record synchronously: the tap invokes the sink before forwarding each
        // event, so once the stream is drained every tapped event is already
        // logged. (Spawning a task per event, as before, raced the assertion.)
        let tap_log: Arc<std::sync::Mutex<Vec<String>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));
        let tap_log_clone = tap_log.clone();
        let tap = StreamTapLayer::new(move |ev: &AgentEvent| {
            tap_log_clone
                .lock()
                .expect("tap log mutex poisoned")
                .push(format!("{:?}", ev));
        });
        let mut svc = tap.layer(agent);
        let mut stream = svc.call(empty_request()).await.unwrap();
        // Drain, remembering what the downstream consumer actually received
        let mut consumed: Vec<String> = Vec::new();
        while let Some(ev) = stream.next().await {
            consumed.push(format!("{:?}", ev));
        }
        let tapped = tap_log.lock().expect("tap log mutex poisoned").clone();
        assert!(!consumed.is_empty());
        // Same events in the same order: no drops, duplicates, or reordering.
        assert_eq!(tapped, consumed);
    }

    #[tokio::test]
    async fn instructions_are_injected_in_streaming_request() {
        let captured: Arc<tokio::sync::Mutex<Option<CreateChatCompletionRequest>>> =
            Arc::new(tokio::sync::Mutex::new(None));
        let provider = Arc::new(CapturingProvider {
            captured: captured.clone(),
        });
        let tool = service_fn(|_inv: ToolInvocation| async move {
            Ok::<_, BoxError>(ToolOutput {
                id: "x".into(),
                result: json!({}),
            })
        });
        let mut svc = StepStreamService::new(provider, tool).instructions("INSTR");
        let req = CreateChatCompletionRequestArgs::default()
            .model("gpt-4o")
            .messages(vec![ChatCompletionRequestUserMessageArgs::default()
                .content("hi")
                .build()
                .unwrap()
                .into()])
            .build()
            .unwrap();
        // Call the service and drop the stream immediately; capture happens on call
        let _ = svc.call(req).await.unwrap();
        let got = captured.lock().await.clone().expect("captured req");
        assert!(!got.messages.is_empty());
        match &got.messages[0] {
            ChatCompletionRequestMessage::System(s) => match &s.content {
                async_openai::types::ChatCompletionRequestSystemMessageContent::Text(t) => {
                    assert_eq!(t, "INSTR");
                }
                _ => panic!("expected text content"),
            },
            _ => panic!("expected first message to be system"),
        }
    }

    #[tokio::test]
    async fn step_stream_parallel_preserve_order() {
        // Provider emits two tool calls
        let provider = Arc::new(FakeProvider {
            items: vec![
                StepChunk::ToolCallStart {
                    id: "c1".into(),
                    name: "slow".into(),
                    arguments: json!({}),
                },
                StepChunk::ToolCallStart {
                    id: "c2".into(),
                    name: "fast".into(),
                    arguments: json!({}),
                },
            ],
        });
        // Tools with different latency
        let tool = service_fn(|inv: ToolInvocation| async move {
            if inv.name == "slow" {
                sleep(Duration::from_millis(40)).await;
            } else {
                sleep(Duration::from_millis(5)).await;
            }
            Ok::<_, BoxError>(ToolOutput {
                id: inv.id,
                result: json!({"label": inv.name}),
            })
        });
        let mut svc = StepStreamService::new(provider, tool).parallel_tools(true);
        let mut stream = svc.call(empty_request()).await.unwrap();
        let mut end_ids: Vec<String> = Vec::new();
        let mut saw_complete = false;
        let mut final_messages: Option<Vec<ChatCompletionRequestMessage>> = None;
        while let Some(item) = stream.next().await {
            match item {
                StepChunk::ToolCallEnd { id, .. } => end_ids.push(id),
                StepChunk::StepComplete { outcome } => {
                    saw_complete = true;
                    match outcome {
                        StepOutcome::Next { messages, .. } | StepOutcome::Done { messages, .. } => {
                            final_messages = Some(messages);
                        }
                    }
                }
                _ => {}
            }
        }
        assert!(saw_complete);
        // The slow call finishes last but must still be reported first.
        assert_eq!(end_ids, vec!["c1".to_string(), "c2".to_string()]);
        // Every StepComplete outcome carries messages, so validation is mandatory.
        let msgs = final_messages.expect("StepComplete should carry the final messages");
        assert!(validate_conversation(&msgs, &lenient_validation_policy()).is_none());
    }

    #[tokio::test]
    async fn step_stream_parallel_error_propagation() {
        // Provider emits two tool calls; one will fail
        let provider = Arc::new(FakeProvider {
            items: vec![
                StepChunk::ToolCallStart {
                    id: "g1".into(),
                    name: "good".into(),
                    arguments: json!({}),
                },
                StepChunk::ToolCallStart {
                    id: "b1".into(),
                    name: "bad".into(),
                    arguments: json!({}),
                },
            ],
        });
        let tool = service_fn(|inv: ToolInvocation| async move {
            if inv.name == "bad" {
                Err::<ToolOutput, BoxError>("boom".into())
            } else {
                Ok::<_, BoxError>(ToolOutput {
                    id: inv.id,
                    result: json!({}),
                })
            }
        });
        let mut svc = StepStreamService::new(provider, tool).parallel_tools(true);
        let mut stream = svc.call(empty_request()).await.unwrap();
        let mut saw_error = false;
        let mut saw_complete = false;
        while let Some(item) = stream.next().await {
            match item {
                StepChunk::Error(e) => {
                    saw_error = true;
                    assert!(e.contains("tool error"));
                }
                StepChunk::StepComplete { .. } => saw_complete = true,
                _ => {}
            }
        }
        assert!(saw_error);
        assert!(!saw_complete);
    }

    #[tokio::test]
    async fn step_stream_parallel_concurrency_limit() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        static CURRENT: AtomicUsize = AtomicUsize::new(0);
        static MAX_OBSERVED: AtomicUsize = AtomicUsize::new(0);

        let mut items = Vec::new();
        for i in 0..8 {
            items.push(StepChunk::ToolCallStart {
                id: format!("c{}", i),
                name: "gate".into(),
                arguments: json!({}),
            });
        }
        let provider = Arc::new(FakeProvider { items });
        let tool = service_fn(|inv: ToolInvocation| async move {
            let now = CURRENT.fetch_add(1, Ordering::SeqCst) + 1;
            // fetch_max never loses an update, unlike a load + compare_exchange
            // pair, so the observed maximum cannot be under-reported.
            MAX_OBSERVED.fetch_max(now, Ordering::SeqCst);
            sleep(Duration::from_millis(10)).await;
            CURRENT.fetch_sub(1, Ordering::SeqCst);
            Ok::<_, BoxError>(ToolOutput {
                id: inv.id,
                result: json!({}),
            })
        });

        let mut svc = StepStreamService::new(provider, tool)
            .parallel_tools(true)
            .tool_concurrency_limit(3);
        let mut stream = svc.call(empty_request()).await.unwrap();
        while let Some(_item) = stream.next().await {}
        let max = MAX_OBSERVED.load(Ordering::SeqCst);
        // Guard against a vacuous pass where no tool ever ran.
        assert!(max >= 1, "no tool invocation was observed");
        assert!(max <= 3, "observed {} concurrent tool calls, limit is 3", max);
    }

    #[tokio::test]
    async fn step_stream_parallel_failfast_early_termination() {
        // Provider emits two tool calls
        let provider = Arc::new(FakeProvider {
            items: vec![
                StepChunk::ToolCallStart {
                    id: "b1".into(),
                    name: "bad".into(),
                    arguments: json!({}),
                },
                StepChunk::ToolCallStart {
                    id: "s1".into(),
                    name: "slow".into(),
                    arguments: json!({}),
                },
            ],
        });
        let tool = service_fn(|inv: ToolInvocation| async move {
            if inv.name == "bad" {
                Err::<ToolOutput, BoxError>("boom".into())
            } else {
                sleep(Duration::from_millis(40)).await;
                Ok::<_, BoxError>(ToolOutput {
                    id: inv.id,
                    result: json!({"ok":true}),
                })
            }
        });
        let mut svc = StepStreamService::new(provider, tool)
            .parallel_tools(true)
            .tool_join_policy(crate::core::ToolJoinPolicy::FailFast);
        let mut stream = svc.call(empty_request()).await.unwrap();
        let mut saw_slow_end = false;
        let mut saw_error = false;
        while let Some(item) = stream.next().await {
            match item {
                StepChunk::ToolCallEnd { id, .. } => {
                    if id == "s1" {
                        saw_slow_end = true;
                    }
                }
                StepChunk::Error(_) => {
                    saw_error = true;
                }
                _ => {}
            }
        }
        assert!(saw_error);
        assert!(!saw_slow_end);
    }

    #[tokio::test]
    async fn step_stream_parallel_joinall_emits_successes_then_error() {
        // Provider emits two tool calls
        let provider = Arc::new(FakeProvider {
            items: vec![
                StepChunk::ToolCallStart {
                    id: "b1".into(),
                    name: "bad".into(),
                    arguments: json!({}),
                },
                StepChunk::ToolCallStart {
                    id: "s1".into(),
                    name: "slow".into(),
                    arguments: json!({}),
                },
            ],
        });
        let tool = service_fn(|inv: ToolInvocation| async move {
            if inv.name == "bad" {
                Err::<ToolOutput, BoxError>("boom".into())
            } else {
                sleep(Duration::from_millis(20)).await;
                Ok::<_, BoxError>(ToolOutput {
                    id: inv.id,
                    result: json!({"ok":true}),
                })
            }
        });
        let mut svc = StepStreamService::new(provider, tool)
            .parallel_tools(true)
            .tool_join_policy(crate::core::ToolJoinPolicy::JoinAll)
            .tool_concurrency_limit(1);
        let mut stream = svc.call(empty_request()).await.unwrap();
        let mut saw_slow_end = false;
        let mut saw_error = false;
        let mut saw_complete = false;
        while let Some(item) = stream.next().await {
            match item {
                StepChunk::ToolCallEnd { id, .. } => {
                    if id == "s1" {
                        saw_slow_end = true;
                    }
                }
                StepChunk::Error(_) => {
                    saw_error = true;
                }
                StepChunk::StepComplete { .. } => {
                    saw_complete = true;
                }
                _ => {}
            }
        }
        assert!(saw_slow_end);
        assert!(saw_error);
        assert!(!saw_complete);
    }
}