ruvector_tiny_dancer_core/
model.rs

//! FastGRNN model implementation
//!
//! Lightweight Gated Recurrent Neural Network optimized for inference

use crate::error::{Result, TinyDancerError};
use ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::path::Path;

/// FastGRNN model configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FastGRNNConfig {
    /// Input dimension
    pub input_dim: usize,
    /// Hidden dimension
    pub hidden_dim: usize,
    /// Output dimension
    pub output_dim: usize,
    /// Scaling factor applied inside the gate sigmoid
    pub nu: f32,
    /// Scaling factor applied inside the candidate tanh
    pub zeta: f32,
    /// Rank constraint for low-rank factorization
    pub rank: Option<usize>,
}

impl Default for FastGRNNConfig {
    fn default() -> Self {
        Self {
            input_dim: 5, // 5 features from feature engineering
            hidden_dim: 8,
            output_dim: 1,
            nu: 1.0,
            zeta: 1.0,
            rank: Some(4),
        }
    }
}

/// FastGRNN model for neural routing
pub struct FastGRNN {
    config: FastGRNNConfig,
    /// Input-to-hidden weights for the reset gate (W_r)
    w_reset: Array2<f32>,
    /// Input-to-hidden weights for the update gate (W_u)
    w_update: Array2<f32>,
    /// Input-to-hidden weights for the candidate (W_c)
    w_candidate: Array2<f32>,
    /// Recurrent (hidden-to-hidden) weight matrix (U)
    w_recurrent: Array2<f32>,
    /// Output weight matrix
    w_output: Array2<f32>,
    /// Bias for reset gate
    b_reset: Array1<f32>,
    /// Bias for update gate
    b_update: Array1<f32>,
    /// Bias for candidate
    b_candidate: Array1<f32>,
    /// Bias for output
    b_output: Array1<f32>,
    /// Whether the model is quantized
    quantized: bool,
}

impl FastGRNN {
    /// Create a new FastGRNN model with the given configuration
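    ///
    /// # Example
    ///
    /// A minimal construction sketch (illustrative only; it assumes the crate
    /// exposes these types under `ruvector_tiny_dancer_core::model`):
    ///
    /// ```ignore
    /// use ruvector_tiny_dancer_core::model::{FastGRNN, FastGRNNConfig};
    ///
    /// let model = FastGRNN::new(FastGRNNConfig::default()).expect("failed to build model");
    /// assert_eq!(model.config().input_dim, 5); // default feature count
    /// ```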
    pub fn new(config: FastGRNNConfig) -> Result<Self> {
        use rand::Rng;
        let mut rng = rand::thread_rng();

        // Initialize weights with small uniform random values in [-0.1, 0.1)
        let w_reset = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
            rng.gen_range(-0.1..0.1)
        });
        let w_update = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
            rng.gen_range(-0.1..0.1)
        });
        let w_candidate = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
            rng.gen_range(-0.1..0.1)
        });
        let w_recurrent = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
            rng.gen_range(-0.1..0.1)
        });
        let w_output = Array2::from_shape_fn((config.output_dim, config.hidden_dim), |_| {
            rng.gen_range(-0.1..0.1)
        });

        let b_reset = Array1::zeros(config.hidden_dim);
        let b_update = Array1::zeros(config.hidden_dim);
        let b_candidate = Array1::zeros(config.hidden_dim);
        let b_output = Array1::zeros(config.output_dim);

        Ok(Self {
            config,
            w_reset,
            w_update,
            w_candidate,
            w_recurrent,
            w_output,
            b_reset,
            b_update,
            b_candidate,
            b_output,
            quantized: false,
        })
    }

    /// Load model from a file (safetensors format)
    pub fn load<P: AsRef<Path>>(_path: P) -> Result<Self> {
        // TODO: Implement safetensors loading
        // For now, return a default model
        Self::new(FastGRNNConfig::default())
    }

    /// Save model to a file (safetensors format)
    pub fn save<P: AsRef<Path>>(&self, _path: P) -> Result<()> {
        // TODO: Implement safetensors saving
        Ok(())
    }

    /// Forward pass through the FastGRNN model (a single recurrent cell step)
    ///
    /// # Arguments
    /// * `input` - Feature vector of length `input_dim`
    /// * `initial_hidden` - Optional initial hidden state of length `hidden_dim`
    ///
    /// # Returns
    /// Output score in (0.0, 1.0) after the final sigmoid
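    ///
    /// # Example
    ///
    /// A single-step usage sketch (illustrative; the five feature values are
    /// arbitrary placeholders for the default `input_dim` of 5):
    ///
    /// ```ignore
    /// let model = FastGRNN::new(FastGRNNConfig::default()).unwrap();
    /// // No prior hidden state: the cell starts from zeros.
    /// let score = model.forward(&[0.1, 0.4, 0.2, 0.9, 0.3], None).unwrap();
    /// assert!((0.0..=1.0).contains(&score));
    /// ```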
    pub fn forward(&self, input: &[f32], initial_hidden: Option<&[f32]>) -> Result<f32> {
        if input.len() != self.config.input_dim {
            return Err(TinyDancerError::InvalidInput(format!(
                "Expected input dimension {}, got {}",
                self.config.input_dim,
                input.len()
            )));
        }

        let x = Array1::from_vec(input.to_vec());
        let mut h = if let Some(hidden) = initial_hidden {
            // Guard against a mismatched hidden state, which would otherwise
            // panic inside the ndarray element-wise operations below.
            if hidden.len() != self.config.hidden_dim {
                return Err(TinyDancerError::InvalidInput(format!(
                    "Expected hidden dimension {}, got {}",
                    self.config.hidden_dim,
                    hidden.len()
                )));
            }
            Array1::from_vec(hidden.to_vec())
        } else {
            Array1::zeros(self.config.hidden_dim)
        };

        // FastGRNN cell computation
        // r_t = sigmoid(W_r * x_t + b_r)
        let r = sigmoid(&(self.w_reset.dot(&x) + &self.b_reset), self.config.nu);

        // u_t = sigmoid(W_u * x_t + b_u)
        let u = sigmoid(&(self.w_update.dot(&x) + &self.b_update), self.config.nu);

        // c_t = tanh(W_c * x_t + U * (r_t ⊙ h_{t-1}) + b_c)
        let c = tanh(
            &(self.w_candidate.dot(&x) + self.w_recurrent.dot(&(&r * &h)) + &self.b_candidate),
            self.config.zeta,
        );

        // h_t = u_t ⊙ h_{t-1} + (1 - u_t) ⊙ c_t
        h = &u * &h + &((Array1::<f32>::ones(u.len()) - &u) * &c);

        // Output: y = W_out * h_t + b_out
        let output = self.w_output.dot(&h) + &self.b_output;

        // Apply sigmoid to get probability
        Ok(sigmoid_scalar(output[0]))
    }

    /// Batch inference for multiple inputs
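    ///
    /// Each inner vector must have length `input_dim`; inputs are scored
    /// independently, each starting from a zero hidden state.
    ///
    /// # Example
    ///
    /// A batch-scoring sketch (illustrative values):
    ///
    /// ```ignore
    /// let model = FastGRNN::new(FastGRNNConfig::default()).unwrap();
    /// let scores = model.forward_batch(&[vec![0.2; 5], vec![0.7; 5]]).unwrap();
    /// assert_eq!(scores.len(), 2);
    /// ```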
    pub fn forward_batch(&self, inputs: &[Vec<f32>]) -> Result<Vec<f32>> {
        inputs
            .iter()
            .map(|input| self.forward(input, None))
            .collect()
    }

    /// Quantize the model to INT8
    pub fn quantize(&mut self) -> Result<()> {
        // TODO: Implement INT8 quantization
        self.quantized = true;
        Ok(())
    }

    /// Apply magnitude-based pruning
    pub fn prune(&mut self, sparsity: f32) -> Result<()> {
        if !(0.0..=1.0).contains(&sparsity) {
            return Err(TinyDancerError::InvalidInput(
                "Sparsity must be between 0.0 and 1.0".to_string(),
            ));
        }

        // TODO: Implement magnitude-based pruning
        Ok(())
    }

    /// Get model size in bytes
    pub fn size_bytes(&self) -> usize {
        let params = self.w_reset.len()
            + self.w_update.len()
            + self.w_candidate.len()
            + self.w_recurrent.len()
            + self.w_output.len()
            + self.b_reset.len()
            + self.b_update.len()
            + self.b_candidate.len()
            + self.b_output.len();

        params * if self.quantized { 1 } else { 4 } // 1 byte for INT8, 4 bytes for f32
    }

    /// Get configuration
    pub fn config(&self) -> &FastGRNNConfig {
        &self.config
    }
}

/// Sigmoid activation with scaling parameter
fn sigmoid(x: &Array1<f32>, scale: f32) -> Array1<f32> {
    x.mapv(|v| sigmoid_scalar(v * scale))
}

/// Scalar sigmoid with numerical stability
fn sigmoid_scalar(x: f32) -> f32 {
    if x > 0.0 {
        1.0 / (1.0 + (-x).exp())
    } else {
        let ex = x.exp();
        ex / (1.0 + ex)
    }
}

/// Tanh activation with scaling parameter
fn tanh(x: &Array1<f32>, scale: f32) -> Array1<f32> {
    x.mapv(|v| (v * scale).tanh())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fastgrnn_creation() {
        let config = FastGRNNConfig::default();
        let model = FastGRNN::new(config).unwrap();
        assert!(model.size_bytes() > 0);
    }

    #[test]
    fn test_forward_pass() {
        let config = FastGRNNConfig {
            input_dim: 10,
            hidden_dim: 8,
            output_dim: 1,
            ..Default::default()
        };
        let model = FastGRNN::new(config).unwrap();
        let input = vec![0.5; 10];
        let output = model.forward(&input, None).unwrap();
        assert!((0.0..=1.0).contains(&output));
    }

    #[test]
    fn test_batch_inference() {
        let config = FastGRNNConfig {
            input_dim: 10,
            ..Default::default()
        };
        let model = FastGRNN::new(config).unwrap();
        let inputs = vec![vec![0.5; 10], vec![0.3; 10], vec![0.8; 10]];
        let outputs = model.forward_batch(&inputs).unwrap();
        assert_eq!(outputs.len(), 3);
    }
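
    // Additional illustrative tests: sketches that exercise behavior already
    // implemented above (dimension checks, optional hidden state, prune
    // argument validation); the specific values are arbitrary.

    #[test]
    fn test_input_dimension_mismatch() {
        let model = FastGRNN::new(FastGRNNConfig::default()).unwrap();
        // The default config expects 5 features; passing 3 must be rejected.
        assert!(model.forward(&[0.1, 0.2, 0.3], None).is_err());
    }

    #[test]
    fn test_forward_with_initial_hidden() {
        let config = FastGRNNConfig::default();
        let hidden_dim = config.hidden_dim;
        let model = FastGRNN::new(config).unwrap();
        let input = vec![0.5; 5];
        let hidden = vec![0.1; hidden_dim];
        let output = model.forward(&input, Some(hidden.as_slice())).unwrap();
        assert!((0.0..=1.0).contains(&output));
    }

    #[test]
    fn test_prune_rejects_invalid_sparsity() {
        let mut model = FastGRNN::new(FastGRNNConfig::default()).unwrap();
        assert!(model.prune(1.5).is_err());
        assert!(model.prune(0.5).is_ok());
    }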
}