// tooltest_core/input.rs
// Shared tooltest input types for CLI and MCP modes.

use std::collections::{BTreeMap, HashSet};
use std::ops::RangeInclusive;
use std::sync::Arc;

use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

use crate::{
    load_lint_suite, HttpConfig, LintSuite, PreRunHook, RunConfig, RunnerOptions,
    StateMachineConfig, StdioConfig, ToolNamePredicate, ToolPredicate,
};

/// Serde default for `TooltestInput::cases`: 32 proptest cases per run.
fn default_cases() -> u32 {
    const DEFAULT_CASES: u32 = 32;
    DEFAULT_CASES
}

/// Serde default for `TooltestInput::min_sequence_len`: at least one call per run.
fn default_min_sequence_len() -> usize {
    const DEFAULT_MIN_SEQUENCE_LEN: usize = 1;
    DEFAULT_MIN_SEQUENCE_LEN
}

/// Serde default for `TooltestInput::max_sequence_len`: up to 20 calls per run.
fn default_max_sequence_len() -> usize {
    const DEFAULT_MAX_SEQUENCE_LEN: usize = 20;
    DEFAULT_MAX_SEQUENCE_LEN
}

/// Serde default for `TooltestInput::uncallable_limit`: one call per tool in traces.
fn default_uncallable_limit() -> usize {
    const DEFAULT_UNCALLABLE_LIMIT: usize = 1;
    DEFAULT_UNCALLABLE_LIMIT
}

/// Shared tooltest input type for CLI and MCP modes.
// Because of `JsonSchema`, the `///` comments on this struct and its fields also become
// JSON-schema descriptions, and serde serializes the fields in declaration order, so both
// are externally visible. Extra maintainer notes therefore use `//` comments.
// Several boolean flags below are layered onto `state_machine_config` by
// `TooltestInput::to_run_config_with_lints`.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct TooltestInput {
    /// Target MCP transport configuration.
    pub target: TooltestTarget,
    /// Number of proptest cases to execute.
    // Must be >= 1 (checked by `validate_run_config`).
    #[serde(default = "default_cases")]
    pub cases: u32,
    /// Minimum sequence length per generated run.
    // Must be >= 1 and <= `max_sequence_len` (checked by `build_sequence_len`).
    #[serde(default = "default_min_sequence_len")]
    pub min_sequence_len: usize,
    /// Maximum sequence length per generated run.
    #[serde(default = "default_max_sequence_len")]
    pub max_sequence_len: usize,
    /// Allow schema-based generation when corpus lacks required values.
    // When true, forces `state_machine.lenient_sourcing = true`; setting it together with
    // `no_lenient_sourcing` is rejected by `validate_run_config`.
    #[serde(default)]
    pub lenient_sourcing: bool,
    /// Mine whitespace-delimited text tokens into the state corpus.
    // This flag and the next two (`dump_corpus`, `log_corpus_deltas`) can only switch the
    // matching state-machine setting on; `false` keeps whatever `state_machine_config` says.
    #[serde(default)]
    pub mine_text: bool,
    /// Dump the final state-machine corpus after the run completes.
    #[serde(default)]
    pub dump_corpus: bool,
    /// Log newly mined corpus values after each tool response.
    #[serde(default)]
    pub log_corpus_deltas: bool,
    /// Disable schema-based generation when corpus lacks required values.
    // When true (and `lenient_sourcing` is false), forces
    // `state_machine.lenient_sourcing = false`.
    #[serde(default)]
    pub no_lenient_sourcing: bool,
    /// State-machine config overrides.
    // Base config for the run; `StateMachineConfig::default()` is used when absent.
    #[serde(default)]
    pub state_machine_config: Option<StateMachineConfig>,
    /// Allowlist tool names eligible for invocation generation.
    // Empty means "no allowlist"; otherwise only the listed names are eligible.
    #[serde(default)]
    pub tool_allowlist: Vec<String>,
    /// Blocklist tool names excluded from invocation generation.
    // Takes precedence over the allowlist when a name appears in both.
    #[serde(default)]
    pub tool_blocklist: Vec<String>,
    /// Fail the run when a tool result reports `isError = true`.
    #[serde(default)]
    pub in_band_error_forbidden: bool,
    /// Pre-run hook configuration.
    #[serde(default)]
    pub pre_run_hook: Option<TooltestPreRunHook>,
    /// Include tool responses in the trace output.
    #[serde(default)]
    pub full_trace: bool,
    /// Include uncallable tool traces when coverage validation fails.
    #[serde(default)]
    pub show_uncallable: bool,
    /// Number of calls per tool to include in uncallable traces.
    // Must be >= 1 (checked by `TooltestInput::validate`).
    #[serde(default = "default_uncallable_limit")]
    pub uncallable_limit: usize,
}

