1use serde::{de::DeserializeOwned, Serialize};
39
40use super::episode::Episode;
41use super::learn_model::LearnError;
42
43pub trait ComponentLearner: Send + Sync {
53 type Output: LearnedComponent;
55
56 fn name(&self) -> &str;
58
59 fn objective(&self) -> &str;
61
62 fn learn(&self, episodes: &[Episode]) -> Result<Self::Output, LearnError>;
64
65 fn update(
67 &self,
68 existing: &Self::Output,
69 new_episodes: &[Episode],
70 ) -> Result<Self::Output, LearnError> {
71 let mut learned = self.learn(new_episodes)?;
72 learned.merge(existing);
73 Ok(learned)
74 }
75}
76
77pub trait LearnedComponent: Send + Sync + Serialize + DeserializeOwned + Clone {
86 fn component_id() -> &'static str
88 where
89 Self: Sized;
90
91 fn confidence(&self) -> f64;
96
97 fn session_count(&self) -> usize;
99
100 fn updated_at(&self) -> u64;
102
103 fn merge(&mut self, other: &Self)
107 where
108 Self: Sized,
109 {
110 if other.confidence() > self.confidence() {
111 *self = other.clone();
112 }
113 }
114
115 fn version() -> u32
117 where
118 Self: Sized,
119 {
120 1
121 }
122}
123
124use crate::exploration::DependencyGraph;
129
130pub use super::offline::RecommendedPath;
132
/// A learned dependency graph together with metadata about how it was learned.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LearnedDepGraph {
    /// The dependency graph itself.
    pub graph: DependencyGraph,

    /// Action names in order.
    // NOTE(review): presumably the execution order derived from `graph` — confirm.
    pub action_order: Vec<String>,

    /// Recommended paths; deserializes to an empty list when the field is
    /// absent from the input (`#[serde(default)]`).
    #[serde(default)]
    pub recommended_paths: Vec<RecommendedPath>,

    /// Confidence score; `new` initializes it to `0.0`.
    pub confidence: f64,

    /// Identifiers of the sessions this graph was learned from.
    pub learned_from: Vec<String>,

    /// Last-update timestamp; `new` sets it to whole seconds since the Unix epoch.
    pub updated_at: u64,
}
158
159impl LearnedDepGraph {
160 pub fn new(graph: DependencyGraph, action_order: Vec<String>) -> Self {
162 Self {
163 graph,
164 action_order,
165 recommended_paths: Vec::new(),
166 confidence: 0.0,
167 learned_from: Vec::new(),
168 updated_at: std::time::SystemTime::now()
169 .duration_since(std::time::UNIX_EPOCH)
170 .map(|d| d.as_secs())
171 .unwrap_or(0),
172 }
173 }
174
175 pub fn with_confidence(mut self, confidence: f64) -> Self {
177 self.confidence = confidence;
178 self
179 }
180
181 pub fn with_sessions(mut self, session_ids: Vec<String>) -> Self {
183 self.learned_from = session_ids;
184 self
185 }
186
187 pub fn with_recommended_paths(mut self, paths: Vec<RecommendedPath>) -> Self {
189 self.recommended_paths = paths;
190 self
191 }
192}
193
194impl LearnedComponent for LearnedDepGraph {
195 fn component_id() -> &'static str {
196 "dep_graph"
197 }
198
199 fn confidence(&self) -> f64 {
200 self.confidence
201 }
202
203 fn session_count(&self) -> usize {
204 self.learned_from.len()
205 }
206
207 fn updated_at(&self) -> u64 {
208 self.updated_at
209 }
210
211 fn merge(&mut self, other: &Self) {
212 if other.learned_from.len() > self.learned_from.len() || other.confidence > self.confidence
214 {
215 self.graph = other.graph.clone();
216 self.action_order = other.action_order.clone();
217 self.confidence = other.confidence;
218 }
219 for id in &other.learned_from {
221 if !self.learned_from.contains(id) {
222 self.learned_from.push(id.clone());
223 }
224 }
225 for path in &other.recommended_paths {
227 if !self
228 .recommended_paths
229 .iter()
230 .any(|p| p.actions == path.actions)
231 {
232 self.recommended_paths.push(path.clone());
233 }
234 }
235 self.updated_at = other.updated_at.max(self.updated_at);
236 }
237}
238
/// Learned exploration parameters plus bookkeeping metadata.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LearnedExploration {
    /// UCB1 parameter; `Default` uses `1.414`.
    // NOTE(review): presumably the exploration constant `c` of UCB1 — confirm.
    pub ucb1_c: f64,

    /// Learning weight; `Default` uses `0.3`.
    pub learning_weight: f64,

    /// N-gram weight; `Default` uses `1.0`.
    pub ngram_weight: f64,

    /// Confidence score; both `Default` and `new` start at `0.0`.
    pub confidence: f64,

    /// Session count; both `Default` and `new` start at `0`.
    pub session_count: usize,

    /// Last-update timestamp; `new` stores whole seconds since the Unix
    /// epoch, while `Default` stores `0`.
    pub updated_at: u64,
}
266
267impl Default for LearnedExploration {
268 fn default() -> Self {
269 Self {
270 ucb1_c: 1.414,
271 learning_weight: 0.3,
272 ngram_weight: 1.0,
273 confidence: 0.0,
274 session_count: 0,
275 updated_at: 0,
276 }
277 }
278}
279
280impl LearnedExploration {
281 pub fn new(ucb1_c: f64, learning_weight: f64, ngram_weight: f64) -> Self {
283 Self {
284 ucb1_c,
285 learning_weight,
286 ngram_weight,
287 confidence: 0.0,
288 session_count: 0,
289 updated_at: std::time::SystemTime::now()
290 .duration_since(std::time::UNIX_EPOCH)
291 .map(|d| d.as_secs())
292 .unwrap_or(0),
293 }
294 }
295}
296
297impl LearnedComponent for LearnedExploration {
298 fn component_id() -> &'static str {
299 "exploration"
300 }
301
302 fn confidence(&self) -> f64 {
303 self.confidence
304 }
305
306 fn session_count(&self) -> usize {
307 self.session_count
308 }
309
310 fn updated_at(&self) -> u64 {
311 self.updated_at
312 }
313}
314
/// Learned strategy-selection parameters plus bookkeeping metadata.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LearnedStrategy {
    /// Name of the initial strategy; `Default` uses `"ucb1"`.
    pub initial_strategy: String,

    /// Maturity threshold; `Default` uses `5`.
    // NOTE(review): unit (sessions? visits?) is not visible here — confirm.
    pub maturity_threshold: usize,

    /// Error-rate threshold; `Default` uses `0.45`.
    pub error_rate_threshold: f64,

    /// Confidence score; `Default` starts at `0.0`.
    pub confidence: f64,

    /// Session count; `Default` starts at `0`.
    pub session_count: usize,

    /// Last-update timestamp; `Default` stores `0`.
    pub updated_at: u64,
}
342
343impl Default for LearnedStrategy {
344 fn default() -> Self {
345 Self {
346 initial_strategy: "ucb1".to_string(),
347 maturity_threshold: 5,
348 error_rate_threshold: 0.45,
349 confidence: 0.0,
350 session_count: 0,
351 updated_at: 0,
352 }
353 }
354}
355
356impl LearnedComponent for LearnedStrategy {
357 fn component_id() -> &'static str {
358 "strategy"
359 }
360
361 fn confidence(&self) -> f64 {
362 self.confidence
363 }
364
365 fn session_count(&self) -> usize {
366 self.session_count
367 }
368
369 fn updated_at(&self) -> u64 {
370 self.updated_at
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use crate::exploration::DependencyGraph;

    /// Turns string literals into the owned `String`s the builders expect.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|&item| String::from(item)).collect()
    }

    /// The builder chain stores confidence and sessions; the id is fixed.
    #[test]
    fn test_learned_dep_graph_creation() {
        let learned = LearnedDepGraph::new(DependencyGraph::new(), owned(&["A", "B"]))
            .with_confidence(0.8)
            .with_sessions(owned(&["s1", "s2"]));

        assert_eq!(learned.confidence(), 0.8);
        assert_eq!(learned.session_count(), 2);
        assert_eq!(LearnedDepGraph::component_id(), "dep_graph");
    }

    /// Merging a stronger component adopts its data and unions the sessions.
    #[test]
    fn test_learned_dep_graph_merge() {
        let graph = DependencyGraph::new();

        let mut target = LearnedDepGraph::new(graph.clone(), owned(&["A"]))
            .with_confidence(0.5)
            .with_sessions(owned(&["s1"]));
        let incoming = LearnedDepGraph::new(graph, owned(&["A", "B"]))
            .with_confidence(0.8)
            .with_sessions(owned(&["s2", "s3"]));

        target.merge(&incoming);

        // `incoming` has both more sessions and higher confidence, so it wins.
        assert_eq!(target.confidence, 0.8);
        assert_eq!(target.action_order.len(), 2);
        // s1 from the target plus s2 and s3 from the incoming component.
        assert_eq!(target.learned_from.len(), 3);
    }

    /// `Default` uses the documented UCB1 constant; the id is fixed.
    #[test]
    fn test_learned_exploration_default() {
        let defaults = LearnedExploration::default();

        assert_eq!(defaults.ucb1_c, 1.414);
        assert_eq!(LearnedExploration::component_id(), "exploration");
    }

    /// `Default` starts with the `"ucb1"` strategy; the id is fixed.
    #[test]
    fn test_learned_strategy_default() {
        let defaults = LearnedStrategy::default();

        assert_eq!(defaults.initial_strategy, "ucb1");
        assert_eq!(LearnedStrategy::component_id(), "strategy");
    }

    /// A JSON round trip preserves the stored parameters.
    #[test]
    fn test_serialization() {
        let original = LearnedExploration::new(2.0, 0.5, 1.5);

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: LearnedExploration = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.ucb1_c, 2.0);
    }
}