Skip to main content

swink_agent/agent/
invoke.rs

1use std::pin::Pin;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
4use std::task::{Context, Poll};
5
6use futures::Stream;
7use tokio::sync::Notify;
8use tokio_util::sync::CancellationToken;
9use tracing::{info, warn};
10
11use crate::agent_options::{ApproveToolFn, GetApiKeyFn};
12use crate::error::AgentError;
13use crate::loop_::{AgentEvent, AgentLoopConfig, agent_loop, agent_loop_continue};
14use crate::message_provider::MessageProvider;
15use crate::types::message_codec::clone_messages_for_send;
16use crate::types::{AgentMessage, AgentResult, ContentBlock, LlmMessage};
17use crate::util::now_timestamp;
18
19use super::queueing::QueueMessageProvider;
20use super::{Agent, SharedRetryStrategy};
21
// ─── LoopGuardStream ────────────────────────────────────────────────────────

/// Wrapper stream that clears the agent's `loop_active` flag when dropped.
///
/// This ensures the agent becomes idle even if the caller drops the stream
/// without draining it to `AgentEnd`. A generation counter prevents a stale
/// guard from clearing the flag for a newer run.
struct LoopGuardStream {
    // The underlying agent-loop event stream; polling is delegated to it
    // unchanged.
    inner: Pin<Box<dyn Stream<Item = AgentEvent> + Send>>,
    // Shared run flag set by `start_loop`; cleared in `Drop` so the agent
    // reads as idle once the stream is gone.
    loop_active: Arc<AtomicBool>,
    // Waiters blocked on agent idleness are woken from `Drop`.
    idle_notify: Arc<Notify>,
    // Run-scoped snapshot state, cleared in `Drop` alongside `loop_active`.
    pending_message_snapshot: Arc<crate::pause_state::PendingMessageSnapshot>,
    loop_context_snapshot: Arc<crate::pause_state::LoopContextSnapshot>,
    // The generation this guard was created for.
    generation: u64,
    // Live generation counter shared with the agent; if it no longer equals
    // `generation`, a newer run has started and this guard must not touch
    // that run's state.
    expected_generation: Arc<AtomicU64>,
}
38
39impl Stream for LoopGuardStream {
40    type Item = AgentEvent;
41
42    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
43        self.inner.as_mut().poll_next(cx)
44    }
45}
46
47impl Drop for LoopGuardStream {
48    fn drop(&mut self) {
49        // Only clear loop_active if this guard belongs to the current run.
50        // A newer start_loop will have incremented loop_generation, making
51        // this guard's generation stale.
52        if self.expected_generation.load(Ordering::Acquire) == self.generation {
53            self.loop_active.store(false, Ordering::Release);
54            self.pending_message_snapshot.clear();
55            self.loop_context_snapshot.clear();
56            self.idle_notify.notify_waiters();
57        }
58    }
59}
60
impl Agent {
    /// Start a new loop with input messages, returning an event stream.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::AlreadyRunning`] if the agent is already running.
    pub fn prompt_stream(
        &mut self,
        input: Vec<AgentMessage>,
    ) -> Result<Pin<Box<dyn Stream<Item = AgentEvent> + Send>>, AgentError> {
        self.check_not_running().inspect_err(|_| {
            warn!("prompt_stream called while agent is already running");
        })?;
        info!(
            model = %self.state.model.model_id,
            input_messages = input.len(),
            "prompt_stream starting"
        );
        self.start_loop(input, false)
    }

    /// Start a new loop with input messages, collecting to completion.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::AlreadyRunning`] if the agent is already running.
    pub async fn prompt_async(
        &mut self,
        input: Vec<AgentMessage>,
    ) -> Result<AgentResult, AgentError> {
        info!(
            model = %self.state.model.model_id,
            input_messages = input.len(),
            "prompt_async starting"
        );
        let stream = self.prompt_stream(input)?;
        self.collect_stream(stream).await
    }

    /// Start a new loop with input messages, blocking the current thread.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::AlreadyRunning`] if the agent is already running,
    /// or [`AgentError::SyncInAsyncContext`] if called from within a Tokio runtime.
    pub fn prompt_sync(&mut self, input: Vec<AgentMessage>) -> Result<AgentResult, AgentError> {
        // Check before building the runtime so an already-running agent fails
        // fast without paying runtime-construction cost.
        self.check_not_running()?;
        let rt = new_blocking_runtime()?;
        rt.block_on(async {
            let stream = self.start_loop(input, false)?;
            self.collect_stream(stream).await
        })
    }

