1mod config;
8mod event;
9mod overflow;
10mod stream;
11mod tool_dispatch;
12mod turn;
13mod types;
14
15pub use config::AgentLoopConfig;
16pub use event::{AgentEvent, TurnEndReason};
17pub use types::*;
18
19use std::error::Error as _;
20use std::pin::Pin;
21use std::sync::Arc;
22
23use futures::Stream;
24use tokio::sync::mpsc;
25use tokio_stream::wrappers::ReceiverStream;
26use tokio_util::sync::CancellationToken;
27use tracing::{Instrument, info, info_span};
28
29use crate::error::AgentError;
30use crate::stream::StreamErrorKind;
31use crate::types::{AgentMessage, AssistantMessage, ModelSpec, StopReason};
32use crate::util::now_timestamp;
33
/// Legacy sentinel string once used to mark a context-window overflow in the
/// transcript. No longer emitted by this module; retained only so downstream
/// code that still references it keeps compiling.
#[deprecated(
    note = "Overflow recovery now happens in-place in run_single_turn. Retained for backward compatibility."
)]
#[allow(dead_code)]
pub const CONTEXT_OVERFLOW_SENTINEL: &str = "__context_overflow__";
43
/// Bounded capacity of the event channel between the spawned loop task and
/// the consumer stream; `send().await` backpressures when the consumer lags.
const EVENT_CHANNEL_CAPACITY: usize = 256;
47
48#[must_use]
55pub fn agent_loop(
56 prompt_messages: Vec<AgentMessage>,
57 system_prompt: String,
58 config: AgentLoopConfig,
59 cancellation_token: CancellationToken,
60) -> Pin<Box<dyn Stream<Item = AgentEvent> + Send>> {
61 let initial_new_messages_len = prompt_messages.len();
62 run_loop(
63 prompt_messages,
64 initial_new_messages_len,
65 system_prompt,
66 config,
67 cancellation_token,
68 )
69}
70
71#[must_use]
76pub fn agent_loop_continue(
77 messages: Vec<AgentMessage>,
78 system_prompt: String,
79 config: AgentLoopConfig,
80 cancellation_token: CancellationToken,
81) -> Pin<Box<dyn Stream<Item = AgentEvent> + Send>> {
82 run_loop(messages, 0, system_prompt, config, cancellation_token)
83}
84
85fn run_loop(
90 initial_messages: Vec<AgentMessage>,
91 initial_new_messages_len: usize,
92 system_prompt: String,
93 config: AgentLoopConfig,
94 cancellation_token: CancellationToken,
95) -> Pin<Box<dyn Stream<Item = AgentEvent> + Send>> {
96 let (tx, rx) = mpsc::channel::<AgentEvent>(EVENT_CHANNEL_CAPACITY);
97
98 tokio::spawn(async move {
99 run_loop_inner(
100 initial_messages,
101 initial_new_messages_len,
102 system_prompt,
103 config,
104 cancellation_token,
105 tx,
106 )
107 .await;
108 });
109
110 Box::pin(ReceiverStream::new(rx))
111}
112
113pub async fn emit(tx: &mpsc::Sender<AgentEvent>, event: AgentEvent) -> bool {
115 tx.send(event).await.is_ok()
116}
117
118#[allow(clippy::too_many_lines)]
122async fn run_loop_inner(
123 initial_messages: Vec<AgentMessage>,
124 initial_new_messages_len: usize,
125 system_prompt: String,
126 config: AgentLoopConfig,
127 cancellation_token: CancellationToken,
128 tx: mpsc::Sender<AgentEvent>,
129) {
130 let config = Arc::new(config);
131 let span = info_span!(
132 "agent.run",
133 model_id = %config.model.model_id,
134 provider = %config.model.provider,
135 tool_count = config.tools.len(),
136 message_count = initial_messages.len(),
137 );
138 async {
139 info!(
140 model = %config.model.model_id,
141 provider = %config.model.provider,
142 tools = config.tools.len(),
143 "starting agent loop"
144 );
145 let mut transfer_chain = config.transfer_chain.clone().unwrap_or_default();
148 if let Some(ref name) = config.agent_name {
149 let _ = transfer_chain.push(name.clone());
152 }
153
154 let mut state = LoopState {
155 context_messages: initial_messages,
156 pending_messages: Vec::new(),
157 initial_new_messages_len,
158 overflow_signal: false,
159 overflow_recovery_attempted: false,
160 turn_index: 0,
161 accumulated_usage: crate::types::Usage::default(),
162 accumulated_cost: crate::types::Cost::default(),
163 last_assistant_message: None,
164 last_tool_results: Vec::new(),
165 transfer_chain,
166 };
167
168 if !emit(&tx, AgentEvent::AgentStart).await {
170 return;
171 }
172
173 'outer: loop {
175 'inner: loop {
177 let turn_result = turn::run_single_turn(
178 &config,
179 &mut state,
180 &system_prompt,
181 &cancellation_token,
182 &tx,
183 )
184 .await;
185
186 let should_break = match turn_result {
187 TurnOutcome::ContinueInner => {
188 state.turn_index += 1;
189 false
190 }
191 TurnOutcome::BreakInner => {
192 state.turn_index += 1;
193 true
194 }
195 TurnOutcome::Return => return,
196 };
197
198 if should_break {
205 break 'inner;
206 }
207 }
208
209 {
211 use crate::policy::{PolicyContext, PolicyVerdict, run_post_loop_policies};
212
213 let state_snapshot = {
214 let guard = config
215 .session_state
216 .read()
217 .unwrap_or_else(std::sync::PoisonError::into_inner);
218 guard.clone()
219 };
220 let policy_ctx = PolicyContext {
221 turn_index: state.turn_index,
222 accumulated_usage: &state.accumulated_usage,
223 accumulated_cost: &state.accumulated_cost,
224 message_count: state.context_messages.len(),
225 overflow_signal: state.overflow_signal,
226 new_messages: &[], state: &state_snapshot,
228 };
229 match run_post_loop_policies(&config.post_loop_policies, &policy_ctx) {
230 PolicyVerdict::Continue => {}
231 PolicyVerdict::Stop(_reason) => {
232 let _ = emit(
233 &tx,
234 AgentEvent::AgentEnd {
235 messages: Arc::new(state.context_messages),
236 },
237 )
238 .await;
239 info!("post-loop policy stopped agent");
240 return;
241 }
242 PolicyVerdict::Inject(msgs) => {
243 state.pending_messages.extend(msgs);
244 config
245 .pending_message_snapshot
246 .replace(&state.pending_messages);
247 continue 'outer;
248 }
249 }
250 }
251
252 if let Some(ref provider) = config.message_provider {
254 let msgs = provider.poll_follow_up();
255 if !msgs.is_empty() {
256 state.pending_messages.extend(msgs);
257 config
258 .pending_message_snapshot
259 .replace(&state.pending_messages);
260 continue 'outer;
261 }
262 }
263
264 let _ = emit(
266 &tx,
267 AgentEvent::AgentEnd {
268 messages: Arc::new(state.context_messages),
269 },
270 )
271 .await;
272 info!("agent loop finished");
273 return;
274 }
275 }
276 .instrument(span)
277 .await;
278}
279
280fn build_terminal_message(
284 model: &ModelSpec,
285 stop_reason: StopReason,
286 error_message: String,
287) -> AssistantMessage {
288 AssistantMessage {
289 content: vec![],
290 provider: model.provider.clone(),
291 model_id: model.model_id.clone(),
292 usage: crate::types::Usage::default(),
293 cost: crate::types::Cost::default(),
294 stop_reason,
295 error_message: Some(error_message),
296 error_kind: None,
297 timestamp: now_timestamp(),
298 cache_hint: None,
299 }
300}
301
302pub fn build_abort_message(model: &ModelSpec) -> AssistantMessage {
304 build_terminal_message(
305 model,
306 StopReason::Aborted,
307 "operation aborted via cancellation token".to_string(),
308 )
309}
310
311pub fn build_error_message(model: &ModelSpec, error: &AgentError) -> AssistantMessage {
313 build_terminal_message(model, StopReason::Error, format_error_with_sources(error))
314}
315
/// Flatten `error` and its `source()` chain into a single string.
///
/// Each source message is appended with a `": "` separator. Sources whose
/// text is empty, or already contained in the accumulated message, are
/// skipped (avoids the common "wrapper repeats its cause" duplication).
///
/// Generalized to accept any [`std::error::Error`]; existing callers passing
/// `&AgentError` coerce to the trait object transparently.
pub fn format_error_with_sources(error: &dyn std::error::Error) -> String {
    let mut message = error.to_string();
    let mut source = error.source();

    while let Some(err) = source {
        let source_message = err.to_string();
        if !source_message.is_empty() && !message.contains(&source_message) {
            message.push_str(": ");
            message.push_str(&source_message);
        }
        source = err.source();
    }

    message
}
331
332pub fn classify_stream_error(
337 error_message: &str,
338 stop_reason: StopReason,
339 error_kind: Option<StreamErrorKind>,
340) -> AgentError {
341 if let Some(kind) = error_kind {
343 return match kind {
344 StreamErrorKind::Throttled => AgentError::ModelThrottled,
345 StreamErrorKind::ContextWindowExceeded => AgentError::ContextWindowOverflow {
346 model: String::new(),
347 },
348 StreamErrorKind::Auth => AgentError::StreamError {
349 source: Box::new(std::io::Error::other(error_message.to_string())),
350 },
351 StreamErrorKind::Network => {
352 AgentError::network(std::io::Error::other(error_message.to_string()))
353 }
354 StreamErrorKind::ContentFiltered => AgentError::ContentFiltered,
355 };
356 }
357
358 let lower = error_message.to_lowercase();
360 if lower.contains("context window") || lower.contains("context_length_exceeded") {
361 return AgentError::ContextWindowOverflow {
362 model: String::new(),
363 };
364 }
365 if lower.contains("rate limit") || lower.contains("429") || lower.contains("throttl") {
366 return AgentError::ModelThrottled;
367 }
368 if lower.contains("cache miss")
369 || lower.contains("cache not found")
370 || lower.contains("cache_miss")
371 {
372 return AgentError::CacheMiss;
373 }
374 if lower.contains("content filter") || lower.contains("content_filter") {
375 return AgentError::ContentFiltered;
376 }
377 if stop_reason == StopReason::Aborted {
378 return AgentError::Aborted;
379 }
380 AgentError::StreamError {
381 source: Box::new(std::io::Error::other(error_message.to_string())),
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::*;

    /// Classify `msg` via the string-matching fallback (no structured kind).
    fn classify_str(msg: &str) -> AgentError {
        classify_stream_error(msg, StopReason::Error, None)
    }

    /// Classify with a structured kind, which should take priority.
    fn classify_kind(msg: &str, kind: StreamErrorKind) -> AgentError {
        classify_stream_error(msg, StopReason::Error, Some(kind))
    }

    #[test]
    fn classify_cache_miss_variants() {
        for msg in [
            "cache miss",
            "Cache Miss detected",
            "provider cache_miss",
            "cache not found",
        ] {
            let err = classify_str(msg);
            assert!(
                matches!(err, AgentError::CacheMiss),
                "expected CacheMiss for \"{msg}\", got {err:?}"
            );
            assert!(!err.is_retryable());
        }
    }

    #[test]
    fn classify_non_cache_miss() {
        let err = classify_str("internal server error");
        assert!(!matches!(err, AgentError::CacheMiss));
    }

    #[test]
    fn classify_content_filtered_by_kind() {
        let err = classify_kind("response blocked", StreamErrorKind::ContentFiltered);
        assert!(matches!(err, AgentError::ContentFiltered));
        assert!(!err.is_retryable());
    }

    #[test]
    fn classify_content_filtered_by_string() {
        let err = classify_str("content filter violation detected");
        assert!(matches!(err, AgentError::ContentFiltered));
        assert!(!err.is_retryable());
    }

    #[test]
    fn classify_throttled_by_kind() {
        assert!(matches!(
            classify_kind("some error", StreamErrorKind::Throttled),
            AgentError::ModelThrottled
        ));
    }

    #[test]
    fn classify_network_by_kind() {
        let err = classify_kind("connection reset", StreamErrorKind::Network);
        assert!(matches!(err, AgentError::NetworkError { .. }));
        assert!(err.is_retryable());
    }

    #[test]
    fn classify_auth_by_kind() {
        let err = classify_kind("invalid api key", StreamErrorKind::Auth);
        assert!(matches!(err, AgentError::StreamError { .. }));
        assert!(!err.is_retryable());
    }

    #[test]
    fn classify_context_overflow_by_kind() {
        assert!(matches!(
            classify_kind("too many tokens", StreamErrorKind::ContextWindowExceeded),
            AgentError::ContextWindowOverflow { .. }
        ));
    }

    #[test]
    fn structured_kind_takes_priority_over_string() {
        // Message says "rate limit" but the structured kind says Network.
        let err = classify_kind("rate limit exceeded", StreamErrorKind::Network);
        assert!(
            matches!(err, AgentError::NetworkError { .. }),
            "structured kind should override string matching, got {err:?}"
        );
    }

    #[test]
    fn string_fallback_for_unclassified_errors() {
        assert!(matches!(
            classify_str("rate limit (429)"),
            AgentError::ModelThrottled
        ));
    }

    #[test]
    fn string_fallback_context_overflow() {
        assert!(matches!(
            classify_str("context_length_exceeded: too long"),
            AgentError::ContextWindowOverflow { .. }
        ));
    }

    #[test]
    fn aborted_stop_reason_without_kind() {
        assert!(matches!(
            classify_stream_error("operation cancelled", StopReason::Aborted, None),
            AgentError::Aborted
        ));
    }
}