1use serde::{de::DeserializeOwned, Serialize};
39
40use super::episode::Episode;
41use super::learn_model::LearnError;
42
43pub trait ComponentLearner: Send + Sync {
53 type Output: LearnedComponent;
55
56 fn name(&self) -> &str;
58
59 fn objective(&self) -> &str;
61
62 fn learn(&self, episodes: &[Episode]) -> Result<Self::Output, LearnError>;
64
65 fn update(
67 &self,
68 existing: &Self::Output,
69 new_episodes: &[Episode],
70 ) -> Result<Self::Output, LearnError> {
71 let mut learned = self.learn(new_episodes)?;
72 learned.merge(existing);
73 Ok(learned)
74 }
75}
76
77pub trait LearnedComponent: Send + Sync + Serialize + DeserializeOwned + Clone {
86 fn component_id() -> &'static str
88 where
89 Self: Sized;
90
91 fn confidence(&self) -> f64;
96
97 fn session_count(&self) -> usize;
99
100 fn updated_at(&self) -> u64;
102
103 fn merge(&mut self, other: &Self)
107 where
108 Self: Sized,
109 {
110 if other.confidence() > self.confidence() {
111 *self = other.clone();
112 }
113 }
114
115 fn version() -> u32
117 where
118 Self: Sized,
119 {
120 1
121 }
122}
123
124use crate::exploration::DependencyGraph;
129
130pub use super::offline::RecommendedPath;
132
/// A learned dependency graph between actions, plus the action orderings derived from it.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LearnedDepGraph {
    /// The learned dependency graph itself.
    pub graph: DependencyGraph,

    /// Full action order.
    ///
    /// When built via `with_orders`, this is `discover_order` followed by
    /// `not_discover_order`.
    pub action_order: Vec<String>,

    /// First part of the split ordering.
    ///
    /// Empty when the graph was built via `new`, or when the field is absent from serialized
    /// data (`serde(default)`).
    #[serde(default)]
    pub discover_order: Vec<String>,

    /// Second part of the split ordering. Same defaulting rules as `discover_order`.
    #[serde(default)]
    pub not_discover_order: Vec<String>,

    /// Recommended action paths. Merging unions them, de-duplicated by their `actions`.
    ///
    /// Empty when absent from serialized data.
    #[serde(default)]
    pub recommended_paths: Vec<RecommendedPath>,

    /// Confidence score. The constructors start it at `0.0`.
    pub confidence: f64,

    /// IDs of the sessions this graph was learned from. Merging unions them without duplicates.
    pub learned_from: Vec<String>,

    /// Last update time. The constructors set it in seconds since the Unix epoch.
    pub updated_at: u64,
}
167
168impl LearnedDepGraph {
169 pub fn new(graph: DependencyGraph, action_order: Vec<String>) -> Self {
171 Self {
172 graph,
173 action_order,
174 discover_order: Vec::new(),
175 not_discover_order: Vec::new(),
176 recommended_paths: Vec::new(),
177 confidence: 0.0,
178 learned_from: Vec::new(),
179 updated_at: std::time::SystemTime::now()
180 .duration_since(std::time::UNIX_EPOCH)
181 .map(|d| d.as_secs())
182 .unwrap_or(0),
183 }
184 }
185
186 pub fn with_orders(
188 graph: DependencyGraph,
189 discover_order: Vec<String>,
190 not_discover_order: Vec<String>,
191 ) -> Self {
192 let mut all_actions = discover_order.clone();
193 all_actions.extend(not_discover_order.clone());
194 Self {
195 graph,
196 action_order: all_actions,
197 discover_order,
198 not_discover_order,
199 recommended_paths: Vec::new(),
200 confidence: 0.0,
201 learned_from: Vec::new(),
202 updated_at: std::time::SystemTime::now()
203 .duration_since(std::time::UNIX_EPOCH)
204 .map(|d| d.as_secs())
205 .unwrap_or(0),
206 }
207 }
208
209 pub fn with_confidence(mut self, confidence: f64) -> Self {
211 self.confidence = confidence;
212 self
213 }
214
215 pub fn with_sessions(mut self, session_ids: Vec<String>) -> Self {
217 self.learned_from = session_ids;
218 self
219 }
220
221 pub fn with_recommended_paths(mut self, paths: Vec<RecommendedPath>) -> Self {
223 self.recommended_paths = paths;
224 self
225 }
226}
227
228impl LearnedComponent for LearnedDepGraph {
229 fn component_id() -> &'static str {
230 "dep_graph"
231 }
232
233 fn confidence(&self) -> f64 {
234 self.confidence
235 }
236
237 fn session_count(&self) -> usize {
238 self.learned_from.len()
239 }
240
241 fn updated_at(&self) -> u64 {
242 self.updated_at
243 }
244
245 fn merge(&mut self, other: &Self) {
246 if other.learned_from.len() > self.learned_from.len() || other.confidence > self.confidence
248 {
249 self.graph = other.graph.clone();
250 self.action_order = other.action_order.clone();
251 self.confidence = other.confidence;
252 }
253 for id in &other.learned_from {
255 if !self.learned_from.contains(id) {
256 self.learned_from.push(id.clone());
257 }
258 }
259 for path in &other.recommended_paths {
261 if !self
262 .recommended_paths
263 .iter()
264 .any(|p| p.actions == path.actions)
265 {
266 self.recommended_paths.push(path.clone());
267 }
268 }
269 self.updated_at = other.updated_at.max(self.updated_at);
270 }
271}
272
/// Learned exploration parameters. See the `Default` impl for the default values.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LearnedExploration {
    /// Exploration constant `c` for UCB1.
    pub ucb1_c: f64,

    /// Weight of the learning-based term. NOTE(review): its consumer is not visible in this
    /// module.
    pub learning_weight: f64,

    /// Weight of the n-gram-based term. NOTE(review): its consumer is not visible in this
    /// module.
    pub ngram_weight: f64,

    /// Confidence score. The default `LearnedComponent::merge` uses it to pick the winner.
    pub confidence: f64,

    /// Number of sessions these parameters were learned from.
    pub session_count: usize,

    /// Last update time in seconds since the Unix epoch. `0` for `Default`.
    pub updated_at: u64,
}
300
301impl Default for LearnedExploration {
302 fn default() -> Self {
303 Self {
304 ucb1_c: 1.414,
305 learning_weight: 0.3,
306 ngram_weight: 1.0,
307 confidence: 0.0,
308 session_count: 0,
309 updated_at: 0,
310 }
311 }
312}
313
314impl LearnedExploration {
315 pub fn new(ucb1_c: f64, learning_weight: f64, ngram_weight: f64) -> Self {
317 Self {
318 ucb1_c,
319 learning_weight,
320 ngram_weight,
321 confidence: 0.0,
322 session_count: 0,
323 updated_at: std::time::SystemTime::now()
324 .duration_since(std::time::UNIX_EPOCH)
325 .map(|d| d.as_secs())
326 .unwrap_or(0),
327 }
328 }
329}
330
331impl LearnedComponent for LearnedExploration {
332 fn component_id() -> &'static str {
333 "exploration"
334 }
335
336 fn confidence(&self) -> f64 {
337 self.confidence
338 }
339
340 fn session_count(&self) -> usize {
341 self.session_count
342 }
343
344 fn updated_at(&self) -> u64 {
345 self.updated_at
346 }
347}
348
/// Learned strategy-selection settings. See the `Default` impl for the default values.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LearnedStrategy {
    /// Name of the strategy to start with (`"ucb1"` by default).
    pub initial_strategy: String,

    /// Maturity threshold. NOTE(review): presumably a count after which the strategy is
    /// considered mature; confirm against the consumer.
    pub maturity_threshold: usize,

    /// Error-rate threshold. NOTE(review): presumably the rate that triggers a strategy
    /// change; confirm against the consumer.
    pub error_rate_threshold: f64,

    /// Confidence score. The default `LearnedComponent::merge` uses it to pick the winner.
    pub confidence: f64,

    /// Number of sessions these settings were learned from.
    pub session_count: usize,

    /// Last update time. `0` for `Default`.
    pub updated_at: u64,
}
376
377impl Default for LearnedStrategy {
378 fn default() -> Self {
379 Self {
380 initial_strategy: "ucb1".to_string(),
381 maturity_threshold: 5,
382 error_rate_threshold: 0.45,
383 confidence: 0.0,
384 session_count: 0,
385 updated_at: 0,
386 }
387 }
388}
389
390impl LearnedComponent for LearnedStrategy {
391 fn component_id() -> &'static str {
392 "strategy"
393 }
394
395 fn confidence(&self) -> f64 {
396 self.confidence
397 }
398
399 fn session_count(&self) -> usize {
400 self.session_count
401 }
402
403 fn updated_at(&self) -> u64 {
404 self.updated_at
405 }
406}
407
#[cfg(test)]
mod tests {
    use super::*;
    use crate::exploration::DependencyGraph;