85impl TooltestInput {
86    /// Validates the input to match CLI semantics.
87    pub fn validate(&self) -> Result<(), String> {
88        if self.uncallable_limit < 1 {
89            return Err("uncallable-limit must be at least 1".to_string());
90        }
91        self.validate_run_config()?;
92        build_sequence_len(self.min_sequence_len, self.max_sequence_len)?;
93        Ok(())
94    }
95
96    /// Builds the target configuration for the run.
97    pub fn to_target_config(&self) -> Result<TooltestTargetConfig, String> {
98        self.target.to_config()
99    }
100
101    /// Builds the run configuration for the run.
102    pub fn to_run_config(&self) -> Result<RunConfig, String> {
103        self.to_run_config_with_lints(load_lint_suite())
104    }
105
106    fn to_run_config_with_lints(
107        &self,
108        lint_suite: Result<LintSuite, String>,
109    ) -> Result<RunConfig, String> {
110        self.validate_run_config()?;
111        let mut state_machine = self.state_machine_config.clone().unwrap_or_default();
112        if self.lenient_sourcing {
113            state_machine.lenient_sourcing = true;
114        } else if self.no_lenient_sourcing {
115            state_machine.lenient_sourcing = false;
116        }
117        if self.mine_text {
118            state_machine.mine_text = true;
119        }
120        if self.dump_corpus {
121            state_machine.dump_corpus = true;
122        }
123        if self.log_corpus_deltas {
124            state_machine.log_corpus_deltas = true;
125        }
126
127        let mut run_config = RunConfig::new()
128            .with_state_machine(state_machine)
129            .with_full_trace(self.full_trace)
130            .with_show_uncallable(self.show_uncallable);
131
132        run_config = run_config.with_uncallable_limit(self.uncallable_limit)?;
133
134        if let Some(hook) = self.pre_run_hook.as_ref() {
135            run_config = run_config.with_pre_run_hook(hook.to_pre_run_hook());
136        }
137        if self.in_band_error_forbidden {
138            run_config = run_config.with_in_band_error_forbidden(true);
139        }
140        if let Some(filters) = build_tool_filters(&self.tool_allowlist, &self.tool_blocklist) {
141            run_config = run_config
142                .with_predicate(filters.predicate)
143                .with_tool_filter(filters.name_predicate);
144        }
145
146        let lints = lint_suite.map_err(|error| format!("lint config error: {error}"))?;
147        Ok(run_config.with_lints(lints))
148    }
149
150    /// Builds the runner options for the run.
151    pub fn to_runner_options(&self) -> Result<RunnerOptions, String> {
152        let sequence_len = build_sequence_len(self.min_sequence_len, self.max_sequence_len)?;
153        RunnerOptions::new(self.cases, sequence_len)
154    }
155
156    /// Builds the target configuration, run configuration, and runner options together.
157    pub fn to_configs(&self) -> Result<TooltestRunConfig, String> {
158        let target = self.to_target_config()?;
159        let run_config = self.to_run_config();
160        let runner_options = self.to_runner_options();
161        match (run_config, runner_options) {
162            (Ok(run_config), Ok(runner_options)) => Ok(TooltestRunConfig {
163                target,
164                run_config,
165                runner_options,
166            }),
167            (Err(error), _) => Err(error),
168            (_, Err(error)) => Err(error),
169        }
170    }
171
172    fn validate_run_config(&self) -> Result<(), String> {
173        if self.cases < 1 {
174            return Err("cases must be at least 1".to_string());
175        }
176        if self.lenient_sourcing && self.no_lenient_sourcing {
177            return Err("lenient-sourcing conflicts with no-lenient-sourcing".to_string());
178        }
179        self.target.validate()?;
180        Ok(())
181    }
182}
183
/// Target configuration input wrapper.
// `untagged`: serde picks the variant by shape, trying `Stdio` (an object with a single
// `stdio` key) before `Http` (an object with a single `http` key). Untagged enums never
// serialize variant names, so `rename_all` should not affect the wire format here.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case", untagged)]
pub enum TooltestTarget {
    /// Stdio transport configuration.
    Stdio(TooltestTargetStdio),
    /// HTTP transport configuration.
    Http(TooltestTargetHttp),
}

