Skip to main content

swink_agent/
message_provider.rs

1//! Trait for polling steering and follow-up messages.
2//!
3//! [`MessageProvider`] replaces inline closures in [`AgentLoopConfig`](crate::loop_::AgentLoopConfig),
4//! giving callers a named, testable abstraction for injecting messages into the
5//! agent loop between turns.
6//!
7//! For push-based messaging, see [`ChannelMessageProvider`] and [`MessageSender`].
8
9use std::sync::Mutex;
10
11use crate::types::AgentMessage;
12
13/// Provides steering and follow-up messages to the agent loop.
14///
15/// Implementors are polled at well-defined points during loop execution:
16/// - [`poll_steering`](Self::poll_steering) is called after each tool execution batch.
17/// - [`poll_follow_up`](Self::poll_follow_up) is called when the agent would otherwise stop.
18pub trait MessageProvider: Send + Sync {
19    /// Return pending steering messages, if any.
20    ///
21    /// Called after tool execution completes. Returning a non-empty vec causes
22    /// a steering interrupt — pending tool calls may be cancelled and the new
23    /// messages are injected into the conversation.
24    fn poll_steering(&self) -> Vec<AgentMessage>;
25
26    /// Return pending follow-up messages, if any.
27    ///
28    /// Called when the model has finished a turn and no tool calls remain.
29    /// Returning a non-empty vec triggers another outer-loop iteration.
30    fn poll_follow_up(&self) -> Vec<AgentMessage>;
31
32    /// Non-draining check for pending steering messages.
33    ///
34    /// Used by tool-dispatch workers to detect steering interrupts early
35    /// without consuming queued messages — the authoritative drain happens
36    /// via [`poll_steering`](Self::poll_steering) in the interrupt collector.
37    ///
38    /// The default implementation returns `false`, so external providers
39    /// that only implement `poll_steering`/`poll_follow_up` will never
40    /// trigger a worker-initiated early interrupt. Built-in channel/queue
41    /// providers override this with a non-draining peek.
42    fn has_steering(&self) -> bool {
43        false
44    }
45}
46
/// A [`MessageProvider`] built from two closures.
///
/// Created via [`from_fns`]. The `steering` closure answers
/// [`poll_steering`](MessageProvider::poll_steering), and the `follow_up`
/// closure answers [`poll_follow_up`](MessageProvider::poll_follow_up).
///
/// The `Fn() -> Vec<AgentMessage> + Send + Sync` bounds live on the
/// [`MessageProvider`] impl and on [`from_fns`], not on the struct itself.
/// Merely naming this type (for example inside another generic type)
/// therefore does not force callers to repeat them.
pub struct FnMessageProvider<S, F> {
    steering: S,
    follow_up: F,
}

