Skip to main content

swink_agent/loop_/
mod.rs

1//! Core agent loop execution engine.
2//!
3//! Implements the nested inner/outer loop, tool dispatch, steering/follow-up
4//! injection, event emission, retry integration, error/abort handling, and max
5//! tokens recovery. Stateless — all state is passed in via [`AgentLoopConfig`].
6
7mod config;
8mod event;
9mod overflow;
10mod stream;
11mod tool_dispatch;
12mod turn;
13mod types;
14
15pub use config::AgentLoopConfig;
16pub use event::{AgentEvent, TurnEndReason};
17pub use types::*;
18
19use std::error::Error as _;
20use std::pin::Pin;
21use std::sync::Arc;
22
23use futures::Stream;
24use tokio::sync::mpsc;
25use tokio_stream::wrappers::ReceiverStream;
26use tokio_util::sync::CancellationToken;
27use tracing::{Instrument, info, info_span};
28
29use crate::error::AgentError;
30use crate::stream::StreamErrorKind;
31use crate::types::{AgentMessage, AssistantMessage, ModelSpec, StopReason};
32use crate::util::now_timestamp;
33
34// ─── Constants ───────────────────────────────────────────────────────────────
35
/// Legacy sentinel string formerly used to signal a context overflow between
/// `handle_stream_result` and `run_single_turn`.
///
/// As the deprecation note states, overflow recovery now happens in-place in
/// `run_single_turn`; the constant is retained only so existing references to
/// it keep compiling.
#[deprecated(
    note = "Overflow recovery now happens in-place in run_single_turn. Retained for backward compatibility."
)]
#[allow(dead_code)]
pub const CONTEXT_OVERFLOW_SENTINEL: &str = "__context_overflow__";
43
/// Capacity of the bounded `mpsc` channel that carries `AgentEvent`s from the
/// spawned loop task to the stream handed back to the caller. Sized to handle
/// burst streaming without backpressure under normal operation.
const EVENT_CHANNEL_CAPACITY: usize = 256;
47
48// ─── Entry Points ────────────────────────────────────────────────────────────
49
50/// Start a new agent loop with prompt messages.
51///
52/// Creates an initial context with the prompt messages, then runs the loop.
53/// Returns a stream of `AgentEvent` values.
54#[must_use]
55pub fn agent_loop(
56    prompt_messages: Vec<AgentMessage>,
57    system_prompt: String,
58    config: AgentLoopConfig,
59    cancellation_token: CancellationToken,
60) -> Pin<Box<dyn Stream<Item = AgentEvent> + Send>> {
61    let initial_new_messages_len = prompt_messages.len();
62    run_loop(
63        prompt_messages,
64        initial_new_messages_len,
65        system_prompt,
66        config,
67        cancellation_token,
68    )
69}
70
71/// Resume an agent loop from existing messages.
72///
73/// Resumes from existing messages (no new prompt), calls the loop.
74/// Returns a stream of `AgentEvent` values.
75#[must_use]
76pub fn agent_loop_continue(
77    messages: Vec<AgentMessage>,
78    system_prompt: String,
79    config: AgentLoopConfig,
80    cancellation_token: CancellationToken,
81) -> Pin<Box<dyn Stream<Item = AgentEvent> + Send>> {
82    run_loop(messages, 0, system_prompt, config, cancellation_token)
83}
84
85// ─── Internal Loop ───────────────────────────────────────────────────────────
86
87/// The core loop implementation. Spawns a task that drives the loop and sends
88/// events through an mpsc channel, returning a stream of events.
89fn run_loop(
90    initial_messages: Vec<AgentMessage>,
91    initial_new_messages_len: usize,
92    system_prompt: String,
93    config: AgentLoopConfig,
94    cancellation_token: CancellationToken,
95) -> Pin<Box<dyn Stream<Item = AgentEvent> + Send>> {
96    let (tx, rx) = mpsc::channel::<AgentEvent>(EVENT_CHANNEL_CAPACITY);
97
98    tokio::spawn(async move {
99        run_loop_inner(
100            initial_messages,
101            initial_new_messages_len,
102            system_prompt,
103            config,
104            cancellation_token,
105            tx,
106        )
107        .await;
108    });
109
110    Box::pin(ReceiverStream::new(rx))
111}
112
113/// Send an event through the channel. Returns false if the receiver is dropped.
114pub async fn emit(tx: &mpsc::Sender<AgentEvent>, event: AgentEvent) -> bool {
115    tx.send(event).await.is_ok()
116}
117
118// ─── run_loop_inner ──────────────────────────────────────────────────────────
119
120/// The actual loop logic running inside the spawned task.
121#[allow(clippy::too_many_lines)]
122async fn run_loop_inner(
123    initial_messages: Vec<AgentMessage>,
124    initial_new_messages_len: usize,
125    system_prompt: String,
126    config: AgentLoopConfig,
127    cancellation_token: CancellationToken,
128    tx: mpsc::Sender<AgentEvent>,
129) {
130    let config = Arc::new(config);
131    let span = info_span!(
132        "agent.run",
133        model_id = %config.model.model_id,
134        provider = %config.model.provider,
135        tool_count = config.tools.len(),
136        message_count = initial_messages.len(),
137    );
138    async {
139        info!(
140            model = %config.model.model_id,
141            provider = %config.model.provider,
142            tools = config.tools.len(),
143            "starting agent loop"
144        );
145        // Build the transfer chain and push the current agent name (if known)
146        // so that circular transfers back to this agent are detected.
147        let mut transfer_chain = config.transfer_chain.clone().unwrap_or_default();
148        if let Some(ref name) = config.agent_name {
149            // Ignore the error — when resuming from a handoff chain the agent
150            // name may already be present as the latest hop.
151            let _ = transfer_chain.push(name.clone());
152        }
153
154        let mut state = LoopState {
155            context_messages: initial_messages,
156            pending_messages: Vec::new(),
157            initial_new_messages_len,
158            overflow_signal: false,
159            overflow_recovery_attempted: false,
160            turn_index: 0,
161            accumulated_usage: crate::types::Usage::default(),
162            accumulated_cost: crate::types::Cost::default(),
163            last_assistant_message: None,
164            last_tool_results: Vec::new(),
165            transfer_chain,
166        };
167
168        // 1. Emit AgentStart
169        if !emit(&tx, AgentEvent::AgentStart).await {
170            return;
171        }
172
173        // 2. Outer loop (follow-up phase)
174        'outer: loop {
175            // Inner loop (turn + tool phase)
176            'inner: loop {
177                let turn_result = turn::run_single_turn(
178                    &config,
179                    &mut state,
180                    &system_prompt,
181                    &cancellation_token,
182                    &tx,
183                )
184                .await;
185
186                let should_break = match turn_result {
187                    TurnOutcome::ContinueInner => {
188                        state.turn_index += 1;
189                        false
190                    }
191                    TurnOutcome::BreakInner => {
192                        state.turn_index += 1;
193                        true
194                    }
195                    TurnOutcome::Return => return,
196                };
197
198                // Post-turn policies are evaluated inside the turn handlers
199                // (handle_no_tool_calls / handle_tool_calls) against the
200                // committed turn snapshot before TurnEnd is emitted or transfer
201                // termination is honored. This lets policies replace the
202                // assistant message before listeners observe the turn.
203
204                if should_break {
205                    break 'inner;
206                }
207            }
208
209            // Post-loop policies: evaluate after inner loop exits
210            {
211                use crate::policy::{PolicyContext, PolicyVerdict, run_post_loop_policies};
212
213                let state_snapshot = {
214                    let guard = config
215                        .session_state
216                        .read()
217                        .unwrap_or_else(std::sync::PoisonError::into_inner);
218                    guard.clone()
219                };
220                let policy_ctx = PolicyContext {
221                    turn_index: state.turn_index,
222                    accumulated_usage: &state.accumulated_usage,
223                    accumulated_cost: &state.accumulated_cost,
224                    message_count: state.context_messages.len(),
225                    overflow_signal: state.overflow_signal,
226                    new_messages: &[], // no new messages at post-loop
227                    state: &state_snapshot,
228                };
229                match run_post_loop_policies(&config.post_loop_policies, &policy_ctx) {
230                    PolicyVerdict::Continue => {}
231                    PolicyVerdict::Stop(_reason) => {
232                        let _ = emit(
233                            &tx,
234                            AgentEvent::AgentEnd {
235                                messages: Arc::new(state.context_messages),
236                            },
237                        )
238                        .await;
239                        info!("post-loop policy stopped agent");
240                        return;
241                    }
242                    PolicyVerdict::Inject(msgs) => {
243                        state.pending_messages.extend(msgs);
244                        config
245                            .pending_message_snapshot
246                            .replace(&state.pending_messages);
247                        continue 'outer;
248                    }
249                }
250            }
251
252            // Outer loop: poll follow-up messages
253            if let Some(ref provider) = config.message_provider {
254                let msgs = provider.poll_follow_up();
255                if !msgs.is_empty() {
256                    state.pending_messages.extend(msgs);
257                    config
258                        .pending_message_snapshot
259                        .replace(&state.pending_messages);
260                    continue 'outer;
261                }
262            }
263
264            // No follow-up → emit AgentEnd and exit
265            let _ = emit(
266                &tx,
267                AgentEvent::AgentEnd {
268                    messages: Arc::new(state.context_messages),
269                },
270            )
271            .await;
272            info!("agent loop finished");
273            return;
274        }
275    }
276    .instrument(span)
277    .await;
278}
279
280// ─── Helpers ─────────────────────────────────────────────────────────────────
281
282/// Build a terminal `AssistantMessage` with the given stop reason and message.
283fn build_terminal_message(
284    model: &ModelSpec,
285    stop_reason: StopReason,
286    error_message: String,
287) -> AssistantMessage {
288    AssistantMessage {
289        content: vec![],
290        provider: model.provider.clone(),
291        model_id: model.model_id.clone(),
292        usage: crate::types::Usage::default(),
293        cost: crate::types::Cost::default(),
294        stop_reason,
295        error_message: Some(error_message),
296        error_kind: None,
297        timestamp: now_timestamp(),
298        cache_hint: None,
299    }
300}
301
302/// Build an aborted `AssistantMessage`.
303pub fn build_abort_message(model: &ModelSpec) -> AssistantMessage {
304    build_terminal_message(
305        model,
306        StopReason::Aborted,
307        "operation aborted via cancellation token".to_string(),
308    )
309}
310
311/// Build an error `AssistantMessage` from a `AgentError`.
312pub fn build_error_message(model: &ModelSpec, error: &AgentError) -> AssistantMessage {
313    build_terminal_message(model, StopReason::Error, format_error_with_sources(error))
314}
315
316pub fn format_error_with_sources(error: &AgentError) -> String {
317    let mut message = error.to_string();
318    let mut source = error.source();
319
320    while let Some(err) = source {
321        let source_message = err.to_string();
322        if !source_message.is_empty() && !message.contains(&source_message) {
323            message.push_str(": ");
324            message.push_str(&source_message);
325        }
326        source = err.source();
327    }
328
329    message
330}
331
332/// Classify an `AssistantMessageEvent::Error` into a `AgentError`.
333///
334/// When `error_kind` is present, structural classification takes priority
335/// over string matching on the error message.
336pub fn classify_stream_error(
337    error_message: &str,
338    stop_reason: StopReason,
339    error_kind: Option<StreamErrorKind>,
340) -> AgentError {
341    // Prefer structural classification when the adapter provides it
342    if let Some(kind) = error_kind {
343        return match kind {
344            StreamErrorKind::Throttled => AgentError::ModelThrottled,
345            StreamErrorKind::ContextWindowExceeded => AgentError::ContextWindowOverflow {
346                model: String::new(),
347            },
348            StreamErrorKind::Auth => AgentError::StreamError {
349                source: Box::new(std::io::Error::other(error_message.to_string())),
350            },
351            StreamErrorKind::Network => {
352                AgentError::network(std::io::Error::other(error_message.to_string()))
353            }
354            StreamErrorKind::ContentFiltered => AgentError::ContentFiltered,
355        };
356    }
357
358    // Fallback to string matching for adapters that don't set error_kind
359    let lower = error_message.to_lowercase();
360    if lower.contains("context window") || lower.contains("context_length_exceeded") {
361        return AgentError::ContextWindowOverflow {
362            model: String::new(),
363        };
364    }
365    if lower.contains("rate limit") || lower.contains("429") || lower.contains("throttl") {
366        return AgentError::ModelThrottled;
367    }
368    if lower.contains("cache miss")
369        || lower.contains("cache not found")
370        || lower.contains("cache_miss")
371    {
372        return AgentError::CacheMiss;
373    }
374    if lower.contains("content filter") || lower.contains("content_filter") {
375        return AgentError::ContentFiltered;
376    }
377    if stop_reason == StopReason::Aborted {
378        return AgentError::Aborted;
379    }
380    AgentError::StreamError {
381        source: Box::new(std::io::Error::other(error_message.to_string())),
382    }
383}
384
#[cfg(test)]
mod tests {
    use super::*;

