1use svod_dtype::DType;
2
3use crate::Tensor;
4use crate::nn::Layer;
5
/// Local shorthand for the crate-wide `crate::Result<T>` type.
type Result<T> = crate::Result<T>;
7
/// A linear layer that owns a weight tensor and a bias tensor.
///
/// Its `Layer::forward` implementation passes both tensors to the input tensor's
/// `linear` builder.
pub struct Linear {
    /// Weight tensor. `Linear::with_dims` builds it with shape
    /// `[out_features, in_features]`; a tensor passed to `Linear::new` is stored as-is.
    pub weight: Tensor,
    /// Bias tensor. `Linear::with_dims` builds it with shape `[out_features]`, filled
    /// with `0.0`; a tensor passed to `Linear::new` is stored as-is.
    pub bias: Tensor,
}
15
16impl Linear {
17 pub fn new(weight: Tensor, bias: Tensor) -> Self {
21 Self { weight, bias }
22 }
23
24 pub fn with_dims(in_features: usize, out_features: usize, dtype: DType) -> Self {
28 let weight_data: Vec<f32> = (0..in_features * out_features).map(|i| ((i as f32) * 0.1).sin() * 0.1).collect();
29 let weight = Tensor::from_slice(&weight_data)
30 .try_reshape([out_features as isize, in_features as isize])
31 .expect("linear weight reshape failed");
32 let bias = Tensor::full(&[out_features], 0.0, dtype).expect("linear bias creation failed");
33 Self { weight, bias }
34 }
35}
36
37impl Layer for Linear {
38 fn forward(&self, x: &Tensor) -> Result<Tensor> {
39 x.linear().weight(&self.weight).bias(&self.bias).call()
40 }
41}