Skip to main content

spice_framework/
runner.rs

1use crate::agent::{AgentConfig, AgentUnderTest};
2use crate::assertion::AssertionResult;
3use crate::report::{SuiteReport, TestReport};
4use crate::test_case::{TestCase, TestSuite};
5use crate::trace::Trace;
6use std::path::PathBuf;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::Semaphore;
10
/// Configuration for the test runner.
#[derive(Debug, Clone)]
pub struct RunnerConfig {
    /// Maximum number of tests executed concurrently.
    pub concurrency: usize,
    /// Timeout per test, used when neither the test nor its suite sets one.
    pub default_timeout: Duration,
    /// Filter: only run tests whose id or name contains this substring
    /// (case-sensitive).
    pub filter: Option<String>,
    /// Filter: only run tests carrying at least one of these tags.
    /// An empty list matches no test.
    pub tag_filter: Option<Vec<String>>,
    /// Directory to write per-run trace files to.
    pub trace_dir: Option<PathBuf>,
    /// Path to write the JSON suite report to.
    pub report_path: Option<PathBuf>,
    /// Print the suite report to the console after the run.
    pub console_output: bool,
}

impl Default for RunnerConfig {
    /// Four concurrent tests, a 60-second per-test timeout, no filters,
    /// no trace or report files, console output enabled.
    fn default() -> Self {
        RunnerConfig {
            concurrency: 4,
            default_timeout: Duration::from_secs(60),
            filter: None,
            tag_filter: None,
            trace_dir: None,
            report_path: None,
            console_output: true,
        }
    }
}
42
43/// The test runner.
44pub struct Runner {
45    pub config: RunnerConfig,
46}
47
48impl Runner {
49    pub fn new(config: RunnerConfig) -> Self {
50        Self { config }
51    }
52
53    /// Run a test suite against an agent, returning the suite report.
54    pub async fn run(
55        &self,
56        suite: TestSuite,
57        agent: Arc<dyn AgentUnderTest>,
58    ) -> SuiteReport {
59        let start = Instant::now();
60        let semaphore = Arc::new(Semaphore::new(self.config.concurrency));
61
62        let tests: Vec<TestCase> = suite
63            .tests
64            .into_iter()
65            .filter(|t| self.matches_filter(t))
66            .collect();
67
68        let total = tests.len();
69        let mut handles = Vec::with_capacity(total);
70
71        for test_case in tests {
72            let sem = semaphore.clone();
73            let agent = agent.clone();
74            let default_timeout = suite
75                .default_timeout
76                .unwrap_or(self.config.default_timeout);
77            let default_retries = suite.default_retries;
78            let default_config = suite.default_config.clone();
79            let trace_dir = self.config.trace_dir.clone();
80
81            let handle = tokio::spawn(async move {
82                let _permit = sem.acquire().await.unwrap();
83                run_single_test(
84                    test_case,
85                    &*agent,
86                    default_timeout,
87                    default_retries,
88                    &default_config,
89                    trace_dir.as_ref(),
90                )
91                .await
92            });
93            handles.push(handle);
94        }
95
96        let mut reports = Vec::with_capacity(total);
97        for handle in handles {
98            match handle.await {
99                Ok(report) => reports.push(report),
100                Err(e) => {
101                    reports.push(TestReport {
102                        test_id: "unknown".into(),
103                        test_name: None,
104                        tags: vec![],
105                        passed: false,
106                        attempts: 0,
107                        assertion_results: vec![],
108                        duration: Duration::ZERO,
109                        error: Some(format!("Task panicked: {}", e)),
110                    });
111                }
112            }
113        }
114
115        let passed = reports.iter().filter(|r| r.passed).count();
116        let failed = reports.len() - passed;
117
118        let suite_report = SuiteReport {
119            suite_name: suite.name,
120            tests: reports,
121            total,
122            passed,
123            failed,
124            duration: start.elapsed(),
125            timestamp: chrono::Utc::now(),
126        };
127
128        if self.config.console_output {
129            suite_report.print_console();
130        }
131
132        if let Some(path) = &self.config.report_path {
133            if let Err(e) = suite_report.save_to_file(path) {
134                eprintln!("Failed to save report: {}", e);
135            }
136        }
137
138        suite_report
139    }
140
141    fn matches_filter(&self, test: &TestCase) -> bool {
142        if let Some(filter) = &self.config.filter {
143            let id_match = test.id.contains(filter.as_str());
144            let name_match = test
145                .name
146                .as_ref()
147                .map(|n| n.contains(filter.as_str()))
148                .unwrap_or(false);
149            if !id_match && !name_match {
150                return false;
151            }
152        }
153        if let Some(tag_filter) = &self.config.tag_filter {
154            if !test.tags.iter().any(|t| tag_filter.contains(t)) {
155                return false;
156            }
157        }
158        true
159    }
160}
161
162async fn run_single_test(
163    test: TestCase,
164    agent: &dyn AgentUnderTest,
165    default_timeout: Duration,
166    default_retries: usize,
167    default_config: &AgentConfig,
168    trace_dir: Option<&PathBuf>,
169) -> TestReport {
170    let start = Instant::now();
171    let timeout = test.timeout.unwrap_or(default_timeout);
172    let max_retries = test.retries.max(default_retries);
173    let config = if test.config.data.is_null() {
174        default_config
175    } else {
176        &test.config
177    };
178    let available_tools = agent.available_tools(config);
179
180    // Consensus mode
181    if let (Some(runs), Some(required)) = (test.consensus_runs, test.consensus_required) {
182        return run_consensus(
183            &test,
184            agent,
185            config,
186            &available_tools,
187            timeout,
188            runs,
189            required,
190            trace_dir,
191            start,
192        )
193        .await;
194    }
195
196    // Standard retry mode
197    let mut last_results = vec![];
198    let mut last_error = None;
199    let mut attempts = 0;
200
201    for attempt in 0..=max_retries {
202        attempts = attempt + 1;
203
204        let run_result = tokio::time::timeout(timeout, agent.run(&test.user_message, config)).await;
205
206        match run_result {
207            Ok(Ok(output)) => {
208                // Save trace
209                if let Some(dir) = trace_dir {
210                    let trace = Trace::new(
211                        test.id.clone(),
212                        test.user_message.clone(),
213                        output.clone(),
214                    );
215                    let path = dir.join(format!("{}_attempt{}.json", test.id, attempt));
216                    let _ = trace.save_to_file(&path);
217                }
218
219                let results: Vec<AssertionResult> = test
220                    .assertions
221                    .iter()
222                    .map(|a| a.evaluate(&output, &available_tools))
223                    .collect();
224
225                let all_passed = results.iter().all(|r| r.passed);
226                if all_passed {
227                    return TestReport {
228                        test_id: test.id,
229                        test_name: test.name,
230                        tags: test.tags,
231                        passed: true,
232                        attempts,
233                        assertion_results: results,
234                        duration: start.elapsed(),
235                        error: None,
236                    };
237                }
238                last_results = results;
239                last_error = None;
240            }
241            Ok(Err(e)) => {
242                last_error = Some(e.to_string());
243                last_results = vec![];
244            }
245            Err(_) => {
246                last_error = Some(format!("Timeout after {:?}", timeout));
247                last_results = vec![];
248            }
249        }
250    }
251
252    TestReport {
253        test_id: test.id,
254        test_name: test.name,
255        tags: test.tags,
256        passed: false,
257        attempts,
258        assertion_results: last_results,
259        duration: start.elapsed(),
260        error: last_error,
261    }
262}
263
264async fn run_consensus(
265    test: &TestCase,
266    agent: &dyn AgentUnderTest,
267    config: &AgentConfig,
268    available_tools: &[String],
269    timeout: Duration,
270    runs: usize,
271    required: usize,
272    trace_dir: Option<&PathBuf>,
273    start: Instant,
274) -> TestReport {
275    let mut pass_count = 0;
276    let mut last_results = vec![];
277
278    for i in 0..runs {
279        let run_result = tokio::time::timeout(timeout, agent.run(&test.user_message, config)).await;
280
281        match run_result {
282            Ok(Ok(output)) => {
283                if let Some(dir) = trace_dir {
284                    let trace = Trace::new(
285                        test.id.clone(),
286                        test.user_message.clone(),
287                        output.clone(),
288                    );
289                    let path = dir.join(format!("{}_consensus{}.json", test.id, i));
290                    let _ = trace.save_to_file(&path);
291                }
292
293                let results: Vec<AssertionResult> = test
294                    .assertions
295                    .iter()
296                    .map(|a| a.evaluate(&output, available_tools))
297                    .collect();
298
299                let all_passed = results.iter().all(|r| r.passed);
300                if all_passed {
301                    pass_count += 1;
302                }
303                last_results = results;
304            }
305            Ok(Err(_)) | Err(_) => {
306                // Count as failure
307            }
308        }
309
310        if pass_count >= required {
311            break;
312        }
313    }
314
315    TestReport {
316        test_id: test.id.clone(),
317        test_name: test.name.clone(),
318        tags: test.tags.clone(),
319        passed: pass_count >= required,
320        attempts: runs,
321        assertion_results: last_results,
322        duration: start.elapsed(),
323        error: if pass_count < required {
324            Some(format!(
325                "Consensus: {}/{} passed, needed {}",
326                pass_count, runs, required
327            ))
328        } else {
329            None
330        },
331    }
332}