//! Model interface for training with Tensorlogic.

use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, ArrayView, Ix2, IxDyn};
use std::collections::HashMap;

/// Trait for trainable models.
///
/// This trait defines the interface for models that can be trained with the
/// Tensorlogic training infrastructure. Models must implement forward and
/// backward passes, parameter management, and optional save/load functionality.
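///
/// # Examples
///
/// A minimal sketch of one hand-rolled SGD step using the [`LinearModel`]
/// defined below; the all-ones output gradient and the learning rate are
/// illustrative, not part of the trait contract.
///
/// ```no_run
/// use scirs2_core::ndarray::{arr2, Array2};
/// use tensorlogic_train::{LinearModel, Model};
///
/// let mut model = LinearModel::new(3, 2);
/// let input = arr2(&[[1.0, 2.0, 3.0]]);
///
/// // Forward pass, then pretend the loss gradient w.r.t. the output is ones.
/// let output = model.forward(&input.view()).unwrap();
/// let grad_output = Array2::ones(output.raw_dim());
/// let grads = model.backward(&input.view(), &grad_output.view()).unwrap();
///
/// // Plain SGD update: p <- p - lr * dL/dp.
/// let lr = 0.01;
/// for (name, grad) in &grads {
///     if let Some(param) = model.parameters_mut().get_mut(name) {
///         *param = &*param - &(grad * lr);
///     }
/// }
/// ```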
pub trait Model {
    /// Perform a forward pass through the model.
    ///
    /// # Arguments
    /// * `input` - Input tensor
    ///
    /// # Returns
    /// Output tensor from the model
    fn forward(&self, input: &ArrayView<f64, Ix2>) -> TrainResult<Array<f64, Ix2>>;

    /// Perform a backward pass to compute gradients.
    ///
    /// # Arguments
    /// * `input` - Input tensor used in forward pass
    /// * `grad_output` - Gradient of loss with respect to model output
    ///
    /// # Returns
    /// Gradients for each model parameter
    fn backward(
        &self,
        input: &ArrayView<f64, Ix2>,
        grad_output: &ArrayView<f64, Ix2>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>>;

    /// Get a reference to the model's parameters.
    fn parameters(&self) -> &HashMap<String, Array<f64, Ix2>>;

    /// Get a mutable reference to the model's parameters.
    fn parameters_mut(&mut self) -> &mut HashMap<String, Array<f64, Ix2>>;

    /// Set the model's parameters.
    fn set_parameters(&mut self, parameters: HashMap<String, Array<f64, Ix2>>);

    /// Get the number of parameters in the model.
    fn num_parameters(&self) -> usize {
        self.parameters().values().map(|p| p.len()).sum()
    }

    /// Save model state to a dictionary.
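    ///
    /// # Examples
    ///
    /// A round-trip sketch: snapshot the parameters, mutate them, then
    /// restore the snapshot (the `[[0, 0]]` index and `42.0` are arbitrary).
    ///
    /// ```no_run
    /// use tensorlogic_train::{LinearModel, Model};
    ///
    /// let mut model = LinearModel::new(2, 2);
    /// let snapshot = model.state_dict();
    ///
    /// model.parameters_mut().get_mut("weight").unwrap()[[0, 0]] = 42.0;
    /// model.load_state_dict(snapshot).unwrap();
    /// assert_eq!(model.parameters()["weight"][[0, 0]], 0.0);
    /// ```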
    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        self.parameters()
            .iter()
            .map(|(name, param)| (name.clone(), param.iter().copied().collect()))
            .collect()
    }

    /// Load model state from a dictionary.
    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) -> TrainResult<()> {
        let parameters = self.parameters_mut();

        for (name, values) in state {
            if let Some(param) = parameters.get_mut(&name) {
                if param.len() != values.len() {
                    return Err(TrainError::InvalidParameter(format!(
                        "Parameter '{}' size mismatch: expected {}, got {}",
                        name,
                        param.len(),
                        values.len()
                    )));
                }

                for (p, v) in param.iter_mut().zip(values.iter()) {
                    *p = *v;
                }
            } else {
                return Err(TrainError::InvalidParameter(format!(
                    "Parameter '{}' not found in model",
                    name
                )));
            }
        }

        Ok(())
    }

    /// Reset model parameters (optional, for retraining).
    fn reset_parameters(&mut self) {
        // Default implementation does nothing
        // Models can override this to implement custom initialization
    }
}

/// Trait for models that support automatic differentiation via scirs2-autograd.
///
/// This trait extends the base Model trait with support for training using
/// SciRS2's automatic differentiation system.
///
/// Note: This trait is currently a placeholder for future scirs2-autograd integration.
/// The actual Variable type will be specified once scirs2-autograd is fully integrated.
pub trait AutodiffModel: Model {
    /// Forward pass with autodiff tracking (placeholder).
    ///
    /// # Arguments
    /// * `input` - Input data array
    ///
    /// # Returns
    /// Success indicator (actual implementation will return autodiff Variable)
    fn forward_autodiff(&self, input: &ArrayView<f64, Ix2>) -> TrainResult<()> {
        // Placeholder implementation
        let _ = input;
        Ok(())
    }

    /// Compute gradients automatically using backward pass (placeholder).
    ///
    /// # Returns
    /// Gradients for all parameters
    fn compute_gradients(&self) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        // Placeholder implementation
        Ok(HashMap::new())
    }
}

/// Trait for models with dynamic computation graphs.
///
/// This extends the model interface to support variable-sized inputs
/// and dynamic graph construction (e.g., for RNNs, variable-length sequences).
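///
/// # Examples
///
/// A minimal implementation sketch: a hypothetical parameter-free model that
/// scales its input elementwise, accepting tensors of any rank.
///
/// ```no_run
/// use scirs2_core::ndarray::{Array, ArrayView, IxDyn};
/// use std::collections::HashMap;
/// use tensorlogic_train::{DynamicModel, TrainResult};
///
/// struct Scale(f64);
///
/// impl DynamicModel for Scale {
///     fn forward_dynamic(&self, input: &ArrayView<f64, IxDyn>) -> TrainResult<Array<f64, IxDyn>> {
///         Ok(input.mapv(|x| x * self.0))
///     }
///
///     fn backward_dynamic(
///         &self,
///         _input: &ArrayView<f64, IxDyn>,
///         _grad_output: &ArrayView<f64, IxDyn>,
///     ) -> TrainResult<HashMap<String, Array<f64, IxDyn>>> {
///         // No trainable parameters, so there are no parameter gradients.
///         Ok(HashMap::new())
///     }
/// }
/// ```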
pub trait DynamicModel {
    /// Forward pass with dynamic input dimensions.
    fn forward_dynamic(&self, input: &ArrayView<f64, IxDyn>) -> TrainResult<Array<f64, IxDyn>>;

    /// Backward pass with dynamic input dimensions.
    fn backward_dynamic(
        &self,
        input: &ArrayView<f64, IxDyn>,
        grad_output: &ArrayView<f64, IxDyn>,
    ) -> TrainResult<HashMap<String, Array<f64, IxDyn>>>;
}

/// A simple linear model for testing and demonstration.
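///
/// # Examples
///
/// A usage sketch; weights and biases start at zero, so the first forward
/// pass returns all zeros until the model is trained or re-initialized.
///
/// ```no_run
/// use scirs2_core::ndarray::arr2;
/// use tensorlogic_train::{LinearModel, Model};
///
/// let model = LinearModel::new(3, 2);
/// let input = arr2(&[[1.0, 2.0, 3.0]]);
/// let output = model.forward(&input.view()).unwrap();
/// assert_eq!(output.shape(), &[1, 2]);
/// assert!(output.iter().all(|&v| v == 0.0));
/// ```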
#[derive(Debug, Clone)]
pub struct LinearModel {
    /// Model parameters (weights and biases).
    parameters: HashMap<String, Array<f64, Ix2>>,
    /// Input dimension.
    input_dim: usize,
    /// Output dimension.
    output_dim: usize,
}

impl LinearModel {
    /// Create a new linear model.
    ///
    /// # Arguments
    /// * `input_dim` - Input dimension
    /// * `output_dim` - Output dimension
    pub fn new(input_dim: usize, output_dim: usize) -> Self {
        let mut parameters = HashMap::new();

        // Initialize weights and biases to zero; call `reset_parameters`
        // (or `xavier_init`) for a non-trivial initialization.
        let weights = Array::zeros((input_dim, output_dim));
        let biases = Array::zeros((1, output_dim));

        parameters.insert("weight".to_string(), weights);
        parameters.insert("bias".to_string(), biases);

        Self {
            parameters,
            input_dim,
            output_dim,
        }
    }

    /// Initialize weight parameters with Xavier/Glorot uniform bounds.
    ///
    /// The bound is `limit = sqrt(6 / (input_dim + output_dim))`. As a
    /// deterministic placeholder, values are spread evenly across
    /// `[-limit, limit]` rather than drawn from a proper RNG.
    pub fn xavier_init(&mut self) {
        let limit = (6.0 / (self.input_dim + self.output_dim) as f64).sqrt();

        if let Some(weights) = self.parameters.get_mut("weight") {
            // Deterministic stand-in for uniform sampling in [-limit, limit]
            // (in practice, draw from a proper random number generator).
            let n = weights.len().max(1) as f64;
            for (i, w) in weights.iter_mut().enumerate() {
                *w = -limit + 2.0 * limit * (i as f64 + 0.5) / n;
            }
        }
    }

    /// Get input dimension.
    pub fn input_dim(&self) -> usize {
        self.input_dim
    }

    /// Get output dimension.
    pub fn output_dim(&self) -> usize {
        self.output_dim
    }
}

impl Model for LinearModel {
    fn forward(&self, input: &ArrayView<f64, Ix2>) -> TrainResult<Array<f64, Ix2>> {
        let weights = self
            .parameters
            .get("weight")
            .ok_or_else(|| TrainError::InvalidParameter("weight not found".to_string()))?;
        let biases = self
            .parameters
            .get("bias")
            .ok_or_else(|| TrainError::InvalidParameter("bias not found".to_string()))?;

        // Linear transformation: Y = X @ W + b
        let output = input.dot(weights) + biases;
        Ok(output)
    }

    fn backward(
        &self,
        input: &ArrayView<f64, Ix2>,
        grad_output: &ArrayView<f64, Ix2>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        let mut gradients = HashMap::new();

        // Gradient w.r.t. weights: dL/dW = X^T @ dL/dY
        let grad_weights = input.t().dot(grad_output);
        gradients.insert("weight".to_string(), grad_weights);

        // Gradient w.r.t. biases: dL/db = sum(dL/dY, axis=0)
        let grad_biases = grad_output
            .sum_axis(scirs2_core::ndarray::Axis(0))
            .insert_axis(scirs2_core::ndarray::Axis(0));
        gradients.insert("bias".to_string(), grad_biases);

        Ok(gradients)
    }

    fn parameters(&self) -> &HashMap<String, Array<f64, Ix2>> {
        &self.parameters
    }

    fn parameters_mut(&mut self) -> &mut HashMap<String, Array<f64, Ix2>> {
        &mut self.parameters
    }

    fn set_parameters(&mut self, parameters: HashMap<String, Array<f64, Ix2>>) {
        self.parameters = parameters;
    }

    fn reset_parameters(&mut self) {
        self.xavier_init();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr2;

    #[test]
    fn test_linear_model_creation() {
        let model = LinearModel::new(10, 5);
        assert_eq!(model.input_dim(), 10);
        assert_eq!(model.output_dim(), 5);
        assert_eq!(model.parameters().len(), 2);
    }

    #[test]
    fn test_linear_model_forward() {
        let model = LinearModel::new(3, 2);
        let input = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let output = model.forward(&input.view()).unwrap();
        assert_eq!(output.shape(), &[2, 2]);
    }

    #[test]
    fn test_linear_model_backward() {
        let model = LinearModel::new(3, 2);
        let input = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let grad_output = arr2(&[[1.0, 1.0], [1.0, 1.0]]);

        let gradients = model.backward(&input.view(), &grad_output.view()).unwrap();

        assert!(gradients.contains_key("weight"));
        assert!(gradients.contains_key("bias"));
        assert_eq!(gradients["weight"].shape(), &[3, 2]);
        assert_eq!(gradients["bias"].shape(), &[1, 2]);
    }

    #[test]
    fn test_model_state_dict() {
        let model = LinearModel::new(2, 2);
        let state = model.state_dict();
        assert_eq!(state.len(), 2);
        assert!(state.contains_key("weight"));
        assert!(state.contains_key("bias"));
    }

    #[test]
    fn test_model_load_state() {
        let mut model = LinearModel::new(2, 2);
        let state = model.state_dict();

        // Modify parameters
        model.parameters_mut().get_mut("weight").unwrap()[[0, 0]] = 99.0;

        // Load original state
        model.load_state_dict(state.clone()).unwrap();

        // Verify state was restored
        assert_eq!(model.parameters().get("weight").unwrap()[[0, 0]], 0.0);
    }

    #[test]
    fn test_num_parameters() {
        let model = LinearModel::new(10, 5);
        // 10*5 weights + 5 biases = 55
        assert_eq!(model.num_parameters(), 55);
    }
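
    // A sketch of a full manual SGD step through the Model trait; the
    // all-ones output gradient and the 0.1 learning rate are illustrative.
    #[test]
    fn test_manual_sgd_step() {
        let mut model = LinearModel::new(2, 1);
        let input = arr2(&[[1.0, 2.0]]);
        // Pretend the loss gradient w.r.t. the model output is all ones.
        let grad_output = arr2(&[[1.0]]);

        let gradients = model.backward(&input.view(), &grad_output.view()).unwrap();

        // Plain SGD update: p <- p - lr * dL/dp.
        let lr = 0.1;
        for (name, grad) in &gradients {
            let param = model.parameters_mut().get_mut(name).unwrap();
            *param = &*param - &(grad * lr);
        }

        // From zero init: dL/dW = X^T @ dL/dY = [[1.0], [2.0]], dL/db = [[1.0]].
        assert_eq!(model.parameters()["weight"][[0, 0]], -0.1);
        assert_eq!(model.parameters()["weight"][[1, 0]], -0.2);
        assert_eq!(model.parameters()["bias"][[0, 0]], -0.1);
    }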
}