//! yscv_model/layers/linear.rs — dense linear layer for the yscv model crate.
1use yscv_autograd::{Graph, NodeId};
2use yscv_tensor::Tensor;
3
4use crate::ModelError;
5
/// Dense linear layer: `y = x @ weight + bias`.
///
/// Supports both graph-mode (autograd training) and inference-mode (raw tensors).
#[derive(Debug, Clone, PartialEq)]
pub struct LinearLayer {
    /// Number of input features; 2-D inputs must have this many columns.
    in_features: usize,
    /// Number of output features; width of `weight` and length of `bias`.
    out_features: usize,
    /// Owned weight matrix, shape `[in_features, out_features]` (validated in `new`).
    weight_tensor: Tensor,
    /// Owned bias vector, shape `[out_features]` (validated in `new`).
    bias_tensor: Tensor,
    /// Autograd node holding the weight; `None` means the parameters are not
    /// registered in a graph, in which case graph-mode `forward` errors.
    weight_node: Option<NodeId>,
    /// Autograd node holding the bias; see `weight_node`.
    bias_node: Option<NodeId>,
}
18
19impl LinearLayer {
20    /// Creates a layer from explicit parameter tensors.
21    pub fn new(
22        graph: &mut Graph,
23        in_features: usize,
24        out_features: usize,
25        weight_init: Tensor,
26        bias_init: Tensor,
27    ) -> Result<Self, ModelError> {
28        let expected_weight = vec![in_features, out_features];
29        if weight_init.shape() != expected_weight {
30            return Err(ModelError::InvalidParameterShape {
31                parameter: "weight",
32                expected: expected_weight,
33                got: weight_init.shape().to_vec(),
34            });
35        }
36        let expected_bias = vec![out_features];
37        if bias_init.shape() != expected_bias {
38            return Err(ModelError::InvalidParameterShape {
39                parameter: "bias",
40                expected: expected_bias,
41                got: bias_init.shape().to_vec(),
42            });
43        }
44
45        let weight_node = graph.variable(weight_init.clone());
46        let bias_node = graph.variable(bias_init.clone());
47        Ok(Self {
48            in_features,
49            out_features,
50            weight_tensor: weight_init,
51            bias_tensor: bias_init,
52            weight_node: Some(weight_node),
53            bias_node: Some(bias_node),
54        })
55    }
56
57    /// Creates a zero-initialized layer.
58    pub fn zero_init(
59        graph: &mut Graph,
60        in_features: usize,
61        out_features: usize,
62    ) -> Result<Self, ModelError> {
63        let weight = Tensor::zeros(vec![in_features, out_features])?;
64        let bias = Tensor::zeros(vec![out_features])?;
65        Self::new(graph, in_features, out_features, weight, bias)
66    }
67
68    /// Synchronizes owned tensors from the graph (e.g. after optimizer step).
69    pub fn sync_from_graph(&mut self, graph: &Graph) -> Result<(), ModelError> {
70        if let Some(w_id) = self.weight_node {
71            self.weight_tensor = graph.value(w_id)?.clone();
72        }
73        if let Some(b_id) = self.bias_node {
74            self.bias_tensor = graph.value(b_id)?.clone();
75        }
76        Ok(())
77    }
78
79    /// Graph-mode forward pass (for training with autograd).
80    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
81        let w_id = self
82            .weight_node
83            .ok_or(ModelError::ParamsNotRegistered { layer: "Linear" })?;
84        let b_id = self
85            .bias_node
86            .ok_or(ModelError::ParamsNotRegistered { layer: "Linear" })?;
87        let input_shape = graph.value(input)?.shape().to_vec();
88        if input_shape.len() != 2 || input_shape[1] != self.in_features {
89            return Err(ModelError::InvalidInputShape {
90                expected_features: self.in_features,
91                got: input_shape,
92            });
93        }
94
95        let projected = graph.matmul_2d(input, w_id)?;
96        let output = graph.add(projected, b_id)?;
97        Ok(output)
98    }
99
100    /// Inference-mode forward pass (no graph needed).
101    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
102        if input.rank() != 2 || input.shape()[1] != self.in_features {
103            return Err(ModelError::InvalidInputShape {
104                expected_features: self.in_features,
105                got: input.shape().to_vec(),
106            });
107        }
108        let projected = yscv_kernels::matmul_2d(input, &self.weight_tensor)?;
109        projected.add(&self.bias_tensor).map_err(ModelError::Tensor)
110    }
111
112    pub fn in_features(&self) -> usize {
113        self.in_features
114    }
115
116    pub fn out_features(&self) -> usize {
117        self.out_features
118    }
119
120    pub(crate) fn trainable_nodes(&self) -> Vec<NodeId> {
121        let mut nodes = Vec::new();
122        if let Some(w) = self.weight_node {
123            nodes.push(w);
124        }
125        if let Some(b) = self.bias_node {
126            nodes.push(b);
127        }
128        nodes
129    }
130
131    pub fn weight_node(&self) -> Option<NodeId> {
132        self.weight_node
133    }
134
135    pub fn bias_node(&self) -> Option<NodeId> {
136        self.bias_node
137    }
138
139    pub fn weight(&self) -> &Tensor {
140        &self.weight_tensor
141    }
142
143    pub fn bias(&self) -> &Tensor {
144        &self.bias_tensor
145    }
146}