    /// Classify `message` with an `Error` stop reason and the given kind —
    /// the combination almost every test below needs.
    fn classify(message: &str, kind: Option<StreamErrorKind>) -> AgentError {
        classify_stream_error(message, StopReason::Error, kind)
    }

    // ── String fallback: cache misses ────────────────────────────────────

    #[test]
    fn classify_cache_miss_variants() {
        for msg in [
            "cache miss",
            "Cache Miss detected",
            "provider cache_miss",
            "cache not found",
        ] {
            let err = classify(msg, None);
            assert!(
                matches!(err, AgentError::CacheMiss),
                "expected CacheMiss for \"{msg}\", got {err:?}"
            );
            assert!(!err.is_retryable());
        }
    }

    #[test]
    fn classify_non_cache_miss() {
        let err = classify("internal server error", None);
        assert!(!matches!(err, AgentError::CacheMiss));
    }

    // ── Content filtering ────────────────────────────────────────────────

    #[test]
    fn classify_content_filtered_by_kind() {
        let err = classify("response blocked", Some(StreamErrorKind::ContentFiltered));
        assert!(matches!(err, AgentError::ContentFiltered));
        assert!(!err.is_retryable());
    }

    #[test]
    fn classify_content_filtered_by_string() {
        let err = classify("content filter violation detected", None);
        assert!(matches!(err, AgentError::ContentFiltered));
        assert!(!err.is_retryable());
    }

