Skip to main content

swarm_engine_core/learn/stats_model/
params.rs

1//! Parametric Trait - パラメータ提供機能
2//!
3//! 戦略設定に使用するパラメータを提供するモデル。
4
5use std::any::Any;
6use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9
10use super::base::{Model, ModelMetadata, ModelType, ModelVersion};
11use crate::learn::offline::{OfflineModel, RecommendedPath, StrategyConfig};
12use crate::util::epoch_millis;
13
14/// パラメータを提供できるモデル(戦略設定に使用)
15pub trait Parametric: Model {
16    /// パラメータ取得
17    fn get_param(&self, key: &str) -> Option<ParamValue>;
18
19    /// 全パラメータ取得
20    fn all_params(&self) -> HashMap<String, ParamValue>;
21}
22
23/// パラメータ値
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub enum ParamValue {
26    Float(f64),
27    Int(i64),
28    Bool(bool),
29    String(String),
30    Array(Vec<ParamValue>),
31}
32
33impl ParamValue {
34    pub fn as_f64(&self) -> Option<f64> {
35        match self {
36            Self::Float(v) => Some(*v),
37            Self::Int(v) => Some(*v as f64),
38            _ => None,
39        }
40    }
41
42    pub fn as_i64(&self) -> Option<i64> {
43        match self {
44            Self::Int(v) => Some(*v),
45            Self::Float(v) => Some(*v as i64),
46            _ => None,
47        }
48    }
49
50    pub fn as_bool(&self) -> Option<bool> {
51        match self {
52            Self::Bool(v) => Some(*v),
53            _ => None,
54        }
55    }
56
57    pub fn as_str(&self) -> Option<&str> {
58        match self {
59            Self::String(v) => Some(v),
60            _ => None,
61        }
62    }
63}
64
65impl From<f64> for ParamValue {
66    fn from(v: f64) -> Self {
67        Self::Float(v)
68    }
69}
70
71impl From<i64> for ParamValue {
72    fn from(v: i64) -> Self {
73        Self::Int(v)
74    }
75}
76
77impl From<bool> for ParamValue {
78    fn from(v: bool) -> Self {
79        Self::Bool(v)
80    }
81}
82
83impl From<String> for ParamValue {
84    fn from(v: String) -> Self {
85        Self::String(v)
86    }
87}
88
89impl From<&str> for ParamValue {
90    fn from(v: &str) -> Self {
91        Self::String(v.to_string())
92    }
93}
94
/// Well-known parameter key constants.
pub mod param_keys {
    // --- Search / learning weights ---

    /// Exploration coefficient for UCB1.
    pub const UCB1_C: &str = "ucb1_c";
    /// Learning weight.
    pub const LEARNING_WEIGHT: &str = "learning_weight";
    /// N-gram weight.
    pub const NGRAM_WEIGHT: &str = "ngram_weight";

    // --- Keys mirroring the strategy configuration ---