// Single-key wrapper (`{"stdio": {...}}`) that lets the untagged `TooltestTarget` tell
// transports apart. `deny_unknown_fields` makes objects with any other key fail to match
// this variant, so the untagged enum then tries `Http`.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct TooltestTargetStdio {
    /// Stdio transport configuration.
    pub stdio: TooltestStdioTarget,
}

// Single-key wrapper (`{"http": {...}}`) for the untagged `TooltestTarget`;
// `deny_unknown_fields` makes objects with any other key fail to match this variant.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct TooltestTargetHttp {
    /// HTTP transport configuration.
    pub http: TooltestHttpTarget,
}

208impl TooltestTarget {
209    fn validate(&self) -> Result<(), String> {
210        match self {
211            TooltestTarget::Stdio(wrapper) => crate::validate_stdio_command(&wrapper.stdio.command),
212            TooltestTarget::Http(wrapper) => crate::validate_http_url(&wrapper.http.url),
213        }
214    }
215
216    fn to_config(&self) -> Result<TooltestTargetConfig, String> {
217        match self {
218            TooltestTarget::Stdio(wrapper) => {
219                Ok(TooltestTargetConfig::Stdio(wrapper.stdio.to_config()?))
220            }
221            TooltestTarget::Http(wrapper) => {
222                Ok(TooltestTargetConfig::Http(wrapper.http.to_config()?))
223            }
224        }
225    }
226}
227
/// Stdio transport input configuration.
// The `///` comments here double as JSON-schema descriptions (via `JsonSchema`).
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct TooltestStdioTarget {
    /// Command to execute for the MCP server.
    // The schema only requires length >= 1; `TooltestTarget::validate` also rejects a
    // whitespace-only command via `crate::validate_stdio_command`.
    #[schemars(length(min = 1))]
    pub command: String,
    /// Command-line arguments passed to the MCP server.
    // `args`, `env` and `cwd` are copied unchanged into `StdioConfig` by `to_config`.
    #[serde(default)]
    pub args: Vec<String>,
    /// Environment variables to add or override for the MCP process.
    #[serde(default)]
    pub env: BTreeMap<String, String>,
    /// Optional working directory for the MCP process.
    #[serde(default)]
    pub cwd: Option<String>,
}

