1use crate::{TrainError, TrainResult};
4use scirs2_core::ndarray::{Array, ArrayView, Ix2, IxDyn};
5use std::collections::HashMap;
6
7pub trait Model {
13 fn forward(&self, input: &ArrayView<f64, Ix2>) -> TrainResult<Array<f64, Ix2>>;
21
22 fn backward(
31 &self,
32 input: &ArrayView<f64, Ix2>,
33 grad_output: &ArrayView<f64, Ix2>,
34 ) -> TrainResult<HashMap<String, Array<f64, Ix2>>>;
35
36 fn parameters(&self) -> &HashMap<String, Array<f64, Ix2>>;
38
39 fn parameters_mut(&mut self) -> &mut HashMap<String, Array<f64, Ix2>>;
41
42 fn set_parameters(&mut self, parameters: HashMap<String, Array<f64, Ix2>>);
44
45 fn num_parameters(&self) -> usize {
47 self.parameters().values().map(|p| p.len()).sum()
48 }
49
50 fn state_dict(&self) -> HashMap<String, Vec<f64>> {
52 self.parameters()
53 .iter()
54 .map(|(name, param)| (name.clone(), param.iter().copied().collect()))
55 .collect()
56 }
57
58 fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) -> TrainResult<()> {
60 let parameters = self.parameters_mut();
61
62 for (name, values) in state {
63 if let Some(param) = parameters.get_mut(&name) {
64 if param.len() != values.len() {
65 return Err(TrainError::InvalidParameter(format!(
66 "Parameter '{}' size mismatch: expected {}, got {}",
67 name,
68 param.len(),
69 values.len()
70 )));
71 }
72
73 for (p, v) in param.iter_mut().zip(values.iter()) {
74 *p = *v;
75 }
76 } else {
77 return Err(TrainError::InvalidParameter(format!(
78 "Parameter '{}' not found in model",
79 name
80 )));
81 }
82 }
83
84 Ok(())
85 }
86
87 fn reset_parameters(&mut self) {
89 }
92}
93
/// Extension of [`Model`] for implementations that record a computation graph
/// and derive gradients automatically instead of via hand-written `backward`.
pub trait AutodiffModel: Model {
    /// Runs a forward pass while recording operations for autodiff.
    ///
    /// The default implementation ignores the input and succeeds, so plain
    /// models can opt in without providing graph recording.
    fn forward_autodiff(&self, input: &ArrayView<f64, Ix2>) -> TrainResult<()> {
        let _ = input;
        Ok(())
    }

    /// Returns gradients accumulated by the recorded graph.
    ///
    /// The default implementation returns an empty map (no graph recorded).
    fn compute_gradients(&self) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        Ok(HashMap::new())
    }
}
124
/// Interface for models that operate on dynamically-ranked arrays (`IxDyn`)
/// rather than fixed 2-D inputs.
pub trait DynamicModel {
    /// Computes the model output for a dynamically-shaped input.
    fn forward_dynamic(&self, input: &ArrayView<f64, IxDyn>) -> TrainResult<Array<f64, IxDyn>>;

