swink_agent_eval/simulation/
orchestrator.rs1#![forbid(unsafe_code)]
9
10use std::time::{Duration, Instant};
11
12use futures::StreamExt;
13use swink_agent::{
14 Agent, AgentEvent, ContentBlock, Cost, LlmMessage, ModelSpec, StopReason, ToolResultMessage,
15 Usage, UserMessage, now_timestamp,
16};
17use tokio_util::sync::CancellationToken;
18
19use super::actor::{ActorSimulator, ActorTurn};
20use super::tool::{ToolSimulationError, ToolSimulator};
21use crate::judge::JudgeError;
22use crate::trajectory::TrajectoryCollector;
23use crate::types::{Invocation, RecordedToolCall};
24
/// How a multi-turn simulation run ended.
///
/// Returned alongside the collected [`Invocation`] by
/// [`run_multiturn_simulation`].
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum SimulationOutcome {
    /// An actor turn carried a `goal_completed` marker.
    GoalCompleted,
    /// The number of agent turns reached the caller-supplied `max_turns`.
    MaxTurnsReached,
    /// Fallback outcome: the value the run starts with and keeps when neither
    /// of the other two conditions is recorded.
    AgentStopped,
}
32
33#[derive(Debug, thiserror::Error)]
35pub enum SimulationError {
36 #[error("actor error: {0}")]
37 Actor(#[source] JudgeError),
38 #[error("tool error: {0}")]
39 Tool(#[source] ToolSimulationError),
40 #[error("simulation cancelled")]
41 Cancelled,
42 #[error("agent error: {0}")]
43 Agent(String),
44 #[error("schema validation failed: {0}")]
46 SchemaValidation(String),
47}
48
49impl From<ToolSimulationError> for SimulationError {
50 fn from(err: ToolSimulationError) -> Self {
51 match err {
52 ToolSimulationError::SchemaValidation(reason) => Self::SchemaValidation(reason),
53 other => Self::Tool(other),
54 }
55 }
56}
57
58#[allow(clippy::too_many_lines)]
60pub async fn run_multiturn_simulation(
61 agent: &mut Agent,
62 actor: &ActorSimulator,
63 tool_sim: Option<&ToolSimulator>,
64 max_turns: u32,
65 cancel: CancellationToken,
66) -> Result<(Invocation, SimulationOutcome), SimulationError> {
67 let overall_start = Instant::now();
68 let mut outcome = SimulationOutcome::AgentStopped;
69 let mut collector = TrajectoryCollector::new();
70 let mut next_user: ActorTurn = actor.greeting();
71 let mut turn_count: u32 = 0;
72
73 while turn_count < max_turns {
74 if cancel.is_cancelled() {
75 return Err(SimulationError::Cancelled);
76 }
77 if next_user.goal_completed.is_some() {
78 outcome = SimulationOutcome::GoalCompleted;
79 break;
80 }
81
82 let conversation = vec![swink_agent::AgentMessage::Llm(LlmMessage::User(
83 UserMessage {
84 content: vec![ContentBlock::Text {
85 text: next_user.message.clone(),
86 }],
87 timestamp: now_timestamp(),
88 cache_hint: None,
89 },
90 ))];
91 let stream = agent
92 .prompt_stream(conversation)
93 .map_err(|err| SimulationError::Agent(err.to_string()))?;
94 tokio::pin!(stream);
95
96 let mut pending_tool_calls: Vec<RecordedToolCall> = Vec::new();
97 let mut last_assistant_text = String::new();
98
99 loop {
100 tokio::select! {
101 biased;
102 () = cancel.cancelled() => return Err(SimulationError::Cancelled),
103 next = stream.next() => match next {
104 None => break,
105 Some(event) => {
106 if let AgentEvent::TurnEnd { assistant_message, .. } = &event {
107 last_assistant_text =
108 ContentBlock::extract_text(&assistant_message.content);
109 for block in &assistant_message.content {
110 if let ContentBlock::ToolCall {
111 id, name, arguments, ..
112 } = block
113 {
114 pending_tool_calls.push(RecordedToolCall {
115 id: id.clone(),
116 name: name.clone(),
117 arguments: arguments.clone(),
118 });
119 }
120 }
121 }
122 collector.observe(&event);
123 }
124 },
125 }
126 }
127
128 if let (Some(sim), false) = (tool_sim, pending_tool_calls.is_empty()) {
130 let last_idx = collector.turns_len_hint().checked_sub(1);
131 for call in std::mem::take(&mut pending_tool_calls) {
132 let value = sim
133 .invoke(&call.name, &call.arguments, &call.id)
134 .await
135 .map_err(SimulationError::from)?;
136 if let Some(idx) = last_idx {
137 collector.append_tool_result(
138 idx,
139 ToolResultMessage {
140 tool_call_id: call.id.clone(),
141 content: vec![ContentBlock::Text {
142 text: value.to_string(),
143 }],
144 is_error: false,
145 timestamp: now_timestamp(),
146 details: serde_json::Value::Null,
147 cache_hint: None,
148 },
149 );
150 }
151 }
152 }
153
154 turn_count += 1;
155 if turn_count >= max_turns {
156 outcome = SimulationOutcome::MaxTurnsReached;
157 break;
158 }
159
160 let assistant_text = if last_assistant_text.is_empty() {
161 "…".to_string()
162 } else {
163 last_assistant_text
164 };
165 let produced = actor
166 .next_turn(&assistant_text)
167 .await
168 .map_err(SimulationError::Actor)?;
169 if produced.goal_completed.is_some() {
170 outcome = SimulationOutcome::GoalCompleted;
171 break;
172 }
173 next_user = produced;
174 }
175
176 let mut invocation = collector.finish();
177 if invocation.total_duration == Duration::ZERO {
178 invocation.total_duration = overall_start.elapsed();
179 }
180 if invocation.model == ModelSpec::new("unknown", "unknown") {
181 invocation.model = ModelSpec::new("simulated", actor.model_id());
182 }
183 if invocation.turns.is_empty() {
184 invocation.total_usage = Usage::default();
185 invocation.total_cost = Cost::default();
186 invocation.stop_reason = StopReason::Stop;
187 }
188
189 Ok((invocation, outcome))
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// A schema-validation failure from the tool simulator must surface as the
    /// dedicated `SchemaValidation` variant rather than the generic `Tool` wrapper.
    #[test]
    fn simulation_error_wraps_schema_variant() {
        let tool_failure = ToolSimulationError::SchemaValidation("boom".into());
        let converted = SimulationError::from(tool_failure);
        assert!(matches!(converted, SimulationError::SchemaValidation(_)));
    }
}