Skip to main content

synwire_core/agents/
runner.rs

//! Agent runner and execution loop.
//!
//! `Runner` drives the agent turn loop:
//! session lookup → middleware chain → model invocation → tool dispatch →
//! directive execution → event emission → usage tracking.
//!
//! It enforces `max_turns` and `max_budget` limits, handles model errors with
//! configurable retry / fallback, and supports graceful and force stop.
9
10use std::sync::Arc;
11
12use serde_json::Value;
13use tokio::sync::{Mutex, mpsc};
14
15use crate::agents::agent_node::Agent;
16use crate::agents::error::AgentError;
17use crate::agents::streaming::{AgentEvent, TerminationReason};
18use crate::agents::usage::Usage;
19
20// ---------------------------------------------------------------------------
21// Stop signal
22// ---------------------------------------------------------------------------
23
/// Kind of stop requested from outside the runner.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StopKind {
    /// Finish any in-flight tool calls, then stop cleanly.
    Graceful,
    /// Abort right away; nothing is drained.
    Force,
}
32
33// ---------------------------------------------------------------------------
34// RunErrorAction
35// ---------------------------------------------------------------------------
36
/// What the runner should do in response to an error.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum RunErrorAction {
    /// Re-issue the current request (bounded by the configured retry limit).
    Retry,
    /// Ignore this error and move on to the next turn.
    Continue,
    /// Stop the run immediately, reporting the given message.
    Abort(String),
    /// Retry after switching to the named model.
    SwitchModel(String),
}
50
51// ---------------------------------------------------------------------------
52// RunnerConfig
53// ---------------------------------------------------------------------------
54
/// Configuration for a single runner execution.
#[derive(Debug, Clone)]
pub struct RunnerConfig {
    /// Override the agent's model for this run.
    pub model_override: Option<String>,
    /// Session ID to resume (None = new session).
    pub session_id: Option<String>,
    /// Maximum number of retries per model error.
    pub max_retries: u32,
}

