1use std::sync::Arc;
11
12use serde_json::Value;
13use tokio::sync::{Mutex, mpsc};
14
15use crate::agents::agent_node::Agent;
16use crate::agents::error::AgentError;
17use crate::agents::streaming::{AgentEvent, TerminationReason};
18use crate::agents::usage::Usage;
19
/// How a running agent loop should be asked to stop.
///
/// The run loop maps each variant to a different termination reason when it
/// observes the request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StopKind {
    /// Cooperative stop; the run loop reports `TerminationReason::Stopped`.
    Graceful,
    /// Forced stop; the run loop reports `TerminationReason::Aborted`.
    Force,
}
32
/// Decision taken by the run loop after a model invocation fails.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum RunErrorAction {
    /// Invoke the same model again.
    Retry,
    /// Log the error and carry on with the next loop iteration.
    Continue,
    /// Stop the run and surface the contained message as an error event.
    Abort(String),
    /// Switch to the named model and try again.
    SwitchModel(String),
}
50
/// Per-run options handed to [`Runner::run`].
#[derive(Clone, Debug)]
pub struct RunnerConfig {
    /// When set, this model is used for the run instead of the runner's
    /// current model.
    pub model_override: Option<String>,
    /// Optional session identifier.
    // NOTE(review): nothing in this file reads `session_id`; confirm how
    // (or whether) callers elsewhere consume it.
    pub session_id: Option<String>,
    /// Retry budget handed to the error-dispatch logic after a model failure.
    pub max_retries: u32,
}
65
66impl Default for RunnerConfig {
67 fn default() -> Self {
68 Self {
69 model_override: None,
70 session_id: None,
71 max_retries: 3,
72 }
73 }
74}
75
76#[derive(Debug)]
85pub struct Runner<O: serde::Serialize + Send + Sync + 'static = ()> {
86 agent: Arc<Agent<O>>,
87 current_model: Mutex<String>,
89 stop_tx: Mutex<Option<mpsc::Sender<StopKind>>>,
91}
92
93impl<O: serde::Serialize + Send + Sync + 'static> Runner<O> {
94 #[must_use]
96 pub fn new(agent: Agent<O>) -> Self {
97 let model = agent.model_name().to_string();
98 Self {
99 agent: Arc::new(agent),
100 current_model: Mutex::new(model),
101 stop_tx: Mutex::new(None),
102 }
103 }
104
105 pub async fn set_model(&self, model: impl Into<String>) {
108 let mut guard = self.current_model.lock().await;
109 *guard = model.into();
110 tracing::info!(model = %*guard, "Runner: model switched");
111 }
112
113 pub async fn stop_graceful(&self) {
116 if let Some(tx) = self.stop_tx.lock().await.as_ref() {
117 let _ = tx.send(StopKind::Graceful).await;
118 }
119 }
120
121 pub async fn stop_force(&self) {
124 if let Some(tx) = self.stop_tx.lock().await.as_ref() {
125 let _ = tx.send(StopKind::Force).await;
126 }
127 }
128
129 pub async fn run(
138 &self,
139 input: Value,
140 config: RunnerConfig,
141 ) -> Result<mpsc::Receiver<AgentEvent>, AgentError> {
142 let (event_tx, event_rx) = mpsc::channel::<AgentEvent>(128);
143 let (stop_tx, stop_rx) = mpsc::channel::<StopKind>(1);
144
145 *self.stop_tx.lock().await = Some(stop_tx);
147
148 let agent = Arc::clone(&self.agent);
149 let model = self.current_model.lock().await.clone();
150
151 let _handle = tokio::spawn(async move {
152 run_loop(agent, input, config, model, event_tx, stop_rx).await;
153 });
154
155 Ok(event_rx)
156 }
157}
158
159#[allow(clippy::too_many_lines)]
164async fn run_loop<O: serde::Serialize + Send + Sync + 'static>(
165 agent: Arc<Agent<O>>,
166 input: Value,
167 config: RunnerConfig,
168 initial_model: String,
169 event_tx: mpsc::Sender<AgentEvent>,
170 mut stop_rx: mpsc::Receiver<StopKind>,
171) {
172 let max_turns = agent.max_turn_count();
173 let max_budget = agent.budget_limit();
174 let max_retries = config.max_retries;
175
176 let mut current_model = config.model_override.unwrap_or(initial_model);
177 let mut turn: u32 = 0;
178 let mut cumulative_cost: f64 = 0.0;
179 let mut messages: Vec<Value> = Vec::new();
180 let mut retry_count: u32 = 0;
181
182 messages.push(serde_json::json!({ "role": "user", "content": input }));
184
185 loop {
186 match stop_rx.try_recv() {
188 Ok(StopKind::Graceful) => {
189 emit(
190 &event_tx,
191 AgentEvent::TurnComplete {
192 reason: TerminationReason::Stopped,
193 },
194 )
195 .await;
196 return;
197 }
198 Ok(StopKind::Force) => {
199 emit(
200 &event_tx,
201 AgentEvent::TurnComplete {
202 reason: TerminationReason::Aborted,
203 },
204 )
205 .await;
206 return;
207 }
208 Err(_) => {}
209 }
210
211 if let Some(limit) = max_turns
213 && turn >= limit
214 {
215 tracing::debug!(turn, limit, "max_turns reached");
216 emit(
217 &event_tx,
218 AgentEvent::TurnComplete {
219 reason: TerminationReason::MaxTurnsExceeded,
220 },
221 )
222 .await;
223 return;
224 }
225
226 if let Some(budget) = max_budget
228 && cumulative_cost > budget
229 {
230 tracing::debug!(cumulative_cost, budget, "budget exceeded");
231 emit(
232 &event_tx,
233 AgentEvent::TurnComplete {
234 reason: TerminationReason::BudgetExceeded,
235 },
236 )
237 .await;
238 return;
239 }
240
241 turn += 1;
242
243 let model_result = invoke_model(¤t_model, &messages);
247
248 match model_result {
249 Ok(response) => {
250 retry_count = 0;
251
252 let usage = Usage {
254 input_tokens: response.input_tokens,
255 output_tokens: response.output_tokens,
256 ..Usage::default()
257 };
258 cumulative_cost += response.estimated_cost;
259
260 emit(&event_tx, AgentEvent::UsageUpdate { usage }).await;
262
263 if let Some(text) = response.text {
265 emit(&event_tx, AgentEvent::TextDelta { content: text }).await;
266 }
267
268 if response.done {
270 emit(
271 &event_tx,
272 AgentEvent::TurnComplete {
273 reason: TerminationReason::Complete,
274 },
275 )
276 .await;
277 return;
278 }
279
280 messages.push(serde_json::json!({ "role": "assistant", "content": response.raw }));
282 }
283
284 Err(err) => {
285 let action = dispatch_model_error(
286 &err,
287 retry_count,
288 max_retries,
289 agent.fallback_model_name(),
290 );
291
292 match action {
293 RunErrorAction::Retry => {
294 retry_count += 1;
295 tracing::warn!(attempt = retry_count, model = %current_model, "Retrying after model error");
296 turn -= 1; }
298 RunErrorAction::SwitchModel(fallback) => {
299 tracing::warn!(
300 from = %current_model,
301 to = %fallback,
302 "Switching to fallback model"
303 );
304 current_model = fallback;
305 retry_count = 0;
306 turn -= 1;
307 }
308 RunErrorAction::Continue => {
309 tracing::warn!(%err, "Model error ignored — continuing");
310 }
311 RunErrorAction::Abort(msg) => {
312 emit(&event_tx, AgentEvent::Error { message: msg }).await;
313 return;
314 }
315 }
316 }
317 }
318 }
319}
320
321fn dispatch_model_error(
326 err: &AgentError,
327 retry_count: u32,
328 max_retries: u32,
329 fallback_model: Option<&str>,
330) -> RunErrorAction {
331 match err {
332 AgentError::Model(model_err) => {
333 if !model_err.is_retryable() {
334 return RunErrorAction::Abort(err.to_string());
335 }
336 if retry_count < max_retries {
337 if retry_count > 0
339 && let Some(fb) = fallback_model
340 {
341 return RunErrorAction::SwitchModel(fb.to_string());
342 }
343 RunErrorAction::Retry
344 } else if let Some(fb) = fallback_model {
345 RunErrorAction::SwitchModel(fb.to_string())
346 } else {
347 RunErrorAction::Abort(format!("Max retries ({max_retries}) exceeded: {err}"))
348 }
349 }
350 AgentError::Panic(msg) => {
351 tracing::error!(%msg, "Agent panicked");
352 RunErrorAction::Abort(format!("Agent panicked: {msg}"))
353 }
354 _ => RunErrorAction::Abort(err.to_string()),
355 }
356}
357
358struct ModelResponse {
363 text: Option<String>,
364 raw: Value,
365 input_tokens: u64,
366 output_tokens: u64,
367 estimated_cost: f64,
368 done: bool,
369}
370
371#[allow(clippy::unnecessary_wraps)]
375fn invoke_model(model: &str, messages: &[Value]) -> Result<ModelResponse, AgentError> {
376 tracing::debug!(%model, message_count = messages.len(), "invoke_model (stub)");
377 Ok(ModelResponse {
379 text: None,
380 raw: Value::Null,
381 input_tokens: 0,
382 output_tokens: 0,
383 estimated_cost: 0.0,
384 done: true,
385 })
386}
387
388async fn emit(tx: &mpsc::Sender<AgentEvent>, event: AgentEvent) {
393 let _ = tx.send(event).await;
395}
396
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use crate::agents::agent_node::Agent;

    /// Drains `rx` until the channel closes and returns the reason carried by
    /// every `TurnComplete` event, in arrival order.
    async fn drain_reasons(mut rx: mpsc::Receiver<AgentEvent>) -> Vec<TerminationReason> {
        let mut reasons = Vec::new();
        while let Some(event) = rx.recv().await {
            if let AgentEvent::TurnComplete { reason } = event {
                reasons.push(reason);
            }
        }
        reasons
    }

    /// The stub model reports `done`, so a plain run must end with `Complete`.
    #[tokio::test]
    async fn test_runner_completes() {
        let agent: Agent = Agent::new("test", "stub-model");
        let runner = Runner::new(agent);
        let rx = runner
            .run(serde_json::json!("Hello"), RunnerConfig::default())
            .await
            .unwrap();

        let reasons = drain_reasons(rx).await;
        for reason in &reasons {
            assert_eq!(*reason, TerminationReason::Complete);
        }
        assert!(!reasons.is_empty(), "expected TurnComplete event");
    }

    /// A turn limit of zero must stop the run before the model is invoked.
    #[tokio::test]
    async fn test_runner_max_turns() {
        let agent: Agent = Agent::new("test", "stub-model").max_turns(0);
        let runner = Runner::new(agent);
        let rx = runner
            .run(serde_json::json!("Hello"), RunnerConfig::default())
            .await
            .unwrap();

        let reasons = drain_reasons(rx).await;
        assert!(
            reasons.contains(&TerminationReason::MaxTurnsExceeded),
            "expected MaxTurnsExceeded"
        );
    }

    /// A graceful stop races with the stub's immediate completion, so either
    /// `Stopped` or `Complete` is an acceptable outcome.
    #[tokio::test]
    async fn test_runner_graceful_stop() {
        let agent: Agent = Agent::new("test", "stub-model");
        let runner = Arc::new(Runner::new(agent));
        let runner2 = Arc::clone(&runner);

        let rx = runner
            .run(serde_json::json!("Hello"), RunnerConfig::default())
            .await
            .unwrap();

        runner2.stop_graceful().await;

        let reasons = drain_reasons(rx).await;
        let saw_stop_or_complete = reasons.iter().any(|reason| {
            matches!(
                reason,
                TerminationReason::Stopped | TerminationReason::Complete
            )
        });
        assert!(saw_stop_or_complete);
    }
}