ruvector_sparse_inference/
ops.rs

//! Basic neural network operations
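//!
//! A minimal usage sketch (crate and module paths are assumed from the file
//! location, so the example is marked `ignore`):
//!
//! ```ignore
//! use ruvector_sparse_inference::ops::Linear;
//!
//! let layer = Linear::new(3, 2, true);
//! let output = layer.forward(&[1.0, 2.0, 3.0]);
//! assert_eq!(output.len(), 2);
//! ```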

use std::f32;

/// Linear layer (fully connected)
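///
/// Computes `y = W * x + b`, with `weight` laid out as `[out_features, in_features]`.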
#[derive(Debug, Clone)]
pub struct Linear {
    pub weight: Vec<Vec<f32>>,  // [out_features, in_features]
    pub bias: Option<Vec<f32>>,
    pub in_features: usize,
    pub out_features: usize,
}

impl Linear {
    pub fn new(in_features: usize, out_features: usize, use_bias: bool) -> Self {
        Self {
            weight: vec![vec![0.0; in_features]; out_features],
            bias: if use_bias {
                Some(vec![0.0; out_features])
            } else {
                None
            },
            in_features,
            out_features,
        }
    }

    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let mut output = vec![0.0; self.out_features];

        for i in 0..self.out_features {
            let mut sum = 0.0;
            // Iterate over the shorter of in_features and the input length so a
            // short input cannot index out of bounds.
            for j in 0..self.in_features.min(input.len()) {
                sum += self.weight[i][j] * input[j];
            }
            if let Some(ref bias) = self.bias {
                sum += bias[i];
            }
            output[i] = sum;
        }

        output
    }
}


/// Embedding layer
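///
/// Maps each token id to the corresponding row of `weight` and concatenates
/// the rows; ids outside the vocabulary yield a zero vector.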
#[derive(Debug, Clone)]
pub struct Embedding {
    pub weight: Vec<Vec<f32>>,  // [vocab_size, embedding_dim]
    pub vocab_size: usize,
    pub embedding_dim: usize,
}

impl Embedding {
    pub fn new(vocab_size: usize, embedding_dim: usize) -> Self {
        Self {
            weight: vec![vec![0.0; embedding_dim]; vocab_size],
            vocab_size,
            embedding_dim,
        }
    }

    pub fn forward(&self, input_ids: &[u64]) -> Vec<f32> {
        let mut output = Vec::with_capacity(input_ids.len() * self.embedding_dim);

        for &id in input_ids {
            let idx = id as usize;
            if idx < self.vocab_size {
                output.extend_from_slice(&self.weight[idx]);
            } else {
                // Out-of-vocabulary ids map to a zero vector.
                output.resize(output.len() + self.embedding_dim, 0.0);
            }
        }

        output
    }
}

/// RMSNorm (Root Mean Square Layer Normalization)
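///
/// `y_i = x_i / sqrt(mean(x^2) + eps) * weight_i`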
#[derive(Debug, Clone)]
pub struct RMSNorm {
    pub weight: Vec<f32>,
    pub eps: f32,
}

impl RMSNorm {
    pub fn new(dim: usize, eps: f32) -> Self {
        Self {
            weight: vec![1.0; dim],
            eps,
        }
    }

    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let mean_square = input.iter().map(|x| x * x).sum::<f32>() / input.len() as f32;
        let rms = (mean_square + self.eps).sqrt();

        input
            .iter()
            .zip(self.weight.iter())
            .map(|(x, w)| (x / rms) * w)
            .collect()
    }
}

/// LayerNorm
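///
/// `y_i = (x_i - mean(x)) / sqrt(var(x) + eps) * weight_i + bias_i`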
#[derive(Debug, Clone)]
pub struct LayerNorm {
    pub weight: Vec<f32>,
    pub bias: Vec<f32>,
    pub eps: f32,
}

impl LayerNorm {
    pub fn new(dim: usize, eps: f32) -> Self {
        Self {
            weight: vec![1.0; dim],
            bias: vec![0.0; dim],
            eps,
        }
    }

    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let mean = input.iter().sum::<f32>() / input.len() as f32;
        let variance = input.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / input.len() as f32;
        let std = (variance + self.eps).sqrt();

        input
            .iter()
            .zip(self.weight.iter().zip(self.bias.iter()))
            .map(|(x, (w, b))| ((x - mean) / std) * w + b)
            .collect()
    }
}

/// SiLU (Swish) activation function
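///
/// `silu(x) = x * sigmoid(x) = x / (1 + e^(-x))`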
pub fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

/// GELU activation
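///
/// Uses the tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`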
pub fn gelu(x: f32) -> f32 {
    0.5 * x * (1.0 + ((2.0 / f32::consts::PI).sqrt() * (x + 0.044715 * x.powi(3))).tanh())
}

/// ReLU activation
pub fn relu(x: f32) -> f32 {
    x.max(0.0)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linear() {
        let mut linear = Linear::new(3, 2, true);
        linear.weight = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        linear.bias = Some(vec![0.1, 0.2]);

        let input = vec![1.0, 2.0, 3.0];
        let output = linear.forward(&input);

        assert_eq!(output.len(), 2);
        assert!((output[0] - 14.1).abs() < 1e-5);
        assert!((output[1] - 32.2).abs() < 1e-5);
    }
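
    // A sketch of an embedding check: one hand-set row plus an out-of-range
    // id, which should map to a zero vector.
    #[test]
    fn test_embedding() {
        let mut emb = Embedding::new(4, 2);
        emb.weight[1] = vec![0.5, -0.5];

        let output = emb.forward(&[1, 99]);

        assert_eq!(output.len(), 4);
        assert!((output[0] - 0.5).abs() < 1e-6);
        assert!((output[1] + 0.5).abs() < 1e-6);
        assert!(output[2].abs() < 1e-6);
        assert!(output[3].abs() < 1e-6);
    }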

    #[test]
    fn test_silu() {
        assert!((silu(0.0) - 0.0).abs() < 1e-5);
        assert!(silu(1.0) > 0.0);
        assert!(silu(-1.0) < 0.0);
    }
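
    // A sketch of a sanity check for the tanh-approximated GELU: gelu(0) = 0,
    // large positive inputs pass through nearly unchanged, and large negative
    // inputs are suppressed toward zero.
    #[test]
    fn test_gelu() {
        assert!(gelu(0.0).abs() < 1e-6);
        assert!((gelu(3.0) - 3.0).abs() < 1e-2);
        assert!(gelu(-3.0).abs() < 1e-2);
    }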

    #[test]
    fn test_rms_norm() {
        let norm = RMSNorm::new(4, 1e-6);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = norm.forward(&input);
        assert_eq!(output.len(), 4);
    }
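
    // A sketch of a LayerNorm check: with unit weight and zero bias the output
    // should have approximately zero mean and unit variance.
    #[test]
    fn test_layer_norm() {
        let norm = LayerNorm::new(4, 1e-6);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = norm.forward(&input);

        assert_eq!(output.len(), 4);
        let mean = output.iter().sum::<f32>() / output.len() as f32;
        let var = output.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / output.len() as f32;
        assert!(mean.abs() < 1e-5);
        assert!((var - 1.0).abs() < 1e-3);
    }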
}