Skip to main content

swink_agent_eval/simulation/
orchestrator.rs

1//! Multi-turn simulation orchestrator (US4, FR-026).
2//!
3//! Drives an [`Agent`] ↔ [`ActorSimulator`] dialogue, optionally routing
4//! emitted tool calls through a [`ToolSimulator`]. Returns a fully-populated
5//! [`Invocation`] plus a [`SimulationOutcome`]. Cancellation is honored
6//! cooperatively at every `await` point.
7
8#![forbid(unsafe_code)]
9
10use std::time::{Duration, Instant};
11
12use futures::StreamExt;
13use swink_agent::{
14    Agent, AgentEvent, ContentBlock, Cost, LlmMessage, ModelSpec, StopReason, ToolResultMessage,
15    Usage, UserMessage, now_timestamp,
16};
17use tokio_util::sync::CancellationToken;
18
19use super::actor::{ActorSimulator, ActorTurn};
20use super::tool::{ToolSimulationError, ToolSimulator};
21use crate::judge::JudgeError;
22use crate::trajectory::TrajectoryCollector;
23use crate::types::{Invocation, RecordedToolCall};
24
/// How a multi-turn simulation ended; returned next to the [`Invocation`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SimulationOutcome {
    /// An actor turn (the greeting or a later reply) carried a
    /// `goal_completed` marker.
    GoalCompleted,
    /// The number of agent turns run reached the `max_turns` budget.
    MaxTurnsReached,
    /// Fallback value: the run ended without either of the other outcomes
    /// being recorded.
    AgentStopped,
}
32
33/// Errors surfaced by [`run_multiturn_simulation`].
34#[derive(Debug, thiserror::Error)]
35pub enum SimulationError {
36    #[error("actor error: {0}")]
37    Actor(#[source] JudgeError),
38    #[error("tool error: {0}")]
39    Tool(#[source] ToolSimulationError),
40    #[error("simulation cancelled")]
41    Cancelled,
42    #[error("agent error: {0}")]
43    Agent(String),
44    /// Tool response body failed JSON-schema validation (FR-025).
45    #[error("schema validation failed: {0}")]
46    SchemaValidation(String),
47}
48
49impl From<ToolSimulationError> for SimulationError {
50    fn from(err: ToolSimulationError) -> Self {
51        match err {
52            ToolSimulationError::SchemaValidation(reason) => Self::SchemaValidation(reason),
53            other => Self::Tool(other),
54        }
55    }
56}
57
58/// Orchestrate a multi-turn dialogue between `agent` and `actor`.
59#[allow(clippy::too_many_lines)]
60pub async fn run_multiturn_simulation(
61    agent: &mut Agent,
62    actor: &ActorSimulator,
63    tool_sim: Option<&ToolSimulator>,
64    max_turns: u32,
65    cancel: CancellationToken,
66) -> Result<(Invocation, SimulationOutcome), SimulationError> {
67    let overall_start = Instant::now();
68    let mut outcome = SimulationOutcome::AgentStopped;
69    let mut collector = TrajectoryCollector::new();
70    let mut next_user: ActorTurn = actor.greeting();
71    let mut turn_count: u32 = 0;
72
73    while turn_count < max_turns {
74        if cancel.is_cancelled() {
75            return Err(SimulationError::Cancelled);
76        }
77        if next_user.goal_completed.is_some() {
78            outcome = SimulationOutcome::GoalCompleted;
79            break;
80        }
81
82        let conversation = vec![swink_agent::AgentMessage::Llm(LlmMessage::User(
83            UserMessage {
84                content: vec![ContentBlock::Text {
85                    text: next_user.message.clone(),
86                }],
87                timestamp: now_timestamp(),
88                cache_hint: None,
89            },
90        ))];
91        let stream = agent
92            .prompt_stream(conversation)
93            .map_err(|err| SimulationError::Agent(err.to_string()))?;
94        tokio::pin!(stream);
95
96        let mut pending_tool_calls: Vec<RecordedToolCall> = Vec::new();
97        let mut last_assistant_text = String::new();
98
99        loop {
100            tokio::select! {
101                biased;
102                () = cancel.cancelled() => return Err(SimulationError::Cancelled),
103                next = stream.next() => match next {
104                    None => break,
105                    Some(event) => {
106                        if let AgentEvent::TurnEnd { assistant_message, .. } = &event {
107                            last_assistant_text =
108                                ContentBlock::extract_text(&assistant_message.content);
109                            for block in &assistant_message.content {
110                                if let ContentBlock::ToolCall {
111                                    id, name, arguments, ..
112                                } = block
113                                {
114                                    pending_tool_calls.push(RecordedToolCall {
115                                        id: id.clone(),
116                                        name: name.clone(),
117                                        arguments: arguments.clone(),
118                                    });
119                                }
120                            }
121                        }
122                        collector.observe(&event);
123                    }
124                },
125            }
126        }
127
128        // Optionally attach simulated tool results to the most recent turn.
129        if let (Some(sim), false) = (tool_sim, pending_tool_calls.is_empty()) {
130            let last_idx = collector.turns_len_hint().checked_sub(1);
131            for call in std::mem::take(&mut pending_tool_calls) {
132                let value = sim
133                    .invoke(&call.name, &call.arguments, &call.id)
134                    .await
135                    .map_err(SimulationError::from)?;
136                if let Some(idx) = last_idx {
137                    collector.append_tool_result(
138                        idx,
139                        ToolResultMessage {
140                            tool_call_id: call.id.clone(),
141                            content: vec![ContentBlock::Text {
142                                text: value.to_string(),
143                            }],
144                            is_error: false,
145                            timestamp: now_timestamp(),
146                            details: serde_json::Value::Null,
147                            cache_hint: None,
148                        },
149                    );
150                }
151            }
152        }
153
154        turn_count += 1;
155        if turn_count >= max_turns {
156            outcome = SimulationOutcome::MaxTurnsReached;
157            break;
158        }
159
160        let assistant_text = if last_assistant_text.is_empty() {
161            "…".to_string()
162        } else {
163            last_assistant_text
164        };
165        let produced = actor
166            .next_turn(&assistant_text)
167            .await
168            .map_err(SimulationError::Actor)?;
169        if produced.goal_completed.is_some() {
170            outcome = SimulationOutcome::GoalCompleted;
171            break;
172        }
173        next_user = produced;
174    }
175
176    let mut invocation = collector.finish();
177    if invocation.total_duration == Duration::ZERO {
178        invocation.total_duration = overall_start.elapsed();
179    }
180    if invocation.model == ModelSpec::new("unknown", "unknown") {
181        invocation.model = ModelSpec::new("simulated", actor.model_id());
182    }
183    if invocation.turns.is_empty() {
184        invocation.total_usage = Usage::default();
185        invocation.total_cost = Cost::default();
186        invocation.stop_reason = StopReason::Stop;
187    }
188
189    Ok((invocation, outcome))
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// A schema-validation failure from the tool simulator must map to the
    /// dedicated `SchemaValidation` variant.
    #[test]
    fn simulation_error_wraps_schema_variant() {
        let source = ToolSimulationError::SchemaValidation("boom".into());
        let converted = SimulationError::from(source);
        assert!(matches!(converted, SimulationError::SchemaValidation(_)));
    }
}
201}