    /// Converts string literals into the owned `Vec<String>` the constructors expect.
    fn strings(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|s| (*s).to_owned()).collect()
    }

    #[test]
    fn test_learned_dep_graph_creation() {
        let learned = LearnedDepGraph::new(DependencyGraph::new(), strings(&["A", "B"]))
            .with_confidence(0.8)
            .with_sessions(strings(&["s1", "s2"]));

        assert_eq!(LearnedDepGraph::component_id(), "dep_graph");
        assert_eq!(learned.session_count(), 2);
        assert_eq!(learned.confidence(), 0.8);
    }

    #[test]
    fn test_learned_dep_graph_merge() {
        let empty_graph = DependencyGraph::new();
        let mut base = LearnedDepGraph::new(empty_graph.clone(), strings(&["A"]))
            .with_confidence(0.5)
            .with_sessions(strings(&["s1"]));
        let incoming = LearnedDepGraph::new(empty_graph, strings(&["A", "B"]))
            .with_confidence(0.8)
            .with_sessions(strings(&["s2", "s3"]));

        base.merge(&incoming);

        // `incoming` is more confident (and has more sessions), so its state wins...
        assert_eq!(base.confidence, 0.8);
        assert_eq!(base.action_order.len(), 2);
        // ...while the session IDs from both sides are unioned.
        assert_eq!(base.learned_from.len(), 3);
    }

    #[test]
    fn test_learned_exploration_default() {
        assert_eq!(LearnedExploration::default().ucb1_c, 1.414);
        assert_eq!(LearnedExploration::component_id(), "exploration");
    }

    #[test]
    fn test_learned_strategy_default() {
        assert_eq!(LearnedStrategy::default().initial_strategy, "ucb1");
        assert_eq!(LearnedStrategy::component_id(), "strategy");
    }

    #[test]
    fn test_serialization() {
        let original = LearnedExploration::new(2.0, 0.5, 1.5);
        let encoded = serde_json::to_string(&original).unwrap();
        let restored = serde_json::from_str::<LearnedExploration>(&encoded).unwrap();
        assert_eq!(restored.ucb1_c, 2.0);
    }
}