//! zenu_matrix/operation/softmax.rs

1use crate::{
2    device::Device,
3    dim::{DimDyn, DimTrait},
4    matrix::{Matrix, Ref, Repr},
5    num::Num,
6};
7
8impl<T: Num, D: Device> Matrix<Ref<&mut T>, DimDyn, D> {
9    #[expect(clippy::missing_panics_doc)]
10    pub fn softmax_assign<R: Repr<Item = T>>(&self, source: &Matrix<R, DimDyn, D>, axis: usize) {
11        assert!(
12            axis < self.shape().len(),
13            "axis must be less than the number of dimensions"
14        );
15        assert!(
16            self.shape().slice() == source.shape().slice(),
17            "softmax shape mismatch"
18        );
19
20        let max_diff = source.to_ref() - source.max_axis(axis, true);
21        let mut output = max_diff.exp();
22        let sum = output.to_ref().sum(axis, true);
23        output /= sum;
24        self.copy_from(&output.to_ref());
25    }
26}
27
#[cfg(test)]
mod softmax {
    #![expect(clippy::unreadable_literal)]
    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };

    /// Softmax over the single axis of a length-4 vector.
    fn softmax_1d<D: Device>() {
        let input = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4.], [4]);
        let mut output = Matrix::<Owned<f32>, DimDyn, D>::zeros([4]);
        output.to_ref_mut().softmax_assign(&input, 0);
        let expected = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![0.0320586, 0.08714432, 0.23688284, 0.64391428],
            [4],
        );
        // Compare via the absolute sum of the elementwise difference.
        assert!((output - expected).asum() < 1e-6);
    }
    #[test]
    fn softmax_1d_cpu() {
        softmax_1d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn softmax_1d_cuda() {
        softmax_1d::<crate::device::nvidia::Nvidia>();
    }

    /// Softmax of a 2x3 matrix along each of its two axes.
    fn softmax_2d<D: Device>() {
        // Axis 1: every row normalizes independently, so both rows agree.
        let input = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]);
        let mut output = Matrix::<Owned<f32>, DimDyn, D>::zeros([2, 3]);
        output.to_ref_mut().softmax_assign(&input, 1);
        let expected_axis1 = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![
                0.09003057, 0.24472847, 0.66524096, 0.09003057, 0.24472847, 0.66524096,
            ],
            [2, 3],
        );
        assert!((output - expected_axis1).asum() < 1e-6);

        // Axis 0: columns normalize independently; each column has the same
        // constant gap, so all columns share one distribution.
        let input = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]);
        let mut output = Matrix::<Owned<f32>, DimDyn, D>::zeros([2, 3]);
        output.to_ref_mut().softmax_assign(&input, 0);
        let expected_axis0 = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![
                0.04742587, 0.04742587, 0.04742587, 0.95257413, 0.95257413, 0.95257413,
            ],
            [2, 3],
        );
        assert!((output - expected_axis0).asum() < 1e-6);
    }
    #[test]
    fn softmax_2d_cpu() {
        softmax_2d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn softmax_2d_cuda() {
        softmax_2d::<crate::device::nvidia::Nvidia>();
    }
}