1use std::collections::HashMap;
42
43use serde::{Deserialize, Serialize};
44
45use super::snapshot::LearningSnapshot;
46
/// Output of offline analysis, as assembled by `OfflineAnalyzer::analyze`:
/// tuned parameters, recommended action paths and strategy settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OfflineModel {
    /// Model format version (`Default` and `OfflineAnalyzer::analyze` both set `1`).
    pub version: u32,
    /// Parameters chosen by `OfflineAnalyzer::analyze_parameters`.
    pub parameters: OptimalParameters,
    /// Top-ranked paths from `OfflineAnalyzer::extract_paths`.
    pub recommended_paths: Vec<RecommendedPath>,
    /// Strategy settings chosen by `OfflineAnalyzer::analyze_strategy`.
    pub strategy_config: StrategyConfig,
    /// Number of snapshots the model was built from.
    pub analyzed_sessions: usize,
    /// Unix time in seconds when the model was built (`0` for the `Default` model).
    pub updated_at: u64,
    /// Learned action ordering, if one has been attached.
    ///
    /// `#[serde(default)]` lets serialized models that lack this field still
    /// deserialize (as `None`). `OfflineAnalyzer::analyze` always leaves it `None`.
    #[serde(default)]
    pub action_order: Option<LearnedActionOrder>,
}
68
/// An action ordering bound to the specific set of actions it was built for.
///
/// `action_set_hash` fingerprints that action set so callers can check, via
/// `LearnedActionOrder::is_exact_match`, whether the ordering still applies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearnedActionOrder {
    /// Action names for the "discover" case (presumably in preferred order — confirm with callers).
    pub discover: Vec<String>,
    /// Action names for the "not discover" case (presumably in preferred order — confirm with callers).
    pub not_discover: Vec<String>,
    /// Order-insensitive hash of the action set, from `LearnedActionOrder::compute_hash`.
    pub action_set_hash: u64,
    /// Origin of the ordering; missing in serialized data means `ActionOrderSource::Llm`.
    #[serde(default)]
    pub source: ActionOrderSource,
    /// Optional LoRA configuration attached via `with_lora` (its role is defined elsewhere — TODO confirm).
    #[serde(default)]
    pub lora: Option<crate::types::LoraConfig>,
    /// Accuracy recorded via `with_accuracy`, if the ordering has been validated.
    #[serde(default)]
    pub validated_accuracy: Option<f64>,
}
100
101#[derive(Debug, Clone, Default, Serialize, Deserialize)]
103pub enum ActionOrderSource {
104 #[default]
106 Llm,
107 Static,
109 Manual,
111}
112
113impl LearnedActionOrder {
114 pub fn new(discover: Vec<String>, not_discover: Vec<String>, actions: &[String]) -> Self {
116 Self {
117 discover,
118 not_discover,
119 action_set_hash: Self::compute_hash(actions),
120 source: ActionOrderSource::Llm,
121 lora: None,
122 validated_accuracy: None,
123 }
124 }
125
126 pub fn with_lora(mut self, lora: crate::types::LoraConfig) -> Self {
128 self.lora = Some(lora);
129 self
130 }
131
132 pub fn with_accuracy(mut self, accuracy: f64) -> Self {
134 self.validated_accuracy = Some(accuracy);
135 self
136 }
137
138 pub fn with_source(mut self, source: ActionOrderSource) -> Self {
140 self.source = source;
141 self
142 }
143
144 pub fn compute_hash(actions: &[String]) -> u64 {
148 use std::collections::hash_map::DefaultHasher;
149 use std::hash::{Hash, Hasher};
150
151 let mut sorted: Vec<&str> = actions.iter().map(|s| s.as_str()).collect();
152 sorted.sort();
153
154 let mut hasher = DefaultHasher::new();
155 for action in sorted {
156 action.hash(&mut hasher);
157 }
158 hasher.finish()
159 }
160
161 pub fn is_exact_match(&self, actions: &[String]) -> bool {
165 self.action_set_hash == Self::compute_hash(actions)
166 }
167
168 #[inline]
170 pub fn matches_actions(&self, actions: &[String]) -> bool {
171 self.is_exact_match(actions)
172 }
173
174 pub fn match_rate(&self, actions: &[String]) -> f64 {
176 use std::collections::HashSet;
177
178 let mut self_actions: Vec<String> = self.discover.clone();
179 self_actions.extend(self.not_discover.clone());
180
181 if self_actions.is_empty() && actions.is_empty() {
182 return 1.0;
183 }
184 if self_actions.is_empty() || actions.is_empty() {
185 return 0.0;
186 }
187
188 let self_set: HashSet<_> = self_actions.iter().collect();
189 let other_set: HashSet<_> = actions.iter().collect();
190
191 let intersection = self_set.intersection(&other_set).count();
192 let union = self_set.union(&other_set).count();
193
194 intersection as f64 / union as f64
195 }
196}
197
/// Tunable constants/weights chosen by `OfflineAnalyzer::analyze_parameters`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimalParameters {
    /// Exploration constant — presumably the `c` in the UCB1 bonus term; defaults to √2.
    pub ucb1_c: f64,
    /// Weight for learned statistics; both `Default` and the analyzer set `0.3`.
    pub learning_weight: f64,
    /// Weight for n-gram statistics; the analyzer picks `0.5`, `1.0` or `1.5`.
    pub ngram_weight: f64,
}
208
209impl Default for OptimalParameters {
210 fn default() -> Self {
211 Self {
212 ucb1_c: std::f64::consts::SQRT_2,
213 learning_weight: 0.3,
214 ngram_weight: 1.0,
215 }
216 }
217}
218
/// An action sequence ranked by observed success rate
/// (`OfflineAnalyzer::extract_paths` fills it with trigrams).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecommendedPath {
    /// The actions, in sequence order.
    pub actions: Vec<String>,
    /// `successes / (successes + failures)` for this path, in `[0, 1]`.
    pub success_rate: f64,
    /// Total successes plus failures observed for this path.
    pub observations: u32,
}
229
/// Strategy-selection settings chosen by `OfflineAnalyzer::analyze_strategy`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StrategyConfig {
    /// Maturity threshold (exact use is up to the consumer — confirm); the analyzer clamps it to `5..=50`.
    pub maturity_threshold: u32,
    /// Error-rate threshold; the analyzer sets `min(1.5 × observed failure rate, 0.5)`.
    pub error_rate_threshold: f64,
    /// Name of the starting strategy; the analyzer emits `"ucb1"`, `"thompson"` or `"greedy"`.
    pub initial_strategy: String,
}
240
241impl Default for StrategyConfig {
242 fn default() -> Self {
243 Self {
244 maturity_threshold: 10,
245 error_rate_threshold: 0.3,
246 initial_strategy: "ucb1".to_string(),
247 }
248 }
249}
250
251impl Default for OfflineModel {
252 fn default() -> Self {
253 Self {
254 version: 1,
255 parameters: OptimalParameters::default(),
256 recommended_paths: Vec::new(),
257 strategy_config: StrategyConfig::default(),
258 analyzed_sessions: 0,
259 updated_at: 0,
260 action_order: None,
261 }
262 }
263}
264
/// Derives an [`OfflineModel`] from a borrowed slice of learning snapshots.
///
/// Holds only a reference; each analysis method re-reads the snapshots.
pub struct OfflineAnalyzer<'a> {
    /// Snapshots to analyze; `analyze` reports their count as `analyzed_sessions`.
    snapshots: &'a [LearningSnapshot],
}
271
272impl<'a> OfflineAnalyzer<'a> {
273 pub fn new(snapshots: &'a [LearningSnapshot]) -> Self {
275 Self { snapshots }
276 }
277
278 pub fn analyze(&self) -> OfflineModel {
280 let now = std::time::SystemTime::now()
281 .duration_since(std::time::UNIX_EPOCH)
282 .map(|d| d.as_secs())
283 .unwrap_or(0);
284
285 OfflineModel {
286 version: 1,
287 parameters: self.analyze_parameters(),
288 recommended_paths: self.extract_paths(),
289 strategy_config: self.analyze_strategy(),
290 analyzed_sessions: self.snapshots.len(),
291 updated_at: now,
292 action_order: None, }
294 }
295
296 pub fn analyze_parameters(&self) -> OptimalParameters {
301 if self.snapshots.is_empty() {
302 return OptimalParameters::default();
303 }
304
305 let (total_success, total_failure) = self.snapshots.iter().fold((0u32, 0u32), |acc, s| {
307 (
308 acc.0 + s.episode_transitions.success_episodes,
309 acc.1 + s.episode_transitions.failure_episodes,
310 )
311 });
312
313 let success_rate = if total_success + total_failure > 0 {
314 total_success as f64 / (total_success + total_failure) as f64
315 } else {
316 0.5
317 };
318
319 let ucb1_c = if success_rate > 0.8 {
323 1.0 } else if success_rate < 0.5 {
325 2.0 } else {
327 std::f64::consts::SQRT_2 };
329
330 let ngram_effectiveness = self.evaluate_ngram_effectiveness();
332 let ngram_weight = if ngram_effectiveness > 0.7 {
333 1.5 } else if ngram_effectiveness < 0.3 {
335 0.5 } else {
337 1.0
338 };
339
340 OptimalParameters {
341 ucb1_c,
342 learning_weight: 0.3, ngram_weight,
344 }
345 }
346
347 fn evaluate_ngram_effectiveness(&self) -> f64 {
351 let mut all_rates: Vec<f64> = Vec::new();
352
353 for snapshot in self.snapshots {
354 for &(success, failure) in snapshot.ngram_stats.trigrams.values() {
355 let total = success + failure;
356 if total >= 3 {
357 all_rates.push(success as f64 / total as f64);
359 }
360 }
361 }
362
363 if all_rates.is_empty() {
364 return 0.5; }
366
367 let mean = all_rates.iter().sum::<f64>() / all_rates.len() as f64;
369 let variance =
370 all_rates.iter().map(|r| (r - mean).powi(2)).sum::<f64>() / all_rates.len() as f64;
371
372 (variance / 0.25).min(1.0)
374 }
375
376 pub fn extract_paths(&self) -> Vec<RecommendedPath> {
380 let mut path_stats: HashMap<Vec<String>, (u32, u32)> = HashMap::new();
382
383 for snapshot in self.snapshots {
384 for (key, &(success, failure)) in &snapshot.ngram_stats.trigrams {
385 let path = vec![key.0.clone(), key.1.clone(), key.2.clone()];
386 let entry = path_stats.entry(path).or_insert((0, 0));
387 entry.0 += success;
388 entry.1 += failure;
389 }
390 }
391
392 let mut paths: Vec<RecommendedPath> = path_stats
394 .into_iter()
395 .filter(|(_, (s, f))| s + f >= 5) .map(|(actions, (success, failure))| {
397 let total = success + failure;
398 RecommendedPath {
399 actions,
400 success_rate: success as f64 / total as f64,
401 observations: total,
402 }
403 })
404 .collect();
405
406 paths.sort_by(|a, b| {
407 b.success_rate
408 .partial_cmp(&a.success_rate)
409 .unwrap_or(std::cmp::Ordering::Equal)
410 });
411
412 paths.into_iter().take(10).collect() }
414
415 pub fn analyze_strategy(&self) -> StrategyConfig {
419 if self.snapshots.is_empty() {
420 return StrategyConfig::default();
421 }
422
423 let (total_success, total_failure) = self.snapshots.iter().fold((0u32, 0u32), |acc, s| {
425 (
426 acc.0 + s.episode_transitions.success_episodes,
427 acc.1 + s.episode_transitions.failure_episodes,
428 )
429 });
430
431 let avg_error_rate = if total_success + total_failure > 0 {
432 total_failure as f64 / (total_success + total_failure) as f64
433 } else {
434 0.3
435 };
436
437 let total_actions: u64 = self
439 .snapshots
440 .iter()
441 .map(|s| s.metadata.total_actions as u64)
442 .sum();
443 let avg_actions = total_actions as f64 / self.snapshots.len().max(1) as f64;
444
445 let maturity_threshold = ((avg_actions * 0.1) as u32).clamp(5, 50);
447
448 let initial_strategy = if avg_error_rate > 0.4 {
450 "thompson" } else if avg_error_rate < 0.1 {
452 "greedy" } else {
454 "ucb1" };
456
457 StrategyConfig {
458 maturity_threshold,
459 error_rate_threshold: (avg_error_rate * 1.5).min(0.5), initial_strategy: initial_strategy.to_string(),
461 }
462 }
463}
464
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::SQRT_2;

    /// Snapshot with the given episode outcome counts and five actions
    /// recorded per episode.
    fn create_test_snapshot(success: u32, failure: u32) -> LearningSnapshot {
        let episodes = success + failure;
        let mut snapshot = LearningSnapshot::empty();
        {
            let transitions = &mut snapshot.episode_transitions;
            transitions.success_episodes = success;
            transitions.failure_episodes = failure;
        }
        snapshot.metadata.total_actions = episodes * 5;
        snapshot
    }

    /// One test snapshot per `(success, failure)` pair, in order.
    fn snapshots_from(outcomes: &[(u32, u32)]) -> Vec<LearningSnapshot> {
        outcomes
            .iter()
            .map(|&(success, failure)| create_test_snapshot(success, failure))
            .collect()
    }

    #[test]
    fn test_analyzer_empty_snapshots() {
        let snapshots: Vec<LearningSnapshot> = Vec::new();
        let model = OfflineAnalyzer::new(&snapshots).analyze();

        assert_eq!(model.analyzed_sessions, 0);
        let deviation = (model.parameters.ucb1_c - SQRT_2).abs();
        assert!(deviation < 0.01);
    }

    #[test]
    fn test_analyzer_high_success_rate() {
        let snapshots = snapshots_from(&[(9, 1), (8, 2), (10, 0)]);
        let params = OfflineAnalyzer::new(&snapshots).analyze_parameters();

        // Mostly-successful sessions should push ucb1_c below the default.
        assert!(params.ucb1_c < SQRT_2);
    }

    #[test]
    fn test_analyzer_low_success_rate() {
        let snapshots = snapshots_from(&[(3, 7), (4, 6), (2, 8)]);
        let params = OfflineAnalyzer::new(&snapshots).analyze_parameters();

        // Mostly-failing sessions should push ucb1_c above the default.
        assert!(params.ucb1_c > SQRT_2);
    }

    #[test]
    fn test_strategy_config_high_error() {
        let snapshots = snapshots_from(&[(3, 7), (4, 6)]);
        let config = OfflineAnalyzer::new(&snapshots).analyze_strategy();

        assert_eq!(config.initial_strategy, "thompson");
    }

    #[test]
    fn test_strategy_config_low_error() {
        let snapshots = snapshots_from(&[(19, 1), (18, 2)]);
        let config = OfflineAnalyzer::new(&snapshots).analyze_strategy();

        assert_eq!(config.initial_strategy, "greedy");
    }
}