1use std::future::Future;
34use std::pin::Pin;
35use std::sync::Arc;
36
37use async_openai::types::{
38 ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
39 ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
40 CreateChatCompletionRequest, CreateChatCompletionRequestArgs,
41};
42use futures::{Stream, StreamExt};
43use serde_json::Value;
44use tokio::sync::{mpsc, Semaphore};
45use tokio_stream::wrappers::ReceiverStream;
46use tower::{BoxError, Layer, Service, ServiceExt};
47
48use crate::core::{
49 AgentPolicy, AgentRun, LoopState, StepAux, StepOutcome, ToolInvocation, ToolJoinPolicy,
50 ToolOutput,
51};
52
53#[derive(Debug, Clone)]
55pub enum StepChunk {
56 Token(String),
57 ToolCallStart {
58 id: String,
59 name: String,
60 arguments: Value,
61 },
62 ToolCallEnd {
63 id: String,
64 output: Value,
65 },
66 UsageDelta {
67 prompt_tokens: usize,
68 completion_tokens: usize,
69 },
70 StepComplete {
72 outcome: StepOutcome,
73 },
74 Error(String),
76}
77
78pub trait StepProvider: Send + Sync + 'static {
83 type Stream: Stream<Item = StepChunk> + Send + 'static;
84 fn stream_step(
85 &self,
86 req: CreateChatCompletionRequest,
87 ) -> Pin<Box<dyn Future<Output = Result<Self::Stream, BoxError>> + Send>>;
88}
89
90pub struct StepStreamService<P, T> {
96 provider: Arc<P>,
97 tools: T, instructions: Option<String>,
99 parallel_tools: bool,
100 tool_concurrency_limit: Option<usize>,
101 join_policy: ToolJoinPolicy,
102}
103
104impl<P, T> StepStreamService<P, T> {
105 pub fn new(provider: Arc<P>, tools: T) -> Self {
106 Self {
107 provider,
108 tools,
109 instructions: None,
110 parallel_tools: false,
111 tool_concurrency_limit: None,
112 join_policy: ToolJoinPolicy::FailFast,
113 }
114 }
115
116 pub fn parallel_tools(mut self, enabled: bool) -> Self {
117 self.parallel_tools = enabled;
118 self
119 }
120
121 pub fn tool_concurrency_limit(mut self, limit: usize) -> Self {
122 self.tool_concurrency_limit = Some(limit);
123 self
124 }
125
126 pub fn tool_join_policy(mut self, policy: ToolJoinPolicy) -> Self {
127 self.join_policy = policy;
128 self
129 }
130
131 pub fn instructions(mut self, text: impl Into<String>) -> Self {
132 self.instructions = Some(text.into());
133 self
134 }
135}
136
137impl<P, T> Service<CreateChatCompletionRequest> for StepStreamService<P, T>
138where
139 P: StepProvider,
140 T: Service<ToolInvocation, Response = ToolOutput, Error = BoxError> + Clone + Send + 'static,
141 T::Future: Send + 'static,
142{
143 type Response = Pin<Box<dyn Stream<Item = StepChunk> + Send>>;
144 type Error = BoxError;
145 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
146
147 fn poll_ready(
148 &mut self,
149 _cx: &mut std::task::Context<'_>,
150 ) -> std::task::Poll<Result<(), Self::Error>> {
151 std::task::Poll::Ready(Ok(()))
152 }
153
154 fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
155 let provider = self.provider.clone();
156 let tools = self.tools.clone();
157 let parallel = self.parallel_tools;
158 let _limit = self.tool_concurrency_limit;
159 let join_policy = self.join_policy;
160 let instructions = self.instructions.clone();
161 Box::pin(async move {
162 let mut token_buf = String::new();
163 let mut tool_calls: Vec<(String, String, Value)> = Vec::new();
164 let mut invoked_tool_names: Vec<String> = Vec::new();
165 let mut aux = StepAux::default();
166 let base_model = req.model.clone();
167
168 let mut builder = CreateChatCompletionRequestArgs::default();
170 builder.model(base_model.clone());
171 let mut injected_messages = req.messages.clone();
173 if let Some(instr) = instructions.clone() {
174 let sys_msg = ChatCompletionRequestSystemMessageArgs::default()
175 .content(instr)
176 .build()
177 .map(ChatCompletionRequestMessage::from)
178 .map_err(|e| format!("system msg build error: {}", e))?;
179 if let Some(pos) = injected_messages
180 .iter()
181 .position(|m| matches!(m, ChatCompletionRequestMessage::System(_)))
182 {
183 injected_messages.remove(pos);
184 }
185 injected_messages.insert(0, sys_msg);
186 }
187 builder.messages(injected_messages);
188 let normalized_req = builder.build().map_err(|e| format!("build req: {}", e))?;
189
190 let stream = provider.stream_step(normalized_req).await?;
191
192 let (tx, rx) = mpsc::channel::<StepChunk>(32);
193 tokio::spawn(async move {
194 futures::pin_mut!(stream);
195 while let Some(item) = stream.next().await {
196 match &item {
197 StepChunk::Token(t) => {
198 token_buf.push_str(t);
199 }
200 StepChunk::ToolCallStart {
201 id,
202 name,
203 arguments,
204 } => {
205 tool_calls.push((id.clone(), name.clone(), arguments.clone()));
206 }
207 StepChunk::UsageDelta {
208 prompt_tokens,
209 completion_tokens,
210 } => {
211 aux.prompt_tokens += *prompt_tokens;
212 aux.completion_tokens += *completion_tokens;
213 }
214 _ => {}
215 }
216 if tx.send(item).await.is_err() {
218 return;
219 }
220 }
221
222 let mut messages: Vec<ChatCompletionRequestMessage> = req.messages.clone();
224 let mut asst = ChatCompletionRequestAssistantMessageArgs::default();
226 asst.content(token_buf.clone());
227 if !tool_calls.is_empty() {
228 let tool_calls_for_msg: Vec<
230 async_openai::types::ChatCompletionMessageToolCall,
231 > = tool_calls
232 .iter()
233 .map(|(id, name, arguments)| {
234 async_openai::types::ChatCompletionMessageToolCall {
235 id: id.clone(),
236 r#type: async_openai::types::ChatCompletionToolType::Function,
237 function: async_openai::types::FunctionCall {
238 name: name.clone(),
239 arguments: arguments.to_string(),
240 },
241 }
242 })
243 .collect();
244 asst.tool_calls(tool_calls_for_msg);
245 }
246 match asst.build() {
247 Ok(msg) => messages.push(msg.into()),
248 Err(e) => {
249 let _ = tx
250 .send(StepChunk::Error(format!("assistant build: {}", e)))
251 .await;
252 return;
253 }
254 }
255
256 if tool_calls.len() > 1 && parallel {
258 let sem = _limit.map(|n| Arc::new(Semaphore::new(n)));
259 let mut futures = Vec::with_capacity(tool_calls.len());
260 for (idx, (id, name, args)) in tool_calls.iter().cloned().enumerate() {
261 invoked_tool_names.push(name.clone());
262 let inv = ToolInvocation {
263 id,
264 name,
265 arguments: args,
266 };
267 let mut svc = tools.clone();
268 let sem_cl = sem.clone();
269 futures.push(async move {
270 let _permit = match &sem_cl {
271 Some(s) => {
272 Some(s.clone().acquire_owned().await.expect("semaphore"))
273 }
274 None => None,
275 };
276 let ToolOutput { id: out_id, result } =
277 ServiceExt::ready(&mut svc).await?.call(inv).await?;
278 Ok::<(usize, String, Value), BoxError>((idx, out_id, result))
279 });
280 }
281 match join_policy {
282 ToolJoinPolicy::FailFast => {
283 let results = futures::future::try_join_all(futures).await;
284 match results {
285 Ok(mut items) => {
286 items.sort_by_key(|(idx, _, _)| *idx);
287 for (_idx, out_id, result) in items.into_iter() {
288 aux.tool_invocations += 1;
289 match ChatCompletionRequestToolMessageArgs::default()
290 .tool_call_id(out_id.clone())
291 .content(result.to_string())
292 .build()
293 {
294 Ok(tool_msg) => messages.push(tool_msg.into()),
295 Err(e) => {
296 let _ = tx
297 .send(StepChunk::Error(format!(
298 "tool msg build: {}",
299 e
300 )))
301 .await;
302 return;
303 }
304 }
305 let _ = tx
306 .send(StepChunk::ToolCallEnd {
307 id: out_id,
308 output: result,
309 })
310 .await;
311 }
312 }
313 Err(e) => {
314 let _ = tx
315 .send(StepChunk::Error(format!("tool error: {}", e)))
316 .await;
317 return;
318 }
319 }
320 }
321 ToolJoinPolicy::JoinAll => {
322 let results = futures::future::join_all(futures).await;
324 let mut items: Vec<(usize, String, Value)> = Vec::new();
325 let mut errors: Vec<String> = Vec::new();
326 for r in results.into_iter() {
327 match r {
328 Ok((idx, id, result)) => items.push((idx, id, result)),
329 Err(e) => errors.push(format!("{}", e)),
330 }
331 }
332 items.sort_by_key(|(idx, _, _)| *idx);
333 for (_idx, out_id, result) in items.into_iter() {
334 aux.tool_invocations += 1;
335 match ChatCompletionRequestToolMessageArgs::default()
336 .tool_call_id(out_id.clone())
337 .content(result.to_string())
338 .build()
339 {
340 Ok(tool_msg) => messages.push(tool_msg.into()),
341 Err(e) => {
342 let _ = tx
343 .send(StepChunk::Error(format!(
344 "tool msg build: {}",
345 e
346 )))
347 .await;
348 return;
349 }
350 }
351 let _ = tx
352 .send(StepChunk::ToolCallEnd {
353 id: out_id,
354 output: result,
355 })
356 .await;
357 }
358 if !errors.is_empty() {
359 let _ = tx
360 .send(StepChunk::Error(format!(
361 "one or more tools failed: {}",
362 errors.join("; ")
363 )))
364 .await;
365 return;
366 }
367 }
368 }
369 } else {
370 for (id, name, args) in tool_calls.into_iter() {
371 invoked_tool_names.push(name.clone());
372 let inv = ToolInvocation {
373 id: id.clone(),
374 name: name.clone(),
375 arguments: args,
376 };
377 let mut svc = tools.clone();
378 match ServiceExt::ready(&mut svc).await {
379 Ok(ready) => match ready.call(inv).await {
380 Ok(ToolOutput { id: out_id, result }) => {
381 aux.tool_invocations += 1;
382 match ChatCompletionRequestToolMessageArgs::default()
383 .tool_call_id(out_id.clone())
384 .content(result.to_string())
385 .build()
386 {
387 Ok(tool_msg) => messages.push(tool_msg.into()),
388 Err(e) => {
389 let _ = tx
390 .send(StepChunk::Error(format!(
391 "tool msg build: {}",
392 e
393 )))
394 .await;
395 return;
396 }
397 }
398 let _ = tx
399 .send(StepChunk::ToolCallEnd {
400 id: out_id,
401 output: result,
402 })
403 .await;
404 }
405 Err(e) => {
406 let _ = tx
407 .send(StepChunk::Error(format!("tool error: {}", e)))
408 .await;
409 return;
410 }
411 },
412 Err(e) => {
413 let _ = tx
414 .send(StepChunk::Error(format!("tool not ready: {}", e)))
415 .await;
416 return;
417 }
418 }
419 }
420 }
421
422 let outcome = if invoked_tool_names.is_empty() {
424 StepOutcome::Done { messages, aux }
425 } else {
426 StepOutcome::Next {
427 messages,
428 aux,
429 invoked_tools: invoked_tool_names,
430 }
431 };
432 let _ = tx.send(StepChunk::StepComplete { outcome }).await;
433 });
434
435 Ok(Box::pin(ReceiverStream::new(rx)) as Pin<Box<dyn Stream<Item = StepChunk> + Send>>)
436 })
437 }
438}
439
440#[derive(Debug, Clone)]
442pub enum AgentEvent {
443 Step(usize),
444 Item(StepChunk),
445 RunComplete(AgentRun),
446}
447
/// Layer that wraps a step-streaming service in a policy-driven agent loop.
pub struct AgentLoopStreamLayer<P> {
    /// Policy cloned into every service produced by this layer.
    policy: P,
}

