1use std::collections::HashMap;
42
43use serde::{Deserialize, Serialize};
44
45use super::snapshot::LearningSnapshot;
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct OfflineModel {
52 pub version: u32,
54 pub parameters: OptimalParameters,
56 pub recommended_paths: Vec<RecommendedPath>,
58 pub strategy_config: StrategyConfig,
60 pub analyzed_sessions: usize,
62 pub updated_at: u64,
64 #[serde(default)]
66 pub action_order: Option<LearnedActionOrder>,
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct LearnedActionOrder {
75 pub discover: Vec<String>,
77 pub not_discover: Vec<String>,
79 pub action_set_hash: u64,
83 #[serde(default)]
85 pub source: ActionOrderSource,
86}
87
88#[derive(Debug, Clone, Default, Serialize, Deserialize)]
90pub enum ActionOrderSource {
91 #[default]
93 Llm,
94 Static,
96 Manual,
98}
99
100impl LearnedActionOrder {
101 pub fn new(discover: Vec<String>, not_discover: Vec<String>, actions: &[String]) -> Self {
103 Self {
104 discover,
105 not_discover,
106 action_set_hash: Self::compute_hash(actions),
107 source: ActionOrderSource::Llm,
108 }
109 }
110
111 pub fn compute_hash(actions: &[String]) -> u64 {
115 use std::collections::hash_map::DefaultHasher;
116 use std::hash::{Hash, Hasher};
117
118 let mut sorted: Vec<&str> = actions.iter().map(|s| s.as_str()).collect();
119 sorted.sort();
120
121 let mut hasher = DefaultHasher::new();
122 for action in sorted {
123 action.hash(&mut hasher);
124 }
125 hasher.finish()
126 }
127
128 pub fn matches_actions(&self, actions: &[String]) -> bool {
130 self.action_set_hash == Self::compute_hash(actions)
131 }
132}
133
134#[derive(Debug, Clone, Serialize, Deserialize)]
136pub struct OptimalParameters {
137 pub ucb1_c: f64,
139 pub learning_weight: f64,
141 pub ngram_weight: f64,
143}
144
145impl Default for OptimalParameters {
146 fn default() -> Self {
147 Self {
148 ucb1_c: std::f64::consts::SQRT_2,
149 learning_weight: 0.3,
150 ngram_weight: 1.0,
151 }
152 }
153}
154
155#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct RecommendedPath {
158 pub actions: Vec<String>,
160 pub success_rate: f64,
162 pub observations: u32,
164}
165
166#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct StrategyConfig {
169 pub maturity_threshold: u32,
171 pub error_rate_threshold: f64,
173 pub initial_strategy: String,
175}
176
177impl Default for StrategyConfig {
178 fn default() -> Self {
179 Self {
180 maturity_threshold: 10,
181 error_rate_threshold: 0.3,
182 initial_strategy: "ucb1".to_string(),
183 }
184 }
185}
186
187impl Default for OfflineModel {
188 fn default() -> Self {
189 Self {
190 version: 1,
191 parameters: OptimalParameters::default(),
192 recommended_paths: Vec::new(),
193 strategy_config: StrategyConfig::default(),
194 analyzed_sessions: 0,
195 updated_at: 0,
196 action_order: None,
197 }
198 }
199}
200
201pub struct OfflineAnalyzer<'a> {
205 snapshots: &'a [LearningSnapshot],
206}
207
208impl<'a> OfflineAnalyzer<'a> {
209 pub fn new(snapshots: &'a [LearningSnapshot]) -> Self {
211 Self { snapshots }
212 }
213
214 pub fn analyze(&self) -> OfflineModel {
216 let now = std::time::SystemTime::now()
217 .duration_since(std::time::UNIX_EPOCH)
218 .map(|d| d.as_secs())
219 .unwrap_or(0);
220
221 OfflineModel {
222 version: 1,
223 parameters: self.analyze_parameters(),
224 recommended_paths: self.extract_paths(),
225 strategy_config: self.analyze_strategy(),
226 analyzed_sessions: self.snapshots.len(),
227 updated_at: now,
228 action_order: None, }
230 }
231
232 pub fn analyze_parameters(&self) -> OptimalParameters {
237 if self.snapshots.is_empty() {
238 return OptimalParameters::default();
239 }
240
241 let (total_success, total_failure) = self.snapshots.iter().fold((0u32, 0u32), |acc, s| {
243 (
244 acc.0 + s.episode_transitions.success_episodes,
245 acc.1 + s.episode_transitions.failure_episodes,
246 )
247 });
248
249 let success_rate = if total_success + total_failure > 0 {
250 total_success as f64 / (total_success + total_failure) as f64
251 } else {
252 0.5
253 };
254
255 let ucb1_c = if success_rate > 0.8 {
259 1.0 } else if success_rate < 0.5 {
261 2.0 } else {
263 std::f64::consts::SQRT_2 };
265
266 let ngram_effectiveness = self.evaluate_ngram_effectiveness();
268 let ngram_weight = if ngram_effectiveness > 0.7 {
269 1.5 } else if ngram_effectiveness < 0.3 {
271 0.5 } else {
273 1.0
274 };
275
276 OptimalParameters {
277 ucb1_c,
278 learning_weight: 0.3, ngram_weight,
280 }
281 }
282
283 fn evaluate_ngram_effectiveness(&self) -> f64 {
287 let mut all_rates: Vec<f64> = Vec::new();
288
289 for snapshot in self.snapshots {
290 for &(success, failure) in snapshot.ngram_stats.trigrams.values() {
291 let total = success + failure;
292 if total >= 3 {
293 all_rates.push(success as f64 / total as f64);
295 }
296 }
297 }
298
299 if all_rates.is_empty() {
300 return 0.5; }
302
303 let mean = all_rates.iter().sum::<f64>() / all_rates.len() as f64;
305 let variance =
306 all_rates.iter().map(|r| (r - mean).powi(2)).sum::<f64>() / all_rates.len() as f64;
307
308 (variance / 0.25).min(1.0)
310 }
311
312 pub fn extract_paths(&self) -> Vec<RecommendedPath> {
316 let mut path_stats: HashMap<Vec<String>, (u32, u32)> = HashMap::new();
318
319 for snapshot in self.snapshots {
320 for (key, &(success, failure)) in &snapshot.ngram_stats.trigrams {
321 let path = vec![key.0.clone(), key.1.clone(), key.2.clone()];
322 let entry = path_stats.entry(path).or_insert((0, 0));
323 entry.0 += success;
324 entry.1 += failure;
325 }
326 }
327
328 let mut paths: Vec<RecommendedPath> = path_stats
330 .into_iter()
331 .filter(|(_, (s, f))| s + f >= 5) .map(|(actions, (success, failure))| {
333 let total = success + failure;
334 RecommendedPath {
335 actions,
336 success_rate: success as f64 / total as f64,
337 observations: total,
338 }
339 })
340 .collect();
341
342 paths.sort_by(|a, b| {
343 b.success_rate
344 .partial_cmp(&a.success_rate)
345 .unwrap_or(std::cmp::Ordering::Equal)
346 });
347
348 paths.into_iter().take(10).collect() }
350
351 pub fn analyze_strategy(&self) -> StrategyConfig {
355 if self.snapshots.is_empty() {
356 return StrategyConfig::default();
357 }
358
359 let (total_success, total_failure) = self.snapshots.iter().fold((0u32, 0u32), |acc, s| {
361 (
362 acc.0 + s.episode_transitions.success_episodes,
363 acc.1 + s.episode_transitions.failure_episodes,
364 )
365 });
366
367 let avg_error_rate = if total_success + total_failure > 0 {
368 total_failure as f64 / (total_success + total_failure) as f64
369 } else {
370 0.3
371 };
372
373 let total_actions: u64 = self
375 .snapshots
376 .iter()
377 .map(|s| s.metadata.total_actions as u64)
378 .sum();
379 let avg_actions = total_actions as f64 / self.snapshots.len().max(1) as f64;
380
381 let maturity_threshold = ((avg_actions * 0.1) as u32).clamp(5, 50);
383
384 let initial_strategy = if avg_error_rate > 0.4 {
386 "thompson" } else if avg_error_rate < 0.1 {
388 "greedy" } else {
390 "ucb1" };
392
393 StrategyConfig {
394 maturity_threshold,
395 error_rate_threshold: (avg_error_rate * 1.5).min(0.5), initial_strategy: initial_strategy.to_string(),
397 }
398 }
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a snapshot with the given episode outcomes and five recorded
    /// actions per episode.
    fn snapshot_with(success: u32, failure: u32) -> LearningSnapshot {
        let mut snap = LearningSnapshot::empty();
        snap.episode_transitions.success_episodes = success;
        snap.episode_transitions.failure_episodes = failure;
        snap.metadata.total_actions = (success + failure) * 5;
        snap
    }

    /// Builds one snapshot per `(success, failure)` pair.
    fn snapshots_from(outcomes: &[(u32, u32)]) -> Vec<LearningSnapshot> {
        outcomes.iter().map(|&(s, f)| snapshot_with(s, f)).collect()
    }

    #[test]
    fn test_analyzer_empty_snapshots() {
        let no_snapshots = snapshots_from(&[]);
        let model = OfflineAnalyzer::new(&no_snapshots).analyze();

        assert_eq!(model.analyzed_sessions, 0);
        let drift = (model.parameters.ucb1_c - std::f64::consts::SQRT_2).abs();
        assert!(drift < 0.01);
    }

    #[test]
    fn test_analyzer_high_success_rate() {
        // 27 of 30 episodes succeed.
        let snapshots = snapshots_from(&[(9, 1), (8, 2), (10, 0)]);
        let params = OfflineAnalyzer::new(&snapshots).analyze_parameters();

        assert!(params.ucb1_c < std::f64::consts::SQRT_2);
    }

    #[test]
    fn test_analyzer_low_success_rate() {
        // 9 of 30 episodes succeed.
        let snapshots = snapshots_from(&[(3, 7), (4, 6), (2, 8)]);
        let params = OfflineAnalyzer::new(&snapshots).analyze_parameters();

        assert!(params.ucb1_c > std::f64::consts::SQRT_2);
    }

    #[test]
    fn test_strategy_config_high_error() {
        // 13 of 20 episodes fail.
        let snapshots = snapshots_from(&[(3, 7), (4, 6)]);
        let config = OfflineAnalyzer::new(&snapshots).analyze_strategy();

        assert_eq!(config.initial_strategy, "thompson");
    }

    #[test]
    fn test_strategy_config_low_error() {
        // 3 of 40 episodes fail.
        let snapshots = snapshots_from(&[(19, 1), (18, 2)]);
        let config = OfflineAnalyzer::new(&snapshots).analyze_strategy();

        assert_eq!(config.initial_strategy, "greedy");
    }
}