use std::collections::{BTreeMap, HashSet};
use std::ops::RangeInclusive;
use std::sync::Arc;

use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

use crate::{
    load_lint_suite, HttpConfig, LintSuite, PreRunHook, RunConfig, RunnerOptions,
    StateMachineConfig, StdioConfig, ToolNamePredicate, ToolPredicate,
};
12
/// Serde default for [`TooltestInput::cases`] when the field is omitted: `32`.
fn default_cases() -> u32 {
    32
}
16
/// Serde default for [`TooltestInput::min_sequence_len`] when the field is omitted: `1`.
fn default_min_sequence_len() -> usize {
    1
}
20
/// Serde default for [`TooltestInput::max_sequence_len`] when the field is omitted: `20`.
fn default_max_sequence_len() -> usize {
    20
}
24
/// Serde default for [`TooltestInput::uncallable_limit`] when the field is omitted: `1`.
fn default_uncallable_limit() -> usize {
    1
}
28
29#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
31#[serde(rename_all = "snake_case", deny_unknown_fields)]
32pub struct TooltestInput {
33 pub target: TooltestTarget,
35 #[serde(default = "default_cases")]
37 pub cases: u32,
38 #[serde(default = "default_min_sequence_len")]
40 pub min_sequence_len: usize,
41 #[serde(default = "default_max_sequence_len")]
43 pub max_sequence_len: usize,
44 #[serde(default)]
46 pub lenient_sourcing: bool,
47 #[serde(default)]
49 pub mine_text: bool,
50 #[serde(default)]
52 pub dump_corpus: bool,
53 #[serde(default)]
55 pub log_corpus_deltas: bool,
56 #[serde(default)]
58 pub no_lenient_sourcing: bool,
59 #[serde(default)]
61 pub state_machine_config: Option<StateMachineConfig>,
62 #[serde(default)]
64 pub tool_allowlist: Vec<String>,
65 #[serde(default)]
67 pub tool_blocklist: Vec<String>,
68 #[serde(default)]
70 pub in_band_error_forbidden: bool,
71 #[serde(default)]
73 pub pre_run_hook: Option<TooltestPreRunHook>,
74 #[serde(default)]
76 pub full_trace: bool,
77 #[serde(default)]
79 pub show_uncallable: bool,
80 #[serde(default = "default_uncallable_limit")]
82 pub uncallable_limit: usize,
83}
84
85impl TooltestInput {
86 pub fn validate(&self) -> Result<(), String> {
88 if self.uncallable_limit < 1 {
89 return Err("uncallable-limit must be at least 1".to_string());
90 }
91 self.validate_run_config()?;
92 build_sequence_len(self.min_sequence_len, self.max_sequence_len)?;
93 Ok(())
94 }
95
96 pub fn to_target_config(&self) -> Result<TooltestTargetConfig, String> {
98 self.target.to_config()
99 }
100
101 pub fn to_run_config(&self) -> Result<RunConfig, String> {
103 self.to_run_config_with_lints(load_lint_suite())
104 }
105
106 fn to_run_config_with_lints(
107 &self,
108 lint_suite: Result<LintSuite, String>,
109 ) -> Result<RunConfig, String> {
110 self.validate_run_config()?;
111 let mut state_machine = self.state_machine_config.clone().unwrap_or_default();
112 if self.lenient_sourcing {
113 state_machine.lenient_sourcing = true;
114 } else if self.no_lenient_sourcing {
115 state_machine.lenient_sourcing = false;
116 }
117 if self.mine_text {
118 state_machine.mine_text = true;
119 }
120 if self.dump_corpus {
121 state_machine.dump_corpus = true;
122 }
123 if self.log_corpus_deltas {
124 state_machine.log_corpus_deltas = true;
125 }
126
127 let mut run_config = RunConfig::new()
128 .with_state_machine(state_machine)
129 .with_full_trace(self.full_trace)
130 .with_show_uncallable(self.show_uncallable);
131
132 run_config = run_config.with_uncallable_limit(self.uncallable_limit)?;
133
134 if let Some(hook) = self.pre_run_hook.as_ref() {
135 run_config = run_config.with_pre_run_hook(hook.to_pre_run_hook());
136 }
137 if self.in_band_error_forbidden {
138 run_config = run_config.with_in_band_error_forbidden(true);
139 }
140 if let Some(filters) = build_tool_filters(&self.tool_allowlist, &self.tool_blocklist) {
141 run_config = run_config
142 .with_predicate(filters.predicate)
143 .with_tool_filter(filters.name_predicate);
144 }
145
146 let lints = lint_suite.map_err(|error| format!("lint config error: {error}"))?;
147 Ok(run_config.with_lints(lints))
148 }
149
150 pub fn to_runner_options(&self) -> Result<RunnerOptions, String> {
152 let sequence_len = build_sequence_len(self.min_sequence_len, self.max_sequence_len)?;
153 RunnerOptions::new(self.cases, sequence_len)
154 }
155
156 pub fn to_configs(&self) -> Result<TooltestRunConfig, String> {
158 let target = self.to_target_config()?;
159 let run_config = self.to_run_config();
160 let runner_options = self.to_runner_options();
161 match (run_config, runner_options) {
162 (Ok(run_config), Ok(runner_options)) => Ok(TooltestRunConfig {
163 target,
164 run_config,
165 runner_options,
166 }),
167 (Err(error), _) => Err(error),
168 (_, Err(error)) => Err(error),
169 }
170 }
171
172 fn validate_run_config(&self) -> Result<(), String> {
173 if self.cases < 1 {
174 return Err("cases must be at least 1".to_string());
175 }
176 if self.lenient_sourcing && self.no_lenient_sourcing {
177 return Err("lenient-sourcing conflicts with no-lenient-sourcing".to_string());
178 }
179 self.target.validate()?;
180 Ok(())
181 }
182}
183
184#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
186#[serde(rename_all = "snake_case", untagged)]
187pub enum TooltestTarget {
188 Stdio(TooltestTargetStdio),
190 Http(TooltestTargetHttp),
192}
193
194#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
195#[serde(rename_all = "snake_case", deny_unknown_fields)]
196pub struct TooltestTargetStdio {
197 pub stdio: TooltestStdioTarget,
199}
200
201#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
202#[serde(rename_all = "snake_case", deny_unknown_fields)]
203pub struct TooltestTargetHttp {
204 pub http: TooltestHttpTarget,
206}
207
208impl TooltestTarget {
209 fn validate(&self) -> Result<(), String> {
210 match self {
211 TooltestTarget::Stdio(wrapper) => crate::validate_stdio_command(&wrapper.stdio.command),
212 TooltestTarget::Http(wrapper) => crate::validate_http_url(&wrapper.http.url),
213 }
214 }
215
216 fn to_config(&self) -> Result<TooltestTargetConfig, String> {
217 match self {
218 TooltestTarget::Stdio(wrapper) => {
219 Ok(TooltestTargetConfig::Stdio(wrapper.stdio.to_config()?))
220 }
221 TooltestTarget::Http(wrapper) => {
222 Ok(TooltestTargetConfig::Http(wrapper.http.to_config()?))
223 }
224 }
225 }
226}
227
228#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
230#[serde(rename_all = "snake_case", deny_unknown_fields)]
231pub struct TooltestStdioTarget {
232 #[schemars(length(min = 1))]
234 pub command: String,
235 #[serde(default)]
237 pub args: Vec<String>,
238 #[serde(default)]
240 pub env: BTreeMap<String, String>,
241 #[serde(default)]
243 pub cwd: Option<String>,
244}
245
246impl TooltestStdioTarget {
247 fn to_config(&self) -> Result<StdioConfig, String> {
248 let mut config = StdioConfig::new(self.command.clone())?;
249 config.args = self.args.clone();
250 config.env = self.env.clone();
251 config.cwd = self.cwd.clone();
252 Ok(config)
253 }
254}
255
256#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
258#[serde(rename_all = "snake_case", deny_unknown_fields)]
259pub struct TooltestHttpTarget {
260 #[schemars(length(min = 1))]
262 pub url: String,
263 #[serde(default)]
265 pub auth_token: Option<String>,
266}
267
268impl TooltestHttpTarget {
269 fn to_config(&self) -> Result<HttpConfig, String> {
270 let mut config = HttpConfig::new(self.url.clone())?;
271 config.auth_token = self.auth_token.clone();
272 Ok(config)
273 }
274}
275
276#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
278#[serde(rename_all = "snake_case", deny_unknown_fields)]
279pub struct TooltestPreRunHook {
280 pub command: String,
282 #[serde(default)]
284 pub env: BTreeMap<String, String>,
285 #[serde(default)]
287 pub cwd: Option<String>,
288}
289
290impl TooltestPreRunHook {
291 fn to_pre_run_hook(&self) -> PreRunHook {
292 PreRunHook {
293 command: self.command.clone(),
294 env: self.env.clone(),
295 cwd: self.cwd.clone(),
296 }
297 }
298}
299
300#[derive(Debug)]
302pub enum TooltestTargetConfig {
303 Stdio(StdioConfig),
305 Http(HttpConfig),
307}
308
309#[derive(Debug)]
311pub struct TooltestRunConfig {
312 pub target: TooltestTargetConfig,
314 pub run_config: RunConfig,
316 pub runner_options: RunnerOptions,
318}
319
320struct ToolFilters {
321 predicate: ToolPredicate,
322 name_predicate: ToolNamePredicate,
323}
324
325fn build_tool_filters(allowlist: &[String], blocklist: &[String]) -> Option<ToolFilters> {
326 if allowlist.is_empty() && blocklist.is_empty() {
327 return None;
328 }
329 let allowlist =
330 (!allowlist.is_empty()).then(|| allowlist.iter().cloned().collect::<HashSet<_>>());
331 let blocklist =
332 (!blocklist.is_empty()).then(|| blocklist.iter().cloned().collect::<HashSet<_>>());
333 let name_predicate: ToolNamePredicate = Arc::new(move |tool_name| {
334 if let Some(allowlist) = allowlist.as_ref() {
335 if !allowlist.contains(tool_name) {
336 return false;
337 }
338 }
339 if let Some(blocklist) = blocklist.as_ref() {
340 if blocklist.contains(tool_name) {
341 return false;
342 }
343 }
344 true
345 });
346 let predicate_name = Arc::clone(&name_predicate);
347 let predicate: ToolPredicate = Arc::new(move |tool_name, _input| predicate_name(tool_name));
348 Some(ToolFilters {
349 predicate,
350 name_predicate,
351 })
352}
353
/// Validates the sequence-length bounds and returns them as an inclusive range.
///
/// # Errors
///
/// Fails when `min_len` is zero, or when `min_len` exceeds `max_len`. The
/// zero check comes first, so `(0, 0)` reports the "at least 1" error.
fn build_sequence_len(min_len: usize, max_len: usize) -> Result<RangeInclusive<usize>, String> {
    match (min_len, max_len) {
        (0, _) => Err("min-sequence-len must be at least 1".to_string()),
        (min, max) if min > max => Err("min-sequence-len must be <= max-sequence-len".to_string()),
        (min, max) => Ok(min..=max),
    }
}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a stdio target for `command` with no args, env, or cwd.
    fn stdio_target(command: &str) -> TooltestTarget {
        TooltestTarget::Stdio(TooltestTargetStdio {
            stdio: TooltestStdioTarget {
                command: command.to_string(),
                args: Vec::new(),
                env: BTreeMap::new(),
                cwd: None,
            },
        })
    }

    #[test]
    fn target_validate_rejects_empty_command() {
        let error = stdio_target(" ").validate().unwrap_err();
        assert!(error.contains("stdio command"));
    }

    #[test]
    fn target_to_config_rejects_invalid_http_url() {
        let target = TooltestTarget::Http(TooltestTargetHttp {
            http: TooltestHttpTarget {
                url: "localhost:8080/mcp".to_string(),
                auth_token: None,
            },
        });
        let error = target.to_config().unwrap_err();
        assert!(error.contains("invalid http url"));
    }

    #[test]
    fn to_run_config_reports_lint_config_error() {
        let input = TooltestInput {
            target: stdio_target("server"),
            cases: 1,
            min_sequence_len: 1,
            max_sequence_len: 1,
            lenient_sourcing: false,
            mine_text: false,
            dump_corpus: false,
            log_corpus_deltas: false,
            no_lenient_sourcing: false,
            state_machine_config: None,
            tool_allowlist: Vec::new(),
            tool_blocklist: Vec::new(),
            in_band_error_forbidden: false,
            pre_run_hook: None,
            full_trace: false,
            show_uncallable: false,
            uncallable_limit: 1,
        };
        // Inject a failed lint suite so the test never touches the real loader.
        let error = input
            .to_run_config_with_lints(Err("bad lint config".to_string()))
            .expect_err("lint config error");
        assert!(error.contains("lint config error"));
    }

    #[test]
    fn build_sequence_len_checks_bounds() {
        assert_eq!(build_sequence_len(1, 20), Ok(1..=20));
        assert!(build_sequence_len(0, 5)
            .unwrap_err()
            .contains("at least 1"));
        assert!(build_sequence_len(5, 4)
            .unwrap_err()
            .contains("<= max-sequence-len"));
    }
}