Skip to main content

swarm_engine_core/learn/stats_model/
base.rs

//! Base model trait: the interface shared by all models.
2
3use std::any::Any;
4use std::collections::HashMap;
5
6use serde::{Deserialize, Serialize};
7
8use crate::util::epoch_millis_for_ordering;
9
10/// 統計モデル識別子
11#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
12pub struct StatsModelId(pub String);
13
14impl StatsModelId {
15    pub fn new(id: impl Into<String>) -> Self {
16        Self(id.into())
17    }
18
19    pub fn generate() -> Self {
20        let ts = epoch_millis_for_ordering();
21        Self(format!("stats-{:x}", ts))
22    }
23
24    pub fn as_str(&self) -> &str {
25        &self.0
26    }
27}
28
29impl std::fmt::Display for StatsModelId {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        write!(f, "{}", self.0)
32    }
33}
34
35/// 全ての学習済みモデルが実装する基本 trait
36pub trait Model: Send + Sync {
37    /// モデルの種類
38    fn model_type(&self) -> ModelType;
39
40    /// バージョン(Lineage追跡用)
41    fn version(&self) -> &ModelVersion;
42
43    /// 作成日時(Unix timestamp ms)
44    fn created_at(&self) -> u64;
45
46    /// メタデータ
47    fn metadata(&self) -> &ModelMetadata;
48
49    /// ダウンキャスト用
50    fn as_any(&self) -> &dyn Any;
51}
52
53/// モデルの種類
54#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
55pub enum ModelType {
56    /// 行動選択スコア
57    ActionScore,
58    /// パラメータ最適化
59    OptimalParams,
60    /// 将来の拡張用
61    Custom(String),
62}
63
64impl ModelType {
65    /// ディレクトリ名を取得
66    pub fn dir_name(&self) -> &str {
67        match self {
68            Self::ActionScore => "action_scores",
69            Self::OptimalParams => "optimal_params",
70            Self::Custom(name) => name,
71        }
72    }
73}
74
75impl std::fmt::Display for ModelType {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        match self {
78            Self::ActionScore => write!(f, "ActionScore"),
79            Self::OptimalParams => write!(f, "OptimalParams"),
80            Self::Custom(name) => write!(f, "Custom({})", name),
81        }
82    }
83}
84
85/// モデルバージョン
86#[derive(Debug, Clone, Default, Serialize, Deserialize)]
87pub struct ModelVersion {
88    pub major: u32,
89    pub minor: u32,
90    /// 元データの識別子(Episode IDs, Snapshot IDs 等)
91    pub source_ids: Vec<String>,
92}
93
94impl ModelVersion {
95    pub fn new(major: u32, minor: u32) -> Self {
96        Self {
97            major,
98            minor,
99            source_ids: Vec::new(),
100        }
101    }
102
103    pub fn with_sources(major: u32, minor: u32, source_ids: Vec<String>) -> Self {
104        Self {
105            major,
106            minor,
107            source_ids,
108        }
109    }
110}
111
112impl std::fmt::Display for ModelVersion {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        write!(f, "{}.{}", self.major, self.minor)
115    }
116}
117
118/// モデルメタデータ
119#[derive(Debug, Clone, Default, Serialize, Deserialize)]
120pub struct ModelMetadata {
121    pub name: Option<String>,
122    pub description: Option<String>,
123    pub tags: HashMap<String, String>,
124}
125
126impl ModelMetadata {
127    pub fn new() -> Self {
128        Self::default()
129    }
130
131    pub fn with_name(mut self, name: impl Into<String>) -> Self {
132        self.name = Some(name.into());
133        self
134    }
135
136    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
137        self.description = Some(desc.into());
138        self
139    }
140
141    pub fn with_tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
142        self.tags.insert(key.into(), value.into());
143        self
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `id` has the `stats-<lowercase hex>` shape produced by
    /// `StatsModelId::generate`.
    fn assert_generated_shape(id: &StatsModelId) {
        let suffix = id
            .as_str()
            .strip_prefix("stats-")
            .expect("generated id must start with `stats-`");
        assert!(!suffix.is_empty());
        assert!(suffix
            .chars()
            .all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));
    }

    #[test]
    fn test_stats_model_id_generate() {
        let id1 = StatsModelId::generate();
        let id2 = StatsModelId::generate();
        assert!(!id1.0.is_empty());
        assert!(!id2.0.is_empty());
        // Check both ids, not only the first one.
        assert_generated_shape(&id1);
        assert_generated_shape(&id2);
    }

    #[test]
    fn test_stats_model_id_new_and_display() {
        let id = StatsModelId::new("abc");
        assert_eq!(id.as_str(), "abc");
        assert_eq!(id.to_string(), "abc");
        assert_eq!(id, StatsModelId::new(String::from("abc")));
    }

    #[test]
    fn test_model_type_dir_name() {
        assert_eq!(ModelType::ActionScore.dir_name(), "action_scores");
        assert_eq!(ModelType::OptimalParams.dir_name(), "optimal_params");
        assert_eq!(
            ModelType::Custom("my_model".to_string()).dir_name(),
            "my_model"
        );
    }

    #[test]
    fn test_model_type_display() {
        assert_eq!(ModelType::ActionScore.to_string(), "ActionScore");
        assert_eq!(ModelType::OptimalParams.to_string(), "OptimalParams");
        assert_eq!(
            ModelType::Custom("my_model".to_string()).to_string(),
            "Custom(my_model)"
        );
    }

    #[test]
    fn test_model_version() {
        let v = ModelVersion::new(1, 2);
        assert_eq!(format!("{}", v), "1.2");
        assert!(v.source_ids.is_empty());
    }

    #[test]
    fn test_model_version_with_sources() {
        let v = ModelVersion::with_sources(3, 4, vec!["ep-1".to_string()]);
        assert_eq!((v.major, v.minor), (3, 4));
        assert_eq!(v.source_ids, vec!["ep-1".to_string()]);
        // Source ids are not part of the displayed version.
        assert_eq!(v.to_string(), "3.4");
    }

    #[test]
    fn test_model_metadata_default_is_empty() {
        let meta = ModelMetadata::new();
        assert!(meta.name.is_none());
        assert!(meta.description.is_none());
        assert!(meta.tags.is_empty());
    }

    #[test]
    fn test_model_metadata_builder() {
        let meta = ModelMetadata::new()
            .with_name("test")
            .with_description("desc")
            .with_tag("env", "prod");

        assert_eq!(meta.name.as_deref(), Some("test"));
        assert_eq!(meta.description.as_deref(), Some("desc"));
        assert_eq!(meta.tags.get("env").map(|s| s.as_str()), Some("prod"));
    }
}