// yscv_model/layers/linear.rs
use yscv_autograd::{Graph, NodeId};
2use yscv_tensor::Tensor;
3
4use crate::ModelError;
5
/// A fully-connected (linear) layer: `output = input.matmul(weight) + bias`.
///
/// Keeps a host-side copy of the parameters (`weight_tensor`, `bias_tensor`)
/// alongside the autograd graph node ids they were registered under, so the
/// layer can run both graph-based training (`forward`) and graph-free
/// inference (`forward_inference`).
#[derive(Debug, Clone, PartialEq)]
pub struct LinearLayer {
    /// Number of input features; 2-D input must have this many columns.
    in_features: usize,
    /// Number of output features.
    out_features: usize,
    /// Host-side weight values, shape `[in_features, out_features]`.
    weight_tensor: Tensor,
    /// Host-side bias values, shape `[out_features]`.
    bias_tensor: Tensor,
    /// Graph node id of the weight parameter; set by `new`.
    /// NOTE(review): `Option` suggests an unregistered state is possible
    /// elsewhere (e.g. deserialization) — confirm against other constructors.
    weight_node: Option<NodeId>,
    /// Graph node id of the bias parameter; set by `new`.
    bias_node: Option<NodeId>,
}
18
19impl LinearLayer {
20 pub fn new(
22 graph: &mut Graph,
23 in_features: usize,
24 out_features: usize,
25 weight_init: Tensor,
26 bias_init: Tensor,
27 ) -> Result<Self, ModelError> {
28 let expected_weight = vec![in_features, out_features];
29 if weight_init.shape() != expected_weight {
30 return Err(ModelError::InvalidParameterShape {
31 parameter: "weight",
32 expected: expected_weight,
33 got: weight_init.shape().to_vec(),
34 });
35 }
36 let expected_bias = vec![out_features];
37 if bias_init.shape() != expected_bias {
38 return Err(ModelError::InvalidParameterShape {
39 parameter: "bias",
40 expected: expected_bias,
41 got: bias_init.shape().to_vec(),
42 });
43 }
44
45 let weight_node = graph.variable(weight_init.clone());
46 let bias_node = graph.variable(bias_init.clone());
47 Ok(Self {
48 in_features,
49 out_features,
50 weight_tensor: weight_init,
51 bias_tensor: bias_init,
52 weight_node: Some(weight_node),
53 bias_node: Some(bias_node),
54 })
55 }
56
57 pub fn zero_init(
59 graph: &mut Graph,
60 in_features: usize,
61 out_features: usize,
62 ) -> Result<Self, ModelError> {
63 let weight = Tensor::zeros(vec![in_features, out_features])?;
64 let bias = Tensor::zeros(vec![out_features])?;
65 Self::new(graph, in_features, out_features, weight, bias)
66 }
67
68 pub fn sync_from_graph(&mut self, graph: &Graph) -> Result<(), ModelError> {
70 if let Some(w_id) = self.weight_node {
71 self.weight_tensor = graph.value(w_id)?.clone();
72 }
73 if let Some(b_id) = self.bias_node {
74 self.bias_tensor = graph.value(b_id)?.clone();
75 }
76 Ok(())
77 }
78
79 pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
81 let w_id = self
82 .weight_node
83 .ok_or(ModelError::ParamsNotRegistered { layer: "Linear" })?;
84 let b_id = self
85 .bias_node
86 .ok_or(ModelError::ParamsNotRegistered { layer: "Linear" })?;
87 let input_shape = graph.value(input)?.shape().to_vec();
88 if input_shape.len() != 2 || input_shape[1] != self.in_features {
89 return Err(ModelError::InvalidInputShape {
90 expected_features: self.in_features,
91 got: input_shape,
92 });
93 }
94
95 let projected = graph.matmul_2d(input, w_id)?;
96 let output = graph.add(projected, b_id)?;
97 Ok(output)
98 }
99
100 pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
102 if input.rank() != 2 || input.shape()[1] != self.in_features {
103 return Err(ModelError::InvalidInputShape {
104 expected_features: self.in_features,
105 got: input.shape().to_vec(),
106 });
107 }
108 let projected = yscv_kernels::matmul_2d(input, &self.weight_tensor)?;
109 projected.add(&self.bias_tensor).map_err(ModelError::Tensor)
110 }
111
112 pub fn in_features(&self) -> usize {
113 self.in_features
114 }
115
116 pub fn out_features(&self) -> usize {
117 self.out_features
118 }
119
120 pub(crate) fn trainable_nodes(&self) -> Vec<NodeId> {
121 let mut nodes = Vec::new();
122 if let Some(w) = self.weight_node {
123 nodes.push(w);
124 }
125 if let Some(b) = self.bias_node {
126 nodes.push(b);
127 }
128 nodes
129 }
130
131 pub fn weight_node(&self) -> Option<NodeId> {
132 self.weight_node
133 }
134
135 pub fn bias_node(&self) -> Option<NodeId> {
136 self.bias_node
137 }
138
139 pub fn weight(&self) -> &Tensor {
140 &self.weight_tensor
141 }
142
143 pub fn bias(&self) -> &Tensor {
144 &self.bias_tensor
145 }
146}