Function one_hot_encode

Source
pub fn one_hot_encode<F: Float + Debug>(
    indices: &Array1<usize>,
    num_classes: usize,
) -> Result<Array2<F>>
Expand description

Computes the one-hot encoding of a vector of class indices.

§Arguments

  • indices - Vector of class indices, one entry per sample
  • num_classes - Number of classes, i.e. the length of each one-hot encoded row

§Returns

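  • On success, a 2D array with one row per entry of indices and num_classes columns, where each row is a one-hot encoded vector (row i holds 1.0 in column indices[i])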

§Examples

use scirs2_neural::utils::one_hot_encode;
use ndarray::arr1;

let indices = arr1(&[0, 2, 1]);
let one_hot = one_hot_encode::<f64>(&indices, 3).unwrap();

assert_eq!(one_hot.shape(), &[3, 3]);
assert_eq!(one_hot[[0, 0]], 1.0f64); // First sample, class 0
assert_eq!(one_hot[[1, 2]], 1.0f64); // Second sample, class 2
assert_eq!(one_hot[[2, 1]], 1.0f64); // Third sample, class 1