swarm_engine_core/learn/stats_model/
base.rs1use std::any::Any;
4use std::collections::HashMap;
5
6use serde::{Deserialize, Serialize};
7
8use crate::util::epoch_millis_for_ordering;
9
10#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
12pub struct StatsModelId(pub String);
13
14impl StatsModelId {
15 pub fn new(id: impl Into<String>) -> Self {
16 Self(id.into())
17 }
18
19 pub fn generate() -> Self {
20 let ts = epoch_millis_for_ordering();
21 Self(format!("stats-{:x}", ts))
22 }
23
24 pub fn as_str(&self) -> &str {
25 &self.0
26 }
27}
28
29impl std::fmt::Display for StatsModelId {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 write!(f, "{}", self.0)
32 }
33}
34
35pub trait Model: Send + Sync {
37 fn model_type(&self) -> ModelType;
39
40 fn version(&self) -> &ModelVersion;
42
43 fn created_at(&self) -> u64;
45
46 fn metadata(&self) -> &ModelMetadata;
48
49 fn as_any(&self) -> &dyn Any;
51}
52
53#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
55pub enum ModelType {
56 ActionScore,
58 OptimalParams,
60 Custom(String),
62}
63
64impl ModelType {
65 pub fn dir_name(&self) -> &str {
67 match self {
68 Self::ActionScore => "action_scores",
69 Self::OptimalParams => "optimal_params",
70 Self::Custom(name) => name,
71 }
72 }
73}
74
75impl std::fmt::Display for ModelType {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 match self {
78 Self::ActionScore => write!(f, "ActionScore"),
79 Self::OptimalParams => write!(f, "OptimalParams"),
80 Self::Custom(name) => write!(f, "Custom({})", name),
81 }
82 }
83}
84
85#[derive(Debug, Clone, Default, Serialize, Deserialize)]
87pub struct ModelVersion {
88 pub major: u32,
89 pub minor: u32,
90 pub source_ids: Vec<String>,
92}
93
94impl ModelVersion {
95 pub fn new(major: u32, minor: u32) -> Self {
96 Self {
97 major,
98 minor,
99 source_ids: Vec::new(),
100 }
101 }
102
103 pub fn with_sources(major: u32, minor: u32, source_ids: Vec<String>) -> Self {
104 Self {
105 major,
106 minor,
107 source_ids,
108 }
109 }
110}
111
112impl std::fmt::Display for ModelVersion {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 write!(f, "{}.{}", self.major, self.minor)
115 }
116}
117
118#[derive(Debug, Clone, Default, Serialize, Deserialize)]
120pub struct ModelMetadata {
121 pub name: Option<String>,
122 pub description: Option<String>,
123 pub tags: HashMap<String, String>,
124}
125
126impl ModelMetadata {
127 pub fn new() -> Self {
128 Self::default()
129 }
130
131 pub fn with_name(mut self, name: impl Into<String>) -> Self {
132 self.name = Some(name.into());
133 self
134 }
135
136 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
137 self.description = Some(desc.into());
138 self
139 }
140
141 pub fn with_tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
142 self.tags.insert(key.into(), value.into());
143 self
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stats_model_id_generate() {
        let id1 = StatsModelId::generate();
        let id2 = StatsModelId::generate();
        assert!(!id1.0.is_empty());
        assert!(!id2.0.is_empty());
        assert!(id1.as_str().starts_with("stats-"));
        assert!(id2.as_str().starts_with("stats-"));
    }

    #[test]
    fn test_stats_model_id_new_and_display() {
        let id = StatsModelId::new("abc");
        assert_eq!(id.as_str(), "abc");
        assert_eq!(id.to_string(), "abc");
        assert_eq!(id, StatsModelId("abc".to_string()));
    }

    #[test]
    fn test_model_type_dir_name() {
        assert_eq!(ModelType::ActionScore.dir_name(), "action_scores");
        assert_eq!(ModelType::OptimalParams.dir_name(), "optimal_params");
        assert_eq!(
            ModelType::Custom("my_model".to_string()).dir_name(),
            "my_model"
        );
    }

    #[test]
    fn test_model_type_display() {
        assert_eq!(ModelType::ActionScore.to_string(), "ActionScore");
        assert_eq!(ModelType::OptimalParams.to_string(), "OptimalParams");
        assert_eq!(
            ModelType::Custom("my_model".to_string()).to_string(),
            "Custom(my_model)"
        );
    }

    #[test]
    fn test_model_version() {
        let v = ModelVersion::new(1, 2);
        assert_eq!(format!("{}", v), "1.2");
        assert!(v.source_ids.is_empty());
    }

    #[test]
    fn test_model_version_with_sources_and_default() {
        let sources = vec!["a".to_string(), "b".to_string()];
        let v = ModelVersion::with_sources(3, 4, sources.clone());
        // Display shows only major.minor, never the sources.
        assert_eq!(v.to_string(), "3.4");
        assert_eq!(v.source_ids, sources);
        assert_eq!(ModelVersion::default().to_string(), "0.0");
    }

    #[test]
    fn test_model_metadata_builder() {
        let meta = ModelMetadata::new()
            .with_name("test")
            .with_description("desc")
            .with_tag("env", "prod");

        assert_eq!(meta.name.as_deref(), Some("test"));
        assert_eq!(meta.description.as_deref(), Some("desc"));
        assert_eq!(meta.tags.get("env").map(|s| s.as_str()), Some("prod"));
    }

    #[test]
    fn test_model_metadata_overwrites() {
        let meta = ModelMetadata::new()
            .with_name("old")
            .with_name("new")
            .with_tag("env", "dev")
            .with_tag("env", "prod");

        assert_eq!(meta.name.as_deref(), Some("new"));
        assert_eq!(meta.description, None);
        assert_eq!(meta.tags.len(), 1);
        assert_eq!(meta.tags.get("env").map(|s| s.as_str()), Some("prod"));
        assert!(ModelMetadata::new().tags.is_empty());
    }
}
187}