1use std::collections::HashMap;
42
43use serde::{Deserialize, Serialize};
44
45use super::snapshot::LearningSnapshot;
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct OfflineModel {
52 pub version: u32,
54 pub parameters: OptimalParameters,
56 pub recommended_paths: Vec<RecommendedPath>,
58 pub strategy_config: StrategyConfig,
60 pub analyzed_sessions: usize,
62 pub updated_at: u64,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct OptimalParameters {
69 pub ucb1_c: f64,
71 pub learning_weight: f64,
73 pub ngram_weight: f64,
75}
76
77impl Default for OptimalParameters {
78 fn default() -> Self {
79 Self {
80 ucb1_c: std::f64::consts::SQRT_2,
81 learning_weight: 0.3,
82 ngram_weight: 1.0,
83 }
84 }
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct RecommendedPath {
90 pub actions: Vec<String>,
92 pub success_rate: f64,
94 pub observations: u32,
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct StrategyConfig {
101 pub maturity_threshold: u32,
103 pub error_rate_threshold: f64,
105 pub initial_strategy: String,
107}
108
109impl Default for StrategyConfig {
110 fn default() -> Self {
111 Self {
112 maturity_threshold: 10,
113 error_rate_threshold: 0.3,
114 initial_strategy: "ucb1".to_string(),
115 }
116 }
117}
118
119impl Default for OfflineModel {
120 fn default() -> Self {
121 Self {
122 version: 1,
123 parameters: OptimalParameters::default(),
124 recommended_paths: Vec::new(),
125 strategy_config: StrategyConfig::default(),
126 analyzed_sessions: 0,
127 updated_at: 0,
128 }
129 }
130}
131
132pub struct OfflineAnalyzer<'a> {
136 snapshots: &'a [LearningSnapshot],
137}
138
139impl<'a> OfflineAnalyzer<'a> {
140 pub fn new(snapshots: &'a [LearningSnapshot]) -> Self {
142 Self { snapshots }
143 }
144
145 pub fn analyze(&self) -> OfflineModel {
147 let now = std::time::SystemTime::now()
148 .duration_since(std::time::UNIX_EPOCH)
149 .map(|d| d.as_secs())
150 .unwrap_or(0);
151
152 OfflineModel {
153 version: 1,
154 parameters: self.analyze_parameters(),
155 recommended_paths: self.extract_paths(),
156 strategy_config: self.analyze_strategy(),
157 analyzed_sessions: self.snapshots.len(),
158 updated_at: now,
159 }
160 }
161
162 pub fn analyze_parameters(&self) -> OptimalParameters {
167 if self.snapshots.is_empty() {
168 return OptimalParameters::default();
169 }
170
171 let (total_success, total_failure) = self.snapshots.iter().fold((0u32, 0u32), |acc, s| {
173 (
174 acc.0 + s.episode_transitions.success_episodes,
175 acc.1 + s.episode_transitions.failure_episodes,
176 )
177 });
178
179 let success_rate = if total_success + total_failure > 0 {
180 total_success as f64 / (total_success + total_failure) as f64
181 } else {
182 0.5
183 };
184
185 let ucb1_c = if success_rate > 0.8 {
189 1.0 } else if success_rate < 0.5 {
191 2.0 } else {
193 std::f64::consts::SQRT_2 };
195
196 let ngram_effectiveness = self.evaluate_ngram_effectiveness();
198 let ngram_weight = if ngram_effectiveness > 0.7 {
199 1.5 } else if ngram_effectiveness < 0.3 {
201 0.5 } else {
203 1.0
204 };
205
206 OptimalParameters {
207 ucb1_c,
208 learning_weight: 0.3, ngram_weight,
210 }
211 }
212
213 fn evaluate_ngram_effectiveness(&self) -> f64 {
217 let mut all_rates: Vec<f64> = Vec::new();
218
219 for snapshot in self.snapshots {
220 for &(success, failure) in snapshot.ngram_stats.trigrams.values() {
221 let total = success + failure;
222 if total >= 3 {
223 all_rates.push(success as f64 / total as f64);
225 }
226 }
227 }
228
229 if all_rates.is_empty() {
230 return 0.5; }
232
233 let mean = all_rates.iter().sum::<f64>() / all_rates.len() as f64;
235 let variance =
236 all_rates.iter().map(|r| (r - mean).powi(2)).sum::<f64>() / all_rates.len() as f64;
237
238 (variance / 0.25).min(1.0)
240 }
241
242 pub fn extract_paths(&self) -> Vec<RecommendedPath> {
246 let mut path_stats: HashMap<Vec<String>, (u32, u32)> = HashMap::new();
248
249 for snapshot in self.snapshots {
250 for (key, &(success, failure)) in &snapshot.ngram_stats.trigrams {
251 let path = vec![key.0.clone(), key.1.clone(), key.2.clone()];
252 let entry = path_stats.entry(path).or_insert((0, 0));
253 entry.0 += success;
254 entry.1 += failure;
255 }
256 }
257
258 let mut paths: Vec<RecommendedPath> = path_stats
260 .into_iter()
261 .filter(|(_, (s, f))| s + f >= 5) .map(|(actions, (success, failure))| {
263 let total = success + failure;
264 RecommendedPath {
265 actions,
266 success_rate: success as f64 / total as f64,
267 observations: total,
268 }
269 })
270 .collect();
271
272 paths.sort_by(|a, b| {
273 b.success_rate
274 .partial_cmp(&a.success_rate)
275 .unwrap_or(std::cmp::Ordering::Equal)
276 });
277
278 paths.into_iter().take(10).collect() }
280
281 pub fn analyze_strategy(&self) -> StrategyConfig {
285 if self.snapshots.is_empty() {
286 return StrategyConfig::default();
287 }
288
289 let (total_success, total_failure) = self.snapshots.iter().fold((0u32, 0u32), |acc, s| {
291 (
292 acc.0 + s.episode_transitions.success_episodes,
293 acc.1 + s.episode_transitions.failure_episodes,
294 )
295 });
296
297 let avg_error_rate = if total_success + total_failure > 0 {
298 total_failure as f64 / (total_success + total_failure) as f64
299 } else {
300 0.3
301 };
302
303 let total_actions: u64 = self
305 .snapshots
306 .iter()
307 .map(|s| s.metadata.total_actions as u64)
308 .sum();
309 let avg_actions = total_actions as f64 / self.snapshots.len().max(1) as f64;
310
311 let maturity_threshold = ((avg_actions * 0.1) as u32).clamp(5, 50);
313
314 let initial_strategy = if avg_error_rate > 0.4 {
316 "thompson" } else if avg_error_rate < 0.1 {
318 "greedy" } else {
320 "ucb1" };
322
323 StrategyConfig {
324 maturity_threshold,
325 error_rate_threshold: (avg_error_rate * 1.5).min(0.5), initial_strategy: initial_strategy.to_string(),
327 }
328 }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing computed floating-point values.
    const EPS: f64 = 1e-9;

    /// Builds a snapshot with the given episode counts and five actions per episode.
    fn create_test_snapshot(success: u32, failure: u32) -> LearningSnapshot {
        let mut snapshot = LearningSnapshot::empty();
        snapshot.episode_transitions.success_episodes = success;
        snapshot.episode_transitions.failure_episodes = failure;
        snapshot.metadata.total_actions = (success + failure) * 5;
        snapshot
    }

    /// Records the trigram `(a, b, c)` with the given `(success, failure)` counts.
    fn add_trigram(snapshot: &mut LearningSnapshot, [a, b, c]: [&str; 3], counts: (u32, u32)) {
        snapshot
            .ngram_stats
            .trigrams
            .insert((a.to_string(), b.to_string(), c.to_string()), counts);
    }

    #[test]
    fn test_analyzer_empty_snapshots() {
        let snapshots: Vec<LearningSnapshot> = vec![];
        let model = OfflineAnalyzer::new(&snapshots).analyze();

        assert_eq!(model.analyzed_sessions, 0);
        assert!((model.parameters.ucb1_c - std::f64::consts::SQRT_2).abs() < EPS);
        assert!((model.parameters.ngram_weight - 1.0).abs() < EPS);
        assert!(model.recommended_paths.is_empty());
        assert_eq!(model.strategy_config.initial_strategy, "ucb1");
        assert_eq!(model.strategy_config.maturity_threshold, 10);
    }

    #[test]
    fn test_analyzer_high_success_rate() {
        // 27 / 30 = 0.9 > 0.8
        let snapshots = vec![
            create_test_snapshot(9, 1),
            create_test_snapshot(8, 2),
            create_test_snapshot(10, 0),
        ];
        let params = OfflineAnalyzer::new(&snapshots).analyze_parameters();

        assert!((params.ucb1_c - 1.0).abs() < EPS);
    }

    #[test]
    fn test_analyzer_low_success_rate() {
        // 9 / 30 = 0.3 < 0.5
        let snapshots = vec![
            create_test_snapshot(3, 7),
            create_test_snapshot(4, 6),
            create_test_snapshot(2, 8),
        ];
        let params = OfflineAnalyzer::new(&snapshots).analyze_parameters();

        assert!((params.ucb1_c - 2.0).abs() < EPS);
    }

    #[test]
    fn test_analyzer_mid_success_rate() {
        // 13 / 20 = 0.65, between the two cut-offs
        let snapshots = vec![create_test_snapshot(6, 4), create_test_snapshot(7, 3)];
        let params = OfflineAnalyzer::new(&snapshots).analyze_parameters();

        assert!((params.ucb1_c - std::f64::consts::SQRT_2).abs() < EPS);
    }

    #[test]
    fn test_ngram_weight_tracks_rate_spread() {
        let mut spread = create_test_snapshot(5, 5);
        add_trigram(&mut spread, ["a", "b", "c"], (3, 0));
        add_trigram(&mut spread, ["a", "b", "d"], (0, 3));
        let spread = vec![spread];
        let params = OfflineAnalyzer::new(&spread).analyze_parameters();
        assert!((params.ngram_weight - 1.5).abs() < EPS);

        let mut uniform = create_test_snapshot(5, 5);
        add_trigram(&mut uniform, ["a", "b", "c"], (3, 0));
        add_trigram(&mut uniform, ["a", "b", "d"], (3, 0));
        let uniform = vec![uniform];
        let params = OfflineAnalyzer::new(&uniform).analyze_parameters();
        assert!((params.ngram_weight - 0.5).abs() < EPS);
    }

    #[test]
    fn test_extract_paths_aggregates_filters_and_sorts() {
        let mut first = create_test_snapshot(1, 0);
        add_trigram(&mut first, ["a", "b", "c"], (3, 0));
        add_trigram(&mut first, ["x", "y", "z"], (1, 3));
        // Only three observations in total: below the recommendation cut-off.
        add_trigram(&mut first, ["p", "q", "r"], (2, 1));

        let mut second = create_test_snapshot(1, 0);
        add_trigram(&mut second, ["a", "b", "c"], (2, 1));
        add_trigram(&mut second, ["x", "y", "z"], (1, 0));

        let snapshots = vec![first, second];
        let paths = OfflineAnalyzer::new(&snapshots).extract_paths();

        assert_eq!(paths.len(), 2);
        assert_eq!(paths[0].actions, vec!["a", "b", "c"]);
        assert_eq!(paths[0].observations, 6);
        assert!((paths[0].success_rate - 5.0 / 6.0).abs() < EPS);
        assert_eq!(paths[1].actions, vec!["x", "y", "z"]);
        assert_eq!(paths[1].observations, 5);
        assert!((paths[1].success_rate - 0.4).abs() < EPS);
    }

    #[test]
    fn test_strategy_config_high_error() {
        // 13 / 20 = 0.65 failure rate
        let snapshots = vec![create_test_snapshot(3, 7), create_test_snapshot(4, 6)];
        let config = OfflineAnalyzer::new(&snapshots).analyze_strategy();

        assert_eq!(config.initial_strategy, "thompson");
        assert!((config.error_rate_threshold - 0.5).abs() < EPS);
        assert_eq!(config.maturity_threshold, 5);
    }

    #[test]
    fn test_strategy_config_low_error() {
        // 3 / 40 = 0.075 failure rate
        let snapshots = vec![create_test_snapshot(19, 1), create_test_snapshot(18, 2)];
        let config = OfflineAnalyzer::new(&snapshots).analyze_strategy();

        assert_eq!(config.initial_strategy, "greedy");
        assert!((config.error_rate_threshold - 0.1125).abs() < EPS);
        assert_eq!(config.maturity_threshold, 10);
    }

    #[test]
    fn test_maturity_threshold_is_clamped() {
        let few = vec![create_test_snapshot(1, 0)];
        assert_eq!(OfflineAnalyzer::new(&few).analyze_strategy().maturity_threshold, 5);

        let many = vec![create_test_snapshot(1000, 0)];
        assert_eq!(OfflineAnalyzer::new(&many).analyze_strategy().maturity_threshold, 50);
    }
}