Skip to main content

swink_agent/agent/
control.rs

1use std::sync::Arc;
2use std::sync::atomic::Ordering;
3
4use tokio::sync::Notify;
5use tracing::info;
6
7use super::Agent;
8
9async fn wait_for_idle_future<F>(
10    notify: Arc<Notify>,
11    active: Arc<std::sync::atomic::AtomicBool>,
12    after_register: F,
13) where
14    F: Fn() + Send + Sync + 'static,
15{
16    loop {
17        let notified = notify.notified();
18        after_register();
19        if !active.load(Ordering::Acquire) {
20            return;
21        }
22        notified.await;
23    }
24}
25
26impl Agent {
27    pub(super) fn clear_transient_runtime_state(&mut self) {
28        self.state.is_running = false;
29        self.state.stream_message = None;
30        self.state.pending_tool_calls.clear();
31        self.state.error = None;
32        self.abort_controller = None;
33        self.in_flight_llm_messages = None;
34        self.in_flight_messages = None;
35        self.pending_message_snapshot.clear();
36        self.loop_context_snapshot.clear();
37    }
38
39    /// Cancel the currently running loop, if any.
40    pub fn abort(&mut self) {
41        if let Some(ref token) = self.abort_controller {
42            info!("aborting agent loop");
43            token.cancel();
44        }
45    }
46
47    /// Reset the agent to its initial state, clearing messages, queues, and error.
48    ///
49    /// If a loop is currently active, the abort token is cancelled and the
50    /// generation counter is bumped so the stale `LoopGuardStream` cannot
51    /// clear `loop_active` for any future run.
52    pub fn reset(&mut self) {
53        // Cancel the running loop *before* dropping the token, so the spawned
54        // stream observes cancellation rather than continuing to emit events.
55        if let Some(ref token) = self.abort_controller {
56            token.cancel();
57        }
58
59        // Bump the generation counter so the old LoopGuardStream's Drop impl
60        // sees a mismatched generation and skips clearing loop_active.
61        self.loop_generation.fetch_add(1, Ordering::AcqRel);
62
63        self.state.messages.clear();
64        self.loop_active.store(false, Ordering::Release);
65        self.clear_transient_runtime_state();
66        self.clear_queues();
67        self.idle_notify.notify_waiters();
68    }
69
70    /// Returns a future that resolves when the agent is no longer running.
71    ///
72    /// Uses the shared `loop_active` flag so the future correctly resolves even
73    /// when the event stream is dropped without being drained to `AgentEnd`.
74    pub fn wait_for_idle(&self) -> impl Future<Output = ()> + Send + '_ {
75        wait_for_idle_future(
76            Arc::clone(&self.idle_notify),
77            Arc::clone(&self.loop_active),
78            || {},
79        )
80    }
81}
82
#[cfg(all(test, feature = "testkit"))]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::task::Poll;

    use futures::pin_mut;
    use tokio::sync::Notify;

    use crate::agent_options::AgentOptions;
    use crate::stream::StreamFn;
    use crate::testing::{
        MockStreamFn, default_convert, default_model, text_only_events, user_msg,
    };

    use super::{Agent, wait_for_idle_future};

    /// An idle transition injected between registering the `Notified` future
    /// and reading the flag must let the wait finish on the first poll.
    #[tokio::test]
    async fn wait_for_idle_returns_when_idle_transition_happens_after_registration() {
        let notify = Arc::new(Notify::new());
        let active = Arc::new(AtomicBool::new(true));

        let go_idle = {
            let active = Arc::clone(&active);
            let notify = Arc::clone(&notify);
            move || {
                active.store(false, Ordering::Release);
                notify.notify_waiters();
            }
        };

        let idle = wait_for_idle_future(notify, active, go_idle);
        pin_mut!(idle);

        assert_eq!(futures::poll!(idle.as_mut()), Poll::Ready(()));
    }

    /// While the flag stays `true` the wait is pending; flipping it and
    /// notifying lets the next poll complete.
    #[tokio::test]
    async fn wait_for_idle_stays_pending_until_idle_notification() {
        let notify = Arc::new(Notify::new());
        let active = Arc::new(AtomicBool::new(true));

        let idle = wait_for_idle_future(Arc::clone(&notify), Arc::clone(&active), || {});
        pin_mut!(idle);

        assert_eq!(futures::poll!(idle.as_mut()), Poll::Pending);
        assert!(active.load(Ordering::Acquire));

        active.store(false, Ordering::Release);
        notify.notify_waiters();

        assert_eq!(futures::poll!(idle.as_mut()), Poll::Ready(()));
    }

    /// `Agent::reset` must wake a waiter that is parked on an active loop.
    #[tokio::test]
    async fn reset_notifies_pending_wait_for_idle_waiters() {
        let stream_fn: Arc<dyn StreamFn> =
            Arc::new(MockStreamFn::new(vec![text_only_events("done")]));
        let options = AgentOptions::new("sys", default_model(), stream_fn, default_convert);
        let mut agent = Agent::new(options);

        // Bound to a named variable (not `_`) so the stream is not dropped
        // before `reset` runs.
        let _stream = agent
            .prompt_stream(vec![user_msg("hi")])
            .expect("prompt_stream should start a loop");

        let idle = wait_for_idle_future(
            Arc::clone(&agent.idle_notify),
            Arc::clone(&agent.loop_active),
            || {},
        );
        pin_mut!(idle);

        assert_eq!(futures::poll!(idle.as_mut()), Poll::Pending);

        agent.reset();

        assert_eq!(futures::poll!(idle.as_mut()), Poll::Ready(()));
    }
}