    /// Maturity threshold.
    pub const MATURITY_THRESHOLD: &str = "maturity_threshold";
    /// Error-rate threshold.
    pub const ERROR_RATE_THRESHOLD: &str = "error_rate_threshold";
    /// Name of the initial strategy.
    pub const INITIAL_STRATEGY: &str = "initial_strategy";
}
104
105// ============================================================================
106// OptimalParamsModel - パラメータ最適化モデル
107// ============================================================================
108
/// Parameter-optimisation model.
///
/// Holds a free-form `name -> ParamValue` map, exposed through [`Parametric`],
/// alongside the recommended strategy configuration and paths.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimalParamsModel {
    version: ModelVersion,
    metadata: ModelMetadata,
    // Returned by `Model::created_at`. `Default` fills it from `epoch_millis()`,
    // while the `OfflineModel` conversion copies the legacy `updated_at`.
    created_at: u64,

    /// Parameter data.
    // Keys used in this file come from `param_keys`, but `set_param` accepts any string.
    params: HashMap<String, ParamValue>,

    /// Recommended strategy settings.
    // NOTE(review): the `OfflineModel` conversion also copies this config's values
    // into `params`. Because this field is `pub` and `set_param` only touches
    // `params`, the two copies can drift apart after construction.
    pub strategy_config: StrategyConfig,
    /// Recommended paths.
    pub recommended_paths: Vec<RecommendedPath>,

    /// Number of sessions used for the analysis.
    pub analyzed_sessions: usize,
}
126
127impl Default for OptimalParamsModel {
128    fn default() -> Self {
129        let mut params = HashMap::new();
130        params.insert(
131            param_keys::UCB1_C.to_string(),
132            ParamValue::Float(std::f64::consts::SQRT_2),
133        );
134        params.insert(
135            param_keys::LEARNING_WEIGHT.to_string(),
136            ParamValue::Float(0.3),
137        );
138        params.insert(param_keys::NGRAM_WEIGHT.to_string(), ParamValue::Float(1.0));
139
140        Self {
141            version: ModelVersion::new(1, 0),
142            metadata: ModelMetadata::default(),
143            created_at: epoch_millis(),
144            params,
145            strategy_config: StrategyConfig::default(),
146            recommended_paths: Vec::new(),
147            analyzed_sessions: 0,
148        }
149    }
150}
151
152impl OptimalParamsModel {
153    /// 新しいモデルを作成
154    pub fn new() -> Self {
155        Self::default()
156    }
157
158    /// パラメータを設定
159    pub fn set_param(&mut self, key: &str, value: impl Into<ParamValue>) {
160        self.params.insert(key.to_string(), value.into());
161    }
162
163    /// UCB1 の探索係数を取得
164    pub fn ucb1_c(&self) -> f64 {
165        self.get_param(param_keys::UCB1_C)
166            .and_then(|v| v.as_f64())
167            .unwrap_or(std::f64::consts::SQRT_2)
168    }
169
170    /// 学習重みを取得
171    pub fn learning_weight(&self) -> f64 {
172        self.get_param(param_keys::LEARNING_WEIGHT)
173            .and_then(|v| v.as_f64())
174            .unwrap_or(0.3)
175    }
176
177    /// N-gram 重みを取得
178    pub fn ngram_weight(&self) -> f64 {
179        self.get_param(param_keys::NGRAM_WEIGHT)
180            .and_then(|v| v.as_f64())
181            .unwrap_or(1.0)
182    }
183}
184
185impl Model for OptimalParamsModel {
186    fn model_type(&self) -> ModelType {
187        ModelType::OptimalParams
188    }
189
190    fn version(&self) -> &ModelVersion {
191        &self.version
192    }
193
194    fn created_at(&self) -> u64 {
195        self.created_at
196    }
197
198    fn metadata(&self) -> &ModelMetadata {
199        &self.metadata
200    }
201
202    fn as_any(&self) -> &dyn Any {
203        self
204    }
205}
206
207impl Parametric for OptimalParamsModel {
208    fn get_param(&self, key: &str) -> Option<ParamValue> {
209        self.params.get(key).cloned()
210    }
211
212    fn all_params(&self) -> HashMap<String, ParamValue> {
213        self.params.clone()
214    }
215}
216
217// ============================================================================
218// 旧 OfflineModel からの変換
219// ============================================================================
220
221impl From<OfflineModel> for OptimalParamsModel {
222    fn from(old: OfflineModel) -> Self {
223        let mut params = HashMap::new();
224        params.insert(
225            param_keys::UCB1_C.to_string(),
226            ParamValue::Float(old.parameters.ucb1_c),
227        );
228        params.insert(
229            param_keys::LEARNING_WEIGHT.to_string(),
230            ParamValue::Float(old.parameters.learning_weight),
231        );
232        params.insert(
233            param_keys::NGRAM_WEIGHT.to_string(),
234            ParamValue::Float(old.parameters.ngram_weight),
235        );
236        params.insert(
237            param_keys::MATURITY_THRESHOLD.to_string(),
238            ParamValue::Int(old.strategy_config.maturity_threshold as i64),
239        );
240        params.insert(
241            param_keys::ERROR_RATE_THRESHOLD.to_string(),
242            ParamValue::Float(old.strategy_config.error_rate_threshold),
243        );
244        params.insert(
245            param_keys::INITIAL_STRATEGY.to_string(),
246            ParamValue::String(old.strategy_config.initial_strategy.clone()),
247        );
248
249        Self {
250            version: ModelVersion::new(old.version, 0),
251            metadata: ModelMetadata::default(),
252            created_at: old.updated_at,
253            params,
254            strategy_config: old.strategy_config,
255            recommended_paths: old.recommended_paths,
256            analyzed_sessions: old.analyzed_sessions,
257        }
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_param_value_conversions() {
        let f = ParamValue::Float(1.5);
        assert_eq!(f.as_f64(), Some(1.5));
        assert_eq!(f.as_i64(), Some(1));

        let i = ParamValue::Int(42);
        assert_eq!(i.as_i64(), Some(42));
        assert_eq!(i.as_f64(), Some(42.0));

        let b = ParamValue::Bool(true);
        assert_eq!(b.as_bool(), Some(true));

        let s = ParamValue::String("test".to_string());
        assert_eq!(s.as_str(), Some("test"));
    }

    #[test]
    fn test_param_value_mismatched_accessors_return_none() {
        assert_eq!(ParamValue::Bool(true).as_f64(), None);
        assert_eq!(ParamValue::Bool(true).as_i64(), None);
        assert_eq!(ParamValue::Int(1).as_bool(), None);
        assert_eq!(ParamValue::Int(1).as_str(), None);
        assert_eq!(ParamValue::String("1".to_string()).as_f64(), None);
        assert_eq!(ParamValue::Array(vec![]).as_i64(), None);
    }

    #[test]
    fn test_param_value_from_impls() {
        assert_eq!(ParamValue::from(2.5).as_f64(), Some(2.5));
        assert_eq!(ParamValue::from(7_i64).as_i64(), Some(7));
        assert_eq!(ParamValue::from(false).as_bool(), Some(false));
        assert_eq!(ParamValue::from("abc").as_str(), Some("abc"));
        assert_eq!(ParamValue::from(String::from("xyz")).as_str(), Some("xyz"));
    }

    #[test]
    fn test_optimal_params_model_default() {
        let model = OptimalParamsModel::new();

        assert!((model.ucb1_c() - std::f64::consts::SQRT_2).abs() < 1e-10);
        assert!((model.learning_weight() - 0.3).abs() < 1e-10);
        assert!((model.ngram_weight() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_optimal_params_model_set_param() {
        let mut model = OptimalParamsModel::new();
        model.set_param(param_keys::UCB1_C, 2.0);

        assert!((model.ucb1_c() - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_getter_falls_back_when_param_is_not_numeric() {
        let mut model = OptimalParamsModel::new();
        model.set_param(param_keys::NGRAM_WEIGHT, "heavy");

        // The getter ignores the non-numeric value and uses its built-in default,
        // but the raw value is still stored and retrievable.
        assert!((model.ngram_weight() - 1.0).abs() < 1e-10);
        let stored = model
            .get_param(param_keys::NGRAM_WEIGHT)
            .expect("value was just set");
        assert_eq!(stored.as_str(), Some("heavy"));
    }

    #[test]
    fn test_parametric_trait() {
        let model = OptimalParamsModel::new();

        let value = model.get_param(param_keys::UCB1_C);
        assert!(value.is_some());

        let all = model.all_params();
        assert!(all.contains_key(param_keys::UCB1_C));
        assert!(all.contains_key(param_keys::LEARNING_WEIGHT));
    }

    #[test]
    fn test_from_offline_model() {
        use crate::learn::offline::{OfflineModel, OptimalParameters, StrategyConfig};

        let expected_strategy = StrategyConfig::default();
        let old = OfflineModel {
            version: 2,
            parameters: OptimalParameters {
                ucb1_c: 1.5,
                learning_weight: 0.4,
                ngram_weight: 1.2,
            },
            recommended_paths: vec![],
            strategy_config: StrategyConfig::default(),
            analyzed_sessions: 5,
            updated_at: 12345,
            action_order: None,
        };

        let model: OptimalParamsModel = old.into();

        assert!((model.ucb1_c() - 1.5).abs() < 1e-10);
        assert!((model.learning_weight() - 0.4).abs() < 1e-10);
        assert!((model.ngram_weight() - 1.2).abs() < 1e-10);
        assert_eq!(model.analyzed_sessions, 5);
        // The legacy `updated_at` becomes the new model's creation timestamp.
        assert_eq!(model.created_at, 12345);
        assert!(model.recommended_paths.is_empty());

        // Strategy-config values are mirrored into the parameter map.
        let maturity = model
            .get_param(param_keys::MATURITY_THRESHOLD)
            .and_then(|v| v.as_i64());
        assert_eq!(maturity, Some(expected_strategy.maturity_threshold as i64));
        let error_rate = model
            .get_param(param_keys::ERROR_RATE_THRESHOLD)
            .and_then(|v| v.as_f64());
        assert_eq!(error_rate, Some(expected_strategy.error_rate_threshold));
        let initial = model
            .get_param(param_keys::INITIAL_STRATEGY)
            .expect("initial strategy is always copied");
        assert_eq!(
            initial.as_str(),
            Some(expected_strategy.initial_strategy.as_str())
        );
    }
}