impl<P> AgentLoopStreamLayer<P> {
    /// Creates a layer that applies `policy` to each wrapped service.
    pub fn new(policy: P) -> Self {
        AgentLoopStreamLayer { policy }
    }
}
458
459pub struct AgentLoopStream<S, P> {
460 inner: Arc<tokio::sync::Mutex<S>>,
461 policy: P,
462}
463
464impl<S, P> Layer<S> for AgentLoopStreamLayer<P>
465where
466 P: Clone,
467{
468 type Service = AgentLoopStream<S, P>;
469 fn layer(&self, inner: S) -> Self::Service {
470 AgentLoopStream {
471 inner: Arc::new(tokio::sync::Mutex::new(inner)),
472 policy: self.policy.clone(),
473 }
474 }
475}
476
477impl<S, P> Service<CreateChatCompletionRequest> for AgentLoopStream<S, P>
478where
479 S: Service<
480 CreateChatCompletionRequest,
481 Response = Pin<Box<dyn Stream<Item = StepChunk> + Send>>,
482 Error = BoxError,
483 > + Send
484 + 'static,
485 S::Future: Send + 'static,
486 P: AgentPolicy + Send + Sync + Clone + 'static,
487{
488 type Response = Pin<Box<dyn Stream<Item = AgentEvent> + Send>>;
489 type Error = BoxError;
490 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
491
492 fn poll_ready(
493 &mut self,
494 _cx: &mut std::task::Context<'_>,
495 ) -> std::task::Poll<Result<(), Self::Error>> {
496 std::task::Poll::Ready(Ok(()))
497 }
498
499 fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
500 let inner = self.inner.clone();
501 let policy = self.policy.clone();
502 Box::pin(async move {
503 let (tx, rx) = mpsc::channel::<AgentEvent>(64);
504 tokio::spawn(async move {
505 let base_model = req.model.clone();
506 let mut current_messages = req.messages.clone();
507 let mut state = LoopState::default();
508 let mut step_index: usize = 0;
509
510 loop {
511 let mut b = CreateChatCompletionRequestArgs::default();
513 b.model(&base_model);
514 b.messages(current_messages.clone());
515 let current_req = match b.build() {
516 Ok(r) => r,
517 Err(e) => {
518 let _ = tx
519 .send(AgentEvent::Item(StepChunk::Error(format!(
520 "build req: {}",
521 e
522 ))))
523 .await;
524 break;
525 }
526 };
527
528 let mut guard = inner.lock().await;
529 let stream = match guard.ready().await {
530 Ok(svc) => match svc.call(current_req).await {
531 Ok(st) => st,
532 Err(e) => {
533 let _ = tx
534 .send(AgentEvent::Item(StepChunk::Error(format!(
535 "step stream: {}",
536 e
537 ))))
538 .await;
539 break;
540 }
541 },
542 Err(e) => {
543 let _ = tx
544 .send(AgentEvent::Item(StepChunk::Error(format!(
545 "step not ready: {}",
546 e
547 ))))
548 .await;
549 break;
550 }
551 };
552 drop(guard);
553
554 step_index += 1;
555 if tx.send(AgentEvent::Step(step_index)).await.is_err() {
556 break;
557 }
558
559 futures::pin_mut!(stream);
561 let mut last_outcome: Option<StepOutcome> = None;
562 while let Some(item) = stream.next().await {
563 let is_complete = matches!(item, StepChunk::StepComplete { .. });
564 if let StepChunk::StepComplete { outcome } = item.clone() {
565 last_outcome = Some(outcome);
566 }
567 if tx.send(AgentEvent::Item(item)).await.is_err() {
568 return;
569 }
570 if is_complete {
571 break;
572 }
573 }
574
575 state.steps += 1;
576 match last_outcome {
577 Some(outcome) => {
578 if let Some(stop) = policy.decide(&state, &outcome) {
579 let messages = match outcome {
581 StepOutcome::Next { messages, .. } => messages,
582 StepOutcome::Done { messages, .. } => messages,
583 };
584 let run = AgentRun {
585 messages,
586 steps: state.steps,
587 stop,
588 };
589 let _ = tx.send(AgentEvent::RunComplete(run)).await;
590 break;
591 }
592 current_messages = match outcome {
594 StepOutcome::Next { messages, .. } => messages,
595 StepOutcome::Done { messages, .. } => messages,
596 };
597 }
598 None => {
599 let _ = tx
601 .send(AgentEvent::Item(StepChunk::Error(
602 "missing StepComplete".into(),
603 )))
604 .await;
605 break;
606 }
607 }
608 }
609 });
610 Ok(Box::pin(ReceiverStream::new(rx)) as Pin<Box<dyn Stream<Item = AgentEvent> + Send>>)
611 })
612 }
613}
614
615pub struct StreamTapLayer {
617 sink: Arc<dyn Fn(&AgentEvent) + Send + Sync + 'static>,
618}
619
620impl StreamTapLayer {
621 pub fn new<F>(f: F) -> Self
622 where
623 F: Fn(&AgentEvent) + Send + Sync + 'static,
624 {
625 Self { sink: Arc::new(f) }
626 }
627}
628
629pub struct StreamTap<S> {
630 inner: S,
631 sink: Arc<dyn Fn(&AgentEvent) + Send + Sync + 'static>,
632}
633
634impl<S> Layer<S> for StreamTapLayer {
635 type Service = StreamTap<S>;
636 fn layer(&self, inner: S) -> Self::Service {
637 StreamTap {
638 inner,
639 sink: self.sink.clone(),
640 }
641 }
642}
643
644impl<S> Service<CreateChatCompletionRequest> for StreamTap<S>
645where
646 S: Service<
647 CreateChatCompletionRequest,
648 Response = Pin<Box<dyn Stream<Item = AgentEvent> + Send>>,
649 Error = BoxError,
650 > + Send
651 + 'static,
652 S::Future: Send + 'static,
653{
654 type Response = Pin<Box<dyn Stream<Item = AgentEvent> + Send>>;
655 type Error = BoxError;
656 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
657
658 fn poll_ready(
659 &mut self,
660 _cx: &mut std::task::Context<'_>,
661 ) -> std::task::Poll<Result<(), Self::Error>> {
662 std::task::Poll::Ready(Ok(()))
663 }
664
665 fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
666 let sink = self.sink.clone();
667 let fut = self.inner.call(req);
668 Box::pin(async move {
669 let stream = fut.await?;
670 let (tx, rx) = mpsc::channel::<AgentEvent>(32);
671 tokio::spawn(async move {
672 futures::pin_mut!(stream);
673 while let Some(item) = stream.next().await {
674 (sink)(&item);
675 if tx.send(item).await.is_err() {
676 return;
677 }
678 }
679 });
680 Ok(Box::pin(ReceiverStream::new(rx)) as Pin<Box<dyn Stream<Item = AgentEvent> + Send>>)
681 })
682 }
683}
684
685pub async fn collect_final<S>(stream: &mut S) -> Option<AgentRun>
687where
688 S: Stream<Item = AgentEvent> + Unpin,
689{
690 let mut final_run: Option<AgentRun> = None;
691 while let Some(ev) = stream.next().await {
692 if let AgentEvent::RunComplete(run) = ev {
693 final_run = Some(run);
694 }
695 }
696 final_run
697}
698
699#[cfg(test)]
700mod tests {
701 use super::*;
702 use crate::validation::{validate_conversation, ValidationPolicy};
703 use async_openai::types::ChatCompletionRequestUserMessageArgs;
704 use futures::stream;
705 use serde_json::json;
706 use tokio::time::{sleep, Duration};
707 use tower::service_fn;
708
709 struct FakeProvider {
710 items: Vec<StepChunk>,
711 }
712
713 impl StepProvider for FakeProvider {
714 type Stream = Pin<Box<dyn Stream<Item = StepChunk> + Send>>;
715 fn stream_step(
716 &self,
717 _req: CreateChatCompletionRequest,
718 ) -> Pin<Box<dyn Future<Output = Result<Self::Stream, BoxError>> + Send>> {
719 let s = stream::iter(self.items.clone());
720 Box::pin(
721 async move { Ok(Box::pin(s) as Pin<Box<dyn Stream<Item = StepChunk> + Send>>) },
722 )
723 }
724 }
725
726 struct CapturingProvider {
727 captured: Arc<tokio::sync::Mutex<Option<CreateChatCompletionRequest>>>,
728 }
729
730 impl StepProvider for CapturingProvider {
731 type Stream = Pin<Box<dyn Stream<Item = StepChunk> + Send>>;
732 fn stream_step(
733 &self,
734 req: CreateChatCompletionRequest,
735 ) -> Pin<Box<dyn Future<Output = Result<Self::Stream, BoxError>> + Send>> {
736 let captured = self.captured.clone();
737 Box::pin(async move {
738 *captured.lock().await = Some(req);
739 let s = stream::iter(Vec::<StepChunk>::new());
740 Ok(Box::pin(s) as Pin<Box<dyn Stream<Item = StepChunk> + Send>>)
741 })
742 }
743 }
744
745 #[tokio::test]
746 async fn step_stream_invokes_tool_and_finishes() {
747 let provider = Arc::new(FakeProvider {
749 items: vec![
750 StepChunk::Token("Hello ".into()),
751 StepChunk::Token("world".into()),
752 StepChunk::ToolCallStart {
753 id: "call_1".into(),
754 name: "echo".into(),
755 arguments: json!({"x": 1}),
756 },
757 ],
758 });
759
760 let tool = service_fn(|inv: ToolInvocation| async move {
762 Ok::<_, BoxError>(ToolOutput {
763 id: inv.id,
764 result: json!({"ok": true}),
765 })
766 });
767
768 let mut svc = StepStreamService::new(provider, tool);
769 let req = CreateChatCompletionRequestArgs::default()
770 .model("gpt-4o")
771 .messages(vec![])
772 .build()
773 .unwrap();
774 let mut stream = svc.call(req).await.unwrap();
775 let mut got_tool_end = false;
776 let mut got_complete = false;
777 while let Some(item) = stream.next().await {
778 match item {
779 StepChunk::ToolCallEnd { id, output } => {
780 assert_eq!(id, "call_1");
781 assert_eq!(output, json!({"ok": true}));
782 got_tool_end = true;
783 }
784 StepChunk::StepComplete { outcome } => {
785 match outcome {
786 StepOutcome::Next {
787 messages,
788 invoked_tools,
789 ..
790 } => {
791 assert!(messages.len() >= 2); assert_eq!(invoked_tools, vec!["echo".to_string()]);
793 let policy = ValidationPolicy {
794 allow_repeated_roles: true,
795 require_user_first: false,
796 require_user_present: false,
797 ..Default::default()
798 };
799 assert!(validate_conversation(&messages, &policy).is_none());
800 }
801 _ => panic!("expected Next"),
802 }
803 got_complete = true;
804 }
805 _ => {}
806 }
807 }
808 assert!(got_tool_end && got_complete);
809 }
810
811 #[tokio::test]
812 async fn loop_stream_runs_until_policy() {
813 let provider = Arc::new(FakeProvider {
815 items: vec![StepChunk::Token("ok".into())],
816 });
817 let tool = service_fn(|_inv: ToolInvocation| async move {
819 Ok::<_, BoxError>(ToolOutput {
820 id: "x".into(),
821 result: json!({}),
822 })
823 });
824 let step = StepStreamService::new(provider, tool);
826 let loop_layer = AgentLoopStreamLayer::new(crate::core::policies::until_no_tool_calls());
828 let mut agent_stream = loop_layer.layer(step);
829
830 let req = CreateChatCompletionRequestArgs::default()
831 .model("gpt-4o")
832 .messages(vec![])
833 .build()
834 .unwrap();
835 let mut stream = agent_stream.call(req).await.unwrap();
836 let mut saw_run_complete = false;
837 while let Some(ev) = stream.next().await {
838 if let AgentEvent::RunComplete(run) = ev {
839 saw_run_complete = true;
840 assert_eq!(run.steps, 1);
841 assert!(matches!(
842 run.stop,
843 crate::core::AgentStopReason::DoneNoToolCalls
844 ));
845 let policy = ValidationPolicy {
846 allow_repeated_roles: true,
847 require_user_first: false,
848 require_user_present: false,
849 ..Default::default()
850 };
851 assert!(validate_conversation(&run.messages, &policy).is_none());
852 }
853 }
854 assert!(saw_run_complete);
855 }
856
857 #[tokio::test]
858 async fn tap_layer_receives_every_event() {
859 let provider = Arc::new(FakeProvider {
860 items: vec![StepChunk::Token("a".into()), StepChunk::Token("b".into())],
861 });
862 let tool = service_fn(|_inv: ToolInvocation| async move {
863 Ok::<_, BoxError>(ToolOutput {
864 id: "i".into(),
865 result: json!({}),
866 })
867 });
868 let step = StepStreamService::new(provider, tool);
869 let loop_layer = AgentLoopStreamLayer::new(crate::core::policies::max_steps(1));
870 let agent = loop_layer.layer(step);
871 let tap_log: Arc<tokio::sync::Mutex<Vec<String>>> =
872 Arc::new(tokio::sync::Mutex::new(vec![]));
873 let tap_log_clone = tap_log.clone();
874 let tap = StreamTapLayer::new(move |ev: &AgentEvent| {
875 let s = format!("{:?}", ev);
876 let tl = tap_log_clone.clone();
877 tokio::spawn(async move {
878 tl.lock().await.push(s);
879 });
880 });
881 let mut svc = tap.layer(agent);
882 let req = CreateChatCompletionRequestArgs::default()
883 .model("gpt-4o")
884 .messages(vec![])
885 .build()
886 .unwrap();
887 let mut stream = svc.call(req).await.unwrap();
888 while let Some(_ev) = stream.next().await {}
890 assert!(!tap_log.lock().await.is_empty());
891 }
892
893 #[tokio::test]
894 async fn instructions_are_injected_in_streaming_request() {
895 let captured: Arc<tokio::sync::Mutex<Option<CreateChatCompletionRequest>>> =
896 Arc::new(tokio::sync::Mutex::new(None));
897 let provider = Arc::new(CapturingProvider {
898 captured: captured.clone(),
899 });
900 let tool = service_fn(|_inv: ToolInvocation| async move {
901 Ok::<_, BoxError>(ToolOutput {
902 id: "x".into(),
903 result: json!({}),
904 })
905 });
906 let mut svc = StepStreamService::new(provider, tool).instructions("INSTR");
907 let req = CreateChatCompletionRequestArgs::default()
908 .model("gpt-4o")
909 .messages(vec![ChatCompletionRequestUserMessageArgs::default()
910 .content("hi")
911 .build()
912 .unwrap()
913 .into()])
914 .build()
915 .unwrap();
916 let _ = svc.call(req).await.unwrap();
918 let got = captured.lock().await.clone().expect("captured req");
919 assert!(!got.messages.is_empty());
920 match &got.messages[0] {
921 ChatCompletionRequestMessage::System(s) => match &s.content {
922 async_openai::types::ChatCompletionRequestSystemMessageContent::Text(t) => {
923 assert_eq!(t, "INSTR");
924 }
925 _ => panic!("expected text content"),
926 },
927 _ => panic!("expected first message to be system"),
928 }
929 }
930
931 #[tokio::test]
932 async fn step_stream_parallel_preserve_order() {
933 let provider = Arc::new(FakeProvider {
935 items: vec![
936 StepChunk::ToolCallStart {
937 id: "c1".into(),
938 name: "slow".into(),
939 arguments: json!({}),
940 },
941 StepChunk::ToolCallStart {
942 id: "c2".into(),
943 name: "fast".into(),
944 arguments: json!({}),
945 },
946 ],
947 });
948 let tool = service_fn(|inv: ToolInvocation| async move {
950 if inv.name == "slow" {
951 sleep(Duration::from_millis(40)).await;
952 } else {
953 sleep(Duration::from_millis(5)).await;
954 }
955 Ok::<_, BoxError>(ToolOutput {
956 id: inv.id,
957 result: json!({"label": inv.name}),
958 })
959 });
960 let mut svc = StepStreamService::new(provider, tool).parallel_tools(true);
961 let req = CreateChatCompletionRequestArgs::default()
962 .model("gpt-4o")
963 .messages(vec![])
964 .build()
965 .unwrap();
966 let mut stream = svc.call(req).await.unwrap();
967 let mut end_ids: Vec<String> = Vec::new();
968 let mut saw_complete = false;
969 let mut final_messages: Option<Vec<ChatCompletionRequestMessage>> = None;
970 while let Some(item) = stream.next().await {
971 match item {
972 StepChunk::ToolCallEnd { id, .. } => end_ids.push(id),
973 StepChunk::StepComplete { outcome } => {
974 saw_complete = true;
975 match outcome {
976 StepOutcome::Next { messages, .. } | StepOutcome::Done { messages, .. } => {
977 final_messages = Some(messages);
978 }
979 }
980 }
981 _ => {}
982 }
983 }
984 assert!(saw_complete);
985 assert_eq!(end_ids, vec!["c1".to_string(), "c2".to_string()]);
986 if let Some(msgs) = final_messages {
987 let policy = ValidationPolicy {
988 allow_repeated_roles: true,
989 require_user_first: false,
990 require_user_present: false,
991 ..Default::default()
992 };
993 assert!(validate_conversation(&msgs, &policy).is_none());
994 }
995 }
996
997 #[tokio::test]
998 async fn step_stream_parallel_error_propagation() {
999 let provider = Arc::new(FakeProvider {
1001 items: vec![
1002 StepChunk::ToolCallStart {
1003 id: "g1".into(),
1004 name: "good".into(),
1005 arguments: json!({}),
1006 },
1007 StepChunk::ToolCallStart {
1008 id: "b1".into(),
1009 name: "bad".into(),
1010 arguments: json!({}),
1011 },
1012 ],
1013 });
1014 let tool = service_fn(|inv: ToolInvocation| async move {
1015 if inv.name == "bad" {
1016 Err::<ToolOutput, BoxError>("boom".into())
1017 } else {
1018 Ok::<_, BoxError>(ToolOutput {
1019 id: inv.id,
1020 result: json!({}),
1021 })
1022 }
1023 });
1024 let mut svc = StepStreamService::new(provider, tool).parallel_tools(true);
1025 let req = CreateChatCompletionRequestArgs::default()
1026 .model("gpt-4o")
1027 .messages(vec![])
1028 .build()
1029 .unwrap();
1030 let mut stream = svc.call(req).await.unwrap();
1031 let mut saw_error = false;
1032 let mut saw_complete = false;
1033 while let Some(item) = stream.next().await {
1034 match item {
1035 StepChunk::Error(e) => {
1036 saw_error = true;
1037 assert!(e.contains("tool error"));
1038 }
1039 StepChunk::StepComplete { .. } => saw_complete = true,
1040 _ => {}
1041 }
1042 }
1043 assert!(saw_error);
1044 assert!(!saw_complete);
1045 }
1046
1047 #[tokio::test]
1048 async fn step_stream_parallel_concurrency_limit() {
1049 use std::sync::atomic::{AtomicUsize, Ordering};
1050 static CURRENT: AtomicUsize = AtomicUsize::new(0);
1051 static MAX_OBSERVED: AtomicUsize = AtomicUsize::new(0);
1052
1053 let mut items = Vec::new();
1054 for i in 0..8 {
1055 items.push(StepChunk::ToolCallStart {
1056 id: format!("c{}", i),
1057 name: "gate".into(),
1058 arguments: json!({}),
1059 });
1060 }
1061 let provider = Arc::new(FakeProvider { items });
1062 let tool = service_fn(|inv: ToolInvocation| async move {
1063 let now = CURRENT.fetch_add(1, Ordering::SeqCst) + 1;
1064 let max = MAX_OBSERVED.load(Ordering::SeqCst);
1065 if now > max {
1066 let _ = MAX_OBSERVED.compare_exchange(max, now, Ordering::SeqCst, Ordering::SeqCst);
1067 }
1068 sleep(Duration::from_millis(10)).await;
1069 CURRENT.fetch_sub(1, Ordering::SeqCst);
1070 Ok::<_, BoxError>(ToolOutput {
1071 id: inv.id,
1072 result: json!({}),
1073 })
1074 });
1075
1076 let mut svc = StepStreamService::new(provider, tool)
1077 .parallel_tools(true)
1078 .tool_concurrency_limit(3);
1079 let req = CreateChatCompletionRequestArgs::default()
1080 .model("gpt-4o")
1081 .messages(vec![])
1082 .build()
1083 .unwrap();
1084 let mut stream = svc.call(req).await.unwrap();
1085 while let Some(_item) = stream.next().await {}
1086 assert!(MAX_OBSERVED.load(Ordering::SeqCst) <= 3);
1087 }
1088
1089 #[tokio::test]
1090 async fn step_stream_parallel_failfast_early_termination() {
1091 use serde_json::json;
1092 let provider = Arc::new(FakeProvider {
1094 items: vec![
1095 StepChunk::ToolCallStart {
1096 id: "b1".into(),
1097 name: "bad".into(),
1098 arguments: json!({}),
1099 },
1100 StepChunk::ToolCallStart {
1101 id: "s1".into(),
1102 name: "slow".into(),
1103 arguments: json!({}),
1104 },
1105 ],
1106 });
1107 let tool = service_fn(|inv: ToolInvocation| async move {
1108 if inv.name == "bad" {
1109 Err::<ToolOutput, BoxError>("boom".into())
1110 } else {
1111 sleep(Duration::from_millis(40)).await;
1112 Ok::<_, BoxError>(ToolOutput {
1113 id: inv.id,
1114 result: json!({"ok":true}),
1115 })
1116 }
1117 });
1118 let mut svc = StepStreamService::new(provider, tool)
1119 .parallel_tools(true)
1120 .tool_join_policy(crate::core::ToolJoinPolicy::FailFast);
1121 let req = CreateChatCompletionRequestArgs::default()
1122 .model("gpt-4o")
1123 .messages(vec![])
1124 .build()
1125 .unwrap();
1126 let mut stream = svc.call(req).await.unwrap();
1127 let mut saw_slow_end = false;
1128 let mut saw_error = false;
1129 while let Some(item) = stream.next().await {
1130 match item {
1131 StepChunk::ToolCallEnd { id, .. } => {
1132 if id == "s1" {
1133 saw_slow_end = true;
1134 }
1135 }
1136 StepChunk::Error(_) => {
1137 saw_error = true;
1138 }
1139 _ => {}
1140 }
1141 }
1142 assert!(saw_error);
1143 assert!(!saw_slow_end);
1144 }
1145
1146 #[tokio::test]
1147 async fn step_stream_parallel_joinall_emits_successes_then_error() {
1148 use serde_json::json;
1149 let provider = Arc::new(FakeProvider {
1151 items: vec![
1152 StepChunk::ToolCallStart {
1153 id: "b1".into(),
1154 name: "bad".into(),
1155 arguments: json!({}),
1156 },
1157 StepChunk::ToolCallStart {
1158 id: "s1".into(),
1159 name: "slow".into(),
1160 arguments: json!({}),
1161 },
1162 ],
1163 });
1164 let tool = service_fn(|inv: ToolInvocation| async move {
1165 if inv.name == "bad" {
1166 Err::<ToolOutput, BoxError>("boom".into())
1167 } else {
1168 sleep(Duration::from_millis(20)).await;
1169 Ok::<_, BoxError>(ToolOutput {
1170 id: inv.id,
1171 result: json!({"ok":true}),
1172 })
1173 }
1174 });
1175 let mut svc = StepStreamService::new(provider, tool)
1176 .parallel_tools(true)
1177 .tool_join_policy(crate::core::ToolJoinPolicy::JoinAll)
1178 .tool_concurrency_limit(1);
1179 let req = CreateChatCompletionRequestArgs::default()
1180 .model("gpt-4o")
1181 .messages(vec![])
1182 .build()
1183 .unwrap();
1184 let mut stream = svc.call(req).await.unwrap();
1185 let mut saw_slow_end = false;
1186 let mut saw_error = false;
1187 let mut saw_complete = false;
1188 while let Some(item) = stream.next().await {
1189 match item {
1190 StepChunk::ToolCallEnd { id, .. } => {
1191 if id == "s1" {
1192 saw_slow_end = true;
1193 }
1194 }
1195 StepChunk::Error(_) => {
1196 saw_error = true;
1197 }
1198 StepChunk::StepComplete { .. } => {
1199 saw_complete = true;
1200 }
1201 _ => {}
1202 }
1203 }
1204 assert!(saw_slow_end);
1205 assert!(saw_error);
1206 assert!(!saw_complete);
1207 }
1208}