    /// Start a new loop from a plain text string, collecting to completion.
    ///
    /// Convenience wrapper that builds a `UserMessage` from the string.
    pub async fn prompt_text(
        &mut self,
        text: impl Into<String>,
    ) -> Result<AgentResult, AgentError> {
        let msg = AgentMessage::Llm(LlmMessage::User(crate::types::UserMessage {
            content: vec![ContentBlock::Text { text: text.into() }],
            timestamp: now_timestamp(),
            cache_hint: None,
        }));
        self.prompt_async(vec![msg]).await
    }

    /// Start a new loop from a text string with images, collecting to completion.
    ///
    /// Convenience wrapper that builds a `UserMessage` from text and image blocks.
    /// The text block always comes first, followed by the images in order.
    pub async fn prompt_text_with_images(
        &mut self,
        text: impl Into<String>,
        images: Vec<crate::types::ImageSource>,
    ) -> Result<AgentResult, AgentError> {
        let mut content = vec![ContentBlock::Text { text: text.into() }];
        for source in images {
            content.push(ContentBlock::Image { source });
        }
        let msg = AgentMessage::Llm(LlmMessage::User(crate::types::UserMessage {
            content,
            timestamp: now_timestamp(),
            cache_hint: None,
        }));
        self.prompt_async(vec![msg]).await
    }

    /// Start a new loop from a plain text string, blocking the current thread.
    ///
    /// Convenience wrapper that builds a `UserMessage` from the string.
    pub fn prompt_text_sync(&mut self, text: impl Into<String>) -> Result<AgentResult, AgentError> {
        let msg = AgentMessage::Llm(LlmMessage::User(crate::types::UserMessage {
            content: vec![ContentBlock::Text { text: text.into() }],
            timestamp: now_timestamp(),
            cache_hint: None,
        }));
        self.prompt_sync(vec![msg])
    }

    /// Continue from existing messages, returning an event stream.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::AlreadyRunning`], [`AgentError::NoMessages`],
    /// or [`AgentError::InvalidContinue`].
    pub fn continue_stream(
        &mut self,
    ) -> Result<Pin<Box<dyn Stream<Item = AgentEvent> + Send>>, AgentError> {
        self.check_not_running()?;
        self.validate_continue()?;
        // `is_continue = true`: no new input, resume from stored messages.
        self.start_loop(Vec::new(), true)
    }

    /// Continue from existing messages, collecting to completion.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::AlreadyRunning`], [`AgentError::NoMessages`],
    /// or [`AgentError::InvalidContinue`].
    pub async fn continue_async(&mut self) -> Result<AgentResult, AgentError> {
        let stream = self.continue_stream()?;
        self.collect_stream(stream).await
    }

    /// Continue from existing messages, blocking the current thread.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::AlreadyRunning`], [`AgentError::NoMessages`],
    /// [`AgentError::InvalidContinue`], or [`AgentError::SyncInAsyncContext`].
    pub fn continue_sync(&mut self) -> Result<AgentResult, AgentError> {
        self.check_not_running()?;
        self.validate_continue()?;
        let rt = new_blocking_runtime()?;
        rt.block_on(async {
            let stream = self.start_loop(Vec::new(), true)?;
            self.collect_stream(stream).await
        })
    }

    /// Fail with [`AgentError::AlreadyRunning`] if a loop is active.
    pub(super) fn check_not_running(&mut self) -> Result<(), AgentError> {
        // Synchronise the observable `state.is_running` from the atomic ground
        // truth so callers that inspect `agent.state()` see an up-to-date value.
        let active = self.loop_active.load(Ordering::Acquire);
        self.state.is_running = active;
        if active {
            return Err(AgentError::AlreadyRunning);
        }
        Ok(())
    }

    /// Check that continuing makes sense: there must be history, and the last
    /// message must not be a completed assistant turn unless queued messages
    /// are pending (which would give the continued loop something to do).
    fn validate_continue(&self) -> Result<(), AgentError> {
        if self.state.messages.is_empty() {
            return Err(AgentError::NoMessages);
        }
        if let Some(AgentMessage::Llm(LlmMessage::Assistant(_))) = self.state.messages.last()
            && !self.has_pending_messages()
        {
            return Err(AgentError::InvalidContinue);
        }
        Ok(())
    }

