Skip to main content

swink_agent_eval/simulation/
tool.rs

1//! Tool-call simulator and shared state registry (US4, FR-025).
2
3#![forbid(unsafe_code)]
4
5use std::collections::{HashMap, VecDeque};
6use std::sync::{Arc, Mutex};
7use std::time::SystemTime;
8
9use crate::judge::{JudgeClient, JudgeError};
10
/// Number of call records a [`StateBucket`] keeps when the registry is built
/// with [`StateRegistry::new`] rather than an explicit cap.
pub const DEFAULT_HISTORY_CAP: usize = 32;
13
14/// Schema record for the simulator's tool catalogue.
15#[derive(Debug, Clone, PartialEq, Eq)]
16pub struct ToolSchema {
17    pub name: String,
18    /// JSON Schema describing the *response* shape produced for this tool.
19    pub response_schema: serde_json::Value,
20}
21
22impl ToolSchema {
23    #[must_use]
24    pub fn new(name: impl Into<String>, response_schema: serde_json::Value) -> Self {
25        Self {
26            name: name.into(),
27            response_schema,
28        }
29    }
30}
31
32/// One recorded tool invocation within a bucket.
33#[derive(Debug, Clone)]
34pub struct ToolCallRecord {
35    pub tool: String,
36    pub args: serde_json::Value,
37    pub result: serde_json::Value,
38    pub timestamp: SystemTime,
39}
40
41/// Mutable state shared across tool calls with the same `state_key`.
42#[derive(Debug, Clone)]
43pub struct StateBucket {
44    pub shared_state: serde_json::Value,
45    pub history: VecDeque<ToolCallRecord>,
46    history_cap: usize,
47}
48
49impl StateBucket {
50    /// A cap of `0` is promoted to `1`.
51    #[must_use]
52    pub fn with_capacity(history_cap: usize) -> Self {
53        Self {
54            shared_state: serde_json::Value::Null,
55            history: VecDeque::new(),
56            history_cap: history_cap.max(1),
57        }
58    }
59
60    /// Record a call, evicting the oldest if we exceed the cap.
61    pub fn record(&mut self, record: ToolCallRecord) {
62        self.history.push_back(record);
63        while self.history.len() > self.history_cap {
64            self.history.pop_front();
65        }
66    }
67
68    #[must_use]
69    pub const fn history_cap(&self) -> usize {
70        self.history_cap
71    }
72}
73
74/// Registry of [`StateBucket`]s keyed by arbitrary string.
75#[derive(Debug)]
76pub struct StateRegistry {
77    buckets: Mutex<HashMap<String, StateBucket>>,
78    history_cap: usize,
79}
80
81impl StateRegistry {
82    #[must_use]
83    pub fn new() -> Self {
84        Self::with_history_cap(DEFAULT_HISTORY_CAP)
85    }
86
87    #[must_use]
88    pub fn with_history_cap(history_cap: usize) -> Self {
89        Self {
90            buckets: Mutex::new(HashMap::new()),
91            history_cap: history_cap.max(1),
92        }
93    }
94
95    /// Run `f` with mutable access to the bucket for `key`, creating if absent.
96    pub fn with_bucket<R>(&self, key: &str, f: impl FnOnce(&mut StateBucket) -> R) -> R {
97        let mut buckets = self
98            .buckets
99            .lock()
100            .unwrap_or_else(std::sync::PoisonError::into_inner);
101        let bucket = buckets
102            .entry(key.to_string())
103            .or_insert_with(|| StateBucket::with_capacity(self.history_cap));
104        f(bucket)
105    }
106
107    #[must_use]
108    pub fn history_snapshot(&self, key: &str) -> Vec<ToolCallRecord> {
109        let buckets = self
110            .buckets
111            .lock()
112            .unwrap_or_else(std::sync::PoisonError::into_inner);
113        buckets
114            .get(key)
115            .map(|bucket| bucket.history.iter().cloned().collect())
116            .unwrap_or_default()
117    }
118
119    #[must_use]
120    pub fn bucket_count(&self) -> usize {
121        self.buckets
122            .lock()
123            .unwrap_or_else(std::sync::PoisonError::into_inner)
124            .len()
125    }
126
127    #[must_use]
128    pub const fn history_cap(&self) -> usize {
129        self.history_cap
130    }
131}
132
133impl Default for StateRegistry {
134    fn default() -> Self {
135        Self::new()
136    }
137}
138
139/// Simulates tool responses by consulting a [`JudgeClient`] and validating
140/// the result against the registered [`ToolSchema::response_schema`].
141pub struct ToolSimulator {
142    tools: HashMap<String, ToolSchema>,
143    judge: Arc<dyn JudgeClient>,
144    model_id: String,
145    registry: Arc<StateRegistry>,
146}
147
148impl ToolSimulator {
149    #[must_use]
150    pub fn new(
151        tools: Vec<ToolSchema>,
152        judge: Arc<dyn JudgeClient>,
153        model_id: impl Into<String>,
154    ) -> Self {
155        Self::with_registry(tools, judge, model_id, Arc::new(StateRegistry::new()))
156    }
157
158    #[must_use]
159    pub fn with_registry(
160        tools: Vec<ToolSchema>,
161        judge: Arc<dyn JudgeClient>,
162        model_id: impl Into<String>,
163        registry: Arc<StateRegistry>,
164    ) -> Self {
165        let tools = tools
166            .into_iter()
167            .map(|schema| (schema.name.clone(), schema))
168            .collect();
169        Self {
170            tools,
171            judge,
172            model_id: model_id.into(),
173            registry,
174        }
175    }
176
177    #[must_use]
178    pub fn registry(&self) -> &Arc<StateRegistry> {
179        &self.registry
180    }
181
182    pub fn tool_names(&self) -> impl Iterator<Item = &str> {
183        self.tools.keys().map(String::as_str)
184    }
185
186    #[must_use]
187    pub fn model_id(&self) -> &str {
188        &self.model_id
189    }
190
191    /// Simulate one tool invocation, record it in the `state_key` bucket,
192    /// and return the schema-validated result body.
193    pub async fn invoke(
194        &self,
195        tool_name: &str,
196        args: &serde_json::Value,
197        state_key: &str,
198    ) -> Result<serde_json::Value, ToolSimulationError> {
199        let schema = self
200            .tools
201            .get(tool_name)
202            .ok_or_else(|| ToolSimulationError::UnknownTool(tool_name.to_string()))?;
203
204        let history = self.registry.history_snapshot(state_key);
205        let prompt = render_tool_prompt(tool_name, args, &history);
206        let verdict = self
207            .judge
208            .judge(&prompt)
209            .await
210            .map_err(ToolSimulationError::Judge)?;
211        let body = verdict
212            .reason
213            .ok_or_else(|| ToolSimulationError::MissingBody(tool_name.to_string()))?;
214        let value: serde_json::Value = serde_json::from_str(body.trim())
215            .map_err(|err| ToolSimulationError::Parse(err.to_string()))?;
216
217        validate_against_schema(&value, &schema.response_schema)?;
218
219        self.registry.with_bucket(state_key, |bucket| {
220            bucket.record(ToolCallRecord {
221                tool: tool_name.to_string(),
222                args: args.clone(),
223                result: value.clone(),
224                timestamp: SystemTime::now(),
225            });
226        });
227
228        Ok(value)
229    }
230}
231
232fn render_tool_prompt(tool: &str, args: &serde_json::Value, history: &[ToolCallRecord]) -> String {
233    let mut prompt = format!("Simulate a response for tool `{tool}`.\nArguments: {args}\n");
234    if !history.is_empty() {
235        prompt.push_str("Prior calls in bucket:\n");
236        for (idx, record) in history.iter().enumerate() {
237            prompt.push_str(&format!(
238                "- [{idx}] {} args={} -> {}\n",
239                record.tool, record.args, record.result
240            ));
241        }
242    }
243    prompt.push_str("Respond with a single JSON object matching the tool's response schema.");
244    prompt
245}
246
247fn validate_against_schema(
248    value: &serde_json::Value,
249    schema: &serde_json::Value,
250) -> Result<(), ToolSimulationError> {
251    let compiled = jsonschema::validator_for(schema)
252        .map_err(|err| ToolSimulationError::SchemaValidation(err.to_string()))?;
253    if let Err(err) = compiled.validate(value) {
254        return Err(ToolSimulationError::SchemaValidation(err.to_string()));
255    }
256    Ok(())
257}
258
259/// Errors surfaced by [`ToolSimulator::invoke`].
260#[derive(Debug, thiserror::Error)]
261pub enum ToolSimulationError {
262    #[error("tool `{0}` not registered with simulator")]
263    UnknownTool(String),
264    #[error("judge produced no body for tool `{0}`")]
265    MissingBody(String),
266    #[error("schema validation failed: {0}")]
267    SchemaValidation(String),
268    #[error("tool response parse error: {0}")]
269    Parse(String),
270    #[error("judge error: {0}")]
271    Judge(#[source] JudgeError),
272}