tensorrs/utils/mod.rs

1use crate::linalg::Vector;
2use crate::Float;
3
4mod autodiff;
5
6//pub use autodiff::*;
7
8/// Performs one-hot encoding on a vector of categorical indices.
9///
10/// Given a vector of indices (`data`), this function converts each index into a one-hot encoded vector
11/// of the specified `size`. Each resulting vector has all elements set to zero except for the position
12/// corresponding to the index, which is set to one.
13///
14/// # Parameters
15/// - `data`: Input vector containing categorical indices (0-based).
16/// - `size`: Dimension of the output one-hot vectors. Must be greater than the maximum index in `data`.
17/// - `_`: Phantom parameter to infer the float type `T` (e.g., `f32` or `f64`). Value is ignored.
18/// # Examples
19/// ```
20/// use tensorrs::linalg::Vector;
21/// use tensorrs::utils::one_hot_encoding;
22/// use tensorrs::vector;
23///
24/// let data = vec![2, 0, 3];
25/// let encoded = one_hot_encoding(data, 4, 0.0f32);
26///
27/// assert_eq!(
28///     encoded,
29///     vec![
30///         vector![0.0, 0.0, 1.0, 0.0],
31///         vector![1.0, 0.0, 0.0, 0.0],
32///         vector![0.0, 0.0, 0.0, 1.0]
33///     ]
34/// );
35/// ```
36pub fn one_hot_encoding<T: Float>(data: Vec<usize>, size: usize, _: T) -> Vec<Vector<T>> {
37    assert!(
38        *data.iter().max().unwrap_or(&0usize) < size,
39        "!!!Size of the vector must be greater then max number:\
40     Vector size: {size}, Max Element: {}!!!",
41        data.iter().max().unwrap_or(&0usize)
42    );
43    let mut vectors = Vec::with_capacity(data.len());
44    for i in data {
45        let mut vector = vec![T::default(); size];
46        vector[i] = T::one();
47        vectors.push(Vector::from(vector));
48    }
49    vectors
50}
51
52pub fn one_hot_decoding<T: Float>(data: Vec<Vector<T>>) -> Vec<usize> {
53    data.iter()
54        .map(|x| {
55            x.data
56                .iter()
57                .position(|&x| x == T::one())
58                .unwrap_or_default()
59        })
60        .collect()
61}
62
#[cfg(test)]
mod tests {
    use crate::linalg::Matrix;
    use crate::utils::{one_hot_decoding, one_hot_encoding};
    use crate::DataType;

    #[test]
    fn test_one_hot() {
        let a = vec![1, 2, 3, 4];
        let encoded = one_hot_encoding(a.clone(), 5, DataType::f32());

        // One encoded vector per input index, each of the requested width.
        assert_eq!(encoded.len(), a.len());
        assert!(encoded.iter().all(|v| v.data.len() == 5));

        // Smoke-test conversion to a matrix and its Display implementation.
        println!("{}", Matrix::from(encoded));
    }

    #[test]
    fn test_decoding() {
        let a = vec![1, 2, 3, 4];
        let one_hot = one_hot_encoding(a.clone(), 5, DataType::f32());
        assert_eq!(one_hot_decoding(one_hot), a);

        // Index 0 must round-trip too (it is also the decoder's fallback value).
        let b = vec![0, 3, 0];
        let one_hot = one_hot_encoding(b.clone(), 4, DataType::f32());
        assert_eq!(one_hot_decoding(one_hot), b);
    }

    #[test]
    #[should_panic]
    fn test_one_hot_index_out_of_range() {
        // Index 5 does not fit in vectors of width 5.
        let _ = one_hot_encoding(vec![0, 5], 5, DataType::f32());
    }
}