246impl TooltestStdioTarget {
247    fn to_config(&self) -> Result<StdioConfig, String> {
248        let mut config = StdioConfig::new(self.command.clone())?;
249        config.args = self.args.clone();
250        config.env = self.env.clone();
251        config.cwd = self.cwd.clone();
252        Ok(config)
253    }
254}
255
/// HTTP transport input configuration.
// The `///` comments here double as JSON-schema descriptions (via `JsonSchema`).
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct TooltestHttpTarget {
    /// MCP endpoint URL.
    // Checked by `crate::validate_http_url` / `HttpConfig::new`; e.g. a scheme-less
    // `localhost:8080/mcp` is rejected with an "invalid http url" error.
    #[schemars(length(min = 1))]
    pub url: String,
    /// Authorization bearer token.
    // NOTE(review): the derived `Debug` (and `Serialize`) output includes this token
    // verbatim; consider redacting it before this type is logged.
    #[serde(default)]
    pub auth_token: Option<String>,
}

268impl TooltestHttpTarget {
269    fn to_config(&self) -> Result<HttpConfig, String> {
270        let mut config = HttpConfig::new(self.url.clone())?;
271        config.auth_token = self.auth_token.clone();
272        Ok(config)
273    }
274}
275
/// Pre-run hook input configuration.
// The `///` comments here double as JSON-schema descriptions (via `JsonSchema`).
// All three fields map 1:1 onto `crate::PreRunHook` (see `to_pre_run_hook`).
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case", deny_unknown_fields)]
pub struct TooltestPreRunHook {
    /// Shell command string to execute before each run and validation.
    pub command: String,
    /// Environment variables to add or override for the hook process.
    #[serde(default)]
    pub env: BTreeMap<String, String>,
    /// Optional working directory for the hook process.
    #[serde(default)]
    pub cwd: Option<String>,
}

