1use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fmt;
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
13pub enum Scale {
14 Linear,
15 Log,
16 ReverseLog,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
21#[serde(tag = "dim_type")]
22#[non_exhaustive]
23pub enum SearchDimension {
24 Float {
26 name: String,
27 low: f64,
28 high: f64,
29 scale: Scale,
30 default: Option<f64>,
31 },
32
33 Int {
35 name: String,
36 low: i64,
37 high: i64,
38 scale: Scale,
39 },
40
41 Categorical {
43 name: String,
44 choices: Vec<serde_json::Value>,
45 },
46
47 Conditional {
49 name: String,
50 parent: String,
51 parent_values: Vec<serde_json::Value>,
52 dimension: Box<SearchDimension>,
53 },
54}
55
56impl SearchDimension {
57 pub fn name(&self) -> &str {
58 match self {
59 Self::Float { name, .. }
60 | Self::Int { name, .. }
61 | Self::Categorical { name, .. }
62 | Self::Conditional { name, .. } => name,
63 }
64 }
65
66 pub fn validate(&self) -> Result<(), String> {
68 match self {
69 Self::Float {
70 low, high, name, ..
71 } => {
72 if low >= high {
73 return Err(format!(
74 "{name}: `low` ({low}) must be less than `high` ({high})"
75 ));
76 }
77 Ok(())
78 }
79 Self::Int {
80 low, high, name, ..
81 } => {
82 if low >= high {
83 return Err(format!(
84 "{name}: `low` ({low}) must be less than `high` ({high})"
85 ));
86 }
87 Ok(())
88 }
89 Self::Categorical { choices, name } => {
90 if choices.is_empty() {
91 return Err(format!("{name}: `choices` must not be empty"));
92 }
93 Ok(())
94 }
95 Self::Conditional { dimension, .. } => dimension.validate(),
96 }
97 }
98}
99
100impl fmt::Display for SearchDimension {
101 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102 match self {
103 Self::Float {
104 name,
105 low,
106 high,
107 scale,
108 ..
109 } => write!(f, "{name}: Float[{low}, {high}] {scale:?}"),
110 Self::Int {
111 name, low, high, ..
112 } => write!(f, "{name}: Int[{low}, {high}]"),
113 Self::Categorical { name, choices } => {
114 let labels: Vec<String> = choices.iter().map(|c| c.to_string()).collect();
115 write!(f, "{name}: Categorical[{}]", labels.join(", "))
116 }
117 Self::Conditional {
118 name,
119 parent,
120 dimension,
121 ..
122 } => write!(f, "{name}: Conditional(if {parent}) -> {dimension}"),
123 }
124 }
125}
126
127#[derive(Debug, Clone, Default, Serialize, Deserialize)]
129pub struct SearchSpace {
130 pub dimensions: Vec<SearchDimension>,
131 pub frozen: HashMap<String, serde_json::Value>,
132}
133
134impl SearchSpace {
135 pub fn new() -> Self {
136 Self::default()
137 }
138
139 pub fn add(&mut self, dim: SearchDimension) {
140 self.dimensions.push(dim);
141 }
142
143 pub fn merge_with_prefix(&mut self, prefix: &str, other: SearchSpace) {
145 for dim in other.dimensions {
146 let prefixed = prefix_dimension(prefix, dim);
147 self.dimensions.push(prefixed);
148 }
149 }
150
151 pub fn freeze(&mut self, name: &str, value: serde_json::Value) {
153 self.frozen.insert(name.to_string(), value);
154 self.dimensions.retain(|d| d.name() != name);
155 }
156
157 pub fn active_dimensions(&self) -> &[SearchDimension] {
159 &self.dimensions
160 }
161
162 pub fn validate(&self) -> Result<(), Vec<String>> {
164 let errors: Vec<String> = self
165 .dimensions
166 .iter()
167 .filter_map(|d| d.validate().err())
168 .collect();
169 if errors.is_empty() {
170 Ok(())
171 } else {
172 Err(errors)
173 }
174 }
175
176 pub fn is_empty(&self) -> bool {
177 self.dimensions.is_empty()
178 }
179
180 pub fn len(&self) -> usize {
181 self.dimensions.len()
182 }
183}
184
185impl fmt::Display for SearchSpace {
186 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
187 for dim in &self.dimensions {
188 writeln!(f, " {dim}")?;
189 }
190 if !self.frozen.is_empty() {
191 writeln!(f, " Frozen:")?;
192 for (name, val) in &self.frozen {
193 writeln!(f, " {name} = {val}")?;
194 }
195 }
196 Ok(())
197 }
198}
199
200fn prefix_dimension(prefix: &str, dim: SearchDimension) -> SearchDimension {
202 match dim {
203 SearchDimension::Float {
204 name,
205 low,
206 high,
207 scale,
208 default,
209 } => SearchDimension::Float {
210 name: format!("{prefix}.{name}"),
211 low,
212 high,
213 scale,
214 default,
215 },
216 SearchDimension::Int {
217 name,
218 low,
219 high,
220 scale,
221 } => SearchDimension::Int {
222 name: format!("{prefix}.{name}"),
223 low,
224 high,
225 scale,
226 },
227 SearchDimension::Categorical { name, choices } => SearchDimension::Categorical {
228 name: format!("{prefix}.{name}"),
229 choices,
230 },
231 SearchDimension::Conditional {
232 name,
233 parent,
234 parent_values,
235 dimension,
236 } => SearchDimension::Conditional {
237 name: format!("{prefix}.{name}"),
238 parent: format!("{prefix}.{parent}"),
239 parent_values,
240 dimension: Box::new(prefix_dimension(prefix, *dimension)),
241 },
242 }
243}
244
245pub trait Searchable {
248 fn search_space() -> SearchSpace;
249 fn from_sample(params: &HashMap<String, serde_json::Value>) -> crate::error::Result<Self>
250 where
251 Self: Sized;
252 fn current_params(&self) -> HashMap<String, serde_json::Value>;
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a `Float` dimension without a default value.
    fn float_dim(name: &str, low: f64, high: f64, scale: Scale) -> SearchDimension {
        SearchDimension::Float {
            name: name.to_string(),
            low,
            high,
            scale,
            default: None,
        }
    }

    /// Builds an `Int` dimension.
    fn int_dim(name: &str, low: i64, high: i64, scale: Scale) -> SearchDimension {
        SearchDimension::Int {
            name: name.to_string(),
            low,
            high,
            scale,
        }
    }

    /// Builds a `Categorical` dimension over the given JSON choices.
    fn categorical_dim(name: &str, choices: Vec<serde_json::Value>) -> SearchDimension {
        SearchDimension::Categorical {
            name: name.to_string(),
            choices,
        }
    }

    #[test]
    fn float_dimension_display() {
        let rendered = float_dim("lr", 0.001, 0.1, Scale::Log).to_string();
        assert_eq!(rendered, "lr: Float[0.001, 0.1] Log");
    }

    #[test]
    fn categorical_dimension_display() {
        let kernel = categorical_dim("kernel", vec![json!("linear"), json!("rbf")]);
        assert_eq!(kernel.to_string(), "kernel: Categorical[\"linear\", \"rbf\"]");
    }

    #[test]
    fn validate_rejects_inverted_range() {
        let inverted = float_dim("lr", 1.0, 0.1, Scale::Linear);
        assert!(inverted.validate().is_err());
    }

    #[test]
    fn validate_rejects_empty_choices() {
        let empty = categorical_dim("kernel", Vec::new());
        assert!(empty.validate().is_err());
    }

    #[test]
    fn validate_accepts_valid_dimensions() {
        let valid = [
            float_dim("lr", 0.001, 0.1, Scale::Log),
            int_dim("epochs", 10, 100, Scale::Linear),
        ];
        for dim in &valid {
            assert!(dim.validate().is_ok());
        }
    }

    #[test]
    fn search_space_merge_with_prefix() {
        let mut scaler = SearchSpace::new();
        scaler.add(float_dim("scale", 0.1, 10.0, Scale::Log));

        let mut svm = SearchSpace::new();
        svm.add(float_dim("C", 0.01, 100.0, Scale::Log));

        let mut combined = SearchSpace::new();
        combined.merge_with_prefix("Scaler", scaler);
        combined.merge_with_prefix("SVM", svm);

        assert_eq!(combined.len(), 2);
        let names: Vec<&str> = combined
            .dimensions
            .iter()
            .map(SearchDimension::name)
            .collect();
        assert_eq!(names, ["Scaler.scale", "SVM.C"]);
    }

    #[test]
    fn search_space_freeze() {
        let mut space = SearchSpace::new();
        space.add(float_dim("lr", 0.001, 0.1, Scale::Log));
        space.add(categorical_dim("kernel", vec![json!("rbf"), json!("linear")]));
        assert_eq!(space.len(), 2);

        space.freeze("kernel", json!("rbf"));

        assert_eq!(space.len(), 1);
        assert_eq!(space.dimensions[0].name(), "lr");
        assert_eq!(space.frozen["kernel"], json!("rbf"));
    }

    #[test]
    fn search_space_validate() {
        let mut space = SearchSpace::new();
        space.add(float_dim("good", 0.0, 1.0, Scale::Linear));
        assert!(space.validate().is_ok());

        space.add(float_dim("bad", 10.0, 1.0, Scale::Linear));
        assert!(space.validate().is_err());
    }

    #[test]
    fn search_space_serde_roundtrip() {
        let mut space = SearchSpace::new();
        space.add(SearchDimension::Float {
            name: "lr".into(),
            low: 0.001,
            high: 0.1,
            scale: Scale::Log,
            default: Some(0.01),
        });
        space.add(int_dim("epochs", 10, 100, Scale::Linear));
        space.add(categorical_dim("kernel", vec![json!("rbf"), json!("linear")]));

        let encoded = serde_json::to_string(&space).unwrap();
        let decoded: SearchSpace = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.len(), 3);
    }

    #[test]
    fn conditional_dimension() {
        let momentum = SearchDimension::Conditional {
            name: "momentum".into(),
            parent: "optimizer".into(),
            parent_values: vec![json!("sgd")],
            dimension: Box::new(float_dim("momentum", 0.0, 0.99, Scale::Linear)),
        };
        assert!(momentum.validate().is_ok());
    }
}