impl Default for RunnerConfig {
    /// Defaults: no model override, a fresh session, and up to 3 retries.
    fn default() -> Self {
        RunnerConfig {
            max_retries: 3,
            model_override: None,
            session_id: None,
        }
    }
}
75
76// ---------------------------------------------------------------------------
77// Runner
78// ---------------------------------------------------------------------------
79
/// Drives the agent execution loop.
///
/// The runner is stateless between runs; all per-run state is held in the
/// channel and local variables inside `run`.
#[derive(Debug)]
pub struct Runner<O: serde::Serialize + Send + Sync + 'static = ()> {
    /// The agent being executed; shared (via `Arc::clone`) with the task
    /// that `run` spawns.
    agent: Arc<Agent<O>>,
    /// Current model — may be changed via `set_model`.  `run` reads this
    /// once at start-up; the spawned loop then works on its own copy.
    current_model: Mutex<String>,
    /// Stop signal sender, installed by `run` and used by `stop_graceful`
    /// / `stop_force`.  `None` until the first run starts.
    stop_tx: Mutex<Option<mpsc::Sender<StopKind>>>,
}
92
impl<O: serde::Serialize + Send + Sync + 'static> Runner<O> {
    /// Create a runner wrapping the given agent.
    ///
    /// The initial model is taken from the agent's own configuration.
    #[must_use]
    pub fn new(agent: Agent<O>) -> Self {
        let model = agent.model_name().to_string();
        Self {
            agent: Arc::new(agent),
            current_model: Mutex::new(model),
            stop_tx: Mutex::new(None),
        }
    }

    /// Switch the model used for runs started after this call.
    ///
    /// NOTE(review): `run` snapshots `current_model` once before spawning
    /// the loop, so this does NOT affect a run that is already in flight —
    /// only future `run` invocations.  Confirm whether mid-run switching
    /// ("subsequent turns") was the intended contract.
    pub async fn set_model(&self, model: impl Into<String>) {
        let mut guard = self.current_model.lock().await;
        *guard = model.into();
        tracing::info!(model = %*guard, "Runner: model switched");
    }

    /// Send a graceful stop signal.  The runner will finish any in-flight
    /// tool call, then emit `TurnComplete { reason: Stopped }`.
    ///
    /// A no-op before the first `run` (`stop_tx` is still `None`); if the
    /// run loop has already exited, the send error is silently ignored.
    pub async fn stop_graceful(&self) {
        if let Some(tx) = self.stop_tx.lock().await.as_ref() {
            let _ = tx.send(StopKind::Graceful).await;
        }
    }

    /// Send a force stop signal.  The runner cancels immediately and emits
    /// `TurnComplete { reason: Aborted }`.
    ///
    /// Same delivery caveats as `stop_graceful`.
    pub async fn stop_force(&self) {
        if let Some(tx) = self.stop_tx.lock().await.as_ref() {
            let _ = tx.send(StopKind::Force).await;
        }
    }

    /// Run the agent with the given input, yielding events over a channel.
    ///
    /// # Event stream
    /// Events are sent on the returned receiver.  The stream ends when the
    /// receiver is closed (after a `TurnComplete` or `Error` event).
    ///
    /// NOTE(review): each call replaces `stop_tx`, so starting a second run
    /// while one is in flight re-targets `stop_graceful` / `stop_force` at
    /// the newest run only.  The spawned task is detached (its `JoinHandle`
    /// is dropped); termination is observed solely via the event stream.
    ///
    /// # Errors
    /// Returns `AgentError` if setup fails before the event stream starts.
    pub async fn run(
        &self,
        input: Value,
        config: RunnerConfig,
    ) -> Result<mpsc::Receiver<AgentEvent>, AgentError> {
        // Bounded channel: a slow consumer back-pressures the run loop.
        let (event_tx, event_rx) = mpsc::channel::<AgentEvent>(128);
        let (stop_tx, stop_rx) = mpsc::channel::<StopKind>(1);

        // Store stop sender so callers can signal stop.
        *self.stop_tx.lock().await = Some(stop_tx);

        let agent = Arc::clone(&self.agent);
        // Snapshot of the current model; the loop owns its copy from here on.
        let model = self.current_model.lock().await.clone();

        let _handle = tokio::spawn(async move {
            run_loop(agent, input, config, model, event_tx, stop_rx).await;
        });

        Ok(event_rx)
    }
}
158
159// ---------------------------------------------------------------------------
160// Core loop (spawned task)
161// ---------------------------------------------------------------------------
162
163#[allow(clippy::too_many_lines)]
164async fn run_loop<O: serde::Serialize + Send + Sync + 'static>(
165    agent: Arc<Agent<O>>,
166    input: Value,
167    config: RunnerConfig,
168    initial_model: String,
169    event_tx: mpsc::Sender<AgentEvent>,
170    mut stop_rx: mpsc::Receiver<StopKind>,
171) {
172    let max_turns = agent.max_turn_count();
173    let max_budget = agent.budget_limit();
174    let max_retries = config.max_retries;
175
176    let mut current_model = config.model_override.unwrap_or(initial_model);
177    let mut turn: u32 = 0;
178    let mut cumulative_cost: f64 = 0.0;
179    let mut messages: Vec<Value> = Vec::new();
180    let mut retry_count: u32 = 0;
181
182    // Seed conversation with the user's input.
183    messages.push(serde_json::json!({ "role": "user", "content": input }));
184
185    loop {
186        // Check for stop signal (non-blocking poll).
187        match stop_rx.try_recv() {
188            Ok(StopKind::Graceful) => {
189                emit(
190                    &event_tx,
191                    AgentEvent::TurnComplete {
192                        reason: TerminationReason::Stopped,
193                    },
194                )
195                .await;
196                return;
197            }
198            Ok(StopKind::Force) => {
199                emit(
200                    &event_tx,
201                    AgentEvent::TurnComplete {
202                        reason: TerminationReason::Aborted,
203                    },
204                )
205                .await;
206                return;
207            }
208            Err(_) => {}
209        }
210
211        // Enforce max_turns.
212        if let Some(limit) = max_turns
213            && turn >= limit
214        {
215            tracing::debug!(turn, limit, "max_turns reached");
216            emit(
217                &event_tx,
218                AgentEvent::TurnComplete {
219                    reason: TerminationReason::MaxTurnsExceeded,
220                },
221            )
222            .await;
223            return;
224        }
225
226        // Enforce max_budget.
227        if let Some(budget) = max_budget
228            && cumulative_cost > budget
229        {
230            tracing::debug!(cumulative_cost, budget, "budget exceeded");
231            emit(
232                &event_tx,
233                AgentEvent::TurnComplete {
234                    reason: TerminationReason::BudgetExceeded,
235                },
236            )
237            .await;
238            return;
239        }
240
241        turn += 1;
242
243        // --- Simulated model invocation ---
244        // In production this would call the LLM backend.  The runner provides
245        // the scaffolding; actual model calls are injected by provider crates.
246        let model_result = invoke_model(&current_model, &messages);
247
248        match model_result {
249            Ok(response) => {
250                retry_count = 0;
251
252                // Accumulate synthetic usage.
253                let usage = Usage {
254                    input_tokens: response.input_tokens,
255                    output_tokens: response.output_tokens,
256                    ..Usage::default()
257                };
258                cumulative_cost += response.estimated_cost;
259
260                // Emit usage update.
261                emit(&event_tx, AgentEvent::UsageUpdate { usage }).await;
262
263                // Emit text delta if present.
264                if let Some(text) = response.text {
265                    emit(&event_tx, AgentEvent::TextDelta { content: text }).await;
266                }
267
268                // Check if model signalled completion.
269                if response.done {
270                    emit(
271                        &event_tx,
272                        AgentEvent::TurnComplete {
273                            reason: TerminationReason::Complete,
274                        },
275                    )
276                    .await;
277                    return;
278                }
279
280                // Append assistant message and continue loop.
281                messages.push(serde_json::json!({ "role": "assistant", "content": response.raw }));
282            }
283
284            Err(err) => {
285                let action = dispatch_model_error(
286                    &err,
287                    retry_count,
288                    max_retries,
289                    agent.fallback_model_name(),
290                );
291
292                match action {
293                    RunErrorAction::Retry => {
294                        retry_count += 1;
295                        tracing::warn!(attempt = retry_count, model = %current_model, "Retrying after model error");
296                        turn -= 1; // don't count against max_turns
297                    }
298                    RunErrorAction::SwitchModel(fallback) => {
299                        tracing::warn!(
300                            from = %current_model,
301                            to = %fallback,
302                            "Switching to fallback model"
303                        );
304                        current_model = fallback;
305                        retry_count = 0;
306                        turn -= 1;
307                    }
308                    RunErrorAction::Continue => {
309                        tracing::warn!(%err, "Model error ignored — continuing");
310                    }
311                    RunErrorAction::Abort(msg) => {
312                        emit(&event_tx, AgentEvent::Error { message: msg }).await;
313                        return;
314                    }
315                }
316            }
317        }
318    }
319}
320
321// ---------------------------------------------------------------------------
322// Error dispatch
323// ---------------------------------------------------------------------------
324
325fn dispatch_model_error(
326    err: &AgentError,
327    retry_count: u32,
328    max_retries: u32,
329    fallback_model: Option<&str>,
330) -> RunErrorAction {
331    match err {
332        AgentError::Model(model_err) => {
333            if !model_err.is_retryable() {
334                return RunErrorAction::Abort(err.to_string());
335            }
336            if retry_count < max_retries {
337                // Try fallback on second retry if available.
338                if retry_count > 0
339                    && let Some(fb) = fallback_model
340                {
341                    return RunErrorAction::SwitchModel(fb.to_string());
342                }
343                RunErrorAction::Retry
344            } else if let Some(fb) = fallback_model {
345                RunErrorAction::SwitchModel(fb.to_string())
346            } else {
347                RunErrorAction::Abort(format!("Max retries ({max_retries}) exceeded: {err}"))
348            }
349        }
350        AgentError::Panic(msg) => {
351            tracing::error!(%msg, "Agent panicked");
352            RunErrorAction::Abort(format!("Agent panicked: {msg}"))
353        }
354        _ => RunErrorAction::Abort(err.to_string()),
355    }
356}
357
358// ---------------------------------------------------------------------------
359// Stub model invocation (replaced by provider crates at runtime)
360// ---------------------------------------------------------------------------
361
/// Result of a single (stubbed) model invocation.
struct ModelResponse {
    // Assistant text to surface as a `TextDelta` event, if any.
    text: Option<String>,
    // Raw response payload; appended to the conversation as the
    // assistant-message content when the turn continues.
    raw: Value,
    // Tokens consumed by the prompt (folded into `Usage`).
    input_tokens: u64,
    // Tokens produced by the completion (folded into `Usage`).
    output_tokens: u64,
    // Estimated cost of this call, accumulated against the run budget.
    estimated_cost: f64,
    // True when the model signals that the turn is complete.
    done: bool,
}
370
371/// Placeholder model invocation.  Real implementations are injected by
372/// provider crates (e.g. `synwire-llm-openai`) via the `AgentNode::run`
373/// delegation path.
374#[allow(clippy::unnecessary_wraps)]
375fn invoke_model(model: &str, messages: &[Value]) -> Result<ModelResponse, AgentError> {
376    tracing::debug!(%model, message_count = messages.len(), "invoke_model (stub)");
377    // Stub: immediately complete with empty response.
378    Ok(ModelResponse {
379        text: None,
380        raw: Value::Null,
381        input_tokens: 0,
382        output_tokens: 0,
383        estimated_cost: 0.0,
384        done: true,
385    })
386}
387
388// ---------------------------------------------------------------------------
389// Helpers
390// ---------------------------------------------------------------------------
391
392async fn emit(tx: &mpsc::Sender<AgentEvent>, event: AgentEvent) {
393    // Ignore send errors — receiver may have been dropped.
394    let _ = tx.send(event).await;
395}
396
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use crate::agents::agent_node::Agent;

    /// Happy path: the stub model reports `done = true` on the first call,
    /// so the stream must end with `TurnComplete { reason: Complete }`.
    #[tokio::test]
    async fn test_runner_completes() {
        let agent: Agent = Agent::new("test", "stub-model");
        let runner = Runner::new(agent);
        let mut rx = runner
            .run(serde_json::json!("Hello"), RunnerConfig::default())
            .await
            .unwrap();

        let mut got_complete = false;
        while let Some(event) = rx.recv().await {
            if let AgentEvent::TurnComplete { reason } = event {
                assert_eq!(reason, TerminationReason::Complete);
                got_complete = true;
            }
        }
        assert!(got_complete, "expected TurnComplete event");
    }

    /// With `max_turns(0)` the limit check fires before the first model
    /// call (the loop checks `turn >= limit` before incrementing), so the
    /// run ends with `MaxTurnsExceeded` rather than `Complete`.
    #[tokio::test]
    async fn test_runner_max_turns() {
        // The stub model never sets done=true on its own in subsequent turns,
        // but does set done=true immediately.  Adjust by giving 0 max_turns.
        let agent: Agent = Agent::new("test", "stub-model").max_turns(0);
        let runner = Runner::new(agent);
        let mut rx = runner
            .run(serde_json::json!("Hello"), RunnerConfig::default())
            .await
            .unwrap();

        let mut got_max_turns = false;
        while let Some(event) = rx.recv().await {
            if let AgentEvent::TurnComplete { reason } = event {
                // With max_turns=0 the first check fires immediately.
                if reason == TerminationReason::MaxTurnsExceeded {
                    got_max_turns = true;
                }
            }
        }
        assert!(got_max_turns, "expected MaxTurnsExceeded");
    }

    /// Wiring test for the stop channel.  The stop signal races the
    /// spawned loop, so either `Stopped` (signal seen first) or `Complete`
    /// (stub finished first) is a valid outcome.
    #[tokio::test]
    async fn test_runner_graceful_stop() {
        let agent: Agent = Agent::new("test", "stub-model");
        let runner = Arc::new(Runner::new(agent));
        let runner2 = Arc::clone(&runner);

        let mut rx = runner
            .run(serde_json::json!("Hello"), RunnerConfig::default())
            .await
            .unwrap();

        // Stop before any events are processed (races, but tests the wiring).
        runner2.stop_graceful().await;

        let mut saw_stop_or_complete = false;
        while let Some(event) = rx.recv().await {
            if let AgentEvent::TurnComplete { reason } = event
                && matches!(
                    reason,
                    TerminationReason::Stopped | TerminationReason::Complete
                )
            {
                saw_stop_or_complete = true;
            }
        }
        assert!(saw_stop_or_complete);
    }
}