use crate::error::{NeuralError, Result};
use ndarray::Array1;
use std::collections::HashMap;
use std::fs::File;
use std::io::Write;
use std::path::Path;

/// Options controlling the size and appearance of an ASCII plot.
#[derive(Clone, Debug)]
pub struct PlotOptions {
    /// Plot width in characters.
    pub width: usize,
    /// Plot height in characters.
    pub height: usize,
    /// Maximum number of ticks on the x-axis.
    pub max_x_ticks: usize,
    /// Maximum number of ticks on the y-axis.
    pub max_y_ticks: usize,
    /// Character used to draw lines.
    pub line_char: char,
    /// Character used to draw individual points.
    pub point_char: char,
    /// Character used for the empty background.
    pub background_char: char,
    /// Whether to draw grid dots.
    pub show_grid: bool,
    /// Whether to print a legend below the plot.
    pub show_legend: bool,
}

impl Default for PlotOptions {
    fn default() -> Self {
        Self {
            width: 80,
            height: 20,
            max_x_ticks: 10,
            max_y_ticks: 5,
            line_char: '─',
            point_char: '●',
            background_char: ' ',
            show_grid: true,
            show_legend: true,
        }
    }
}

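/// Renders each named data series in `data` as a set of symbols on a shared
/// ASCII chart and returns the chart as a `String`.
///
/// A minimal usage sketch (the key names and values here are illustrative):
///
/// ```ignore
/// use std::collections::HashMap;
///
/// let mut history = HashMap::new();
/// history.insert("train_loss".to_string(), vec![0.9_f64, 0.5, 0.3, 0.2]);
/// history.insert("val_loss".to_string(), vec![1.0_f64, 0.6, 0.4, 0.35]);
///
/// let chart = ascii_plot(&history, Some("Loss"), None)?;
/// println!("{}", chart);
/// ```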
pub fn ascii_plot<F: num_traits::Float + std::fmt::Display + std::fmt::Debug>(
    data: &HashMap<String, Vec<F>>,
    title: Option<&str>,
    options: Option<PlotOptions>,
) -> Result<String> {
    let options = options.unwrap_or_default();
    let width = options.width;
    let height = options.height;

    if data.is_empty() {
        return Err(NeuralError::ValidationError("No data to plot".to_string()));
    }

    // Find the shared y range and the longest series across all data.
    let mut min_y = F::infinity();
    let mut max_y = F::neg_infinity();
    let mut max_len = 0;

    for values in data.values() {
        if values.is_empty() {
            continue;
        }

        max_len = max_len.max(values.len());

        for &v in values {
            if v.is_finite() {
                min_y = min_y.min(v);
                max_y = max_y.max(v);
            }
        }
    }

    if max_len == 0 {
        return Err(NeuralError::ValidationError(
            "All data series are empty".to_string(),
        ));
    }

    if !min_y.is_finite() || !max_y.is_finite() {
        return Err(NeuralError::ValidationError(
            "Data contains no finite values to plot".to_string(),
        ));
    }

    // Pad the y range by 5% so extreme points don't sit on the border.
    let y_range = max_y - min_y;
    let margin = y_range * F::from(0.05).unwrap();
    min_y = min_y - margin;
    max_y = max_y + margin;

    // If all values are (nearly) identical, widen the range artificially.
    if (max_y - min_y).abs() < F::epsilon() {
        min_y = min_y - F::from(0.5).unwrap();
        max_y = max_y + F::from(0.5).unwrap();
    }

    let mut plot = vec![vec![options.background_char; width]; height];

    if options.show_grid {
        // Clamp the strides to at least 1 so a plot smaller than the tick
        // counts cannot trigger a remainder-by-zero panic.
        let x_stride = (width / options.max_x_ticks.max(1)).max(1);
        let y_stride = (height / options.max_y_ticks.max(1)).max(1);
        for (y, row) in plot.iter_mut().enumerate().take(height) {
            for (x, cell) in row.iter_mut().enumerate().take(width) {
                if x % x_stride == 0 && y % y_stride == 0 {
                    *cell = '·';
                }
            }
        }
    }

    // Draw the y-axis, the x-axis, and the corner where they meet.
    for row in plot.iter_mut().take(height) {
        row[0] = '│';
    }

    for x in 0..width {
        plot[height - 1][x] = '─';
    }

    plot[height - 1][0] = '└';

    // Marker symbols, cycled through per series.
    let symbols = ['●', '■', '▲', '◆', '★', '✖', '◎'];

    let mut result = String::with_capacity(height * (width + 2) + 100);

    if let Some(title) = title {
        // saturating_sub avoids an underflow panic when the title is wider
        // than the plot.
        let title_padding = width.saturating_sub(title.len()) / 2;
        result.push_str(&" ".repeat(title_padding));
        result.push_str(title);
        result.push('\n');
        result.push('\n');
    }

    let mut legend_entries = Vec::new();

    for (i, (name, values)) in data.iter().enumerate() {
        let symbol = symbols[i % symbols.len()];

        if values.is_empty() {
            continue;
        }

        legend_entries.push((name, symbol));

        for (x_idx, &y_val) in values.iter().enumerate() {
            if !y_val.is_finite() {
                continue;
            }

            // Map the sample index onto columns 1..width (column 0 holds the
            // y-axis); the denominator is clamped so a single-point series
            // cannot divide by zero.
            let x = ((x_idx as f64) / (max_len as f64 - 1.0).max(1.0) * (width as f64 - 2.0))
                .round() as usize
                + 1;

            if x >= width {
                continue;
            }

            let y_norm = ((y_val - min_y) / (max_y - min_y)).to_f64().unwrap();
            let y = height - (y_norm * (height as f64 - 2.0)).round() as usize - 1;

            if y < height {
                plot[y][x] = symbol;
            }
        }
    }

    // Y-axis tick labels, from max_y at the top down toward min_y.
    let y_ticks = (0..options.max_y_ticks.min(height))
        .map(|i| {
            let val = max_y
                - F::from(i as f64 / (options.max_y_ticks as f64 - 1.0).max(1.0)).unwrap()
                    * (max_y - min_y);
            format!("{:.2}", val)
        })
        .collect::<Vec<_>>();

    let max_y_tick_width = y_ticks.iter().map(|t| t.len()).max().unwrap_or(0);

    let y_stride = (height / options.max_y_ticks.max(1)).max(1);
    for y in 0..height {
        // Label every y_stride-th row with the tick of matching index, not
        // the raw row index.
        if y % y_stride == 0 && y / y_stride < y_ticks.len() {
            let tick = &y_ticks[y / y_stride];
            result.push_str(&format!("{:>width$} ", tick, width = max_y_tick_width));
        } else {
            result.push_str(&" ".repeat(max_y_tick_width + 1));
        }

        for x in 0..width {
            result.push(plot[y][x]);
        }

        result.push('\n');
    }

    // X-axis labels: epoch indices spread evenly beneath the plot.
    result.push_str(&" ".repeat(max_y_tick_width + 1));
    for i in 0..options.max_x_ticks {
        let epoch = (i as f64 * (max_len as f64 - 1.0)
            / (options.max_x_ticks as f64 - 1.0).max(1.0))
        .round() as usize;

        let tick = format!("{}", epoch);
        // Center each label in its slot; saturating_sub guards against labels
        // wider than the slot.
        let padding = (width / options.max_x_ticks).saturating_sub(tick.len());
        let left_padding = padding / 2;
        let right_padding = padding - left_padding;

        result.push_str(&" ".repeat(left_padding));
        result.push_str(&tick);
        result.push_str(&" ".repeat(right_padding));
    }

    result.push('\n');

    if options.show_legend && !legend_entries.is_empty() {
        result.push('\n');
        result.push_str("Legend: ");

        for (i, (name, symbol)) in legend_entries.iter().enumerate() {
            if i > 0 {
                result.push_str(", ");
            }
            result.push_str(&format!("{} {}", symbol, name));
        }

        result.push('\n');
    }

    Ok(result)
}

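/// Writes a training history to a CSV file with one row per epoch.
///
/// Columns are the sorted history keys after a leading `epoch` column; series
/// shorter than the longest one get empty trailing cells. A sketch of the
/// expected output (the file name is illustrative):
///
/// ```ignore
/// use std::collections::HashMap;
///
/// let mut history = HashMap::new();
/// history.insert("train_loss".to_string(), vec![0.9_f64, 0.5]);
/// history.insert("val_loss".to_string(), vec![1.0_f64, 0.7]);
///
/// export_history_to_csv(&history, "history.csv")?;
/// // history.csv:
/// // epoch,train_loss,val_loss
/// // 0,0.9,1
/// // 1,0.5,0.7
/// ```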
pub fn export_history_to_csv<F: std::fmt::Display>(
    history: &HashMap<String, Vec<F>>,
    filepath: impl AsRef<Path>,
) -> Result<()> {
    let mut file = File::create(filepath)
        .map_err(|e| NeuralError::IOError(format!("Failed to create CSV file: {}", e)))?;

    let max_len = history.values().map(|v| v.len()).max().unwrap_or(0);

    let mut header = String::from("epoch");

    // Sort the keys so the column order is deterministic.
    let mut keys: Vec<&String> = history.keys().collect();
    keys.sort();

    for key in keys.iter() {
        header.push_str(&format!(",{}", key));
    }
    header.push('\n');

    file.write_all(header.as_bytes())
        .map_err(|e| NeuralError::IOError(format!("Failed to write CSV header: {}", e)))?;

    for i in 0..max_len {
        let mut row = i.to_string();

        // Series shorter than max_len leave their trailing cells empty.
        for key in keys.iter() {
            row.push(',');
            if let Some(values) = history.get(*key) {
                if i < values.len() {
                    row.push_str(&format!("{}", values[i]));
                }
            }
        }

        row.push('\n');

        file.write_all(row.as_bytes())
            .map_err(|e| NeuralError::IOError(format!("Failed to write CSV row: {}", e)))?;
    }

    Ok(())
}

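/// A learning-rate schedule mapping an epoch index to a learning rate.
///
/// Step decay computes `initial_lr * decay_factor^(epoch / step_size)` with
/// integer division, and exponential decay computes
/// `initial_lr * decay_factor^epoch`. A worked example (the values follow
/// directly from the step-decay formula):
///
/// ```ignore
/// let schedule = LearningRateSchedule::StepDecay {
///     initial_lr: 0.1_f64,
///     decay_factor: 0.5,
///     step_size: 10,
/// };
/// assert!((schedule.get_learning_rate(25) - 0.025).abs() < 1e-12); // 0.1 * 0.5^2
/// ```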
pub enum LearningRateSchedule<F: num_traits::Float> {
    /// A fixed learning rate for every epoch.
    Constant(F),
    /// Multiplies the rate by `decay_factor` every `step_size` epochs.
    StepDecay {
        initial_lr: F,
        decay_factor: F,
        step_size: usize,
    },
    /// Multiplies the rate by `decay_factor` every epoch.
    ExponentialDecay {
        initial_lr: F,
        decay_factor: F,
    },
    /// A user-supplied function from epoch index to learning rate.
    Custom(Box<dyn Fn(usize) -> F>),
}

impl<F: num_traits::Float> LearningRateSchedule<F> {
    /// Returns the learning rate for the given epoch.
    pub fn get_learning_rate(&self, epoch: usize) -> F {
        match self {
            Self::Constant(lr) => *lr,
            Self::StepDecay {
                initial_lr,
                decay_factor,
                step_size,
            } => {
                let num_steps = epoch / step_size;
                *initial_lr * (*decay_factor).powi(num_steps as i32)
            }
            Self::ExponentialDecay {
                initial_lr,
                decay_factor,
            } => *initial_lr * (*decay_factor).powi(epoch as i32),
            Self::Custom(f) => f(epoch),
        }
    }

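    /// Evaluates the schedule at every epoch in `0..num_epochs` and returns
    /// the rates as an `Array1<F>`. A sketch:
    ///
    /// ```ignore
    /// let lrs = LearningRateSchedule::ExponentialDecay {
    ///     initial_lr: 0.1_f64,
    ///     decay_factor: 0.9,
    /// }
    /// .generate_schedule(3);
    /// // lrs ≈ [0.1, 0.09, 0.081]
    /// ```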
    pub fn generate_schedule(&self, num_epochs: usize) -> Array1<F> {
        Array1::from_shape_fn(num_epochs, |i| self.get_learning_rate(i))
    }
}

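/// Scans a training history for common failure patterns and returns
/// human-readable findings with suggested remedies.
///
/// The heuristics look for the keys `train_loss`, `val_loss`, and `accuracy`;
/// other keys are ignored. A minimal usage sketch:
///
/// ```ignore
/// use std::collections::HashMap;
///
/// let mut history = HashMap::new();
/// history.insert("train_loss".to_string(), vec![0.9_f64, 0.5, 0.3, 0.2]);
/// history.insert("val_loss".to_string(), vec![1.0_f64, 0.8, 0.75, 0.74]);
///
/// for finding in analyze_training_history(&history) {
///     println!("{}", finding);
/// }
/// ```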
pub fn analyze_training_history<F: num_traits::Float + std::fmt::Display>(
    history: &HashMap<String, Vec<F>>,
) -> Vec<String> {
    let mut issues = Vec::new();

    if let (Some(train_loss), Some(val_loss)) = (history.get("train_loss"), history.get("val_loss"))
    {
        if train_loss.len() < 2 || val_loss.len() < 2 {
            return vec!["Not enough epochs to analyze training history.".to_string()];
        }

        let last_train = train_loss.last().unwrap();
        let last_val = val_loss.last().unwrap();

        // Overfitting: validation loss more than 10% above training loss.
        if last_val.to_f64().unwrap() > last_train.to_f64().unwrap() * 1.1 {
            issues.push("Potential overfitting: validation loss is significantly higher than training loss.".to_string());
            issues.push(" - Try adding regularization (L1, L2, dropout)".to_string());
            issues.push(" - Consider data augmentation".to_string());
            issues.push(" - Try reducing model complexity".to_string());
        }

        // Underfitting: training loss is still above an absolute threshold.
        let last_train_float = last_train.to_f64().unwrap();
        if last_train_float > 0.1 {
            issues.push("Potential underfitting: training loss is still high.".to_string());
            issues.push(" - Try increasing model complexity".to_string());
            issues.push(" - Train for more epochs".to_string());
            issues.push(" - Try different optimization algorithms or learning rates".to_string());
        }

        // Instability: count epochs where the training loss went up.
        let mut fluctuations = 0;
        for i in 1..train_loss.len() {
            if train_loss[i] > train_loss[i - 1] {
                fluctuations += 1;
            }
        }

        let fluctuation_rate = fluctuations as f64 / (train_loss.len() as f64 - 1.0);
        if fluctuation_rate > 0.3 {
            issues.push("Unstable training: loss values fluctuate frequently.".to_string());
            issues.push(" - Try reducing learning rate".to_string());
            issues.push(
                " - Use a different optimizer (Adam usually helps stabilize training)".to_string(),
            );
            issues.push(" - Try gradient clipping".to_string());
        }

        // Plateau: the second half of training improved far less than the first.
        if train_loss.len() >= 4 {
            let first_half_improvement = train_loss[train_loss.len() / 2].to_f64().unwrap()
                - train_loss[0].to_f64().unwrap();
            let second_half_improvement = train_loss.last().unwrap().to_f64().unwrap()
                - train_loss[train_loss.len() / 2].to_f64().unwrap();

            if second_half_improvement.abs() < first_half_improvement.abs() * 0.2 {
                issues.push("Training plateau: little improvement in later epochs.".to_string());
                issues.push(" - Try learning rate scheduling".to_string());
                issues.push(" - Use early stopping to avoid wasting computation".to_string());
                issues.push(" - Consider a different optimizer or model architecture".to_string());
            }
        }

        // Late overfitting: validation loss rising in most of the last few epochs.
        let mut val_increasing_count = 0;
        for i in 1..val_loss.len().min(5) {
            if val_loss[val_loss.len() - i] > val_loss[val_loss.len() - i - 1] {
                val_increasing_count += 1;
            }
        }

        if val_increasing_count >= 3 && val_loss.len() >= 5 {
            issues.push(
                "Validation loss is increasing in recent epochs, indicating overfitting."
                    .to_string(),
            );
            issues.push(" - Consider stopping training now to prevent overfitting".to_string());
            issues.push(" - Increase regularization strength".to_string());
            issues.push(" - Reduce model complexity".to_string());
        }
    }

    if let Some(accuracy) = history.get("accuracy") {
        if accuracy.len() >= 3 {
            let last_accuracy = accuracy.last().unwrap().to_f64().unwrap();

            if last_accuracy > 0.95 {
                issues.push("Model has achieved very high accuracy (>95%).".to_string());
                issues.push(
                    " - Consider stopping training or validating on more challenging data"
                        .to_string(),
                );
            }

            // Accuracy plateau: less than one point of change over five epochs.
            if accuracy.len() >= 5 {
                let recent_change = (accuracy.last().unwrap().to_f64().unwrap()
                    - accuracy[accuracy.len() - 5].to_f64().unwrap())
                .abs();

                if recent_change < 0.01 {
                    issues.push(
                        "Accuracy has plateaued with minimal improvement in recent epochs."
                            .to_string(),
                    );
                    issues.push(" - Try adjusting learning rate".to_string());
                    issues.push(" - Consider stopping training to avoid overfitting".to_string());
                }
            }
        }
    }

    if issues.is_empty() {
        issues.push("No significant issues detected in the training process.".to_string());
    }

    issues
}
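
#[cfg(test)]
mod tests {
    use super::*;

    // A small sketch verifying the schedule arithmetic; the expected values
    // follow directly from the decay formulas documented above.
    #[test]
    fn step_decay_halves_every_step_size_epochs() {
        let schedule = LearningRateSchedule::StepDecay {
            initial_lr: 0.1_f64,
            decay_factor: 0.5,
            step_size: 10,
        };
        assert!((schedule.get_learning_rate(0) - 0.1).abs() < 1e-12);
        assert!((schedule.get_learning_rate(10) - 0.05).abs() < 1e-12);
        assert!((schedule.get_learning_rate(25) - 0.025).abs() < 1e-12);
    }

    #[test]
    fn constant_schedule_is_flat() {
        let lrs = LearningRateSchedule::Constant(0.01_f64).generate_schedule(4);
        assert_eq!(lrs.len(), 4);
        assert!(lrs.iter().all(|&lr| (lr - 0.01).abs() < 1e-12));
    }
}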