ruvector_tiny_dancer_core/
model.rs

use crate::error::{Result, TinyDancerError};
use ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::path::Path;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FastGRNNConfig {
    /// Number of input features per sample.
    pub input_dim: usize,
    /// Size of the recurrent hidden state.
    pub hidden_dim: usize,
    /// Number of output values produced by the final projection.
    pub output_dim: usize,
    /// Scale factor applied inside the gate sigmoids.
    pub nu: f32,
    /// Scale factor applied inside the candidate tanh.
    pub zeta: f32,
    /// Optional low-rank factorization rank for the weight matrices.
    pub rank: Option<usize>,
}

impl Default for FastGRNNConfig {
    fn default() -> Self {
        Self {
            input_dim: 5,
            hidden_dim: 8,
            output_dim: 1,
            nu: 1.0,
            zeta: 1.0,
            rank: Some(4),
        }
    }
}

pub struct FastGRNN {
    config: FastGRNNConfig,
    /// Reset-gate input weights (hidden_dim x input_dim).
    w_reset: Array2<f32>,
    /// Update-gate input weights (hidden_dim x input_dim).
    w_update: Array2<f32>,
    /// Candidate-state input weights (hidden_dim x input_dim).
    w_candidate: Array2<f32>,
    /// Recurrent weights applied to the reset-gated hidden state (hidden_dim x hidden_dim).
    w_recurrent: Array2<f32>,
    /// Output projection weights (output_dim x hidden_dim).
    w_output: Array2<f32>,
    b_reset: Array1<f32>,
    b_update: Array1<f32>,
    b_candidate: Array1<f32>,
    b_output: Array1<f32>,
    /// Set by `quantize`; only affects `size_bytes` reporting.
    quantized: bool,
}

impl FastGRNN {
    /// Create a new model with weights drawn uniformly from [-0.1, 0.1) and
    /// zero biases.
    pub fn new(config: FastGRNNConfig) -> Result<Self> {
        use rand::Rng;
        let mut rng = rand::thread_rng();

        let w_reset = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
            rng.gen_range(-0.1..0.1)
        });
        let w_update = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
            rng.gen_range(-0.1..0.1)
        });
        let w_candidate = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
            rng.gen_range(-0.1..0.1)
        });
        let w_recurrent = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
            rng.gen_range(-0.1..0.1)
        });
        let w_output = Array2::from_shape_fn((config.output_dim, config.hidden_dim), |_| {
            rng.gen_range(-0.1..0.1)
        });

        let b_reset = Array1::zeros(config.hidden_dim);
        let b_update = Array1::zeros(config.hidden_dim);
        let b_candidate = Array1::zeros(config.hidden_dim);
        let b_output = Array1::zeros(config.output_dim);

        Ok(Self {
            config,
            w_reset,
            w_update,
            w_candidate,
            w_recurrent,
            w_output,
            b_reset,
            b_update,
            b_candidate,
            b_output,
            quantized: false,
        })
    }

    /// Load a model from disk. Currently a stub that ignores the path and
    /// returns a freshly initialized model with the default configuration.
    pub fn load<P: AsRef<Path>>(_path: P) -> Result<Self> {
        Self::new(FastGRNNConfig::default())
    }

    /// Persist the model to disk. Currently a no-op stub.
    pub fn save<P: AsRef<Path>>(&self, _path: P) -> Result<()> {
        Ok(())
    }

    /// Run a single forward pass over one input vector and squash the first
    /// output to (0, 1).
    pub fn forward(&self, input: &[f32], initial_hidden: Option<&[f32]>) -> Result<f32> {
        if input.len() != self.config.input_dim {
            return Err(TinyDancerError::InvalidInput(format!(
                "Expected input dimension {}, got {}",
                self.config.input_dim,
                input.len()
            )));
        }

        let x = Array1::from_vec(input.to_vec());
        let mut h = if let Some(hidden) = initial_hidden {
            Array1::from_vec(hidden.to_vec())
        } else {
            Array1::zeros(self.config.hidden_dim)
        };

        // Reset gate: r = sigmoid(nu * (W_r x + b_r))
        let r = sigmoid(&(self.w_reset.dot(&x) + &self.b_reset), self.config.nu);

        // Update gate: u = sigmoid(nu * (W_u x + b_u))
        let u = sigmoid(&(self.w_update.dot(&x) + &self.b_update), self.config.nu);

        // Candidate state: c = tanh(zeta * (W_c x + W_h (r * h) + b_c))
        let c = tanh(
            &(self.w_candidate.dot(&x) + self.w_recurrent.dot(&(&r * &h)) + &self.b_candidate),
            self.config.zeta,
        );

        // Hidden state update: h = u * h + (1 - u) * c
        h = &u * &h + &((Array1::<f32>::ones(u.len()) - &u) * &c);

        // Linear output projection.
        let output = self.w_output.dot(&h) + &self.b_output;

        Ok(sigmoid_scalar(output[0]))
    }

    /// Run independent forward passes over a batch of inputs, each starting
    /// from a zero hidden state.
    pub fn forward_batch(&self, inputs: &[Vec<f32>]) -> Result<Vec<f32>> {
        inputs
            .iter()
            .map(|input| self.forward(input, None))
            .collect()
    }

    /// Mark the model as quantized. This currently only flips the flag used by
    /// `size_bytes`; the weights themselves are not modified.
    pub fn quantize(&mut self) -> Result<()> {
        self.quantized = true;
        Ok(())
    }

    /// Prune the model to the given sparsity. Currently only validates the
    /// argument; the weights are left unchanged.
    pub fn prune(&mut self, sparsity: f32) -> Result<()> {
        if !(0.0..=1.0).contains(&sparsity) {
            return Err(TinyDancerError::InvalidInput(
                "Sparsity must be between 0.0 and 1.0".to_string(),
            ));
        }

        Ok(())
    }

    /// Approximate model size in bytes: one byte per parameter when quantized,
    /// four bytes (f32) otherwise.
    pub fn size_bytes(&self) -> usize {
        let params = self.w_reset.len()
            + self.w_update.len()
            + self.w_candidate.len()
            + self.w_recurrent.len()
            + self.w_output.len()
            + self.b_reset.len()
            + self.b_update.len()
            + self.b_candidate.len()
            + self.b_output.len();

        params * if self.quantized { 1 } else { 4 }
    }

    /// Borrow the model configuration.
    pub fn config(&self) -> &FastGRNNConfig {
        &self.config
    }
}

fn sigmoid(x: &Array1<f32>, scale: f32) -> Array1<f32> {
    x.mapv(|v| sigmoid_scalar(v * scale))
}

/// Numerically stable logistic sigmoid.
fn sigmoid_scalar(x: f32) -> f32 {
    if x > 0.0 {
        1.0 / (1.0 + (-x).exp())
    } else {
        // For negative inputs, use exp(x) / (1 + exp(x)) to avoid overflow in exp(-x).
        let ex = x.exp();
        ex / (1.0 + ex)
    }
}

fn tanh(x: &Array1<f32>, scale: f32) -> Array1<f32> {
    x.mapv(|v| (v * scale).tanh())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fastgrnn_creation() {
        let config = FastGRNNConfig::default();
        let model = FastGRNN::new(config).unwrap();
        assert!(model.size_bytes() > 0);
    }

    #[test]
    fn test_forward_pass() {
        let config = FastGRNNConfig {
            input_dim: 10,
            hidden_dim: 8,
            output_dim: 1,
            ..Default::default()
        };
        let model = FastGRNN::new(config).unwrap();
        let input = vec![0.5; 10];
        let output = model.forward(&input, None).unwrap();
        assert!((0.0..=1.0).contains(&output));
    }
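
    // Illustrative sketch (not part of the original suite): forward validates
    // the input length against `input_dim`, so a mismatched slice should
    // surface as an error rather than a panic.
    #[test]
    fn test_forward_rejects_wrong_input_dim() {
        // The default config expects 5 input features; pass only 3.
        let model = FastGRNN::new(FastGRNNConfig::default()).unwrap();
        assert!(model.forward(&[0.1, 0.2, 0.3], None).is_err());
    }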

    #[test]
    fn test_batch_inference() {
        let config = FastGRNNConfig {
            input_dim: 10,
            ..Default::default()
        };
        let model = FastGRNN::new(config).unwrap();
        let inputs = vec![vec![0.5; 10], vec![0.3; 10], vec![0.8; 10]];
        let outputs = model.forward_batch(&inputs).unwrap();
        assert_eq!(outputs.len(), 3);
    }
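
    // Illustrative sketch (not part of the original suite): `quantize` only
    // flips the internal flag, so the reported size drops from four bytes to
    // one byte per parameter.
    #[test]
    fn test_quantize_shrinks_reported_size() {
        let mut model = FastGRNN::new(FastGRNNConfig::default()).unwrap();
        let full_size = model.size_bytes();
        model.quantize().unwrap();
        assert_eq!(model.size_bytes() * 4, full_size);
    }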
}