//! A linear fully-connected layer.
use crate::Tensor;
use std::borrow::Borrow;

/// Configuration for a linear layer.
#[derive(Debug, Clone, Copy)]
pub struct LinearConfig {
    /// Initialization used for the weight tensor.
    pub ws_init: super::Init,
    /// Initialization used for the bias; when `None`, a uniform default
    /// derived from the input dimension is used.
    pub bs_init: Option<super::Init>,
    /// Whether the layer has a bias at all.
    pub bias: bool,
}

impl Default for LinearConfig {
    fn default() -> Self {
        LinearConfig { ws_init: super::init::DEFAULT_KAIMING_UNIFORM, bs_init: None, bias: true }
    }
}

/// A linear fully-connected layer: a weight tensor and an optional bias.
#[derive(Debug)]
pub struct Linear {
    pub ws: Tensor,
    pub bs: Option<Tensor>,
}

/// Creates a new linear layer with `in_dim` input features and `out_dim`
/// output features, registering its variables under the given path.
pub fn linear<'a, T: Borrow<super::Path<'a>>>(
    vs: T,
    in_dim: i64,
    out_dim: i64,
    c: LinearConfig,
) -> Linear {
    let vs = vs.borrow();
    let bs = if c.bias {
        // Without an explicit bias init, fall back to a uniform distribution
        // over [-1/sqrt(in_dim), 1/sqrt(in_dim)], PyTorch's nn.Linear default.
        let bs_init = c.bs_init.unwrap_or_else(|| {
            let bound = 1.0 / (in_dim as f64).sqrt();
            super::Init::Uniform { lo: -bound, up: bound }
        });
        Some(vs.var("bias", &[out_dim], bs_init))
    } else {
        None
    };

    Linear { ws: vs.var("weight", &[out_dim, in_dim], c.ws_init), bs }
}

impl super::module::Module for Linear {
    fn forward(&self, xs: &Tensor) -> Tensor {
        // xs * ws^T, plus the bias when one is present.
        xs.linear(&self.ws, self.bs.as_ref())
    }
}
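
// A minimal usage sketch (added for illustration, not part of the original
// module): it builds a layer through a `VarStore` and checks the output
// shape of a forward pass. It assumes the usual tch-rs crate layout, i.e.
// `VarStore` re-exported from `super` (crate::nn) and `Device`/`Kind` at the
// crate root; adjust the paths or the `randn` call if the API differs.
#[test]
fn usage_sketch() {
    use crate::nn::Module;

    // The variable store owns the layer's parameters ("weight" and "bias").
    let vs = super::VarStore::new(crate::Device::Cpu);
    let layer = linear(vs.root(), 784, 10, Default::default());

    // Forward pass on a random batch: [64, 784] -> [64, 10].
    let xs = Tensor::randn(&[64, 784], (crate::Kind::Float, crate::Device::Cpu));
    let ys = layer.forward(&xs);
    assert_eq!(ys.size(), vec![64_i64, 10]);
}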

#[test]
fn matches_pytorch() {
    use crate::nn::Module;

    // Reference tensors stored alongside the tests as .npy files: an input
    // batch, the expected output, and the layer's weight and bias.
    let input = Tensor::read_npy("tests/linear/in.npy").unwrap();
    let expected_output = Tensor::read_npy("tests/linear/out.npy").unwrap();
    let ws = Tensor::read_npy("tests/linear/ws.npy").unwrap();
    let bs = Some(Tensor::read_npy("tests/linear/bs.npy").unwrap());

    // Manual computation of the same affine map: input * ws^T + bs.
    let original_output =
        if let Some(bias) = &bs { input.matmul(&ws.tr()) + bias } else { input.matmul(&ws.tr()) };

    let linear = Linear { ws, bs };
    let output = linear.forward(&input);

    let delta_output: f32 = (&output - &expected_output).norm().try_into().unwrap();
    let delta_original: f32 = (&original_output - &expected_output).norm().try_into().unwrap();

    // The layer should match the expected output, and be at least as close
    // to it as the manual matmul-based computation.
    assert!(output.allclose(&expected_output, 1e-5, 1e-8, false));
    assert!(delta_output <= delta_original);
}