Skip to main content

svod_tensor/nn/
linear.rs

1use svod_dtype::DType;
2
3use crate::Tensor;
4use crate::nn::Layer;
5
/// Module-local shorthand for the crate-wide [`crate::Result`] type.
type Result<T> = crate::Result<T>;
7
8/// Fully connected layer: `y = x @ weight.T + bias`.
9///
10/// Weight shape: `[out_features, in_features]`, bias shape: `[out_features]`.
11pub struct Linear {
12    pub weight: Tensor,
13    pub bias: Tensor,
14}
15
16impl Linear {
17    /// Create a linear layer from existing weight and bias tensors.
18    ///
19    /// Weight must have shape `[out_features, in_features]`, bias must have shape `[out_features]`.
20    pub fn new(weight: Tensor, bias: Tensor) -> Self {
21        Self { weight, bias }
22    }
23
24    /// Create a linear layer with deterministic initialization using `sin()`.
25    ///
26    /// Weight shape: `[out_features, in_features]`, bias: zeros.
27    pub fn with_dims(in_features: usize, out_features: usize, dtype: DType) -> Self {
28        let weight_data: Vec<f32> = (0..in_features * out_features).map(|i| ((i as f32) * 0.1).sin() * 0.1).collect();
29        let weight = Tensor::from_slice(&weight_data)
30            .try_reshape([out_features as isize, in_features as isize])
31            .expect("linear weight reshape failed");
32        let bias = Tensor::full(&[out_features], 0.0, dtype).expect("linear bias creation failed");
33        Self { weight, bias }
34    }
35}
36
37impl Layer for Linear {
38    fn forward(&self, x: &Tensor) -> Result<Tensor> {
39        x.linear().weight(&self.weight).bias(&self.bias).call()
40    }
41}