tensorrs/utils/mod.rs
1use crate::linalg::Vector;
2use crate::Float;
3
4mod autodiff;
5
6//pub use autodiff::*;
7
8/// Performs one-hot encoding on a vector of categorical indices.
9///
10/// Given a vector of indices (`data`), this function converts each index into a one-hot encoded vector
11/// of the specified `size`. Each resulting vector has all elements set to zero except for the position
12/// corresponding to the index, which is set to one.
13///
14/// # Parameters
15/// - `data`: Input vector containing categorical indices (0-based).
16/// - `size`: Dimension of the output one-hot vectors. Must be greater than the maximum index in `data`.
17/// - `_`: Phantom parameter to infer the float type `T` (e.g., `f32` or `f64`). Value is ignored.
18/// # Examples
19/// ```
20/// use tensorrs::linalg::Vector;
21/// use tensorrs::utils::one_hot_encoding;
22/// use tensorrs::vector;
23///
24/// let data = vec![2, 0, 3];
25/// let encoded = one_hot_encoding(data, 4, 0.0f32);
26///
27/// assert_eq!(
28/// encoded,
29/// vec![
30/// vector![0.0, 0.0, 1.0, 0.0],
31/// vector![1.0, 0.0, 0.0, 0.0],
32/// vector![0.0, 0.0, 0.0, 1.0]
33/// ]
34/// );
35/// ```
36pub fn one_hot_encoding<T: Float>(data: Vec<usize>, size: usize, _: T) -> Vec<Vector<T>> {
37 assert!(
38 *data.iter().max().unwrap_or(&0usize) < size,
39 "!!!Size of the vector must be greater then max number:\
40 Vector size: {size}, Max Element: {}!!!",
41 data.iter().max().unwrap_or(&0usize)
42 );
43 let mut vectors = Vec::with_capacity(data.len());
44 for i in data {
45 let mut vector = vec![T::default(); size];
46 vector[i] = T::one();
47 vectors.push(Vector::from(vector));
48 }
49 vectors
50}
51
52pub fn one_hot_decoding<T: Float>(data: Vec<Vector<T>>) -> Vec<usize> {
53 data.iter()
54 .map(|x| {
55 x.data
56 .iter()
57 .position(|&x| x == T::one())
58 .unwrap_or_default()
59 })
60 .collect()
61}
62
#[cfg(test)]
mod tests {
    use crate::linalg::{Matrix, Vector};
    use crate::utils::{one_hot_decoding, one_hot_encoding};
    use crate::DataType;

    /// Encoding places a single one at each index and zeros elsewhere.
    #[test]
    fn test_one_hot() {
        let a = vec![1, 2, 3, 4];
        let encoded = one_hot_encoding(a, 5, 0.0f32);
        let expected = vec![
            Vector::from(vec![0.0f32, 1.0, 0.0, 0.0, 0.0]),
            Vector::from(vec![0.0f32, 0.0, 1.0, 0.0, 0.0]),
            Vector::from(vec![0.0f32, 0.0, 0.0, 1.0, 0.0]),
            Vector::from(vec![0.0f32, 0.0, 0.0, 0.0, 1.0]),
        ];
        assert_eq!(encoded, expected);
        // Still exercise the Matrix conversion / Display path the original test covered.
        println!("{}", Matrix::from(encoded));
    }

    /// Decoding an encoding recovers the original indices.
    #[test]
    fn test_decoding() {
        let a = vec![1, 2, 3, 4];
        let one_hot = one_hot_encoding(a.clone(), 5, DataType::f32());
        let one_dec = one_hot_decoding(one_hot);
        assert_eq!(one_dec, a);
    }

    /// Empty inputs produce empty outputs in both directions.
    #[test]
    fn test_empty_input() {
        assert!(one_hot_encoding(Vec::new(), 3, 0.0f32).is_empty());
        assert!(one_hot_decoding::<f32>(Vec::new()).is_empty());
    }

    /// An index equal to `size` is out of range and must be rejected.
    #[test]
    #[should_panic]
    fn test_encoding_rejects_out_of_range_index() {
        let _ = one_hot_encoding(vec![5], 5, 0.0f32);
    }
}