    // ── Structured kinds ─────────────────────────────────────────────────

    #[test]
    fn classify_throttled_by_kind() {
        let err = classify("some error", Some(StreamErrorKind::Throttled));
        assert!(matches!(err, AgentError::ModelThrottled));
    }

    #[test]
    fn classify_network_by_kind() {
        let err = classify("connection reset", Some(StreamErrorKind::Network));
        assert!(matches!(err, AgentError::NetworkError { .. }));
        assert!(err.is_retryable());
    }

    #[test]
    fn classify_auth_by_kind() {
        let err = classify("invalid api key", Some(StreamErrorKind::Auth));
        assert!(matches!(err, AgentError::StreamError { .. }));
        assert!(!err.is_retryable());
    }

    #[test]
    fn classify_context_overflow_by_kind() {
        let err = classify(
            "too many tokens",
            Some(StreamErrorKind::ContextWindowExceeded),
        );
        assert!(matches!(err, AgentError::ContextWindowOverflow { .. }));
    }

    #[test]
    fn structured_kind_takes_priority_over_string() {
        // The text mentions a rate limit, yet the adapter reported Network:
        // the structured kind must win.
        let err = classify("rate limit exceeded", Some(StreamErrorKind::Network));
        assert!(
            matches!(err, AgentError::NetworkError { .. }),
            "structured kind should override string matching, got {err:?}"
        );
    }

    // ── String fallback for adapters that set no kind ────────────────────

    #[test]
    fn string_fallback_for_unclassified_errors() {
        let err = classify("rate limit (429)", None);
        assert!(matches!(err, AgentError::ModelThrottled));
    }

    #[test]
    fn string_fallback_context_overflow() {
        let err = classify("context_length_exceeded: too long", None);
        assert!(matches!(err, AgentError::ContextWindowOverflow { .. }));
    }

    #[test]
    fn aborted_stop_reason_without_kind() {
        // The only case that needs a non-`Error` stop reason.
        let err = classify_stream_error("operation cancelled", StopReason::Aborted, None);
        assert!(matches!(err, AgentError::Aborted));
    }
}