1pub mod loader;
9pub mod metrics;
10
11pub use anyhow::Result;
12
13use clap::Parser;
14use indicatif::{ProgressBar, ProgressStyle};
15use loader::Dataset;
16use metrics::{Metric, all_metrics, minmax_normalize};
17use polars::io::SerWriter;
18use polars::prelude::{Column, CsvWriter, DataFrame};
19use std::fs::File;
20use std::marker::PhantomData;
21use std::path::{Path, PathBuf};
22use std::time::Instant;
23
/// A streaming detector that scores points one at a time.
///
/// `Send` is a supertrait, so every implementor (and `Box<dyn Detector>`) can
/// be moved to another thread.
pub trait Detector: Send {
    /// Returns the static label identifying this detector type.
    ///
    /// Bounded by `Self: Sized` so the trait stays usable as `dyn Detector`.
    fn name() -> &'static str
    where
        Self: Sized;

    /// Builds a detector for points that have `n_dimensions` components.
    ///
    /// Bounded by `Self: Sized` for the same object-safety reason as `name`.
    fn new(n_dimensions: usize) -> Self
    where
        Self: Sized;

    /// Feeds one `point` to the detector and returns the score it assigns.
    fn update(&mut self, point: &[f32]) -> f32;
}
40
41pub trait DetectorFactory: Send {
43 fn name(&self) -> String;
45 fn create(&self, n_dims: usize) -> Box<dyn Detector>;
47}
48
49struct FactoryDetector<D> {
51 _detector: PhantomData<D>,
53}
54
55impl<D> DetectorFactory for FactoryDetector<D>
56where
57 D: Detector + 'static,
58{
59 fn name(&self) -> String {
60 D::name().to_string()
61 }
62
63 fn create(&self, n_dims: usize) -> Box<dyn Detector> {
64 Box::new(D::new(n_dims))
65 }
66}
67
68pub struct Touchstone {
70 detector_factories: Vec<Box<dyn DetectorFactory>>,
72 metrics: Vec<Box<dyn Metric>>,
74 data_dir: PathBuf,
76}
77
78impl Touchstone {
79 pub fn new(data_dir: &Path) -> Self {
81 Self {
82 detector_factories: Vec::new(),
83 metrics: Vec::new(),
84 data_dir: data_dir.into(),
85 }
86 }
87
88 pub fn add_detector<D>(&mut self)
93 where
94 D: Detector + 'static,
95 {
96 let detector_factory = FactoryDetector::<D> {
97 _detector: PhantomData,
98 };
99 self.detector_factories.push(Box::new(detector_factory));
100 }
101
102 pub fn add_detector_factory(&mut self, factory: Box<dyn DetectorFactory>) {
106 self.detector_factories.push(factory);
107 }
108
109 pub fn add_metric<M>(&mut self, metric: M)
113 where
114 M: Metric + 'static,
115 {
116 self.metrics.push(Box::new(metric));
117 }
118
119 pub fn run(&mut self) -> Result<DataFrame> {
127 let entries = loader::list_datasets(&self.data_dir)?;
128 if self.metrics.is_empty() {
129 self.metrics = all_metrics();
130 }
131 let metric_names: Vec<String> = self
132 .metrics
133 .iter()
134 .map(|m| m.name().to_string())
135 .chain(["time_sec".to_string()])
136 .collect();
137 let detector_names: Vec<String> =
138 self.detector_factories.iter().map(|d| d.name()).collect();
139 let mut dataset_col: Vec<String> = Vec::new();
140 let mut detector_col: Vec<String> = Vec::new();
141 let mut metric_cols: Vec<Vec<f64>> = vec![Vec::new(); self.metrics.len() + 1];
142
143 let total = (entries.len() * self.detector_factories.len()) as u64;
144 let pb = ProgressBar::new(total);
145 pb.set_style(
146 ProgressStyle::with_template(
147 "{spinner:.cyan} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} {msg}",
148 )
149 .unwrap()
150 .progress_chars("█▉▊▋▌▍▎▏ "),
151 );
152
153 for (name, path) in entries {
154 let dataset_name = name.clone();
155 let dataset = match loader::load_dataset(name, &path) {
156 Ok(ds) => ds,
157 Err(e) => {
158 pb.println(format!("skipping {}: {e}", path.display()));
159 pb.inc(self.detector_factories.len() as u64);
160 for det_name in &detector_names {
161 dataset_col.push(dataset_name.clone());
162 detector_col.push(det_name.clone());
163 for metric_values in &mut metric_cols {
164 metric_values.push(f64::NAN);
165 }
166 }
167 continue;
168 }
169 };
170
171 pb.set_message(dataset.name.clone());
172 let n_dims = dataset.features.first().map(|f| f.len()).unwrap_or(1);
173 let detectors = self
174 .detector_factories
175 .iter()
176 .map(|factory| factory.create(n_dims))
177 .collect::<Vec<_>>();
178 let ds_results = run_dataset(&dataset, &self.metrics, detectors);
179 pb.inc(self.detector_factories.len() as u64);
180 for (det_name, det_scores) in detector_names.iter().zip(ds_results.iter()) {
181 dataset_col.push(dataset.name.clone());
182 detector_col.push(det_name.clone());
183 for (mi, value) in det_scores.iter().enumerate() {
184 metric_cols[mi].push(*value);
185 }
186 }
187 }
188
189 pb.finish_and_clear();
190
191 let height = dataset_col.len();
192 let mut columns = Vec::with_capacity(2 + metric_names.len());
193 columns.push(Column::new("dataset".into(), dataset_col));
194 columns.push(Column::new("detector".into(), detector_col));
195 for (metric_name, values) in metric_names.iter().zip(metric_cols) {
196 columns.push(Column::new(metric_name.as_str().into(), values));
197 }
198
199 Ok(DataFrame::new(height, columns)?)
200 }
201}
202
203fn run_dataset(
208 dataset: &Dataset,
209 metrics: &[Box<dyn Metric>],
210 mut detectors: Vec<Box<dyn Detector>>,
211) -> Vec<Vec<f64>> {
212 detectors
213 .iter_mut()
214 .map(|detector| {
215 let start = Instant::now();
216 let raw_scores: Vec<f32> = dataset
217 .features
218 .iter()
219 .map(|point| detector.update(point))
220 .collect();
221 let time_secs = (Instant::now() - start).as_secs_f64();
222
223 let (valid_scores, valid_labels): (Vec<f32>, Vec<u8>) = raw_scores
224 .iter()
225 .zip(dataset.labels.iter())
226 .filter(|(s, _)| !s.is_nan())
227 .map(|(&s, &l)| (s, l))
228 .unzip();
229
230 if valid_scores.is_empty() {
231 return vec![f64::NAN; metrics.len() + 1]; }
233
234 let norm_scores = minmax_normalize(&valid_scores);
235 metrics
236 .iter()
237 .map(|m| m.score(&valid_labels, &norm_scores))
238 .chain([time_secs])
239 .collect()
240 })
241 .collect()
242}
243
244#[derive(Parser, Debug)]
246pub struct RunArgs {
247 #[arg(long)]
249 pub data_dir: PathBuf,
250}
251
252pub fn run_cli<D>() -> Result<()>
258where
259 D: Detector + 'static,
260{
261 let args = RunArgs::parse();
262 let mut experiment = Touchstone::new(&args.data_dir);
263 experiment.add_detector::<D>();
264 let mut report_df = experiment.run()?;
265
266 let mut file = File::create(format!("./touchstone-{}.csv", D::name())).unwrap();
267 CsvWriter::new(&mut file)
268 .include_header(true)
269 .with_separator(b',')
270 .finish(&mut report_df)
271 .unwrap();
272
273 Ok(())
274}
275
/// Generates a `fn main` that benchmarks the given detector type.
///
/// The expansion is a `main` returning `$crate::Result<()>` whose body is a
/// single call to `$crate::run_cli` instantiated with the supplied type.
#[macro_export]
macro_rules! touchstone_main {
    ($det_ty:ty) => {
        fn main() -> $crate::Result<()> {
            $crate::run_cli::<$det_ty>()
        }
    };
}