// Closures do not implement `Debug`, so only the type name is printed.
impl<S, F> std::fmt::Debug for FnMessageProvider<S, F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FnMessageProvider").finish_non_exhaustive()
    }
}
58
59impl<S, F> MessageProvider for FnMessageProvider<S, F>
60where
61    S: Fn() -> Vec<AgentMessage> + Send + Sync,
62    F: Fn() -> Vec<AgentMessage> + Send + Sync,
63{
64    fn poll_steering(&self) -> Vec<AgentMessage> {
65        (self.steering)()
66    }
67
68    fn poll_follow_up(&self) -> Vec<AgentMessage> {
69        (self.follow_up)()
70    }
71}
72
73/// Create a [`MessageProvider`] from two closures.
74///
75/// # Example
76///
77/// ```
78/// use swink_agent::from_fns;
79///
80/// let provider = from_fns(
81///     || vec![],  // no steering messages
82///     || vec![],  // no follow-up messages
83/// );
84/// ```
85pub const fn from_fns<S, F>(steering: S, follow_up: F) -> FnMessageProvider<S, F>
86where
87    S: Fn() -> Vec<AgentMessage> + Send + Sync,
88    F: Fn() -> Vec<AgentMessage> + Send + Sync,
89{
90    FnMessageProvider {
91        steering,
92        follow_up,
93    }
94}
95
96// ─── Channel-based MessageProvider ──────────────────────────────────────────
97
98/// A clonable handle for pushing messages into a [`ChannelMessageProvider`].
99///
100/// Obtained from [`message_channel`]. Messages sent through this handle are
101/// delivered as **follow-up** messages by default. Use [`send_steering`](Self::send_steering)
102/// to inject steering messages instead.
103#[derive(Clone)]
104pub struct MessageSender {
105    steering_tx: tokio::sync::mpsc::UnboundedSender<AgentMessage>,
106    follow_up_tx: tokio::sync::mpsc::UnboundedSender<AgentMessage>,
107}
108
109impl MessageSender {
110    /// Push a steering message to the agent.
111    ///
112    /// Steering messages are polled after each tool execution batch and can
113    /// interrupt in-progress tool calls.
114    ///
115    /// Returns `false` if the receiver has been dropped.
116    pub fn send_steering(&self, message: AgentMessage) -> bool {
117        self.steering_tx.send(message).is_ok()
118    }
119
120    /// Push a follow-up message to the agent.
121    ///
122    /// Follow-up messages are polled when the agent would otherwise stop,
123    /// triggering another outer-loop iteration.
124    ///
125    /// Returns `false` if the receiver has been dropped.
126    pub fn send_follow_up(&self, message: AgentMessage) -> bool {
127        self.follow_up_tx.send(message).is_ok()
128    }
129
130    /// Alias for [`send_follow_up`](Self::send_follow_up).
131    pub fn send(&self, message: AgentMessage) -> bool {
132        self.send_follow_up(message)
133    }
134}
135
136impl std::fmt::Debug for MessageSender {
137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        f.debug_struct("MessageSender").finish_non_exhaustive()
139    }
140}
141
142/// A [`MessageProvider`] backed by tokio unbounded mpsc channels.
143///
144/// Created via [`message_channel`]. External code pushes messages through the
145/// paired [`MessageSender`]; the provider drains them when the agent loop polls.
146pub struct ChannelMessageProvider {
147    steering_rx: Mutex<tokio::sync::mpsc::UnboundedReceiver<AgentMessage>>,
148    follow_up_rx: Mutex<tokio::sync::mpsc::UnboundedReceiver<AgentMessage>>,
149}
150
151impl ChannelMessageProvider {
152    /// Drain all currently buffered messages from a receiver.
153    fn drain_receiver(
154        rx: &Mutex<tokio::sync::mpsc::UnboundedReceiver<AgentMessage>>,
155    ) -> Vec<AgentMessage> {
156        let mut guard = rx.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
157        let mut messages = Vec::new();
158        while let Ok(msg) = guard.try_recv() {
159            messages.push(msg);
160        }
161        messages
162    }
163}
164
165impl MessageProvider for ChannelMessageProvider {
166    fn poll_steering(&self) -> Vec<AgentMessage> {
167        Self::drain_receiver(&self.steering_rx)
168    }
169
170    fn poll_follow_up(&self) -> Vec<AgentMessage> {
171        Self::drain_receiver(&self.follow_up_rx)
172    }
173
174    fn has_steering(&self) -> bool {
175        let guard = self
176            .steering_rx
177            .lock()
178            .unwrap_or_else(std::sync::PoisonError::into_inner);
179        !guard.is_empty()
180    }
181}
182
183/// A [`MessageProvider`] that combines two providers, draining both on each poll.
184///
185/// Messages from the primary provider are returned first, followed by those
186/// from the secondary provider.
187pub struct ComposedMessageProvider {
188    primary: std::sync::Arc<dyn MessageProvider>,
189    secondary: std::sync::Arc<dyn MessageProvider>,
190}
191
192impl ComposedMessageProvider {
193    /// Create a composed provider from two providers.
194    pub fn new(
195        primary: std::sync::Arc<dyn MessageProvider>,
196        secondary: std::sync::Arc<dyn MessageProvider>,
197    ) -> Self {
198        Self { primary, secondary }
199    }
200}
201
202impl MessageProvider for ComposedMessageProvider {
203    fn poll_steering(&self) -> Vec<AgentMessage> {
204        let mut msgs = self.primary.poll_steering();
205        msgs.extend(self.secondary.poll_steering());
206        msgs
207    }
208
209    fn poll_follow_up(&self) -> Vec<AgentMessage> {
210        let mut msgs = self.primary.poll_follow_up();
211        msgs.extend(self.secondary.poll_follow_up());
212        msgs
213    }
214
215    fn has_steering(&self) -> bool {
216        self.primary.has_steering() || self.secondary.has_steering()
217    }
218}
219
220/// Create a channel-backed [`MessageProvider`] and its paired [`MessageSender`].
221///
222/// The returned `ChannelMessageProvider` implements [`MessageProvider`] and can
223/// be passed to [`AgentLoopConfig`](crate::loop_::AgentLoopConfig) or used with
224/// [`AgentOptions::with_message_channel`](crate::AgentOptions::with_message_channel).
225/// The `MessageSender` is a clonable handle that external code uses to push
226/// messages into the agent.
227///
228/// # Example
229///
230/// ```
231/// use swink_agent::message_channel;
232///
233/// let (provider, sender) = message_channel();
234/// // sender.send(msg) pushes a follow-up message
235/// // sender.send_steering(msg) pushes a steering message
236/// ```
237pub fn message_channel() -> (ChannelMessageProvider, MessageSender) {
238    let (steering_tx, steering_rx) = tokio::sync::mpsc::unbounded_channel();
239    let (follow_up_tx, follow_up_rx) = tokio::sync::mpsc::unbounded_channel();
240
241    let provider = ChannelMessageProvider {
242        steering_rx: Mutex::new(steering_rx),
243        follow_up_rx: Mutex::new(follow_up_rx),
244    };
245
246    let sender = MessageSender {
247        steering_tx,
248        follow_up_tx,
249    };
250
251    (provider, sender)
252}