    /// Computes parameter gradients for a dynamically-shaped input, given the
    /// gradient of the loss with respect to the output.
    fn backward_dynamic(
        &self,
        input: &ArrayView<f64, IxDyn>,
        grad_output: &ArrayView<f64, IxDyn>,
    ) -> TrainResult<HashMap<String, Array<f64, IxDyn>>>;
}
140
/// A fully-connected linear layer: `output = input · weight + bias`.
#[derive(Debug, Clone)]
pub struct LinearModel {
    // Named tensors: "weight" is (input_dim, output_dim), "bias" is (1, output_dim).
    parameters: HashMap<String, Array<f64, Ix2>>,
    // Number of input features (columns expected in the input batch).
    input_dim: usize,
    // Number of output features (columns of the produced output).
    output_dim: usize,
}
151
152impl LinearModel {
153 pub fn new(input_dim: usize, output_dim: usize) -> Self {
159 let mut parameters = HashMap::new();
160
161 let weights = Array::zeros((input_dim, output_dim));
163 let biases = Array::zeros((1, output_dim));
164
165 parameters.insert("weight".to_string(), weights);
166 parameters.insert("bias".to_string(), biases);
167
168 Self {
169 parameters,
170 input_dim,
171 output_dim,
172 }
173 }
174
175 pub fn xavier_init(&mut self) {
177 let limit = (6.0 / (self.input_dim + self.output_dim) as f64).sqrt();
178
179 if let Some(weights) = self.parameters.get_mut("weight") {
180 weights.mapv_inplace(|_| (limit * 2.0 * 0.5) - limit);
182 }
183 }
184
185 pub fn input_dim(&self) -> usize {
187 self.input_dim
188 }
189
190 pub fn output_dim(&self) -> usize {
192 self.output_dim
193 }
194}
195
196impl Model for LinearModel {
197 fn forward(&self, input: &ArrayView<f64, Ix2>) -> TrainResult<Array<f64, Ix2>> {
198 let weights = self
199 .parameters
200 .get("weight")
201 .ok_or_else(|| TrainError::InvalidParameter("weight not found".to_string()))?;
202 let biases = self
203 .parameters
204 .get("bias")
205 .ok_or_else(|| TrainError::InvalidParameter("bias not found".to_string()))?;
206
207 let output = input.dot(weights) + biases;
209 Ok(output)
210 }
211
212 fn backward(
213 &self,
214 input: &ArrayView<f64, Ix2>,
215 grad_output: &ArrayView<f64, Ix2>,
216 ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
217 let mut gradients = HashMap::new();
218
219 let grad_weights = input.t().dot(grad_output);
221 gradients.insert("weight".to_string(), grad_weights);
222
223 let grad_biases = grad_output
225 .sum_axis(scirs2_core::ndarray::Axis(0))
226 .insert_axis(scirs2_core::ndarray::Axis(0));
227 gradients.insert("bias".to_string(), grad_biases);
228
229 Ok(gradients)
230 }
231
232 fn parameters(&self) -> &HashMap<String, Array<f64, Ix2>> {
233 &self.parameters
234 }
235
236 fn parameters_mut(&mut self) -> &mut HashMap<String, Array<f64, Ix2>> {
237 &mut self.parameters
238 }
239
240 fn set_parameters(&mut self, parameters: HashMap<String, Array<f64, Ix2>>) {
241 self.parameters = parameters;
242 }
243
244 fn reset_parameters(&mut self) {
245 self.xavier_init();
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr2;

    #[test]
    fn test_linear_model_creation() {
        let m = LinearModel::new(10, 5);
        assert_eq!((m.input_dim(), m.output_dim()), (10, 5));
        // Exactly "weight" and "bias".
        assert_eq!(m.parameters().len(), 2);
    }

    #[test]
    fn test_linear_model_forward() {
        let m = LinearModel::new(3, 2);
        let x = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        // A (2, 3) batch through a 3 -> 2 layer yields a (2, 2) output.
        assert_eq!(m.forward(&x.view()).unwrap().shape(), &[2, 2]);
    }

    #[test]
    fn test_linear_model_backward() {
        let m = LinearModel::new(3, 2);
        let x = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let g = arr2(&[[1.0, 1.0], [1.0, 1.0]]);

        let grads = m.backward(&x.view(), &g.view()).unwrap();

        // Gradient shapes mirror the parameter shapes.
        for (key, shape) in [("weight", [3usize, 2]), ("bias", [1, 2])] {
            assert!(grads.contains_key(key));
            assert_eq!(grads[key].shape(), &shape);
        }
    }

    #[test]
    fn test_model_state_dict() {
        let state = LinearModel::new(2, 2).state_dict();
        assert_eq!(state.len(), 2);
        assert!(state.contains_key("weight") && state.contains_key("bias"));
    }

    #[test]
    fn test_model_load_state() {
        let mut m = LinearModel::new(2, 2);
        let snapshot = m.state_dict();

        // Perturb a weight, then restore from the snapshot.
        m.parameters_mut().get_mut("weight").unwrap()[[0, 0]] = 99.0;
        m.load_state_dict(snapshot.clone()).unwrap();

        assert_eq!(m.parameters().get("weight").unwrap()[[0, 0]], 0.0);
    }

    #[test]
    fn test_num_parameters() {
        // 10 * 5 weights + 5 biases.
        assert_eq!(LinearModel::new(10, 5).num_parameters(), 55);
    }
}