swink_agent_eval/simulation/
tool.rs1#![forbid(unsafe_code)]
4
5use std::collections::{HashMap, VecDeque};
6use std::sync::{Arc, Mutex};
7use std::time::SystemTime;
8
9use crate::judge::{JudgeClient, JudgeError};
10
/// Per-bucket history length used by [`StateRegistry::new`]: once a bucket holds this many
/// tool-call records, recording another evicts the oldest.
pub const DEFAULT_HISTORY_CAP: usize = 32;
13
14#[derive(Debug, Clone, PartialEq, Eq)]
16pub struct ToolSchema {
17 pub name: String,
18 pub response_schema: serde_json::Value,
20}
21
22impl ToolSchema {
23 #[must_use]
24 pub fn new(name: impl Into<String>, response_schema: serde_json::Value) -> Self {
25 Self {
26 name: name.into(),
27 response_schema,
28 }
29 }
30}
31
32#[derive(Debug, Clone)]
34pub struct ToolCallRecord {
35 pub tool: String,
36 pub args: serde_json::Value,
37 pub result: serde_json::Value,
38 pub timestamp: SystemTime,
39}
40
41#[derive(Debug, Clone)]
43pub struct StateBucket {
44 pub shared_state: serde_json::Value,
45 pub history: VecDeque<ToolCallRecord>,
46 history_cap: usize,
47}
48
49impl StateBucket {
50 #[must_use]
52 pub fn with_capacity(history_cap: usize) -> Self {
53 Self {
54 shared_state: serde_json::Value::Null,
55 history: VecDeque::new(),
56 history_cap: history_cap.max(1),
57 }
58 }
59
60 pub fn record(&mut self, record: ToolCallRecord) {
62 self.history.push_back(record);
63 while self.history.len() > self.history_cap {
64 self.history.pop_front();
65 }
66 }
67
68 #[must_use]
69 pub const fn history_cap(&self) -> usize {
70 self.history_cap
71 }
72}
73
74#[derive(Debug)]
76pub struct StateRegistry {
77 buckets: Mutex<HashMap<String, StateBucket>>,
78 history_cap: usize,
79}
80
81impl StateRegistry {
82 #[must_use]
83 pub fn new() -> Self {
84 Self::with_history_cap(DEFAULT_HISTORY_CAP)
85 }
86
87 #[must_use]
88 pub fn with_history_cap(history_cap: usize) -> Self {
89 Self {
90 buckets: Mutex::new(HashMap::new()),
91 history_cap: history_cap.max(1),
92 }
93 }
94
95 pub fn with_bucket<R>(&self, key: &str, f: impl FnOnce(&mut StateBucket) -> R) -> R {
97 let mut buckets = self
98 .buckets
99 .lock()
100 .unwrap_or_else(std::sync::PoisonError::into_inner);
101 let bucket = buckets
102 .entry(key.to_string())
103 .or_insert_with(|| StateBucket::with_capacity(self.history_cap));
104 f(bucket)
105 }
106
107 #[must_use]
108 pub fn history_snapshot(&self, key: &str) -> Vec<ToolCallRecord> {
109 let buckets = self
110 .buckets
111 .lock()
112 .unwrap_or_else(std::sync::PoisonError::into_inner);
113 buckets
114 .get(key)
115 .map(|bucket| bucket.history.iter().cloned().collect())
116 .unwrap_or_default()
117 }
118
119 #[must_use]
120 pub fn bucket_count(&self) -> usize {
121 self.buckets
122 .lock()
123 .unwrap_or_else(std::sync::PoisonError::into_inner)
124 .len()
125 }
126
127 #[must_use]
128 pub const fn history_cap(&self) -> usize {
129 self.history_cap
130 }
131}
132
133impl Default for StateRegistry {
134 fn default() -> Self {
135 Self::new()
136 }
137}
138
139pub struct ToolSimulator {
142 tools: HashMap<String, ToolSchema>,
143 judge: Arc<dyn JudgeClient>,
144 model_id: String,
145 registry: Arc<StateRegistry>,
146}
147
148impl ToolSimulator {
149 #[must_use]
150 pub fn new(
151 tools: Vec<ToolSchema>,
152 judge: Arc<dyn JudgeClient>,
153 model_id: impl Into<String>,
154 ) -> Self {
155 Self::with_registry(tools, judge, model_id, Arc::new(StateRegistry::new()))
156 }
157
158 #[must_use]
159 pub fn with_registry(
160 tools: Vec<ToolSchema>,
161 judge: Arc<dyn JudgeClient>,
162 model_id: impl Into<String>,
163 registry: Arc<StateRegistry>,
164 ) -> Self {
165 let tools = tools
166 .into_iter()
167 .map(|schema| (schema.name.clone(), schema))
168 .collect();
169 Self {
170 tools,
171 judge,
172 model_id: model_id.into(),
173 registry,
174 }
175 }
176
177 #[must_use]
178 pub fn registry(&self) -> &Arc<StateRegistry> {
179 &self.registry
180 }
181
182 pub fn tool_names(&self) -> impl Iterator<Item = &str> {
183 self.tools.keys().map(String::as_str)
184 }
185
186 #[must_use]
187 pub fn model_id(&self) -> &str {
188 &self.model_id
189 }
190
191 pub async fn invoke(
194 &self,
195 tool_name: &str,
196 args: &serde_json::Value,
197 state_key: &str,
198 ) -> Result<serde_json::Value, ToolSimulationError> {
199 let schema = self
200 .tools
201 .get(tool_name)
202 .ok_or_else(|| ToolSimulationError::UnknownTool(tool_name.to_string()))?;
203
204 let history = self.registry.history_snapshot(state_key);
205 let prompt = render_tool_prompt(tool_name, args, &history);
206 let verdict = self
207 .judge
208 .judge(&prompt)
209 .await
210 .map_err(ToolSimulationError::Judge)?;
211 let body = verdict
212 .reason
213 .ok_or_else(|| ToolSimulationError::MissingBody(tool_name.to_string()))?;
214 let value: serde_json::Value = serde_json::from_str(body.trim())
215 .map_err(|err| ToolSimulationError::Parse(err.to_string()))?;
216
217 validate_against_schema(&value, &schema.response_schema)?;
218
219 self.registry.with_bucket(state_key, |bucket| {
220 bucket.record(ToolCallRecord {
221 tool: tool_name.to_string(),
222 args: args.clone(),
223 result: value.clone(),
224 timestamp: SystemTime::now(),
225 });
226 });
227
228 Ok(value)
229 }
230}
231
232fn render_tool_prompt(tool: &str, args: &serde_json::Value, history: &[ToolCallRecord]) -> String {
233 let mut prompt = format!("Simulate a response for tool `{tool}`.\nArguments: {args}\n");
234 if !history.is_empty() {
235 prompt.push_str("Prior calls in bucket:\n");
236 for (idx, record) in history.iter().enumerate() {
237 prompt.push_str(&format!(
238 "- [{idx}] {} args={} -> {}\n",
239 record.tool, record.args, record.result
240 ));
241 }
242 }
243 prompt.push_str("Respond with a single JSON object matching the tool's response schema.");
244 prompt
245}
246
247fn validate_against_schema(
248 value: &serde_json::Value,
249 schema: &serde_json::Value,
250) -> Result<(), ToolSimulationError> {
251 let compiled = jsonschema::validator_for(schema)
252 .map_err(|err| ToolSimulationError::SchemaValidation(err.to_string()))?;
253 if let Err(err) = compiled.validate(value) {
254 return Err(ToolSimulationError::SchemaValidation(err.to_string()));
255 }
256 Ok(())
257}
258
259#[derive(Debug, thiserror::Error)]
261pub enum ToolSimulationError {
262 #[error("tool `{0}` not registered with simulator")]
263 UnknownTool(String),
264 #[error("judge produced no body for tool `{0}`")]
265 MissingBody(String),
266 #[error("schema validation failed: {0}")]
267 SchemaValidation(String),
268 #[error("tool response parse error: {0}")]
269 Parse(String),
270 #[error("judge error: {0}")]
271 Judge(#[source] JudgeError),
272}