1use crate::core::agent::events::{EventSink, SharedLifecycleEmitter};
2use crate::core::agent::session::AgentSessionState;
3use crate::core::agent::steering::SteeringMessage;
4use crate::exec::events::{ThreadEvent, ToolCallStatus};
5use crate::llm::provider::{
6 AssistantPhase, FinishReason, LLMProvider, LLMRequest, LLMResponse, NormalizedStreamEvent,
7 ToolCall, Usage as ProviderUsage,
8};
9use crate::llm::providers::gemini::wire::{Content, Part};
10use anyhow::Result;
11use async_trait::async_trait;
12use futures::StreamExt;
13use std::collections::VecDeque;
14use tokio::sync::mpsc::UnboundedReceiver;
15use tokio::sync::mpsc::error::TryRecvError;
16
/// Reconciles text accumulated from streaming deltas with the provider's final payload.
///
/// * `None` or an empty `completed` leaves `accumulated` untouched.
/// * An empty `accumulated` adopts `completed` verbatim.
/// * When `completed` starts with `accumulated` (including the identical case), only the
///   missing tail is appended.
/// * Any other mismatch means the stream and the final payload disagree; the completed
///   payload wins and replaces the accumulated text.
fn merge_stream_and_completed_text(accumulated: &mut String, completed: Option<&str>) {
    match completed {
        None | Some("") => {}
        Some(final_text) if accumulated.is_empty() => accumulated.push_str(final_text),
        Some(final_text) => match final_text.strip_prefix(accumulated.as_str()) {
            // An identical payload yields an empty tail, which is a no-op append.
            Some(missing_tail) => accumulated.push_str(missing_tail),
            None => {
                accumulated.clear();
                accumulated.push_str(final_text);
            }
        },
    }
}
38
/// Outcome of polling the steering channel at a turn or tool boundary.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RuntimeControl {
    /// No pause or resume message was observed; keep going.
    Continue,
    /// A pause and/or resume message was observed and execution may proceed.
    Resumed,
    /// A stop was requested; the current work should be abandoned.
    StopRequested,
}
45
/// Incremental event reported by a model adapter while a model call is in flight.
#[doc(hidden)]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RuntimeModelProgress {
    /// A chunk of assistant output text.
    OutputDelta(String),
    /// A chunk of reasoning text.
    ReasoningDelta(String),
    /// A label describing the current reasoning stage.
    ReasoningStage(String),
    /// A tool call identified by `call_id` was announced; its `name` may not be known yet.
    ToolCallStarted { call_id: String, name: Option<String> },
    /// A streamed fragment for the tool call identified by `call_id` (presumably part of its
    /// arguments).
    ToolCallDelta { call_id: String, delta: String },
}
61
62#[derive(Debug, Clone)]
63struct RuntimeModelOutput {
64 response: LLMResponse,
65}
66
67#[async_trait]
68trait RuntimeModelAdapter {
69 async fn execute(
70 &mut self,
71 request: LLMRequest,
72 timeout: Option<std::time::Duration>,
73 on_progress: &mut (dyn FnMut(RuntimeModelProgress) + Send),
74 ) -> Result<RuntimeModelOutput>;
75}
76
77struct ProviderRuntimeModelAdapter<'a> {
78 provider: &'a mut Box<dyn LLMProvider>,
79 steering: &'a mut RuntimeSteering,
80}
81
82impl<'a> ProviderRuntimeModelAdapter<'a> {
83 fn new(provider: &'a mut Box<dyn LLMProvider>, steering: &'a mut RuntimeSteering) -> Self {
84 Self { provider, steering }
85 }
86}
87
88#[async_trait]
89impl RuntimeModelAdapter for ProviderRuntimeModelAdapter<'_> {
90 async fn execute(
91 &mut self,
92 request: LLMRequest,
93 timeout: Option<std::time::Duration>,
94 on_progress: &mut (dyn FnMut(RuntimeModelProgress) + Send),
95 ) -> Result<RuntimeModelOutput> {
96 let request_model = request.model.clone();
97 let mut stream = if let Some(duration) = timeout {
98 match tokio::time::timeout(duration, self.provider.stream_normalized(request)).await {
99 Ok(result) => result?,
100 Err(_) => {
101 return Err(anyhow::anyhow!(
102 "Stream request timed out after {:?}",
103 duration
104 ));
105 }
106 }
107 } else {
108 self.provider.stream_normalized(request).await?
109 };
110
111 let mut final_usage = ProviderUsage::default();
112 let mut completed_response: Option<LLMResponse> = None;
113 while let Some(event_result) = stream.next().await {
114 if matches!(
115 self.steering.poll_turn_control().await,
116 RuntimeControl::StopRequested
117 ) {
118 let mut response = LLMResponse {
119 model: request_model.clone(),
120 finish_reason: FinishReason::Error("Cancelled".to_string()),
121 usage: Some(final_usage.clone()),
122 ..Default::default()
123 };
124 if response.usage.as_ref().is_some_and(|usage| {
125 usage.prompt_tokens == 0
126 && usage.completion_tokens == 0
127 && usage.total_tokens == 0
128 }) {
129 response.usage = None;
130 }
131 return Ok(RuntimeModelOutput { response });
132 }
133
134 match event_result? {
135 NormalizedStreamEvent::TextDelta { delta } => {
136 on_progress(RuntimeModelProgress::OutputDelta(delta));
137 }
138 NormalizedStreamEvent::ReasoningDelta { delta } => {
139 on_progress(RuntimeModelProgress::ReasoningDelta(delta));
140 }
141 NormalizedStreamEvent::ToolCallStart { call_id, name } => {
142 on_progress(RuntimeModelProgress::ToolCallStarted { call_id, name });
143 }
144 NormalizedStreamEvent::ToolCallDelta { call_id, delta } => {
145 on_progress(RuntimeModelProgress::ToolCallDelta { call_id, delta });
146 }
147 NormalizedStreamEvent::Usage { usage } => {
148 final_usage = usage;
149 }
150 NormalizedStreamEvent::Done { response } => {
151 let mut response = *response;
152 if response.usage.is_none()
153 && (final_usage.prompt_tokens > 0
154 || final_usage.completion_tokens > 0
155 || final_usage.total_tokens > 0)
156 {
157 response.usage = Some(final_usage.clone());
158 }
159 completed_response = Some(response);
160 break;
161 }
162 }
163 }
164
165 let mut response = completed_response.unwrap_or_default();
166 if response.model.is_empty() {
167 response.model = request_model;
168 }
169 if response.usage.is_none()
170 && (final_usage.prompt_tokens > 0
171 || final_usage.completion_tokens > 0
172 || final_usage.total_tokens > 0)
173 {
174 response.usage = Some(final_usage);
175 }
176
177 Ok(RuntimeModelOutput { response })
178 }
179}
180
181pub struct RuntimeSteering {
182 steering_receiver: Option<UnboundedReceiver<SteeringMessage>>,
183 queued_follow_up_inputs: VecDeque<String>,
184}
185
186impl Default for RuntimeSteering {
187 fn default() -> Self {
188 Self::new(None)
189 }
190}
191
192impl RuntimeSteering {
193 fn new(steering_receiver: Option<UnboundedReceiver<SteeringMessage>>) -> Self {
194 Self {
195 steering_receiver,
196 queued_follow_up_inputs: VecDeque::new(),
197 }
198 }
199
200 pub fn set_receiver(&mut self, receiver: Option<UnboundedReceiver<SteeringMessage>>) {
201 self.steering_receiver = receiver;
202 }
203
204 pub fn take_receiver(&mut self) -> Option<UnboundedReceiver<SteeringMessage>> {
205 self.steering_receiver.take()
206 }
207
208 #[must_use]
209 pub fn has_pending_follow_up_inputs(&self) -> bool {
210 !self.queued_follow_up_inputs.is_empty()
211 }
212
213 pub fn pop_follow_up_input(&mut self) -> Option<String> {
214 self.queued_follow_up_inputs.pop_front()
215 }
216
217 pub fn queue_follow_up_input(&mut self, input: String) {
218 self.queued_follow_up_inputs.push_back(input);
219 }
220
221 pub async fn poll_turn_control(&mut self) -> RuntimeControl {
222 self.poll_control().await
223 }
224
225 pub async fn poll_tool_control(&mut self) -> RuntimeControl {
226 self.poll_control().await
227 }
228
229 async fn poll_control(&mut self) -> RuntimeControl {
230 let mut paused = false;
231
232 loop {
233 let Some(receiver) = self.steering_receiver.as_mut() else {
234 return if paused {
235 RuntimeControl::Resumed
236 } else {
237 RuntimeControl::Continue
238 };
239 };
240
241 match receiver.try_recv() {
242 Ok(SteeringMessage::SteerStop) => return RuntimeControl::StopRequested,
243 Ok(SteeringMessage::Pause) => {
244 paused = true;
245 if matches!(self.wait_for_resume().await, RuntimeControl::StopRequested) {
246 return RuntimeControl::StopRequested;
247 }
248 }
249 Ok(SteeringMessage::Resume) => {
250 paused = true;
251 }
252 Ok(SteeringMessage::FollowUpInput(input)) => {
253 self.queued_follow_up_inputs.push_back(input);
254 }
255 Err(TryRecvError::Empty | TryRecvError::Disconnected) => {
256 return if paused {
257 RuntimeControl::Resumed
258 } else {
259 RuntimeControl::Continue
260 };
261 }
262 }
263 }
264 }
265
266 async fn wait_for_resume(&mut self) -> RuntimeControl {
267 loop {
268 let Some(receiver) = self.steering_receiver.as_mut() else {
269 return RuntimeControl::Continue;
270 };
271
272 match receiver.recv().await {
273 Some(SteeringMessage::Resume) => return RuntimeControl::Continue,
274 Some(SteeringMessage::SteerStop) => return RuntimeControl::StopRequested,
275 Some(SteeringMessage::FollowUpInput(input)) => {
276 self.queued_follow_up_inputs.push_back(input);
277 }
278 Some(SteeringMessage::Pause) => {}
279 None => return RuntimeControl::Continue,
280 }
281 }
282 }
283}
284
285pub struct TurnExecution {
286 pub response: LLMResponse,
287 pub content: String,
288 pub reasoning: Option<String>,
289}
290
/// Minimum growth of the reasoning buffer since the last emitted reasoning snapshot before
/// another incremental update is sent. Stage changes bypass this threshold.
const MIN_REASONING_UPDATE_BYTES: usize = 256;
/// Maximum number of incremental reasoning updates per stream, not counting the initial
/// snapshot.
const MAX_REASONING_UPDATE_EVENTS: usize = 2;
293
294#[doc(hidden)]
295pub struct StreamingLifecycleBridge {
296 event_sink: Option<EventSink>,
297 assistant_item_id: String,
298 reasoning_item_id: String,
299 lifecycle: SharedLifecycleEmitter,
300 tool_call_item_ids: hashbrown::HashMap<String, String>,
301 reasoning_stage: Option<String>,
302 reasoning_update_events: usize,
303 last_reasoning_emit_len: usize,
304}
305
306impl StreamingLifecycleBridge {
307 #[must_use]
308 pub fn new(event_sink: Option<EventSink>, turn_id: &str, step: usize, attempt: usize) -> Self {
309 Self {
310 event_sink,
311 assistant_item_id: format!("{turn_id}-step-{step}-assistant-stream-{attempt}"),
312 reasoning_item_id: format!("{turn_id}-step-{step}-reasoning-stream-{attempt}"),
313 lifecycle: SharedLifecycleEmitter::default(),
314 tool_call_item_ids: hashbrown::HashMap::new(),
315 reasoning_stage: None,
316 reasoning_update_events: 0,
317 last_reasoning_emit_len: 0,
318 }
319 }
320
321 pub fn on_progress(&mut self, event: RuntimeModelProgress) {
322 match event {
323 RuntimeModelProgress::OutputDelta(delta) => self.push_assistant_delta(&delta),
324 RuntimeModelProgress::ReasoningDelta(delta) => self.push_reasoning_delta(&delta),
325 RuntimeModelProgress::ReasoningStage(stage) => self.update_reasoning_stage(stage),
326 RuntimeModelProgress::ToolCallStarted { call_id, name } => {
327 self.start_tool_call(call_id, name);
328 }
329 RuntimeModelProgress::ToolCallDelta { call_id, delta } => {
330 self.push_tool_call_delta(call_id, &delta);
331 }
332 }
333 }
334
335 pub fn abort(&mut self) {
336 self.lifecycle.complete_open_text_items();
337 self.lifecycle
338 .complete_open_tool_calls_with_status(ToolCallStatus::Failed);
339 self.emit_pending_events();
340 }
341
342 pub fn complete_open_items(&mut self) {
343 self.lifecycle.complete_open_text_items();
344 self.emit_pending_events();
345 }
346
347 #[must_use]
348 pub fn take_streamed_tool_call_items(&mut self) -> hashbrown::HashMap<String, String> {
349 std::mem::take(&mut self.tool_call_item_ids)
350 }
351
352 fn push_assistant_delta(&mut self, delta: &str) {
353 if !self.lifecycle.append_assistant_delta(delta) {
354 return;
355 }
356
357 let _ = self
358 .lifecycle
359 .emit_assistant_snapshot(Some(self.assistant_item_id.clone()));
360 self.emit_pending_events();
361 }
362
363 fn push_reasoning_delta(&mut self, delta: &str) {
364 if !self.lifecycle.append_reasoning_delta(delta) {
365 return;
366 }
367
368 if !self.lifecycle.reasoning_started() {
369 if self
370 .lifecycle
371 .emit_reasoning_snapshot(Some(self.reasoning_item_id.clone()))
372 {
373 self.last_reasoning_emit_len = self.lifecycle.reasoning_len();
374 self.emit_pending_events();
375 }
376 return;
377 }
378
379 if !self.should_emit_reasoning_update(false) {
380 return;
381 }
382
383 if self
384 .lifecycle
385 .emit_reasoning_snapshot(Some(self.reasoning_item_id.clone()))
386 {
387 self.record_reasoning_update();
388 self.emit_pending_events();
389 }
390 }
391
392 fn update_reasoning_stage(&mut self, stage: String) {
393 let stage_changed = self.reasoning_stage.as_deref() != Some(stage.as_str());
394 self.reasoning_stage = Some(stage);
395 if !stage_changed
396 || !self
397 .lifecycle
398 .set_reasoning_stage(self.reasoning_stage.clone())
399 {
400 return;
401 }
402
403 if !self.lifecycle.reasoning_started() || !self.should_emit_reasoning_update(true) {
404 return;
405 }
406
407 if self.lifecycle.emit_reasoning_stage_update() {
408 self.record_reasoning_update();
409 self.emit_pending_events();
410 }
411 }
412
413 fn should_emit_reasoning_update(&self, stage_changed: bool) -> bool {
414 if self.reasoning_update_events >= MAX_REASONING_UPDATE_EVENTS {
415 return false;
416 }
417
418 stage_changed
419 || self
420 .lifecycle
421 .reasoning_len()
422 .saturating_sub(self.last_reasoning_emit_len)
423 >= MIN_REASONING_UPDATE_BYTES
424 }
425
426 fn record_reasoning_update(&mut self) {
427 self.reasoning_update_events += 1;
428 self.last_reasoning_emit_len = self.lifecycle.reasoning_len();
429 }
430
431 fn start_tool_call(&mut self, call_id: String, name: Option<String>) {
432 let item_id = format!("{}-tool-call-{call_id}", self.assistant_item_id);
433 self.tool_call_item_ids
434 .insert(call_id.clone(), item_id.clone());
435 let _ = self
436 .lifecycle
437 .start_tool_call(&call_id, name, Some(item_id));
438 self.emit_pending_events();
439 }
440
441 fn push_tool_call_delta(&mut self, call_id: String, delta: &str) {
442 if !self.lifecycle.append_tool_call_delta(
443 &call_id,
444 delta,
445 None,
446 Some(format!("{}-tool-call-{call_id}", self.assistant_item_id)),
447 ) {
448 return;
449 }
450 self.emit_pending_events();
451 }
452
453 fn emit_pending_events(&mut self) {
454 let Some(sink) = &self.event_sink else {
455 let _ = self.lifecycle.drain_events();
456 return;
457 };
458
459 for event in self.lifecycle.drain_events() {
460 let mut callback = sink.lock();
461 callback(&event);
462 }
463 }
464}
465
466pub struct AgentRuntime {
467 pub state: AgentSessionState,
468 steering: RuntimeSteering,
469 event_sink: Option<EventSink>,
470 lifecycle: SharedLifecycleEmitter,
471 emitted_events: Vec<ThreadEvent>,
472}
473
474impl AgentRuntime {
475 pub fn new(
476 state: AgentSessionState,
477 event_sink: Option<EventSink>,
478 steering_receiver: Option<UnboundedReceiver<SteeringMessage>>,
479 ) -> Self {
480 Self {
481 state,
482 steering: RuntimeSteering::new(steering_receiver),
483 event_sink,
484 lifecycle: SharedLifecycleEmitter::default(),
485 emitted_events: Vec::new(),
486 }
487 }
488
489 pub fn set_event_handler(&mut self, sink: Option<EventSink>) {
490 self.event_sink = sink;
491 }
492
493 pub fn set_steering_receiver(&mut self, receiver: Option<UnboundedReceiver<SteeringMessage>>) {
494 self.steering.set_receiver(receiver);
495 }
496
497 pub fn take_steering_receiver(&mut self) -> Option<UnboundedReceiver<SteeringMessage>> {
498 self.steering.take_receiver()
499 }
500
501 pub fn split_mut(&mut self) -> (&mut AgentSessionState, &mut RuntimeSteering) {
502 (&mut self.state, &mut self.steering)
503 }
504
505 #[must_use]
506 pub fn has_pending_follow_up_inputs(&self) -> bool {
507 self.steering.has_pending_follow_up_inputs()
508 }
509
510 pub fn run_until_idle(&mut self) -> Option<String> {
511 let input = self.steering.pop_follow_up_input()?;
512 self.state.add_user_message(input.clone());
513 Some(input)
514 }
515
516 pub async fn poll_turn_control(&mut self) -> RuntimeControl {
517 self.steering.poll_turn_control().await
518 }
519
520 pub async fn poll_tool_control(&mut self) -> RuntimeControl {
521 self.steering.poll_tool_control().await
522 }
523
524 pub fn take_emitted_events(&mut self) -> Vec<ThreadEvent> {
525 std::mem::take(&mut self.emitted_events)
526 }
527
528 #[must_use]
529 pub fn tool_call_item_id(&self, call_id: &str) -> Option<String> {
530 self.lifecycle
531 .tool_call_item_id(call_id)
532 .map(str::to_string)
533 }
534
535 pub fn complete_tool_call(&mut self, call_id: &str, status: ToolCallStatus) {
536 let _ = self.lifecycle.complete_tool_call(call_id, status);
537 self.emit_pending_lifecycle_events();
538 }
539
540 pub fn complete_open_tool_calls(&mut self, status: ToolCallStatus) {
541 self.lifecycle.complete_open_tool_calls_with_status(status);
542 self.emit_pending_lifecycle_events();
543 }
544
545 fn emit_event(&mut self, event: ThreadEvent) {
546 self.emitted_events.push(event.clone());
547 if let Some(sink) = &self.event_sink {
548 let mut callback = sink.lock();
549 callback(&event);
550 }
551 }
552
553 fn emit_pending_lifecycle_events(&mut self) {
554 for event in self.lifecycle.drain_events() {
555 self.emit_event(event);
556 }
557 }
558
559 fn finalize_assistant_lifecycle(&mut self, text: &str) {
560 if text.trim().is_empty() {
561 return;
562 }
563
564 let should_emit_snapshot =
565 !self.lifecycle.assistant_started() || self.lifecycle.replace_assistant_text(text);
566 if should_emit_snapshot {
567 let _ = self.lifecycle.emit_assistant_snapshot(None);
568 }
569 let _ = self.lifecycle.complete_assistant_stream();
570 }
571
572 fn finalize_reasoning_lifecycle(&mut self, text: &str) {
573 if text.trim().is_empty() {
574 return;
575 }
576
577 let should_emit_snapshot =
578 !self.lifecycle.reasoning_started() || self.lifecycle.replace_reasoning_text(text);
579 if should_emit_snapshot {
580 let _ = self.lifecycle.emit_reasoning_snapshot(None);
581 }
582 let _ = self.lifecycle.complete_reasoning_stream();
583 }
584
585 fn finalize_tool_call_lifecycle(
586 &mut self,
587 tool_calls: Option<&[ToolCall]>,
588 _finish_reason: &str,
589 ) {
590 if let Some(tool_calls) = tool_calls {
591 for call in tool_calls {
592 let tool_name = call.function.as_ref().map(|function| function.name.clone());
593 let _ = self
594 .lifecycle
595 .start_tool_call(&call.id, tool_name.clone(), None);
596 if let Some(function) = call.function.as_ref() {
597 let _ = self.lifecycle.sync_tool_call_arguments(
598 &call.id,
599 &function.arguments,
600 tool_name,
601 None,
602 );
603 }
604 }
605 return;
606 }
607
608 self.lifecycle
609 .complete_open_tool_calls_with_status(ToolCallStatus::Failed);
610 }
611
612 fn record_model_progress(
613 &mut self,
614 event: RuntimeModelProgress,
615 full_text: &mut String,
616 full_reasoning: &mut String,
617 ) {
618 match event {
619 RuntimeModelProgress::OutputDelta(delta) => {
620 full_text.push_str(&delta);
621 if self.lifecycle.append_assistant_delta(&delta) {
622 let _ = self.lifecycle.emit_assistant_snapshot(None);
623 self.emit_pending_lifecycle_events();
624 }
625 }
626 RuntimeModelProgress::ReasoningDelta(delta) => {
627 full_reasoning.push_str(&delta);
628 if self.lifecycle.append_reasoning_delta(&delta) {
629 let _ = self.lifecycle.emit_reasoning_snapshot(None);
630 self.emit_pending_lifecycle_events();
631 }
632 }
633 RuntimeModelProgress::ReasoningStage(stage) => {
634 if self.lifecycle.set_reasoning_stage(Some(stage)) {
635 let _ = self.lifecycle.emit_reasoning_stage_update();
636 self.emit_pending_lifecycle_events();
637 }
638 }
639 RuntimeModelProgress::ToolCallStarted { call_id, name } => {
640 let _ = self.lifecycle.start_tool_call(&call_id, name, None);
641 self.emit_pending_lifecycle_events();
642 }
643 RuntimeModelProgress::ToolCallDelta { call_id, delta } => {
644 if self
645 .lifecycle
646 .append_tool_call_delta(&call_id, &delta, None, None)
647 {
648 self.emit_pending_lifecycle_events();
649 }
650 }
651 }
652 }
653
654 async fn run_turn_once_with_adapter<A: RuntimeModelAdapter + ?Sized>(
655 &mut self,
656 adapter: &mut A,
657 request: LLMRequest,
658 timeout: Option<std::time::Duration>,
659 ) -> Result<TurnExecution> {
660 let request_model = request.model.clone();
661 let start_time = std::time::Instant::now();
662 let mut full_text = String::new();
663 let mut full_reasoning = String::new();
664 let mut on_progress =
665 |event| self.record_model_progress(event, &mut full_text, &mut full_reasoning);
666 let RuntimeModelOutput { mut response } =
667 adapter.execute(request, timeout, &mut on_progress).await?;
668
669 merge_stream_and_completed_text(&mut full_text, response.content.as_deref());
670 merge_stream_and_completed_text(&mut full_reasoning, response.reasoning.as_deref());
671
672 let finish_reason = match response.finish_reason.clone() {
673 FinishReason::Stop => "stop".to_string(),
674 FinishReason::ToolCalls => "tool_calls".to_string(),
675 FinishReason::Length => "length".to_string(),
676 FinishReason::Error(message) => message,
677 _ => "unknown".to_string(),
678 };
679 let final_usage = response.usage.clone().unwrap_or_default();
680 let mut aggregated_tool_calls = response.tool_calls.clone();
681
682 self.finalize_assistant_lifecycle(&full_text);
683 self.finalize_reasoning_lifecycle(&full_reasoning);
684 self.finalize_tool_call_lifecycle(aggregated_tool_calls.as_deref(), &finish_reason);
685 self.emit_pending_lifecycle_events();
686
687 let mut turn_recorded = false;
688 self.state.record_turn(&start_time, &mut turn_recorded);
689
690 if final_usage.prompt_tokens > 0 || final_usage.completion_tokens > 0 {
691 self.state.stats.merge_usage(final_usage.clone());
692 }
693
694 aggregated_tool_calls = aggregated_tool_calls.filter(|calls| !calls.is_empty());
695
696 let mut assistant_message = crate::llm::provider::Message::assistant(full_text.clone());
697 if !full_reasoning.is_empty() {
698 assistant_message = assistant_message.with_reasoning(Some(full_reasoning.clone()));
699 }
700
701 match aggregated_tool_calls.as_ref() {
702 Some(calls) => {
703 assistant_message = assistant_message
704 .with_tool_calls(calls.clone())
705 .with_phase(Some(AssistantPhase::Commentary));
706 }
707 None => {
708 assistant_message = assistant_message.with_phase(Some(AssistantPhase::FinalAnswer));
709 }
710 }
711
712 self.state.messages.push(assistant_message);
713
714 self.state.conversation.push(Content {
715 role: "model".to_string(),
716 parts: vec![Part::Text {
717 text: full_text.clone(),
718 thought_signature: None,
719 }],
720 });
721 self.state.last_processed_message_idx = self.state.conversation.len();
722
723 if response.model.is_empty() {
724 response.model = request_model;
725 }
726 response.content = Some(full_text.clone());
727 response.reasoning = if full_reasoning.is_empty() {
728 None
729 } else {
730 Some(full_reasoning.clone())
731 };
732 response.tool_calls = aggregated_tool_calls.clone();
733 response.usage = Some(final_usage.clone());
734 response.finish_reason = if finish_reason == "tool_calls" {
735 FinishReason::ToolCalls
736 } else if finish_reason == "Cancelled" || finish_reason == "cancelled" {
737 FinishReason::Error("Cancelled".to_string())
738 } else {
739 response.finish_reason
740 };
741
742 Ok(TurnExecution {
743 response,
744 content: full_text,
745 reasoning: if full_reasoning.is_empty() {
746 None
747 } else {
748 Some(full_reasoning)
749 },
750 })
751 }
752
753 pub async fn run_turn_once(
754 &mut self,
755 provider: &mut Box<dyn LLMProvider>,
756 request: LLMRequest,
757 timeout: Option<std::time::Duration>,
758 ) -> Result<TurnExecution> {
759 let mut steering = std::mem::take(&mut self.steering);
760 let mut adapter = ProviderRuntimeModelAdapter::new(provider, &mut steering);
761 let result = self
762 .run_turn_once_with_adapter(&mut adapter, request, timeout)
763 .await;
764 self.steering = steering;
765 result
766 }
767}
768
// Unit tests for steering control, turn execution, the provider adapter, and the streaming
// lifecycle bridge.
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use futures::stream;

    use crate::llm::provider::{
        LLMError, LLMNormalizedStream, LLMStream, LLMStreamEvent, NormalizedStreamEvent,
    };

    /// Test provider whose streams yield only a final completed/done event, with no deltas.
    #[derive(Clone)]
    struct CompletedOnlyStreamProvider {
        response: LLMResponse,
    }

    /// Test provider whose normalized stream yields one reasoning delta, then one text
    /// delta, then a `Done` event carrying `response`.
    #[derive(Clone)]
    struct DeltaStreamProvider {
        response: LLMResponse,
        text_delta: String,
        reasoning_delta: String,
    }

    #[async_trait]
    impl LLMProvider for CompletedOnlyStreamProvider {
        fn name(&self) -> &str {
            "test-provider"
        }

        fn supports_streaming(&self) -> bool {
            true
        }

        async fn generate(&self, _request: LLMRequest) -> Result<LLMResponse, LLMError> {
            Ok(self.response.clone())
        }

        async fn stream(&self, _request: LLMRequest) -> Result<LLMStream, LLMError> {
            Ok(Box::pin(stream::iter(vec![Ok(
                LLMStreamEvent::Completed {
                    response: Box::new(self.response.clone()),
                },
            )])))
        }

        // Emits only `Done`, so the runtime must take content from the completed payload.
        async fn stream_normalized(
            &self,
            _request: LLMRequest,
        ) -> Result<LLMNormalizedStream, LLMError> {
            Ok(Box::pin(stream::iter(vec![Ok(
                NormalizedStreamEvent::Done {
                    response: Box::new(self.response.clone()),
                },
            )])))
        }

        fn supported_models(&self) -> Vec<String> {
            vec!["test-model".to_string()]
        }

        fn validate_request(&self, _request: &LLMRequest) -> Result<(), LLMError> {
            Ok(())
        }
    }

    #[async_trait]
    impl LLMProvider for DeltaStreamProvider {
        fn name(&self) -> &str {
            "delta-provider"
        }

        fn supports_streaming(&self) -> bool {
            true
        }

        async fn generate(&self, _request: LLMRequest) -> Result<LLMResponse, LLMError> {
            Ok(self.response.clone())
        }

        async fn stream(&self, _request: LLMRequest) -> Result<LLMStream, LLMError> {
            Ok(Box::pin(stream::iter(vec![Ok(
                LLMStreamEvent::Completed {
                    response: Box::new(self.response.clone()),
                },
            )])))
        }

        // Reasoning delta first, then text delta, then `Done`; tests pin this exact order.
        async fn stream_normalized(
            &self,
            _request: LLMRequest,
        ) -> Result<LLMNormalizedStream, LLMError> {
            Ok(Box::pin(stream::iter(vec![
                Ok(NormalizedStreamEvent::ReasoningDelta {
                    delta: self.reasoning_delta.clone(),
                }),
                Ok(NormalizedStreamEvent::TextDelta {
                    delta: self.text_delta.clone(),
                }),
                Ok(NormalizedStreamEvent::Done {
                    response: Box::new(self.response.clone()),
                }),
            ])))
        }

        fn supported_models(&self) -> Vec<String> {
            vec!["test-model".to_string()]
        }

        fn validate_request(&self, _request: &LLMRequest) -> Result<(), LLMError> {
            Ok(())
        }
    }

    // Polling queues follow-ups without consuming them; `run_until_idle` then consumes
    // exactly one per call, in FIFO order, appending each as the latest message.
    #[tokio::test]
    async fn queued_follow_up_inputs_are_applied_one_at_a_time() {
        let state = AgentSessionState::new("session".to_string(), 16, 4, 128_000);
        let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
        let mut runtime = AgentRuntime::new(state, None, Some(receiver));

        sender
            .send(SteeringMessage::FollowUpInput("first".to_string()))
            .expect("first follow-up should queue");
        sender
            .send(SteeringMessage::FollowUpInput("second".to_string()))
            .expect("second follow-up should queue");

        assert_eq!(runtime.poll_turn_control().await, RuntimeControl::Continue);
        assert!(runtime.has_pending_follow_up_inputs());
        assert!(runtime.state.messages.is_empty());

        assert_eq!(runtime.run_until_idle().as_deref(), Some("first"));
        assert_eq!(
            runtime
                .state
                .messages
                .last()
                .map(|message| message.get_text_content().into_owned()),
            Some("first".to_string())
        );
        assert!(runtime.has_pending_follow_up_inputs());

        assert_eq!(runtime.run_until_idle().as_deref(), Some("second"));
        assert_eq!(
            runtime
                .state
                .messages
                .last()
                .map(|message| message.get_text_content().into_owned()),
            Some("second".to_string())
        );
        assert!(!runtime.has_pending_follow_up_inputs());
    }

    // Pause -> follow-up -> resume: the poll reports `Resumed`, and the follow-up received
    // while paused is kept.
    #[tokio::test]
    async fn paused_runtime_resumes_and_preserves_follow_up_inputs() {
        let state = AgentSessionState::new("session".to_string(), 16, 4, 128_000);
        let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
        let mut runtime = AgentRuntime::new(state, None, Some(receiver));

        sender
            .send(SteeringMessage::Pause)
            .expect("pause should send");
        sender
            .send(SteeringMessage::FollowUpInput(
                "queued while paused".to_string(),
            ))
            .expect("follow-up should send");
        sender
            .send(SteeringMessage::Resume)
            .expect("resume should send");

        assert_eq!(runtime.poll_turn_control().await, RuntimeControl::Resumed);
        assert!(runtime.has_pending_follow_up_inputs());
        assert_eq!(
            runtime.run_until_idle().as_deref(),
            Some("queued while paused")
        );
    }

    // A stop received while paused ends the poll with `StopRequested`.
    #[tokio::test]
    async fn paused_runtime_stop_request_wins() {
        let state = AgentSessionState::new("session".to_string(), 16, 4, 128_000);
        let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
        let mut runtime = AgentRuntime::new(state, None, Some(receiver));

        sender
            .send(SteeringMessage::Pause)
            .expect("pause should send");
        sender
            .send(SteeringMessage::SteerStop)
            .expect("stop should send");

        assert_eq!(
            runtime.poll_turn_control().await,
            RuntimeControl::StopRequested
        );
        assert!(!runtime.has_pending_follow_up_inputs());
    }

    // With no deltas streamed, content and reasoning come from the completed payload.
    #[tokio::test]
    async fn run_turn_once_uses_completed_payload_when_no_deltas_exist() {
        let response = LLMResponse {
            content: Some("### Header\n- item".to_string()),
            model: "test-model".to_string(),
            finish_reason: FinishReason::Stop,
            reasoning: Some("**why** this works".to_string()),
            ..Default::default()
        };
        let provider = CompletedOnlyStreamProvider {
            response: response.clone(),
        };
        let state = AgentSessionState::new("session".to_string(), 16, 4, 128_000);
        let mut runtime = AgentRuntime::new(state, None, None);
        let mut provider_box: Box<dyn LLMProvider> = Box::new(provider);
        let request = LLMRequest {
            model: "test-model".to_string(),
            ..Default::default()
        };

        let turn = runtime
            .run_turn_once(&mut provider_box, request, None)
            .await
            .expect("run_turn_once should succeed");

        assert_eq!(turn.content, "### Header\n- item");
        assert_eq!(turn.reasoning.as_deref(), Some("**why** this works"));
        assert_eq!(turn.response.content.as_deref(), Some("### Header\n- item"));
        assert_eq!(
            turn.response.reasoning.as_deref(),
            Some("**why** this works")
        );
    }

    // The provider adapter forwards deltas to the progress callback in stream order and
    // returns the `Done` response unchanged.
    #[tokio::test]
    async fn provider_runtime_model_adapter_emits_delta_progress() {
        let response = LLMResponse {
            content: Some("hello world".to_string()),
            model: "test-model".to_string(),
            finish_reason: FinishReason::Stop,
            reasoning: Some("trace".to_string()),
            ..Default::default()
        };
        let provider = DeltaStreamProvider {
            response,
            text_delta: "hello world".to_string(),
            reasoning_delta: "trace".to_string(),
        };
        let mut steering = RuntimeSteering::default();
        let mut provider_box: Box<dyn LLMProvider> = Box::new(provider);
        let request = LLMRequest {
            model: "test-model".to_string(),
            ..Default::default()
        };

        let mut adapter = ProviderRuntimeModelAdapter::new(&mut provider_box, &mut steering);
        let mut seen_progress = Vec::new();
        let mut callback = |event| seen_progress.push(event);
        let output = adapter
            .execute(request, None, &mut callback)
            .await
            .expect("adapter execution should succeed");

        assert_eq!(output.response.content.as_deref(), Some("hello world"));
        assert_eq!(output.response.reasoning.as_deref(), Some("trace"));
        assert_eq!(
            seen_progress,
            vec![
                RuntimeModelProgress::ReasoningDelta("trace".to_string()),
                RuntimeModelProgress::OutputDelta("hello world".to_string()),
            ]
        );
    }

    // A started tool call is recorded under an item id derived from the bridge's assistant
    // item id.
    #[test]
    fn streaming_lifecycle_bridge_tracks_tool_call_item_ids() {
        let mut bridge = StreamingLifecycleBridge::new(None, "turn_tool_map", 5, 2);
        bridge.on_progress(RuntimeModelProgress::ToolCallStarted {
            call_id: "call_42".to_string(),
            name: Some("shell".to_string()),
        });

        let item_ids = bridge.take_streamed_tool_call_items();
        assert_eq!(
            item_ids.get("call_42").map(String::as_str),
            Some("turn_tool_map-step-5-assistant-stream-2-tool-call-call_42")
        );
    }
}