    /// Build the loop config and start the agent loop, returning a wrapped stream.
    #[allow(clippy::unnecessary_wraps)]
    fn start_loop(
        &mut self,
        input: Vec<AgentMessage>,
        is_continue: bool,
    ) -> Result<Pin<Box<dyn Stream<Item = AgentEvent> + Send>>, AgentError> {
        // Mark the run active and reset per-run state. The generation bump
        // invalidates any still-alive LoopGuardStream from a previous run, so
        // a stale guard's Drop cannot clear this run's flags.
        self.state.is_running = true;
        self.state.error = None;
        self.pending_message_snapshot.clear();
        self.loop_context_snapshot.clear();
        self.loop_active.store(true, Ordering::Release);
        let generation = self.loop_generation.fetch_add(1, Ordering::AcqRel) + 1;

        // Fresh cancellation token per run; stored so `abort` can trigger it.
        let token = CancellationToken::new();
        self.abort_controller = Some(token.clone());

        let config = self.build_loop_config();
        let system_prompt = self.state.system_prompt.clone();
        // Snapshot the LLM-visible messages (stored history plus, for fresh
        // runs, the new input) BEFORE the history is moved into the loop.
        // `Custom` messages are not part of the LLM conversation and are
        // filtered out of this snapshot.
        let llm_source: Box<dyn Iterator<Item = &AgentMessage>> = if is_continue {
            Box::new(self.state.messages.iter())
        } else {
            Box::new(self.state.messages.iter().chain(input.iter()))
        };
        let in_flight_llm_messages: Vec<AgentMessage> = llm_source
            .filter_map(|msg| match msg {
                AgentMessage::Llm(llm) => Some(AgentMessage::Llm(llm.clone())),
                AgentMessage::Custom(_) => None,
            })
            .collect();

        // Move the stored history out of `state` (leaving it empty) and, for
        // fresh runs, append the new input; the loop owns the messages while
        // it runs.
        let messages_for_loop = if is_continue {
            std::mem::take(&mut self.state.messages)
        } else {
            let mut msgs = std::mem::take(&mut self.state.messages);
            msgs.extend(input);
            msgs
        };
        let in_flight_messages = clone_messages_for_send(&messages_for_loop);

        let raw_stream = if is_continue {
            agent_loop_continue(messages_for_loop, system_prompt, config, token)
        } else {
            agent_loop(messages_for_loop, system_prompt, config, token)
        };

        // Keep copies visible to the agent while the loop holds the originals.
        self.in_flight_llm_messages = Some(in_flight_llm_messages);
        self.in_flight_messages = Some(in_flight_messages);

        // Wrap the raw stream so dropping it always returns the agent to
        // idle (see `LoopGuardStream`).
        let guarded: Pin<Box<dyn Stream<Item = AgentEvent> + Send>> = Box::pin(LoopGuardStream {
            inner: raw_stream,
            loop_active: Arc::clone(&self.loop_active),
            idle_notify: Arc::clone(&self.idle_notify),
            pending_message_snapshot: Arc::clone(&self.pending_message_snapshot),
            loop_context_snapshot: Arc::clone(&self.loop_context_snapshot),
            generation,
            expected_generation: Arc::clone(&self.loop_generation),
        });
        Ok(guarded)
    }

