swarm_engine_core/learn/stats_model/
params.rs1use std::any::Any;
6use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9
10use super::base::{Model, ModelMetadata, ModelType, ModelVersion};
11use crate::learn::offline::{OfflineModel, RecommendedPath, StrategyConfig};
12use crate::util::epoch_millis;
13
14pub trait Parametric: Model {
16 fn get_param(&self, key: &str) -> Option<ParamValue>;
18
19 fn all_params(&self) -> HashMap<String, ParamValue>;
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub enum ParamValue {
26 Float(f64),
27 Int(i64),
28 Bool(bool),
29 String(String),
30 Array(Vec<ParamValue>),
31}
32
33impl ParamValue {
34 pub fn as_f64(&self) -> Option<f64> {
35 match self {
36 Self::Float(v) => Some(*v),
37 Self::Int(v) => Some(*v as f64),
38 _ => None,
39 }
40 }
41
42 pub fn as_i64(&self) -> Option<i64> {
43 match self {
44 Self::Int(v) => Some(*v),
45 Self::Float(v) => Some(*v as i64),
46 _ => None,
47 }
48 }
49
50 pub fn as_bool(&self) -> Option<bool> {
51 match self {
52 Self::Bool(v) => Some(*v),
53 _ => None,
54 }
55 }
56
57 pub fn as_str(&self) -> Option<&str> {
58 match self {
59 Self::String(v) => Some(v),
60 _ => None,
61 }
62 }
63}
64
65impl From<f64> for ParamValue {
66 fn from(v: f64) -> Self {
67 Self::Float(v)
68 }
69}
70
71impl From<i64> for ParamValue {
72 fn from(v: i64) -> Self {
73 Self::Int(v)
74 }
75}
76
77impl From<bool> for ParamValue {
78 fn from(v: bool) -> Self {
79 Self::Bool(v)
80 }
81}
82
83impl From<String> for ParamValue {
84 fn from(v: String) -> Self {
85 Self::String(v)
86 }
87}
88
89impl From<&str> for ParamValue {
90 fn from(v: &str) -> Self {
91 Self::String(v.to_string())
92 }
93}
94
/// String keys for the entries of a model's parameter map.
pub mod param_keys {
    // Numeric weights read by `OptimalParamsModel`'s typed getters.

    /// Key for the UCB1 constant `c` (read by `OptimalParamsModel::ucb1_c`).
    pub const UCB1_C: &str = "ucb1_c";
    /// Key for the learning weight (read by `OptimalParamsModel::learning_weight`).
    pub const LEARNING_WEIGHT: &str = "learning_weight";
    /// Key for the n-gram weight (read by `OptimalParamsModel::ngram_weight`).
    pub const NGRAM_WEIGHT: &str = "ngram_weight";

    // Strategy settings copied from `StrategyConfig` when an `OfflineModel`
    // is converted.

    /// Key for the strategy maturity threshold (stored as `ParamValue::Int`).
    pub const MATURITY_THRESHOLD: &str = "maturity_threshold";
    /// Key for the strategy error-rate threshold (stored as `ParamValue::Float`).
    pub const ERROR_RATE_THRESHOLD: &str = "error_rate_threshold";
    /// Key for the initial strategy name (stored as `ParamValue::String`).
    pub const INITIAL_STRATEGY: &str = "initial_strategy";
}
104
/// Learned model that stores its tunable values in a name → [`ParamValue`] map.
///
/// Alongside the map it carries the strategy configuration, recommended paths
/// and analysed-session count. These are the same fields an [`OfflineModel`]
/// has, and `From<OfflineModel>` moves them over directly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimalParamsModel {
    /// Model version (`1.0` for a default model; `<old.version>.0` after conversion).
    version: ModelVersion,
    /// Model metadata (always `ModelMetadata::default()` in this file's constructors).
    metadata: ModelMetadata,
    /// Creation timestamp: `epoch_millis()` for a default model, the legacy
    /// model's `updated_at` after conversion.
    /// NOTE(review): presumably milliseconds since the Unix epoch. Confirm
    /// against `crate::util::epoch_millis`.
    created_at: u64,

    /// Named parameters. Keys are normally the constants in `param_keys`.
    params: HashMap<String, ParamValue>,

    /// Strategy configuration. On conversion from an `OfflineModel`, its
    /// threshold and initial-strategy values are also mirrored into `params`.
    pub strategy_config: StrategyConfig,
    /// Recommended paths carried over from the legacy model (empty by default).
    pub recommended_paths: Vec<RecommendedPath>,

    /// Number of sessions the model was derived from (0 by default).
    pub analyzed_sessions: usize,
}
126
127impl Default for OptimalParamsModel {
128 fn default() -> Self {
129 let mut params = HashMap::new();
130 params.insert(
131 param_keys::UCB1_C.to_string(),
132 ParamValue::Float(std::f64::consts::SQRT_2),
133 );
134 params.insert(
135 param_keys::LEARNING_WEIGHT.to_string(),
136 ParamValue::Float(0.3),
137 );
138 params.insert(param_keys::NGRAM_WEIGHT.to_string(), ParamValue::Float(1.0));
139
140 Self {
141 version: ModelVersion::new(1, 0),
142 metadata: ModelMetadata::default(),
143 created_at: epoch_millis(),
144 params,
145 strategy_config: StrategyConfig::default(),
146 recommended_paths: Vec::new(),
147 analyzed_sessions: 0,
148 }
149 }
150}
151
152impl OptimalParamsModel {
153 pub fn new() -> Self {
155 Self::default()
156 }
157
158 pub fn set_param(&mut self, key: &str, value: impl Into<ParamValue>) {
160 self.params.insert(key.to_string(), value.into());
161 }
162
163 pub fn ucb1_c(&self) -> f64 {
165 self.get_param(param_keys::UCB1_C)
166 .and_then(|v| v.as_f64())
167 .unwrap_or(std::f64::consts::SQRT_2)
168 }
169
170 pub fn learning_weight(&self) -> f64 {
172 self.get_param(param_keys::LEARNING_WEIGHT)
173 .and_then(|v| v.as_f64())
174 .unwrap_or(0.3)
175 }
176
177 pub fn ngram_weight(&self) -> f64 {
179 self.get_param(param_keys::NGRAM_WEIGHT)
180 .and_then(|v| v.as_f64())
181 .unwrap_or(1.0)
182 }
183}
184
185impl Model for OptimalParamsModel {
186 fn model_type(&self) -> ModelType {
187 ModelType::OptimalParams
188 }
189
190 fn version(&self) -> &ModelVersion {
191 &self.version
192 }
193
194 fn created_at(&self) -> u64 {
195 self.created_at
196 }
197
198 fn metadata(&self) -> &ModelMetadata {
199 &self.metadata
200 }
201
202 fn as_any(&self) -> &dyn Any {
203 self
204 }
205}
206
207impl Parametric for OptimalParamsModel {
208 fn get_param(&self, key: &str) -> Option<ParamValue> {
209 self.params.get(key).cloned()
210 }
211
212 fn all_params(&self) -> HashMap<String, ParamValue> {
213 self.params.clone()
214 }
215}
216
217impl From<OfflineModel> for OptimalParamsModel {
222 fn from(old: OfflineModel) -> Self {
223 let mut params = HashMap::new();
224 params.insert(
225 param_keys::UCB1_C.to_string(),
226 ParamValue::Float(old.parameters.ucb1_c),
227 );
228 params.insert(
229 param_keys::LEARNING_WEIGHT.to_string(),
230 ParamValue::Float(old.parameters.learning_weight),
231 );
232 params.insert(
233 param_keys::NGRAM_WEIGHT.to_string(),
234 ParamValue::Float(old.parameters.ngram_weight),
235 );
236 params.insert(
237 param_keys::MATURITY_THRESHOLD.to_string(),
238 ParamValue::Int(old.strategy_config.maturity_threshold as i64),
239 );
240 params.insert(
241 param_keys::ERROR_RATE_THRESHOLD.to_string(),
242 ParamValue::Float(old.strategy_config.error_rate_threshold),
243 );
244 params.insert(
245 param_keys::INITIAL_STRATEGY.to_string(),
246 ParamValue::String(old.strategy_config.initial_strategy.clone()),
247 );
248
249 Self {
250 version: ModelVersion::new(old.version, 0),
251 metadata: ModelMetadata::default(),
252 created_at: old.updated_at,
253 params,
254 strategy_config: old.strategy_config,
255 recommended_paths: old.recommended_paths,
256 analyzed_sessions: old.analyzed_sessions,
257 }
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_param_value_conversions() {
        let f = ParamValue::Float(1.5);
        assert_eq!(f.as_f64(), Some(1.5));
        assert_eq!(f.as_i64(), Some(1));

        let i = ParamValue::Int(42);
        assert_eq!(i.as_i64(), Some(42));
        assert_eq!(i.as_f64(), Some(42.0));

        let b = ParamValue::Bool(true);
        assert_eq!(b.as_bool(), Some(true));

        let s = ParamValue::String("test".to_string());
        assert_eq!(s.as_str(), Some("test"));
    }

    /// Accessors must return `None` for a variant they do not represent.
    #[test]
    fn test_param_value_mismatched_accessors_return_none() {
        assert_eq!(ParamValue::Bool(true).as_f64(), None);
        assert_eq!(ParamValue::String("1".to_string()).as_i64(), None);
        assert_eq!(ParamValue::Int(1).as_bool(), None);
        assert_eq!(ParamValue::Float(1.0).as_str(), None);
        assert_eq!(ParamValue::Array(vec![ParamValue::Int(1)]).as_f64(), None);
    }

    #[test]
    fn test_param_value_from_impls() {
        assert_eq!(ParamValue::from(2.5_f64).as_f64(), Some(2.5));
        assert_eq!(ParamValue::from(7_i64).as_i64(), Some(7));
        assert_eq!(ParamValue::from(false).as_bool(), Some(false));
        assert_eq!(ParamValue::from("abc").as_str(), Some("abc"));
        assert_eq!(ParamValue::from(String::from("xyz")).as_str(), Some("xyz"));
    }

    #[test]
    fn test_optimal_params_model_default() {
        let model = OptimalParamsModel::new();

        assert!((model.ucb1_c() - std::f64::consts::SQRT_2).abs() < 1e-10);
        assert!((model.learning_weight() - 0.3).abs() < 1e-10);
        assert!((model.ngram_weight() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_optimal_params_model_set_param() {
        let mut model = OptimalParamsModel::new();
        model.set_param(param_keys::UCB1_C, 2.0);

        assert!((model.ucb1_c() - 2.0).abs() < 1e-10);
    }

    /// A non-numeric value falls back to the default. An `Int` value is
    /// accepted and widened to `f64`.
    #[test]
    fn test_getter_fallback_and_int_widening() {
        let mut model = OptimalParamsModel::new();

        model.set_param(param_keys::LEARNING_WEIGHT, "not a number");
        assert!((model.learning_weight() - 0.3).abs() < 1e-10);

        model.set_param(param_keys::LEARNING_WEIGHT, 5_i64);
        assert!((model.learning_weight() - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_parametric_trait() {
        let model = OptimalParamsModel::new();

        let value = model.get_param(param_keys::UCB1_C);
        assert!(value.is_some());

        let all = model.all_params();
        assert!(all.contains_key(param_keys::UCB1_C));
        assert!(all.contains_key(param_keys::LEARNING_WEIGHT));
    }

    #[test]
    fn test_from_offline_model() {
        use crate::learn::offline::{OfflineModel, OptimalParameters, StrategyConfig};

        let old = OfflineModel {
            version: 2,
            parameters: OptimalParameters {
                ucb1_c: 1.5,
                learning_weight: 0.4,
                ngram_weight: 1.2,
            },
            recommended_paths: vec![],
            strategy_config: StrategyConfig::default(),
            analyzed_sessions: 5,
            updated_at: 12345,
        };
        // Separate default instance holding the expected strategy values.
        let expected_strategy = StrategyConfig::default();

        let model: OptimalParamsModel = old.into();

        assert!((model.ucb1_c() - 1.5).abs() < 1e-10);
        assert!((model.learning_weight() - 0.4).abs() < 1e-10);
        assert!((model.ngram_weight() - 1.2).abs() < 1e-10);
        assert_eq!(model.analyzed_sessions, 5);
        assert_eq!(model.created_at(), 12345);

        // Strategy settings are mirrored into the parameter map.
        assert_eq!(
            model
                .get_param(param_keys::MATURITY_THRESHOLD)
                .and_then(|v| v.as_i64()),
            Some(expected_strategy.maturity_threshold as i64)
        );
        assert_eq!(
            model
                .get_param(param_keys::ERROR_RATE_THRESHOLD)
                .and_then(|v| v.as_f64()),
            Some(expected_strategy.error_rate_threshold)
        );
        assert_eq!(
            model
                .get_param(param_keys::INITIAL_STRATEGY)
                .as_ref()
                .and_then(|v| v.as_str()),
            Some(expected_strategy.initial_strategy.as_str())
        );
    }
}