swink_agent/agent/
state_updates.rs1use std::pin::Pin;
2use std::sync::Arc;
3use std::sync::atomic::Ordering;
4
5use futures::{Stream, StreamExt};
6
7use crate::error::AgentError;
8use crate::loop_::AgentEvent;
9use crate::types::{AgentMessage, AgentResult, LlmMessage, StopReason, Usage};
10
11use super::Agent;
12
13impl Agent {
14 pub(super) async fn collect_stream(
16 &mut self,
17 mut stream: Pin<Box<dyn Stream<Item = AgentEvent> + Send>>,
18 ) -> Result<AgentResult, AgentError> {
19 let mut all_messages: Vec<AgentMessage> = Vec::new();
20 let mut state_messages = self.in_flight_llm_messages.take().unwrap_or_default();
21 let mut checkpoint_messages = self
22 .in_flight_messages
23 .take()
24 .unwrap_or_else(|| clone_messages(&state_messages));
25 let mut received_full_context = false;
26 let mut stop_reason = StopReason::Stop;
27 let mut usage = Usage::default();
28 let mut cost = crate::types::Cost::default();
29 let mut error: Option<String> = None;
30 let mut transfer_signal: Option<crate::transfer::TransferSignal> = None;
31
32 while let Some(event) = stream.next().await {
33 self.dispatch_event(&event);
34 self.update_state_from_event(&event);
35
36 match event {
37 AgentEvent::TransferInitiated { signal } => {
38 transfer_signal = Some(signal);
39 stop_reason = StopReason::Transfer;
40 }
41 AgentEvent::TurnEnd {
42 assistant_message,
43 tool_results,
44 ..
45 } => {
46 if transfer_signal.is_none() {
48 stop_reason = assistant_message.stop_reason;
49 }
50 usage += assistant_message.usage.clone();
51 cost += assistant_message.cost.clone();
52 if let Some(ref err) = assistant_message.error_message {
53 error = Some(err.clone());
54 }
55 let assistant_llm = LlmMessage::Assistant(assistant_message);
56 state_messages.push(AgentMessage::Llm(assistant_llm.clone()));
57 checkpoint_messages.push(AgentMessage::Llm(assistant_llm.clone()));
58 all_messages.push(AgentMessage::Llm(assistant_llm));
59 for tr in tool_results {
60 state_messages.push(AgentMessage::Llm(LlmMessage::ToolResult(tr.clone())));
61 checkpoint_messages
62 .push(AgentMessage::Llm(LlmMessage::ToolResult(tr.clone())));
63 all_messages.push(AgentMessage::Llm(LlmMessage::ToolResult(tr)));
64 }
65 }
66 AgentEvent::AgentEnd { messages } => match Arc::try_unwrap(messages) {
67 Ok(returned) => {
68 self.state.messages = returned;
69 received_full_context = true;
70 }
71 Err(messages) => {
72 self.state.messages = clone_messages(messages.as_ref());
73 received_full_context = true;
74 }
75 },
76 _ => {}
77 }
78 }
79
80 if !received_full_context {
81 self.state.messages = checkpoint_messages;
82 }
83 self.state.is_running = false;
84 self.loop_active.store(false, Ordering::Release);
85 self.pending_message_snapshot.clear();
86 self.loop_context_snapshot.clear();
87 self.state.error.clone_from(&error);
88 self.idle_notify.notify_waiters();
89
90 Ok(AgentResult {
91 messages: all_messages,
92 stop_reason,
93 usage,
94 cost,
95 error,
96 transfer_signal,
97 })
98 }
99
100 pub fn handle_stream_event(&mut self, event: &AgentEvent) {
102 self.dispatch_event(event);
103 self.update_state_from_event(event);
104
105 match event {
106 AgentEvent::TurnEnd {
107 assistant_message,
108 tool_results,
109 ..
110 } => {
111 let msgs = self.in_flight_llm_messages.get_or_insert_with(Vec::new);
112 msgs.push(AgentMessage::Llm(LlmMessage::Assistant(
113 assistant_message.clone(),
114 )));
115 let checkpoint_msgs = self.in_flight_messages.get_or_insert_with(Vec::new);
116 checkpoint_msgs.push(AgentMessage::Llm(LlmMessage::Assistant(
117 assistant_message.clone(),
118 )));
119 for tr in tool_results {
120 msgs.push(AgentMessage::Llm(LlmMessage::ToolResult(tr.clone())));
121 checkpoint_msgs.push(AgentMessage::Llm(LlmMessage::ToolResult(tr.clone())));
122 }
123 if let Some(ref err) = assistant_message.error_message {
125 self.state.error = Some(err.clone());
126 }
127 }
128 AgentEvent::AgentEnd { messages } => {
129 self.state.messages = clone_messages(messages.as_ref());
130 self.in_flight_llm_messages = None;
131 self.in_flight_messages = None;
132 self.pending_message_snapshot.clear();
133 self.loop_context_snapshot.clear();
134 self.idle_notify.notify_waiters();
136 }
137 _ => {}
138 }
139 }
140
141 fn update_state_from_event(&mut self, event: &AgentEvent) {
142 match event {
143 AgentEvent::MessageStart => {
144 self.state.stream_message = None;
145 }
146 AgentEvent::MessageEnd { message } => {
147 self.state.stream_message =
148 Some(AgentMessage::Llm(LlmMessage::Assistant(message.clone())));
149 }
150 AgentEvent::ToolExecutionStart { id, .. } => {
151 self.state.pending_tool_calls.insert(id.clone());
152 }
153 AgentEvent::TurnEnd { .. } => {
154 self.state.pending_tool_calls.clear();
155 self.state.stream_message = None;
156 }
157 AgentEvent::AgentEnd { .. } => {
158 self.state.is_running = false;
159 self.loop_active.store(false, Ordering::Release);
160 self.state.pending_tool_calls.clear();
161 self.state.stream_message = None;
162 }
163 _ => {}
164 }
165 }
166}
167
168fn clone_messages(messages: &[AgentMessage]) -> Vec<AgentMessage> {
169 messages
170 .iter()
171 .filter_map(|message| match message {
172 AgentMessage::Llm(llm) => Some(AgentMessage::Llm(llm.clone())),
173 AgentMessage::Custom(cm) => cm.clone_box().map_or_else(
174 || {
175 tracing::warn!(
176 "CustomMessage {:?} does not support clone_box — dropped during state rebuild",
177 cm
178 );
179 None
180 },
181 |cloned| Some(AgentMessage::Custom(cloned)),
182 ),
183 })
184 .collect()
185}