1use crate::{
2 device::Device,
3 dim::{DimDyn, DimTrait},
4 matrix::{Matrix, Ref, Repr},
5 num::Num,
6};
7
8impl<T: Num, D: Device> Matrix<Ref<&mut T>, DimDyn, D> {
9 #[expect(clippy::missing_panics_doc)]
10 pub fn softmax_assign<R: Repr<Item = T>>(&self, source: &Matrix<R, DimDyn, D>, axis: usize) {
11 assert!(
12 axis < self.shape().len(),
13 "axis must be less than the number of dimensions"
14 );
15 assert!(
16 self.shape().slice() == source.shape().slice(),
17 "softmax shape mismatch"
18 );
19
20 let max_diff = source.to_ref() - source.max_axis(axis, true);
21 let mut output = max_diff.exp();
22 let sum = output.to_ref().sum(axis, true);
23 output /= sum;
24 self.copy_from(&output.to_ref());
25 }
26}
27
#[cfg(test)]
mod softmax {
    #![expect(clippy::unreadable_literal)]
    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };

    /// Softmax over the only axis of a 1-D matrix.
    fn softmax_1d<D: Device>() {
        let input = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4.], [4]);
        let mut actual = Matrix::<Owned<f32>, DimDyn, D>::zeros([4]);
        actual.to_ref_mut().softmax_assign(&input, 0);
        let expected = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![0.0320586, 0.08714432, 0.23688284, 0.64391428],
            [4],
        );
        let residual = actual - expected;
        assert!(residual.asum() < 1e-6);
    }

    #[test]
    fn softmax_1d_cpu() {
        softmax_1d::<crate::device::cpu::Cpu>();
    }

    #[cfg(feature = "nvidia")]
    #[test]
    fn softmax_1d_cuda() {
        softmax_1d::<crate::device::nvidia::Nvidia>();
    }

    /// Softmax of the same 2-D matrix along each of its two axes.
    fn softmax_2d<D: Device>() {
        // Axis 1: the two rows differ only by a constant shift, so they
        // yield the same softmax distribution.
        let input =
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]);
        let mut actual = Matrix::<Owned<f32>, DimDyn, D>::zeros([2, 3]);
        actual.to_ref_mut().softmax_assign(&input, 1);
        let expected = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![
                0.09003057, 0.24472847, 0.66524096, 0.09003057, 0.24472847, 0.66524096,
            ],
            [2, 3],
        );
        let residual = actual - expected;
        assert!(residual.asum() < 1e-6);

        // Axis 0: every column is [x, x + 3], so all three columns share
        // one softmax distribution.
        let input =
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]);
        let mut actual = Matrix::<Owned<f32>, DimDyn, D>::zeros([2, 3]);
        actual.to_ref_mut().softmax_assign(&input, 0);
        let expected = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![
                0.04742587, 0.04742587, 0.04742587, 0.95257413, 0.95257413, 0.95257413,
            ],
            [2, 3],
        );
        let residual = actual - expected;
        assert!(residual.asum() < 1e-6);
    }

    #[test]
    fn softmax_2d_cpu() {
        softmax_2d::<crate::device::cpu::Cpu>();
    }

    #[cfg(feature = "nvidia")]
    #[test]
    fn softmax_2d_cuda() {
        softmax_2d::<crate::device::nvidia::Nvidia>();
    }
}