290impl TooltestPreRunHook {
291    fn to_pre_run_hook(&self) -> PreRunHook {
292        PreRunHook {
293            command: self.command.clone(),
294            env: self.env.clone(),
295            cwd: self.cwd.clone(),
296        }
297    }
298}
299
300/// Target configuration for a tooltest run.
301#[derive(Debug)]
302pub enum TooltestTargetConfig {
303    /// Stdio transport configuration.
304    Stdio(StdioConfig),
305    /// HTTP transport configuration.
306    Http(HttpConfig),
307}
308
309/// Combined configuration output from shared tooltest input.
310#[derive(Debug)]
311pub struct TooltestRunConfig {
312    /// Target transport configuration.
313    pub target: TooltestTargetConfig,
314    /// Run configuration.
315    pub run_config: RunConfig,
316    /// Runner options.
317    pub runner_options: RunnerOptions,
318}
319
320struct ToolFilters {
321    predicate: ToolPredicate,
322    name_predicate: ToolNamePredicate,
323}
324
325fn build_tool_filters(allowlist: &[String], blocklist: &[String]) -> Option<ToolFilters> {
326    if allowlist.is_empty() && blocklist.is_empty() {
327        return None;
328    }
329    let allowlist =
330        (!allowlist.is_empty()).then(|| allowlist.iter().cloned().collect::<HashSet<_>>());
331    let blocklist =
332        (!blocklist.is_empty()).then(|| blocklist.iter().cloned().collect::<HashSet<_>>());
333    let name_predicate: ToolNamePredicate = Arc::new(move |tool_name| {
334        if let Some(allowlist) = allowlist.as_ref() {
335            if !allowlist.contains(tool_name) {
336                return false;
337            }
338        }
339        if let Some(blocklist) = blocklist.as_ref() {
340            if blocklist.contains(tool_name) {
341                return false;
342            }
343        }
344        true
345    });
346    let predicate_name = Arc::clone(&name_predicate);
347    let predicate: ToolPredicate = Arc::new(move |tool_name, _input| predicate_name(tool_name));
348    Some(ToolFilters {
349        predicate,
350        name_predicate,
351    })
352}
353
/// Converts the configured sequence-length bounds into an inclusive range.
///
/// A zero minimum is rejected first, then a minimum above the maximum. Equal
/// bounds yield a single-length range.
fn build_sequence_len(min_len: usize, max_len: usize) -> Result<RangeInclusive<usize>, String> {
    match (min_len, max_len) {
        (0, _) => Err("min-sequence-len must be at least 1".to_string()),
        (lo, hi) if lo > hi => Err("min-sequence-len must be <= max-sequence-len".to_string()),
        (lo, hi) => Ok(lo..=hi),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Stdio target running `command` with no arguments, env overrides, or cwd.
    fn stdio_target(command: &str) -> TooltestTarget {
        TooltestTarget::Stdio(TooltestTargetStdio {
            stdio: TooltestStdioTarget {
                command: command.to_string(),
                args: Vec::new(),
                env: BTreeMap::new(),
                cwd: None,
            },
        })
    }

    /// Smallest accepted input: one case, sequences of length 1, every flag off.
    fn minimal_input() -> TooltestInput {
        TooltestInput {
            target: stdio_target("server"),
            cases: 1,
            min_sequence_len: 1,
            max_sequence_len: 1,
            lenient_sourcing: false,
            mine_text: false,
            dump_corpus: false,
            log_corpus_deltas: false,
            no_lenient_sourcing: false,
            state_machine_config: None,
            tool_allowlist: Vec::new(),
            tool_blocklist: Vec::new(),
            in_band_error_forbidden: false,
            pre_run_hook: None,
            full_trace: false,
            show_uncallable: false,
            uncallable_limit: 1,
        }
    }

    #[test]
    fn target_validate_rejects_empty_command() {
        let error = stdio_target("  ").validate().unwrap_err();
        assert!(error.contains("stdio command"));
    }

    #[test]
    fn target_to_config_rejects_invalid_http_url() {
        let target = TooltestTarget::Http(TooltestTargetHttp {
            http: TooltestHttpTarget {
                url: "localhost:8080/mcp".to_string(),
                auth_token: None,
            },
        });
        let error = target.to_config().unwrap_err();
        assert!(error.contains("invalid http url"));
    }

    #[test]
    fn to_run_config_reports_lint_config_error() {
        let error = minimal_input()
            .to_run_config_with_lints(Err("bad lint config".to_string()))
            .expect_err("lint config error");
        assert!(error.contains("lint config error"));
    }

    #[test]
    fn validate_accepts_minimal_input() {
        assert_eq!(minimal_input().validate(), Ok(()));
    }

    #[test]
    fn validate_rejects_zero_uncallable_limit() {
        let input = TooltestInput {
            uncallable_limit: 0,
            ..minimal_input()
        };
        assert_eq!(
            input.validate(),
            Err("uncallable-limit must be at least 1".to_string())
        );
    }

    #[test]
    fn validate_rejects_zero_cases() {
        let input = TooltestInput {
            cases: 0,
            ..minimal_input()
        };
        assert_eq!(input.validate(), Err("cases must be at least 1".to_string()));
    }

    #[test]
    fn validate_rejects_conflicting_lenient_flags() {
        let input = TooltestInput {
            lenient_sourcing: true,
            no_lenient_sourcing: true,
            ..minimal_input()
        };
        assert_eq!(
            input.validate(),
            Err("lenient-sourcing conflicts with no-lenient-sourcing".to_string())
        );
    }

    #[test]
    fn validate_rejects_inverted_sequence_bounds() {
        let input = TooltestInput {
            min_sequence_len: 3,
            max_sequence_len: 2,
            ..minimal_input()
        };
        assert_eq!(
            input.validate(),
            Err("min-sequence-len must be <= max-sequence-len".to_string())
        );
    }

    #[test]
    fn build_sequence_len_enforces_bounds() {
        assert_eq!(build_sequence_len(1, 20), Ok(1..=20));
        assert_eq!(build_sequence_len(5, 5), Ok(5..=5));
        assert_eq!(
            build_sequence_len(0, 5),
            Err("min-sequence-len must be at least 1".to_string())
        );
        assert_eq!(
            build_sequence_len(3, 2),
            Err("min-sequence-len must be <= max-sequence-len".to_string())
        );
    }
}