Skip to main content

swink_agent/agent/
state_updates.rs

1use std::pin::Pin;
2use std::sync::Arc;
3use std::sync::atomic::Ordering;
4
5use futures::{Stream, StreamExt};
6
7use crate::error::AgentError;
8use crate::loop_::AgentEvent;
9use crate::types::{AgentMessage, AgentResult, LlmMessage, StopReason, Usage};
10
11use super::Agent;
12
13impl Agent {
14    /// Collect a stream to completion, updating agent state along the way.
15    pub(super) async fn collect_stream(
16        &mut self,
17        mut stream: Pin<Box<dyn Stream<Item = AgentEvent> + Send>>,
18    ) -> Result<AgentResult, AgentError> {
19        let mut all_messages: Vec<AgentMessage> = Vec::new();
20        let mut state_messages = self.in_flight_llm_messages.take().unwrap_or_default();
21        let mut checkpoint_messages = self
22            .in_flight_messages
23            .take()
24            .unwrap_or_else(|| clone_messages(&state_messages));
25        let mut received_full_context = false;
26        let mut stop_reason = StopReason::Stop;
27        let mut usage = Usage::default();
28        let mut cost = crate::types::Cost::default();
29        let mut error: Option<String> = None;
30        let mut transfer_signal: Option<crate::transfer::TransferSignal> = None;
31
32        while let Some(event) = stream.next().await {
33            self.dispatch_event(&event);
34            self.update_state_from_event(&event);
35
36            match event {
37                AgentEvent::TransferInitiated { signal } => {
38                    transfer_signal = Some(signal);
39                    stop_reason = StopReason::Transfer;
40                }
41                AgentEvent::TurnEnd {
42                    assistant_message,
43                    tool_results,
44                    ..
45                } => {
46                    // Preserve Transfer stop reason set by TransferInitiated event
47                    if transfer_signal.is_none() {
48                        stop_reason = assistant_message.stop_reason;
49                    }
50                    usage += assistant_message.usage.clone();
51                    cost += assistant_message.cost.clone();
52                    if let Some(ref err) = assistant_message.error_message {
53                        error = Some(err.clone());
54                    }
55                    let assistant_llm = LlmMessage::Assistant(assistant_message);
56                    state_messages.push(AgentMessage::Llm(assistant_llm.clone()));
57                    checkpoint_messages.push(AgentMessage::Llm(assistant_llm.clone()));
58                    all_messages.push(AgentMessage::Llm(assistant_llm));
59                    for tr in tool_results {
60                        state_messages.push(AgentMessage::Llm(LlmMessage::ToolResult(tr.clone())));
61                        checkpoint_messages
62                            .push(AgentMessage::Llm(LlmMessage::ToolResult(tr.clone())));
63                        all_messages.push(AgentMessage::Llm(LlmMessage::ToolResult(tr)));
64                    }
65                }
66                AgentEvent::AgentEnd { messages } => match Arc::try_unwrap(messages) {
67                    Ok(returned) => {
68                        self.state.messages = returned;
69                        received_full_context = true;
70                    }
71                    Err(messages) => {
72                        self.state.messages = clone_messages(messages.as_ref());
73                        received_full_context = true;
74                    }
75                },
76                _ => {}
77            }
78        }
79
80        if !received_full_context {
81            self.state.messages = checkpoint_messages;
82        }
83        self.state.is_running = false;
84        self.loop_active.store(false, Ordering::Release);
85        self.pending_message_snapshot.clear();
86        self.loop_context_snapshot.clear();
87        self.state.error.clone_from(&error);
88        self.idle_notify.notify_waiters();
89
90        Ok(AgentResult {
91            messages: all_messages,
92            stop_reason,
93            usage,
94            cost,
95            error,
96            transfer_signal,
97        })
98    }
99
100    /// Processes a streaming event, updating [`Agent::state`] and notifying subscribers.
101    pub fn handle_stream_event(&mut self, event: &AgentEvent) {
102        self.dispatch_event(event);
103        self.update_state_from_event(event);
104
105        match event {
106            AgentEvent::TurnEnd {
107                assistant_message,
108                tool_results,
109                ..
110            } => {
111                let msgs = self.in_flight_llm_messages.get_or_insert_with(Vec::new);
112                msgs.push(AgentMessage::Llm(LlmMessage::Assistant(
113                    assistant_message.clone(),
114                )));
115                let checkpoint_msgs = self.in_flight_messages.get_or_insert_with(Vec::new);
116                checkpoint_msgs.push(AgentMessage::Llm(LlmMessage::Assistant(
117                    assistant_message.clone(),
118                )));
119                for tr in tool_results {
120                    msgs.push(AgentMessage::Llm(LlmMessage::ToolResult(tr.clone())));
121                    checkpoint_msgs.push(AgentMessage::Llm(LlmMessage::ToolResult(tr.clone())));
122                }
123                // Capture terminal error so it survives through AgentEnd.
124                if let Some(ref err) = assistant_message.error_message {
125                    self.state.error = Some(err.clone());
126                }
127            }
128            AgentEvent::AgentEnd { messages } => {
129                self.state.messages = clone_messages(messages.as_ref());
130                self.in_flight_llm_messages = None;
131                self.in_flight_messages = None;
132                self.pending_message_snapshot.clear();
133                self.loop_context_snapshot.clear();
134                // Preserve terminal error — do not clear self.state.error.
135                self.idle_notify.notify_waiters();
136            }
137            _ => {}
138        }
139    }
140
141    fn update_state_from_event(&mut self, event: &AgentEvent) {
142        match event {
143            AgentEvent::MessageStart => {
144                self.state.stream_message = None;
145            }
146            AgentEvent::MessageEnd { message } => {
147                self.state.stream_message =
148                    Some(AgentMessage::Llm(LlmMessage::Assistant(message.clone())));
149            }
150            AgentEvent::ToolExecutionStart { id, .. } => {
151                self.state.pending_tool_calls.insert(id.clone());
152            }
153            AgentEvent::TurnEnd { .. } => {
154                self.state.pending_tool_calls.clear();
155                self.state.stream_message = None;
156            }
157            AgentEvent::AgentEnd { .. } => {
158                self.state.is_running = false;
159                self.loop_active.store(false, Ordering::Release);
160                self.state.pending_tool_calls.clear();
161                self.state.stream_message = None;
162            }
163            _ => {}
164        }
165    }
166}
167
168fn clone_messages(messages: &[AgentMessage]) -> Vec<AgentMessage> {
169    messages
170        .iter()
171        .filter_map(|message| match message {
172            AgentMessage::Llm(llm) => Some(AgentMessage::Llm(llm.clone())),
173            AgentMessage::Custom(cm) => cm.clone_box().map_or_else(
174                || {
175                    tracing::warn!(
176                        "CustomMessage {:?} does not support clone_box — dropped during state rebuild",
177                        cm
178                    );
179                    None
180                },
181                |cloned| Some(AgentMessage::Custom(cloned)),
182            ),
183        })
184        .collect()
185}