// scirs2_neural/utils/visualization.rs

use crate::error::{NeuralError, Result};
use ndarray::Array1;
use std::collections::HashMap;
use std::fs::File;
use std::io::Write;
use std::path::Path;

/// Represents ASCII plotting options
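///
/// Defaults come from the `Default` impl below; individual fields can be
/// overridden with struct-update syntax. A minimal sketch (field values are
/// illustrative):
///
/// ```ignore
/// let options = PlotOptions {
///     width: 60,
///     height: 15,
///     ..PlotOptions::default()
/// };
/// ```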
#[derive(Clone, Debug)]
pub struct PlotOptions {
    /// Width of the plot in characters
    pub width: usize,
    /// Height of the plot in characters
    pub height: usize,
    /// Maximum number of ticks on x-axis
    pub max_x_ticks: usize,
    /// Maximum number of ticks on y-axis
    pub max_y_ticks: usize,
    /// Character to use for the plot line
    pub line_char: char,
    /// Character to use for plot points
    pub point_char: char,
    /// Character to use for the plot background
    pub background_char: char,
    /// Whether to show a grid
    pub show_grid: bool,
    /// Whether to show a legend
    pub show_legend: bool,
}

impl Default for PlotOptions {
    fn default() -> Self {
        Self {
            width: 80,
            height: 20,
            max_x_ticks: 10,
            max_y_ticks: 5,
            line_char: '─',
            point_char: '●',
            background_char: ' ',
            show_grid: true,
            show_legend: true,
        }
    }
}

/// Simple ASCII plotting function for training metrics visualization
///
/// # Arguments
///
/// * `data` - Map of series names to data arrays
/// * `title` - Optional title for the plot
/// * `options` - Optional plot options
///
/// # Returns
///
/// * `Result<String>` - The ASCII plot as a string
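///
/// # Examples
///
/// A minimal usage sketch (assumes this module is reachable as
/// `scirs2_neural::utils::visualization`):
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_neural::utils::visualization::ascii_plot;
///
/// let mut data = HashMap::new();
/// data.insert("train_loss".to_string(), vec![1.0f64, 0.6, 0.4, 0.3]);
/// data.insert("val_loss".to_string(), vec![1.1f64, 0.8, 0.7, 0.65]);
///
/// let plot = ascii_plot(&data, Some("Loss curves"), None).unwrap();
/// println!("{}", plot);
/// ```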
pub fn ascii_plot<F: num_traits::Float + std::fmt::Display + std::fmt::Debug>(
    data: &HashMap<String, Vec<F>>,
    title: Option<&str>,
    options: Option<PlotOptions>,
) -> Result<String> {
    let options = options.unwrap_or_default();
    let width = options.width;
    let height = options.height;

    if data.is_empty() {
        return Err(NeuralError::ValidationError("No data to plot".to_string()));
    }

    // Find the global min and max for y-axis scaling
    let mut min_y = F::infinity();
    let mut max_y = F::neg_infinity();
    let mut max_len = 0;

    for values in data.values() {
        if values.is_empty() {
            continue;
        }

        max_len = max_len.max(values.len());

        for &v in values {
            if v.is_finite() {
                min_y = min_y.min(v);
                max_y = max_y.max(v);
            }
        }
    }

    if max_len == 0 {
        return Err(NeuralError::ValidationError(
            "All data series are empty".to_string(),
        ));
    }

    // min_y/max_y are only updated by finite values, so they remain
    // non-finite exactly when no finite value was seen
    if !min_y.is_finite() || !max_y.is_finite() {
        return Err(NeuralError::ValidationError(
            "Data contains no finite values".to_string(),
        ));
    }

    // Add a small margin to the y-range
    let y_range = max_y - min_y;
    let margin = y_range * F::from(0.05).unwrap();
    min_y = min_y - margin;
    max_y = max_y + margin;

    // If min and max are the same, create a small range
    if (max_y - min_y).abs() < F::epsilon() {
        min_y = min_y - F::from(0.5).unwrap();
        max_y = max_y + F::from(0.5).unwrap();
    }

    // Create the plot canvas
    let mut plot = vec![vec![options.background_char; width]; height];

    // Draw grid if enabled; clamp each step to at least 1 so the modulo below
    // cannot divide by zero when the tick count exceeds the plot size
    if options.show_grid {
        let x_step = (width / options.max_x_ticks.max(1)).max(1);
        let y_step = (height / options.max_y_ticks.max(1)).max(1);
        for (y, row) in plot.iter_mut().enumerate().take(height) {
            for (x, cell) in row.iter_mut().enumerate().take(width) {
                if x % x_step == 0 && y % y_step == 0 {
                    *cell = '·';
                }
            }
        }
    }

    // Draw axes
    for row in plot.iter_mut().take(height) {
        row[0] = '│';
    }

    for x in 0..width {
        plot[height - 1][x] = '─';
    }

    plot[height - 1][0] = '└';

    // Plot each series with a distinct symbol so multiple series stay
    // distinguishable (these take precedence over `point_char`)
    let symbols = ['●', '■', '▲', '◆', '★', '✖', '◎'];

    let mut result = String::with_capacity(height * (width + 2) + 100);

    // Add title if provided; saturate so a title wider than the plot
    // cannot underflow the padding
    if let Some(title) = title {
        let title_padding = width.saturating_sub(title.len()) / 2;
        result.push_str(&" ".repeat(title_padding));
        result.push_str(title);
        result.push('\n');
        result.push('\n');
    }

    let mut legend_entries = Vec::new();

    for (i, (name, values)) in data.iter().enumerate() {
        let symbol = symbols[i % symbols.len()];

        if values.is_empty() {
            continue;
        }

        // Store legend entry
        legend_entries.push((name, symbol));

        // Plot the series; clamp the denominator so a single-point series
        // does not divide by zero
        let x_denom = max_len.saturating_sub(1).max(1) as f64;
        for (x_idx, &y_val) in values.iter().enumerate() {
            if !y_val.is_finite() {
                continue;
            }

            let x = ((x_idx as f64) / x_denom * (width as f64 - 2.0)).round() as usize + 1;

            if x >= width {
                continue;
            }

            let y_norm = ((y_val - min_y) / (max_y - min_y)).to_f64().unwrap();
            let y = height - (y_norm * (height as f64 - 2.0)).round() as usize - 1;

            if y < height {
                plot[y][x] = symbol;
            }
        }
    }

    // Render the plot; use at least two ticks in the divisor so the tick
    // spacing cannot divide by zero
    let y_ticks = (0..options.max_y_ticks.min(height))
        .map(|i| {
            let val = max_y
                - F::from(i as f64 / (options.max_y_ticks.max(2) as f64 - 1.0)).unwrap()
                    * (max_y - min_y);
            format!("{:.2}", val)
        })
        .collect::<Vec<_>>();

    let max_y_tick_width = y_ticks.iter().map(|t| t.len()).max().unwrap_or(0);

    let y_tick_step = (height / options.max_y_ticks.max(1)).max(1);
    for y in 0..height {
        // Add a y-axis tick on every step row, indexing ticks by tick number
        // (y / y_tick_step) rather than by row number
        if y % y_tick_step == 0 && y / y_tick_step < y_ticks.len() {
            let tick = &y_ticks[y / y_tick_step];
            result.push_str(&format!("{:>width$} ", tick, width = max_y_tick_width));
        } else {
            result.push_str(&" ".repeat(max_y_tick_width + 1));
        }

        // Add the plot row
        for x in 0..width {
            result.push(plot[y][x]);
        }

        result.push('\n');
    }

    // Add x-axis labels, mapping each tick to an epoch index; saturating
    // arithmetic guards against zero tick counts and over-wide labels
    result.push_str(&" ".repeat(max_y_tick_width + 1));
    let x_tick_denom = options.max_x_ticks.saturating_sub(1).max(1) as f64;
    for i in 0..options.max_x_ticks {
        let epoch = (i as f64 * (max_len as f64 - 1.0) / x_tick_denom).round() as usize;

        let tick = format!("{}", epoch);
        let slot = width / options.max_x_ticks.max(1);
        let padding = slot.saturating_sub(tick.len());
        let left_padding = padding / 2;
        let right_padding = padding - left_padding;

        result.push_str(&" ".repeat(left_padding));
        result.push_str(&tick);
        result.push_str(&" ".repeat(right_padding));
    }

    result.push('\n');

    // Add legend if enabled
    if options.show_legend && !legend_entries.is_empty() {
        result.push('\n');
        result.push_str("Legend: ");

        for (i, (name, symbol)) in legend_entries.iter().enumerate() {
            if i > 0 {
                result.push_str(", ");
            }
            result.push_str(&format!("{} {}", symbol, name));
        }

        result.push('\n');
    }

    Ok(result)
}

/// Export training history to a CSV file
///
/// # Arguments
///
/// * `history` - Map of metric names to values
/// * `filepath` - Path to save the CSV file
///
/// # Returns
///
/// * `Result<()>` - Result of the operation
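///
/// # Examples
///
/// A minimal sketch (assumed module path; writes `history.csv` to the current
/// working directory):
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_neural::utils::visualization::export_history_to_csv;
///
/// let mut history = HashMap::new();
/// history.insert("train_loss".to_string(), vec![0.9f64, 0.5, 0.3]);
/// history.insert("val_loss".to_string(), vec![1.0f64, 0.7, 0.6]);
///
/// export_history_to_csv(&history, "history.csv").unwrap();
/// ```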
pub fn export_history_to_csv<F: std::fmt::Display>(
    history: &HashMap<String, Vec<F>>,
    filepath: impl AsRef<Path>,
) -> Result<()> {
    let mut file = File::create(filepath)
        .map_err(|e| NeuralError::IOError(format!("Failed to create CSV file: {}", e)))?;

    // Find the maximum series length
    let max_len = history.values().map(|v| v.len()).max().unwrap_or(0);

    // Write header
    let mut header = String::from("epoch");

    // Get sorted keys for consistent column order
    let mut keys: Vec<&String> = history.keys().collect();
    keys.sort();

    for key in keys.iter() {
        header.push_str(&format!(",{}", key));
    }
    header.push('\n');

    file.write_all(header.as_bytes())
        .map_err(|e| NeuralError::IOError(format!("Failed to write CSV header: {}", e)))?;

    // Write data rows; series shorter than max_len yield empty cells
    for i in 0..max_len {
        let mut row = i.to_string();

        // Ensure columns match the header order using the same sorted keys
        for key in keys.iter() {
            row.push(',');
            if let Some(values) = history.get(*key) {
                if i < values.len() {
                    row.push_str(&format!("{}", values[i]));
                }
            }
        }

        row.push('\n');

        file.write_all(row.as_bytes())
            .map_err(|e| NeuralError::IOError(format!("Failed to write CSV row: {}", e)))?;
    }

    Ok(())
}

/// Simple utility to generate a learning rate schedule
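///
/// The `Custom` variant accepts any boxed `Fn(usize) -> F` closure, e.g. for
/// warmup. A minimal sketch (values are illustrative):
///
/// ```ignore
/// // Linear warmup over the first 5 epochs, then a constant 0.01.
/// let schedule = LearningRateSchedule::Custom(Box::new(|epoch| {
///     0.01f64 * ((epoch as f64 + 1.0) / 5.0).min(1.0)
/// }));
/// ```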
pub enum LearningRateSchedule<F: num_traits::Float> {
    /// Constant learning rate
    Constant(F),
    /// Step decay learning rate
    StepDecay {
        /// Initial learning rate
        initial_lr: F,
        /// Decay factor
        decay_factor: F,
        /// Epochs per step
        step_size: usize,
    },
    /// Exponential decay learning rate
    ExponentialDecay {
        /// Initial learning rate
        initial_lr: F,
        /// Decay factor
        decay_factor: F,
    },
    /// Custom learning rate schedule function
    Custom(Box<dyn Fn(usize) -> F>),
}

impl<F: num_traits::Float> LearningRateSchedule<F> {
    /// Get the learning rate for a given epoch
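    ///
    /// For `StepDecay` the rate is `initial_lr * decay_factor^(epoch / step_size)`;
    /// e.g. `initial_lr = 0.1`, `decay_factor = 0.5`, `step_size = 10` gives
    /// `0.1 * 0.5^2 = 0.025` at epoch 25. `ExponentialDecay` applies the decay
    /// factor every epoch rather than every `step_size` epochs.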
    pub fn get_learning_rate(&self, epoch: usize) -> F {
        match self {
            Self::Constant(lr) => *lr,
            Self::StepDecay {
                initial_lr,
                decay_factor,
                step_size,
            } => {
                // Clamp to at least 1 so a zero step size cannot divide by zero
                let num_steps = epoch / (*step_size).max(1);
                *initial_lr * (*decay_factor).powi(num_steps as i32)
            }
            Self::ExponentialDecay {
                initial_lr,
                decay_factor,
            } => *initial_lr * (*decay_factor).powi(epoch as i32),
            Self::Custom(f) => f(epoch),
        }
    }

    /// Generate the learning rate schedule for all epochs
    ///
    /// # Arguments
    ///
    /// * `num_epochs` - Number of epochs
    ///
    /// # Returns
    ///
    /// * `Array1<F>` - Learning rate for each epoch
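    ///
    /// # Examples
    ///
    /// A minimal sketch (assumed module path):
    ///
    /// ```ignore
    /// let schedule = LearningRateSchedule::StepDecay {
    ///     initial_lr: 0.1f64,
    ///     decay_factor: 0.5,
    ///     step_size: 10,
    /// };
    /// let lrs = schedule.generate_schedule(30);
    /// assert_eq!(lrs[0], 0.1);
    /// assert_eq!(lrs[10], 0.05);
    /// ```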
    pub fn generate_schedule(&self, num_epochs: usize) -> Array1<F> {
        Array1::from_shape_fn(num_epochs, |i| self.get_learning_rate(i))
    }
}

/// Analyze training history to find potential issues
///
/// # Arguments
///
/// * `history` - Map of metric names to values
///
/// # Returns
///
/// * `Vec<String>` - List of potential issues and suggestions
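///
/// # Examples
///
/// A minimal sketch (assumed module path; the loss values are illustrative):
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_neural::utils::visualization::analyze_training_history;
///
/// let mut history = HashMap::new();
/// history.insert("train_loss".to_string(), vec![1.0f64, 0.5, 0.3, 0.25]);
/// history.insert("val_loss".to_string(), vec![1.1f64, 0.9, 0.8, 0.95]);
///
/// for issue in analyze_training_history(&history) {
///     println!("{}", issue);
/// }
/// ```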
pub fn analyze_training_history<F: num_traits::Float + std::fmt::Display>(
    history: &HashMap<String, Vec<F>>,
) -> Vec<String> {
    let mut issues = Vec::new();

    // Check if we have training and validation loss
    if let (Some(train_loss), Some(val_loss)) = (history.get("train_loss"), history.get("val_loss"))
    {
        if train_loss.len() < 2 || val_loss.len() < 2 {
            return vec!["Not enough epochs to analyze training history.".to_string()];
        }

        // Check for overfitting
        let last_train = train_loss.last().unwrap();
        let last_val = val_loss.last().unwrap();

        if last_val.to_f64().unwrap() > last_train.to_f64().unwrap() * 1.1 {
            issues.push("Potential overfitting: validation loss is significantly higher than training loss.".to_string());
            issues.push("  - Try adding regularization (L1, L2, dropout)".to_string());
            issues.push("  - Consider data augmentation".to_string());
            issues.push("  - Try reducing model complexity".to_string());
        }

        // Check for underfitting (heuristic absolute threshold on the loss)
        let last_train_float = last_train.to_f64().unwrap();
        if last_train_float > 0.1 {
            issues.push("Potential underfitting: training loss is still high.".to_string());
            issues.push("  - Try increasing model complexity".to_string());
            issues.push("  - Train for more epochs".to_string());
            issues.push("  - Try different optimization algorithms or learning rates".to_string());
        }

        // Check for unstable training
        let mut fluctuations = 0;
        for i in 1..train_loss.len() {
            if train_loss[i] > train_loss[i - 1] {
                fluctuations += 1;
            }
        }

        let fluctuation_rate = fluctuations as f64 / (train_loss.len() as f64 - 1.0);
        if fluctuation_rate > 0.3 {
            issues.push("Unstable training: loss values fluctuate frequently.".to_string());
            issues.push("  - Try reducing learning rate".to_string());
            issues.push(
                "  - Use a different optimizer (Adam usually helps stabilize training)".to_string(),
            );
            issues.push("  - Try gradient clipping".to_string());
        }

        // Check for plateauing
        if train_loss.len() >= 4 {
            // Ensure we have enough data points for this analysis
            let first_half_improvement = train_loss[train_loss.len() / 2].to_f64().unwrap()
                - train_loss[0].to_f64().unwrap();
            let second_half_improvement = train_loss.last().unwrap().to_f64().unwrap()
                - train_loss[train_loss.len() / 2].to_f64().unwrap();

            if second_half_improvement.abs() < first_half_improvement.abs() * 0.2 {
                issues.push("Training plateau: little improvement in later epochs.".to_string());
                issues.push("  - Try learning rate scheduling".to_string());
                issues.push("  - Use early stopping to avoid wasting computation".to_string());
                issues.push("  - Consider a different optimizer or model architecture".to_string());
            }
        }

        // Check for divergent validation loss
        let mut val_increasing_count = 0;
        for i in 1..val_loss.len().min(5) {
            // Look at the last 5 epochs or fewer
            if val_loss[val_loss.len() - i] > val_loss[val_loss.len() - i - 1] {
                val_increasing_count += 1;
            }
        }

        if val_increasing_count >= 3 && val_loss.len() >= 5 {
            issues.push(
                "Validation loss is increasing in recent epochs, indicating overfitting."
                    .to_string(),
            );
            issues.push("  - Consider stopping training now to prevent overfitting".to_string());
            issues.push("  - Increase regularization strength".to_string());
            issues.push("  - Reduce model complexity".to_string());
        }
    }

    // Check accuracy trends if available
    if let Some(accuracy) = history.get("accuracy") {
        if accuracy.len() >= 3 {
            let last_accuracy = accuracy.last().unwrap().to_f64().unwrap();

            // Check if accuracy is high
            if last_accuracy > 0.95 {
                issues.push("Model has achieved very high accuracy (>95%).".to_string());
                issues.push(
                    "  - Consider stopping training or validating on more challenging data"
                        .to_string(),
                );
            }

            // Check for accuracy plateaus
            if accuracy.len() >= 5 {
                let recent_change = (accuracy.last().unwrap().to_f64().unwrap()
                    - accuracy[accuracy.len() - 5].to_f64().unwrap())
                .abs();

                if recent_change < 0.01 {
                    issues.push(
                        "Accuracy has plateaued with minimal improvement in recent epochs."
                            .to_string(),
                    );
                    issues.push("  - Try adjusting learning rate".to_string());
                    issues.push("  - Consider stopping training to avoid overfitting".to_string());
                }
            }
        }
    }

    if issues.is_empty() {
        issues.push("No significant issues detected in the training process.".to_string());
    }

    issues
}
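
#[cfg(test)]
mod tests {
    // Minimal smoke-test sketches of intended usage; the numeric values are
    // illustrative, not taken from any real training run.
    use super::*;

    #[test]
    fn ascii_plot_renders_title_and_legend() {
        let mut data = HashMap::new();
        data.insert("train_loss".to_string(), vec![1.0f64, 0.6, 0.4, 0.3]);
        let plot = ascii_plot(&data, Some("Loss"), None).unwrap();
        assert!(plot.contains("Loss"));
        assert!(plot.contains("Legend"));
    }

    #[test]
    fn step_decay_halves_every_step() {
        let schedule = LearningRateSchedule::StepDecay {
            initial_lr: 0.1f64,
            decay_factor: 0.5,
            step_size: 10,
        };
        // Epoch 25 -> 2 completed steps -> 0.1 * 0.5^2 = 0.025
        assert!((schedule.get_learning_rate(25) - 0.025).abs() < 1e-12);
    }

    #[test]
    fn analyze_always_reports_something() {
        let mut history = HashMap::new();
        history.insert("train_loss".to_string(), vec![1.0f64, 0.5, 0.3]);
        history.insert("val_loss".to_string(), vec![1.1f64, 0.9, 0.8]);
        assert!(!analyze_training_history(&history).is_empty());
    }
}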