// somatize_core/search.rs

//! Search spaces for hyperparameter optimization.
//!
//! Defines [`SearchSpace`], a collection of [`SearchDimension`]s that
//! samplers use to generate trial configurations. A dimension can be a
//! float range, an int range, a categorical choice, or a conditional
//! wrapper around another dimension.
6
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fmt;
10
11/// Scale for continuous search ranges.
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
13pub enum Scale {
14    Linear,
15    Log,
16    ReverseLog,
17}
18
19/// A single searchable parameter dimension.
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
21#[serde(tag = "dim_type")]
22#[non_exhaustive]
23pub enum SearchDimension {
24    /// Continuous range (f64)
25    Float {
26        name: String,
27        low: f64,
28        high: f64,
29        scale: Scale,
30        default: Option<f64>,
31    },
32
33    /// Integer range
34    Int {
35        name: String,
36        low: i64,
37        high: i64,
38        scale: Scale,
39    },
40
41    /// Discrete set of choices
42    Categorical {
43        name: String,
44        choices: Vec<serde_json::Value>,
45    },
46
47    /// Active only when parent parameter has specific values
48    Conditional {
49        name: String,
50        parent: String,
51        parent_values: Vec<serde_json::Value>,
52        dimension: Box<SearchDimension>,
53    },
54}
55
56impl SearchDimension {
57    pub fn name(&self) -> &str {
58        match self {
59            Self::Float { name, .. }
60            | Self::Int { name, .. }
61            | Self::Categorical { name, .. }
62            | Self::Conditional { name, .. } => name,
63        }
64    }
65
66    /// Validate the dimension configuration.
67    pub fn validate(&self) -> Result<(), String> {
68        match self {
69            Self::Float {
70                low, high, name, ..
71            } => {
72                if low >= high {
73                    return Err(format!(
74                        "{name}: `low` ({low}) must be less than `high` ({high})"
75                    ));
76                }
77                Ok(())
78            }
79            Self::Int {
80                low, high, name, ..
81            } => {
82                if low >= high {
83                    return Err(format!(
84                        "{name}: `low` ({low}) must be less than `high` ({high})"
85                    ));
86                }
87                Ok(())
88            }
89            Self::Categorical { choices, name } => {
90                if choices.is_empty() {
91                    return Err(format!("{name}: `choices` must not be empty"));
92                }
93                Ok(())
94            }
95            Self::Conditional { dimension, .. } => dimension.validate(),
96        }
97    }
98}
99
100impl fmt::Display for SearchDimension {
101    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102        match self {
103            Self::Float {
104                name,
105                low,
106                high,
107                scale,
108                ..
109            } => write!(f, "{name}: Float[{low}, {high}] {scale:?}"),
110            Self::Int {
111                name, low, high, ..
112            } => write!(f, "{name}: Int[{low}, {high}]"),
113            Self::Categorical { name, choices } => {
114                let labels: Vec<String> = choices.iter().map(|c| c.to_string()).collect();
115                write!(f, "{name}: Categorical[{}]", labels.join(", "))
116            }
117            Self::Conditional {
118                name,
119                parent,
120                dimension,
121                ..
122            } => write!(f, "{name}: Conditional(if {parent}) -> {dimension}"),
123        }
124    }
125}
126
127/// Aggregation of search dimensions from one or more filters.
128#[derive(Debug, Clone, Default, Serialize, Deserialize)]
129pub struct SearchSpace {
130    pub dimensions: Vec<SearchDimension>,
131    pub frozen: HashMap<String, serde_json::Value>,
132}
133
134impl SearchSpace {
135    pub fn new() -> Self {
136        Self::default()
137    }
138
139    pub fn add(&mut self, dim: SearchDimension) {
140        self.dimensions.push(dim);
141    }
142
143    /// Merge another search space with a prefix to avoid name collisions.
144    pub fn merge_with_prefix(&mut self, prefix: &str, other: SearchSpace) {
145        for dim in other.dimensions {
146            let prefixed = prefix_dimension(prefix, dim);
147            self.dimensions.push(prefixed);
148        }
149    }
150
151    /// Freeze a parameter to a fixed value (exclude from search).
152    pub fn freeze(&mut self, name: &str, value: serde_json::Value) {
153        self.frozen.insert(name.to_string(), value);
154        self.dimensions.retain(|d| d.name() != name);
155    }
156
157    /// Get only the active (non-frozen) dimensions.
158    pub fn active_dimensions(&self) -> &[SearchDimension] {
159        &self.dimensions
160    }
161
162    /// Validate all dimensions.
163    pub fn validate(&self) -> Result<(), Vec<String>> {
164        let errors: Vec<String> = self
165            .dimensions
166            .iter()
167            .filter_map(|d| d.validate().err())
168            .collect();
169        if errors.is_empty() {
170            Ok(())
171        } else {
172            Err(errors)
173        }
174    }
175
176    pub fn is_empty(&self) -> bool {
177        self.dimensions.is_empty()
178    }
179
180    pub fn len(&self) -> usize {
181        self.dimensions.len()
182    }
183}
184
185impl fmt::Display for SearchSpace {
186    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
187        for dim in &self.dimensions {
188            writeln!(f, "  {dim}")?;
189        }
190        if !self.frozen.is_empty() {
191            writeln!(f, "  Frozen:")?;
192            for (name, val) in &self.frozen {
193                writeln!(f, "    {name} = {val}")?;
194            }
195        }
196        Ok(())
197    }
198}
199
200/// Prefix all dimension names with a filter label.
201fn prefix_dimension(prefix: &str, dim: SearchDimension) -> SearchDimension {
202    match dim {
203        SearchDimension::Float {
204            name,
205            low,
206            high,
207            scale,
208            default,
209        } => SearchDimension::Float {
210            name: format!("{prefix}.{name}"),
211            low,
212            high,
213            scale,
214            default,
215        },
216        SearchDimension::Int {
217            name,
218            low,
219            high,
220            scale,
221        } => SearchDimension::Int {
222            name: format!("{prefix}.{name}"),
223            low,
224            high,
225            scale,
226        },
227        SearchDimension::Categorical { name, choices } => SearchDimension::Categorical {
228            name: format!("{prefix}.{name}"),
229            choices,
230        },
231        SearchDimension::Conditional {
232            name,
233            parent,
234            parent_values,
235            dimension,
236        } => SearchDimension::Conditional {
237            name: format!("{prefix}.{name}"),
238            parent: format!("{prefix}.{parent}"),
239            parent_values,
240            dimension: Box::new(prefix_dimension(prefix, *dimension)),
241        },
242    }
243}
244
245/// Trait for filters that declare their search space.
246/// Auto-generated by `#[derive(Filter)]`.
247pub trait Searchable {
248    fn search_space() -> SearchSpace;
249    fn from_sample(params: &HashMap<String, serde_json::Value>) -> crate::error::Result<Self>
250    where
251        Self: Sized;
252    fn current_params(&self) -> HashMap<String, serde_json::Value>;
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn float_dimension_display() {
        let dim = SearchDimension::Float {
            name: "lr".into(),
            low: 0.001,
            high: 0.1,
            scale: Scale::Log,
            default: None,
        };
        assert_eq!(dim.to_string(), "lr: Float[0.001, 0.1] Log");
    }

    #[test]
    fn int_dimension_display_starts_with_name_and_range() {
        let dim = SearchDimension::Int {
            name: "epochs".into(),
            low: 10,
            high: 100,
            scale: Scale::Linear,
        };
        assert!(dim.to_string().starts_with("epochs: Int[10, 100]"));
    }

    #[test]
    fn categorical_dimension_display() {
        let dim = SearchDimension::Categorical {
            name: "kernel".into(),
            choices: vec![json!("linear"), json!("rbf")],
        };
        assert_eq!(dim.to_string(), "kernel: Categorical[\"linear\", \"rbf\"]");
    }

    #[test]
    fn validate_rejects_inverted_range() {
        let dim = SearchDimension::Float {
            name: "lr".into(),
            low: 1.0,
            high: 0.1,
            scale: Scale::Linear,
            default: None,
        };
        let err = dim.validate().unwrap_err();
        // Error messages are prefixed with the offending dimension's name.
        assert!(err.starts_with("lr:"), "unexpected message: {err}");
    }

    #[test]
    fn validate_rejects_equal_float_bounds() {
        let dim = SearchDimension::Float {
            name: "lr".into(),
            low: 1.0,
            high: 1.0,
            scale: Scale::Linear,
            default: None,
        };
        assert!(dim.validate().is_err());
    }

    #[test]
    fn validate_rejects_inverted_int_range() {
        let dim = SearchDimension::Int {
            name: "epochs".into(),
            low: 100,
            high: 10,
            scale: Scale::Linear,
        };
        assert!(dim.validate().is_err());
    }

    #[test]
    fn validate_rejects_empty_choices() {
        let dim = SearchDimension::Categorical {
            name: "kernel".into(),
            choices: vec![],
        };
        assert!(dim.validate().is_err());
    }

    #[test]
    fn validate_accepts_valid_dimensions() {
        let float = SearchDimension::Float {
            name: "lr".into(),
            low: 0.001,
            high: 0.1,
            scale: Scale::Log,
            default: None,
        };
        let int = SearchDimension::Int {
            name: "epochs".into(),
            low: 10,
            high: 100,
            scale: Scale::Linear,
        };
        assert!(float.validate().is_ok());
        assert!(int.validate().is_ok());
    }

    #[test]
    fn search_space_merge_with_prefix() {
        let mut space1 = SearchSpace::new();
        space1.add(SearchDimension::Float {
            name: "scale".into(),
            low: 0.1,
            high: 10.0,
            scale: Scale::Log,
            default: None,
        });

        let mut space2 = SearchSpace::new();
        space2.add(SearchDimension::Float {
            name: "C".into(),
            low: 0.01,
            high: 100.0,
            scale: Scale::Log,
            default: None,
        });

        let mut combined = SearchSpace::new();
        combined.merge_with_prefix("Scaler", space1);
        combined.merge_with_prefix("SVM", space2);

        assert_eq!(combined.len(), 2);
        assert_eq!(combined.dimensions[0].name(), "Scaler.scale");
        assert_eq!(combined.dimensions[1].name(), "SVM.C");
    }

    #[test]
    fn merge_with_prefix_prefixes_conditional_parent_and_inner_dimension() {
        let mut inner = SearchSpace::new();
        inner.add(SearchDimension::Conditional {
            name: "momentum".into(),
            parent: "optimizer".into(),
            parent_values: vec![json!("sgd")],
            dimension: Box::new(SearchDimension::Float {
                name: "momentum".into(),
                low: 0.0,
                high: 0.99,
                scale: Scale::Linear,
                default: None,
            }),
        });

        let mut combined = SearchSpace::new();
        combined.merge_with_prefix("Opt", inner);

        match &combined.dimensions[0] {
            SearchDimension::Conditional {
                name,
                parent,
                dimension,
                ..
            } => {
                assert_eq!(name, "Opt.momentum");
                assert_eq!(parent, "Opt.optimizer");
                assert_eq!(dimension.name(), "Opt.momentum");
            }
            other => panic!("expected a conditional dimension, got {other:?}"),
        }
    }

    #[test]
    fn search_space_freeze() {
        let mut space = SearchSpace::new();
        space.add(SearchDimension::Float {
            name: "lr".into(),
            low: 0.001,
            high: 0.1,
            scale: Scale::Log,
            default: None,
        });
        space.add(SearchDimension::Categorical {
            name: "kernel".into(),
            choices: vec![json!("rbf"), json!("linear")],
        });

        assert_eq!(space.len(), 2);
        space.freeze("kernel", json!("rbf"));
        assert_eq!(space.len(), 1);
        assert_eq!(space.dimensions[0].name(), "lr");
        assert_eq!(space.frozen["kernel"], json!("rbf"));
    }

    #[test]
    fn freeze_unknown_name_keeps_dimensions() {
        let mut space = SearchSpace::new();
        space.add(SearchDimension::Int {
            name: "epochs".into(),
            low: 10,
            high: 100,
            scale: Scale::Linear,
        });

        space.freeze("seed", json!(42));

        assert_eq!(space.len(), 1);
        assert_eq!(space.frozen["seed"], json!(42));
    }

    #[test]
    fn search_space_validate() {
        let mut space = SearchSpace::new();
        space.add(SearchDimension::Float {
            name: "good".into(),
            low: 0.0,
            high: 1.0,
            scale: Scale::Linear,
            default: None,
        });
        assert!(space.validate().is_ok());

        space.add(SearchDimension::Float {
            name: "bad".into(),
            low: 10.0,
            high: 1.0,
            scale: Scale::Linear,
            default: None,
        });
        let errors = space.validate().unwrap_err();
        assert_eq!(errors.len(), 1);
        assert!(errors[0].starts_with("bad:"));
    }

    #[test]
    fn search_space_serde_roundtrip() {
        let mut space = SearchSpace::new();
        space.add(SearchDimension::Float {
            name: "lr".into(),
            low: 0.001,
            high: 0.1,
            scale: Scale::Log,
            default: Some(0.01),
        });
        space.add(SearchDimension::Int {
            name: "epochs".into(),
            low: 10,
            high: 100,
            scale: Scale::Linear,
        });
        space.add(SearchDimension::Categorical {
            name: "kernel".into(),
            choices: vec![json!("rbf"), json!("linear")],
        });
        // Nested internally-tagged value exercises `dim_type` inside `dimension`.
        space.add(SearchDimension::Conditional {
            name: "momentum".into(),
            parent: "optimizer".into(),
            parent_values: vec![json!("sgd")],
            dimension: Box::new(SearchDimension::Float {
                name: "momentum".into(),
                low: 0.0,
                high: 0.99,
                scale: Scale::Linear,
                default: None,
            }),
        });
        space.freeze("seed", json!(42));

        let json = serde_json::to_string(&space).unwrap();
        let deserialized: SearchSpace = serde_json::from_str(&json).unwrap();

        // Compare full contents, not just the count, so a field that serde
        // drops or mangles is caught.
        assert_eq!(deserialized.len(), 4);
        assert_eq!(deserialized.dimensions, space.dimensions);
        assert_eq!(deserialized.frozen, space.frozen);
    }

    #[test]
    fn conditional_dimension() {
        let dim = SearchDimension::Conditional {
            name: "momentum".into(),
            parent: "optimizer".into(),
            parent_values: vec![json!("sgd")],
            dimension: Box::new(SearchDimension::Float {
                name: "momentum".into(),
                low: 0.0,
                high: 0.99,
                scale: Scale::Linear,
                default: None,
            }),
        };
        assert!(dim.validate().is_ok());
    }

    #[test]
    fn conditional_validation_propagates_inner_error() {
        let dim = SearchDimension::Conditional {
            name: "momentum".into(),
            parent: "optimizer".into(),
            parent_values: vec![json!("sgd")],
            dimension: Box::new(SearchDimension::Float {
                name: "momentum".into(),
                low: 1.0,
                high: 0.1,
                scale: Scale::Linear,
                default: None,
            }),
        };
        assert!(dim.validate().is_err());
    }

    #[test]
    fn conditional_name_is_outer_name() {
        let dim = SearchDimension::Conditional {
            name: "opt.momentum".into(),
            parent: "optimizer".into(),
            parent_values: vec![json!("sgd")],
            dimension: Box::new(SearchDimension::Float {
                name: "momentum".into(),
                low: 0.0,
                high: 0.99,
                scale: Scale::Linear,
                default: None,
            }),
        };
        assert_eq!(dim.name(), "opt.momentum");
    }
}