1use std::pin::Pin;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
4use std::task::{Context, Poll};
5
6use futures::Stream;
7use tokio::sync::Notify;
8use tokio_util::sync::CancellationToken;
9use tracing::{info, warn};
10
11use crate::agent_options::{ApproveToolFn, GetApiKeyFn};
12use crate::error::AgentError;
13use crate::loop_::{AgentEvent, AgentLoopConfig, agent_loop, agent_loop_continue};
14use crate::message_provider::MessageProvider;
15use crate::types::message_codec::clone_messages_for_send;
16use crate::types::{AgentMessage, AgentResult, ContentBlock, LlmMessage};
17use crate::util::now_timestamp;
18
19use super::queueing::QueueMessageProvider;
20use super::{Agent, SharedRetryStrategy};
21
22struct LoopGuardStream {
30 inner: Pin<Box<dyn Stream<Item = AgentEvent> + Send>>,
31 loop_active: Arc<AtomicBool>,
32 idle_notify: Arc<Notify>,
33 pending_message_snapshot: Arc<crate::pause_state::PendingMessageSnapshot>,
34 loop_context_snapshot: Arc<crate::pause_state::LoopContextSnapshot>,
35 generation: u64,
36 expected_generation: Arc<AtomicU64>,
37}
38
39impl Stream for LoopGuardStream {
40 type Item = AgentEvent;
41
42 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
43 self.inner.as_mut().poll_next(cx)
44 }
45}
46
47impl Drop for LoopGuardStream {
48 fn drop(&mut self) {
49 if self.expected_generation.load(Ordering::Acquire) == self.generation {
53 self.loop_active.store(false, Ordering::Release);
54 self.pending_message_snapshot.clear();
55 self.loop_context_snapshot.clear();
56 self.idle_notify.notify_waiters();
57 }
58 }
59}
60
61impl Agent {
62 pub fn prompt_stream(
68 &mut self,
69 input: Vec<AgentMessage>,
70 ) -> Result<Pin<Box<dyn Stream<Item = AgentEvent> + Send>>, AgentError> {
71 self.check_not_running().inspect_err(|_| {
72 warn!("prompt_stream called while agent is already running");
73 })?;
74 info!(
75 model = %self.state.model.model_id,
76 input_messages = input.len(),
77 "prompt_stream starting"
78 );
79 self.start_loop(input, false)
80 }
81
82 pub async fn prompt_async(
88 &mut self,
89 input: Vec<AgentMessage>,
90 ) -> Result<AgentResult, AgentError> {
91 info!(
92 model = %self.state.model.model_id,
93 input_messages = input.len(),
94 "prompt_async starting"
95 );
96 let stream = self.prompt_stream(input)?;
97 self.collect_stream(stream).await
98 }
99
100 pub fn prompt_sync(&mut self, input: Vec<AgentMessage>) -> Result<AgentResult, AgentError> {
107 self.check_not_running()?;
108 let rt = new_blocking_runtime()?;
109 rt.block_on(async {
110 let stream = self.start_loop(input, false)?;
111 self.collect_stream(stream).await
112 })
113 }
114
115 pub async fn prompt_text(
119 &mut self,
120 text: impl Into<String>,
121 ) -> Result<AgentResult, AgentError> {
122 let msg = AgentMessage::Llm(LlmMessage::User(crate::types::UserMessage {
123 content: vec![ContentBlock::Text { text: text.into() }],
124 timestamp: now_timestamp(),
125 cache_hint: None,
126 }));
127 self.prompt_async(vec![msg]).await
128 }
129
130 pub async fn prompt_text_with_images(
134 &mut self,
135 text: impl Into<String>,
136 images: Vec<crate::types::ImageSource>,
137 ) -> Result<AgentResult, AgentError> {
138 let mut content = vec![ContentBlock::Text { text: text.into() }];
139 for source in images {
140 content.push(ContentBlock::Image { source });
141 }
142 let msg = AgentMessage::Llm(LlmMessage::User(crate::types::UserMessage {
143 content,
144 timestamp: now_timestamp(),
145 cache_hint: None,
146 }));
147 self.prompt_async(vec![msg]).await
148 }
149
150 pub fn prompt_text_sync(&mut self, text: impl Into<String>) -> Result<AgentResult, AgentError> {
154 let msg = AgentMessage::Llm(LlmMessage::User(crate::types::UserMessage {
155 content: vec![ContentBlock::Text { text: text.into() }],
156 timestamp: now_timestamp(),
157 cache_hint: None,
158 }));
159 self.prompt_sync(vec![msg])
160 }
161
162 pub fn continue_stream(
169 &mut self,
170 ) -> Result<Pin<Box<dyn Stream<Item = AgentEvent> + Send>>, AgentError> {
171 self.check_not_running()?;
172 self.validate_continue()?;
173 self.start_loop(Vec::new(), true)
174 }
175
176 pub async fn continue_async(&mut self) -> Result<AgentResult, AgentError> {
183 let stream = self.continue_stream()?;
184 self.collect_stream(stream).await
185 }
186
187 pub fn continue_sync(&mut self) -> Result<AgentResult, AgentError> {
194 self.check_not_running()?;
195 self.validate_continue()?;
196 let rt = new_blocking_runtime()?;
197 rt.block_on(async {
198 let stream = self.start_loop(Vec::new(), true)?;
199 self.collect_stream(stream).await
200 })
201 }
202
203 pub(super) fn check_not_running(&mut self) -> Result<(), AgentError> {
204 let active = self.loop_active.load(Ordering::Acquire);
207 self.state.is_running = active;
208 if active {
209 return Err(AgentError::AlreadyRunning);
210 }
211 Ok(())
212 }
213
214 fn validate_continue(&self) -> Result<(), AgentError> {
215 if self.state.messages.is_empty() {
216 return Err(AgentError::NoMessages);
217 }
218 if let Some(AgentMessage::Llm(LlmMessage::Assistant(_))) = self.state.messages.last()
219 && !self.has_pending_messages()
220 {
221 return Err(AgentError::InvalidContinue);
222 }
223 Ok(())
224 }
225
226 #[allow(clippy::unnecessary_wraps)]
228 fn start_loop(
229 &mut self,
230 input: Vec<AgentMessage>,
231 is_continue: bool,
232 ) -> Result<Pin<Box<dyn Stream<Item = AgentEvent> + Send>>, AgentError> {
233 self.state.is_running = true;
234 self.state.error = None;
235 self.pending_message_snapshot.clear();
236 self.loop_context_snapshot.clear();
237 self.loop_active.store(true, Ordering::Release);
238 let generation = self.loop_generation.fetch_add(1, Ordering::AcqRel) + 1;
239
240 let token = CancellationToken::new();
241 self.abort_controller = Some(token.clone());
242
243 let config = self.build_loop_config();
244 let system_prompt = self.state.system_prompt.clone();
245 let llm_source: Box<dyn Iterator<Item = &AgentMessage>> = if is_continue {
246 Box::new(self.state.messages.iter())
247 } else {
248 Box::new(self.state.messages.iter().chain(input.iter()))
249 };
250 let in_flight_llm_messages: Vec<AgentMessage> = llm_source
251 .filter_map(|msg| match msg {
252 AgentMessage::Llm(llm) => Some(AgentMessage::Llm(llm.clone())),
253 AgentMessage::Custom(_) => None,
254 })
255 .collect();
256
257 let messages_for_loop = if is_continue {
258 std::mem::take(&mut self.state.messages)
259 } else {
260 let mut msgs = std::mem::take(&mut self.state.messages);
261 msgs.extend(input);
262 msgs
263 };
264 let in_flight_messages = clone_messages_for_send(&messages_for_loop);
265
266 let raw_stream = if is_continue {
267 agent_loop_continue(messages_for_loop, system_prompt, config, token)
268 } else {
269 agent_loop(messages_for_loop, system_prompt, config, token)
270 };
271
272 self.in_flight_llm_messages = Some(in_flight_llm_messages);
273 self.in_flight_messages = Some(in_flight_messages);
274
275 let guarded: Pin<Box<dyn Stream<Item = AgentEvent> + Send>> = Box::pin(LoopGuardStream {
276 inner: raw_stream,
277 loop_active: Arc::clone(&self.loop_active),
278 idle_notify: Arc::clone(&self.idle_notify),
279 pending_message_snapshot: Arc::clone(&self.pending_message_snapshot),
280 loop_context_snapshot: Arc::clone(&self.loop_context_snapshot),
281 generation,
282 expected_generation: Arc::clone(&self.loop_generation),
283 });
284 Ok(guarded)
285 }
286
287 #[allow(clippy::type_complexity)]
288 fn build_loop_config(&self) -> AgentLoopConfig {
289 let convert = Arc::clone(&self.convert_to_llm);
290 let convert_box: Box<dyn Fn(&AgentMessage) -> Option<LlmMessage> + Send + Sync> =
291 Box::new(move |msg| convert(msg));
292
293 let transform = self.transform_context.as_ref().map(Arc::clone);
294
295 let api_key_box = self.get_api_key.as_ref().map(|k| {
296 let k = Arc::clone(k);
297 let b: Box<GetApiKeyFn> = Box::new(move |provider| k(provider));
298 b
299 });
300
301 let queue_provider: Arc<dyn MessageProvider> = Arc::new(QueueMessageProvider {
302 steering_queue: Arc::clone(&self.steering_queue),
303 follow_up_queue: Arc::clone(&self.follow_up_queue),
304 steering_mode: self.steering_mode,
305 follow_up_mode: self.follow_up_mode,
306 pending_message_snapshot: Arc::clone(&self.pending_message_snapshot),
307 });
308
309 let message_provider: Arc<dyn MessageProvider> =
310 if let Some(ref external) = self.external_message_provider {
311 Arc::new(crate::message_provider::ComposedMessageProvider::new(
312 queue_provider,
313 Arc::clone(external),
314 ))
315 } else {
316 queue_provider
317 };
318
319 AgentLoopConfig {
320 agent_name: self.agent_name.clone(),
321 transfer_chain: self.transfer_chain.clone(),
322 model: self.state.model.clone(),
323 stream_options: self.stream_options.clone(),
324 retry_strategy: Box::new(SharedRetryStrategy(Arc::clone(&self.retry_strategy))),
325 stream_fn: Arc::clone(&self.stream_fn),
326 tools: self.state.tools.clone(),
327 convert_to_llm: convert_box,
328 transform_context: transform,
329 get_api_key: api_key_box,
330 message_provider: Some(message_provider),
331 pending_message_snapshot: Arc::clone(&self.pending_message_snapshot),
332 loop_context_snapshot: Arc::clone(&self.loop_context_snapshot),
333 approve_tool: self.approve_tool.as_ref().map(|a| {
334 let a = Arc::clone(a);
335 let b: Box<ApproveToolFn> = Box::new(move |req| a(req));
336 b
337 }),
338 approval_mode: self.approval_mode,
339 pre_turn_policies: self.pre_turn_policies.clone(),
340 pre_dispatch_policies: self.pre_dispatch_policies.clone(),
341 post_turn_policies: self.post_turn_policies.clone(),
342 post_loop_policies: self.post_loop_policies.clone(),
343 async_transform_context: self.async_transform_context.as_ref().map(Arc::clone),
344 metrics_collector: self.metrics_collector.as_ref().map(Arc::clone),
345 fallback: self.fallback.clone(),
346 tool_execution_policy: self.tool_execution_policy.clone(),
347 session_state: Arc::clone(&self.session_state),
348 credential_resolver: self.credential_resolver.as_ref().map(Arc::clone),
349 cache_config: self.cache_config.clone(),
350 cache_state: std::sync::Mutex::new(crate::context_cache::CacheState::new()),
351 dynamic_system_prompt: self.dynamic_system_prompt.clone(),
352 }
353 }
354}
355
356fn new_blocking_runtime_with(
357 build: impl FnOnce() -> std::io::Result<tokio::runtime::Runtime>,
358) -> Result<tokio::runtime::Runtime, AgentError> {
359 if tokio::runtime::Handle::try_current().is_ok() {
360 return Err(AgentError::SyncInAsyncContext);
361 }
362 build().map_err(AgentError::runtime_init)
363}
364
365pub(super) fn new_blocking_runtime() -> Result<tokio::runtime::Runtime, AgentError> {
370 new_blocking_runtime_with(tokio::runtime::Runtime::new)
371}
372
#[cfg(test)]
mod tests {
    use std::cell::Cell;

    use super::new_blocking_runtime_with;
    use crate::error::AgentError;

    /// A failing builder surfaces as `AgentError::RuntimeInit` with a stable message.
    #[test]
    fn new_blocking_runtime_returns_runtime_init_error() {
        let err = new_blocking_runtime_with(|| Err(std::io::Error::other("boom"))).unwrap_err();

        assert!(matches!(err, AgentError::RuntimeInit { .. }));
        assert_eq!(
            err.to_string(),
            "failed to create Tokio runtime for sync API"
        );
    }

    /// Inside an entered runtime, the sync path is rejected before any runtime is built.
    #[test]
    fn new_blocking_runtime_rejects_async_context() {
        let outer = tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("current-thread runtime should build");
        let _entered = outer.enter();

        let builder_called = Cell::new(false);
        let err = new_blocking_runtime_with(|| {
            builder_called.set(true);
            Err(std::io::Error::other("builder must not run"))
        })
        .unwrap_err();

        assert!(matches!(err, AgentError::SyncInAsyncContext));
        assert!(
            !builder_called.get(),
            "runtime builder ran inside an async context"
        );
    }
}
388}