    /// Assemble the per-run [`AgentLoopConfig`] from the agent's configuration,
    /// re-wrapping shared `Arc` callbacks as boxed functions where the loop
    /// expects owned closures.
    #[allow(clippy::type_complexity)]
    fn build_loop_config(&self) -> AgentLoopConfig {
        // Adapt the Arc'd conversion callback to the boxed form the loop wants.
        let convert = Arc::clone(&self.convert_to_llm);
        let convert_box: Box<dyn Fn(&AgentMessage) -> Option<LlmMessage> + Send + Sync> =
            Box::new(move |msg| convert(msg));

        let transform = self.transform_context.as_ref().map(Arc::clone);

        let api_key_box = self.get_api_key.as_ref().map(|k| {
            let k = Arc::clone(k);
            let b: Box<GetApiKeyFn> = Box::new(move |provider| k(provider));
            b
        });

        // The queue provider feeds steering/follow-up messages into the loop;
        // if an external provider is configured it is composed after the queue.
        let queue_provider: Arc<dyn MessageProvider> = Arc::new(QueueMessageProvider {
            steering_queue: Arc::clone(&self.steering_queue),
            follow_up_queue: Arc::clone(&self.follow_up_queue),
            steering_mode: self.steering_mode,
            follow_up_mode: self.follow_up_mode,
            pending_message_snapshot: Arc::clone(&self.pending_message_snapshot),
        });

        let message_provider: Arc<dyn MessageProvider> =
            if let Some(ref external) = self.external_message_provider {
                Arc::new(crate::message_provider::ComposedMessageProvider::new(
                    queue_provider,
                    Arc::clone(external),
                ))
            } else {
                queue_provider
            };

        AgentLoopConfig {
            agent_name: self.agent_name.clone(),
            transfer_chain: self.transfer_chain.clone(),
            model: self.state.model.clone(),
            stream_options: self.stream_options.clone(),
            retry_strategy: Box::new(SharedRetryStrategy(Arc::clone(&self.retry_strategy))),
            stream_fn: Arc::clone(&self.stream_fn),
            tools: self.state.tools.clone(),
            convert_to_llm: convert_box,
            transform_context: transform,
            get_api_key: api_key_box,
            message_provider: Some(message_provider),
            pending_message_snapshot: Arc::clone(&self.pending_message_snapshot),
            loop_context_snapshot: Arc::clone(&self.loop_context_snapshot),
            approve_tool: self.approve_tool.as_ref().map(|a| {
                let a = Arc::clone(a);
                let b: Box<ApproveToolFn> = Box::new(move |req| a(req));
                b
            }),
            approval_mode: self.approval_mode,
            pre_turn_policies: self.pre_turn_policies.clone(),
            pre_dispatch_policies: self.pre_dispatch_policies.clone(),
            post_turn_policies: self.post_turn_policies.clone(),
            post_loop_policies: self.post_loop_policies.clone(),
            async_transform_context: self.async_transform_context.as_ref().map(Arc::clone),
            metrics_collector: self.metrics_collector.as_ref().map(Arc::clone),
            fallback: self.fallback.clone(),
            tool_execution_policy: self.tool_execution_policy.clone(),
            session_state: Arc::clone(&self.session_state),
            credential_resolver: self.credential_resolver.as_ref().map(Arc::clone),
            cache_config: self.cache_config.clone(),
            // Cache state is per-run, so a fresh one is created each loop.
            cache_state: std::sync::Mutex::new(crate::context_cache::CacheState::new()),
            dynamic_system_prompt: self.dynamic_system_prompt.clone(),
        }
    }
}
355
356fn new_blocking_runtime_with(
357    build: impl FnOnce() -> std::io::Result<tokio::runtime::Runtime>,
358) -> Result<tokio::runtime::Runtime, AgentError> {
359    if tokio::runtime::Handle::try_current().is_ok() {
360        return Err(AgentError::SyncInAsyncContext);
361    }
362    build().map_err(AgentError::runtime_init)
363}
364
365/// Create a new Tokio runtime for blocking sync APIs, returning
366/// [`AgentError::SyncInAsyncContext`] if a runtime is already active on
367/// the current thread and [`AgentError::RuntimeInit`] if runtime construction
368/// fails.
369pub(super) fn new_blocking_runtime() -> Result<tokio::runtime::Runtime, AgentError> {
370    new_blocking_runtime_with(tokio::runtime::Runtime::new)
371}
372
#[cfg(test)]
mod tests {
    use super::new_blocking_runtime_with;
    use crate::error::AgentError;

    /// A builder that fails must surface as `AgentError::RuntimeInit` carrying
    /// the documented display message.
    #[test]
    fn new_blocking_runtime_returns_runtime_init_error() {
        let failing_build = || Err(std::io::Error::other("boom"));

        let err = new_blocking_runtime_with(failing_build).unwrap_err();

        assert!(matches!(err, AgentError::RuntimeInit { .. }));
        let rendered = err.to_string();
        assert_eq!(rendered, "failed to create Tokio runtime for sync API");
    }
}