ruvector_sparse_inference/
ops.rs1use std::f32;
4
/// Fully connected layer computing `output[i] = Σ_j weight[i][j] * input[j] + bias[i]`.
#[derive(Debug, Clone)]
pub struct Linear {
    /// Weight matrix indexed as `weight[output_row][input_column]`.
    pub weight: Vec<Vec<f32>>,
    /// Optional per-output bias added after the dot product.
    pub bias: Option<Vec<f32>>,
    /// Number of input columns read by [`Linear::forward`].
    pub in_features: usize,
    /// Number of values produced by [`Linear::forward`].
    pub out_features: usize,
}

impl Linear {
    /// Builds a layer whose weights (and bias, when `use_bias` is true) are all zero.
    pub fn new(in_features: usize, out_features: usize, use_bias: bool) -> Self {
        let weight = vec![vec![0.0; in_features]; out_features];
        let bias = use_bias.then(|| vec![0.0; out_features]);
        Self {
            weight,
            bias,
            in_features,
            out_features,
        }
    }

    /// Applies the layer to `input`, returning `out_features` values.
    ///
    /// Only the first `min(in_features, input.len())` input values participate:
    /// a short input is not an error, and surplus input values are ignored.
    ///
    /// # Panics
    ///
    /// Panics if `weight` has fewer than `out_features` rows (when any column is
    /// read), if a row is shorter than the number of columns read, or if the bias
    /// has fewer than `out_features` entries.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let width = self.in_features.min(input.len());

        (0..self.out_features)
            .map(|row| {
                // Accumulate left-to-right starting from 0.0, column by column.
                let dot = (0..width)
                    .fold(0.0, |acc, col| acc + self.weight[row][col] * input[col]);
                match &self.bias {
                    Some(bias) => dot + bias[row],
                    None => dot,
                }
            })
            .collect()
    }
}
45
/// Lookup table mapping token ids to dense vectors.
#[derive(Debug, Clone)]
pub struct Embedding {
    /// One row per token id; `weight[id]` is the vector for `id`.
    pub weight: Vec<Vec<f32>>,
    /// Number of valid ids; ids `>= vocab_size` are treated as unknown.
    pub vocab_size: usize,
    /// Length of the zero vector emitted for unknown ids.
    pub embedding_dim: usize,
}

impl Embedding {
    /// Builds a table of `vocab_size` all-zero rows of length `embedding_dim`.
    pub fn new(vocab_size: usize, embedding_dim: usize) -> Self {
        Self {
            weight: vec![vec![0.0; embedding_dim]; vocab_size],
            vocab_size,
            embedding_dim,
        }
    }

    /// Looks up each id and concatenates the rows into one flat vector.
    ///
    /// Ids outside `0..vocab_size` contribute `embedding_dim` zeros instead of
    /// a table row.
    pub fn forward(&self, input_ids: &[u64]) -> Vec<f32> {
        let mut output = Vec::with_capacity(input_ids.len() * self.embedding_dim);

        for &id in input_ids {
            // Compare in u64 space before narrowing: a bare `id as usize` truncates
            // on 32-bit targets, which would let an out-of-vocabulary id (e.g.
            // 2^32 + 1) alias onto a valid row. `usize -> u64` is lossless here.
            if id < self.vocab_size as u64 {
                output.extend_from_slice(&self.weight[id as usize]);
            } else {
                // Pad in place rather than allocating a temporary zero vector.
                output.resize(output.len() + self.embedding_dim, 0.0);
            }
        }

        output
    }
}
78
/// Root-mean-square normalisation with a per-element scale.
///
/// Computes `y_i = x_i / sqrt(mean(x^2) + eps) * weight_i`.
#[derive(Debug, Clone)]
pub struct RMSNorm {
    /// Per-element scale; initialised to 1.0 by [`RMSNorm::new`].
    pub weight: Vec<f32>,
    /// Added to the mean square before the square root.
    pub eps: f32,
}

impl RMSNorm {
    /// Builds a norm over `dim` elements with unit scale.
    pub fn new(dim: usize, eps: f32) -> Self {
        let weight = vec![1.0; dim];
        Self { weight, eps }
    }

    /// Normalises `input` by its RMS and applies the scale.
    ///
    /// The output has `min(input.len(), weight.len())` elements; an empty input
    /// yields an empty output.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let count = input.len() as f32;
        let sum_sq: f32 = input.iter().map(|v| v * v).sum();
        let rms = (sum_sq / count + self.eps).sqrt();

        let mut out = Vec::with_capacity(input.len().min(self.weight.len()));
        for (v, scale) in input.iter().zip(&self.weight) {
            out.push(v / rms * scale);
        }
        out
    }
}
105
/// Layer normalisation with learned scale (`weight`) and shift (`bias`).
///
/// Computes `y_i = (x_i - mean) / sqrt(var + eps) * weight_i + bias_i`, using the
/// population variance of the input.
#[derive(Debug, Clone)]
pub struct LayerNorm {
    /// Per-element scale; initialised to 1.0 by [`LayerNorm::new`].
    pub weight: Vec<f32>,
    /// Per-element shift; initialised to 0.0 by [`LayerNorm::new`].
    pub bias: Vec<f32>,
    /// Added to the variance before the square root.
    pub eps: f32,
}

impl LayerNorm {
    /// Builds an identity-initialised norm over `dim` elements.
    pub fn new(dim: usize, eps: f32) -> Self {
        let weight = vec![1.0; dim];
        let bias = vec![0.0; dim];
        Self { weight, bias, eps }
    }

    /// Normalises `input` to zero mean / unit variance, then applies scale and shift.
    ///
    /// The output length is the shortest of `input`, `weight` and `bias`.
    pub fn forward(&self, input: &[f32]) -> Vec<f32> {
        let count = input.len() as f32;
        let mean: f32 = input.iter().sum::<f32>() / count;
        let variance: f32 = input.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / count;
        let denom = (variance + self.eps).sqrt();

        let mut normalized = Vec::with_capacity(input.len());
        for ((v, gamma), beta) in input.iter().zip(&self.weight).zip(&self.bias) {
            normalized.push((v - mean) / denom * gamma + beta);
        }
        normalized
    }
}
135
/// SiLU (swish) activation, `x * sigmoid(x)`, evaluated as `x / (1 + e^-x)`.
pub fn silu(x: f32) -> f32 {
    let denominator = 1.0 + (-x).exp();
    x / denominator
}
140
/// GELU activation using the tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))`.
///
/// This is not the exact erf-based GELU.
pub fn gelu(x: f32) -> f32 {
    let sqrt_two_over_pi = (2.0 / std::f32::consts::PI).sqrt();
    let inner = sqrt_two_over_pi * (x + 0.044715 * x.powi(3));
    0.5 * x * (1.0 + inner.tanh())
}
145
/// Rectified linear unit: returns `x` when positive, otherwise `0.0`.
///
/// A NaN input yields `0.0`.
pub fn relu(x: f32) -> f32 {
    if x > 0.0 {
        x
    } else {
        0.0
    }
}
150
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linear() {
        let mut linear = Linear::new(3, 2, true);
        linear.weight = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        linear.bias = Some(vec![0.1, 0.2]);

        let input = vec![1.0, 2.0, 3.0];
        let output = linear.forward(&input);

        assert_eq!(output.len(), 2);
        assert!((output[0] - 14.1).abs() < 1e-5);
        assert!((output[1] - 32.2).abs() < 1e-5);
    }

    #[test]
    fn test_linear_without_bias_uses_only_supplied_inputs() {
        let mut linear = Linear::new(3, 1, false);
        linear.weight = vec![vec![1.0, 2.0, 3.0]];
        // A short input only engages the matching leading weights.
        assert_eq!(linear.forward(&[1.0, 1.0]), vec![3.0]);
    }

    #[test]
    fn test_silu() {
        assert!((silu(0.0) - 0.0).abs() < 1e-5);
        // silu(x) = x * sigmoid(x); pin actual values, not just signs.
        assert!((silu(1.0) - 0.731_058_6).abs() < 1e-5);
        assert!((silu(-1.0) + 0.268_941_4).abs() < 1e-5);
    }

    #[test]
    fn test_rms_norm() {
        let norm = RMSNorm::new(4, 1e-6);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let output = norm.forward(&input);
        assert_eq!(output.len(), 4);

        // mean(x^2) = 30 / 4 = 7.5, so every element is divided by sqrt(7.5 + eps).
        let rms = (7.5f32 + 1e-6).sqrt();
        for (o, x) in output.iter().zip(&input) {
            assert!((o - x / rms).abs() < 1e-6);
        }
    }
}
183}