1use super::goldilocks::{Goldilocks, MODULUS};
7use super::PrimeField;
8
9pub const SCALE: u64 = 1 << 16;
11
12const HALF_P: u64 = MODULUS / 2;
14
15fn inv_scale() -> Goldilocks {
17 static INV: std::sync::OnceLock<Goldilocks> = std::sync::OnceLock::new();
18 *INV.get_or_init(|| Goldilocks::from_u64(SCALE).inv().expect("SCALE is nonzero"))
19}
20
21#[derive(Clone, Copy, Debug, PartialEq, Eq)]
23pub struct Fixed(pub Goldilocks);
24
25impl Fixed {
26 pub const ZERO: Self = Self(Goldilocks(0));
27 pub const ONE: Self = Self(Goldilocks(SCALE));
28
29 pub fn from_f64(v: f64) -> Self {
33 let scaled = v * SCALE as f64;
34 if scaled >= 0.0 {
35 Self(Goldilocks::from_u64(scaled.round() as u64))
36 } else {
37 let abs = (-scaled).round() as u64;
39 Self(Goldilocks::from_u64(MODULUS - abs))
40 }
41 }
42
43 pub fn to_f64(self) -> f64 {
47 let raw = self.0.to_u64();
48 if raw <= HALF_P {
49 raw as f64 / SCALE as f64
50 } else {
51 -((MODULUS - raw) as f64 / SCALE as f64)
52 }
53 }
54
55 pub fn raw(self) -> Goldilocks {
57 self.0
58 }
59
60 pub fn from_raw(g: Goldilocks) -> Self {
62 Self(g)
63 }
64
65 #[inline]
67 pub fn add(self, rhs: Self) -> Self {
68 Self(self.0.add(rhs.0))
69 }
70
71 #[inline]
73 pub fn sub(self, rhs: Self) -> Self {
74 Self(self.0.sub(rhs.0))
75 }
76
77 #[inline]
79 pub fn mul(self, rhs: Self) -> Self {
80 Self(self.0.mul(rhs.0).mul(inv_scale()))
81 }
82
83 #[inline]
85 pub fn neg(self) -> Self {
86 Self(self.0.neg())
87 }
88
89 pub fn inv(self) -> Self {
94 let raw_inv = self.0.inv().expect("cannot invert zero");
95 let s = Goldilocks::from_u64(SCALE);
96 Self(raw_inv.mul(s).mul(s))
97 }
98
99 #[inline]
101 pub fn relu(self) -> Self {
102 if self.0.to_u64() <= HALF_P {
103 self
104 } else {
105 Self::ZERO
106 }
107 }
108
109 #[inline]
111 pub fn madd(self, a: Self, b: Self) -> Self {
112 self.add(a.mul(b))
113 }
114}
115
116impl std::fmt::Display for Fixed {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 write!(f, "{:.4}", self.to_f64())
119 }
120}
121
122pub struct RawAccum(pub Goldilocks);
130
131impl RawAccum {
132 #[inline]
133 pub fn zero() -> Self {
134 Self(Goldilocks(0))
135 }
136
137 #[inline]
139 pub fn add_prod(&mut self, a: Fixed, b: Fixed) {
140 self.0 = self.0.add(a.0.mul(b.0));
141 }
142
143 #[inline]
146 pub fn add_bias(&mut self, bias: Fixed) {
147 self.0 = self.0.add(bias.0.mul(Goldilocks(SCALE)));
148 }
149
150 #[inline]
152 pub fn finish(self) -> Fixed {
153 Fixed(self.0.mul(inv_scale()))
154 }
155}
156
157pub fn dot(a: &[Fixed], b: &[Fixed]) -> Fixed {
161 debug_assert_eq!(a.len(), b.len());
162 let mut acc = RawAccum::zero();
163 for i in 0..a.len() {
164 acc.add_prod(a[i], b[i]);
165 }
166 acc.finish()
167}
168
169pub fn matvec(mat: &[Fixed], vec: &[Fixed], cols: usize) -> Vec<Fixed> {
172 let rows = mat.len() / cols;
173 let mut out = Vec::with_capacity(rows);
174 for r in 0..rows {
175 let row = &mat[r * cols..(r + 1) * cols];
176 out.push(dot(row, vec));
177 }
178 out
179}
180
181pub fn relu_vec(v: &mut [Fixed]) {
183 for x in v.iter_mut() {
184 *x = x.relu();
185 }
186}
187
188pub fn layer_norm(v: &mut [Fixed]) {
191 let n = v.len();
192 if n == 0 {
193 return;
194 }
195 let n_fixed = Fixed::from_f64(n as f64);
196
197 let mut sum = Fixed::ZERO;
199 for x in v.iter() {
200 sum = sum.add(*x);
201 }
202 let mean = sum.mul(n_fixed.inv());
203
204 for x in v.iter_mut() {
206 *x = x.sub(mean);
207 }
208
209 let mut var_sum = Fixed::ZERO;
211 for x in v.iter() {
212 var_sum = var_sum.madd(*x, *x);
213 }
214 let variance = var_sum.mul(n_fixed.inv());
215
216 let epsilon = Fixed::from_f64(1e-5);
219 let scale = if variance.to_f64().abs() < epsilon.to_f64() {
220 Fixed::ONE
221 } else {
222 variance.inv()
223 };
224 for x in v.iter_mut() {
228 *x = x.mul(scale);
229 }
230}
231
#[cfg(test)]
mod tests {
    //! Unit tests for fixed-point encoding, arithmetic, and the vector helpers.
    //! All inputs are dyadic, so the expected results are exact.
    use super::*;

    #[test]
    fn roundtrip_positive() {
        let vals = [0.0, 0.5, 1.0, 0.375, 100.0, 0.001];
        for &v in &vals {
            let f = Fixed::from_f64(v);
            let back = f.to_f64();
            assert!(
                (back - v).abs() < 0.001,
                "roundtrip failed for {}: got {}",
                v,
                back
            );
        }
    }

    #[test]
    fn roundtrip_negative() {
        let vals = [-0.5, -1.0, -100.0, -0.001];
        for &v in &vals {
            let f = Fixed::from_f64(v);
            let back = f.to_f64();
            assert!(
                (back - v).abs() < 0.001,
                "roundtrip failed for {}: got {}",
                v,
                back
            );
        }
    }

    #[test]
    fn add_commutative() {
        let a = Fixed::from_f64(0.5);
        let b = Fixed::from_f64(0.25);
        assert_eq!(a.add(b), b.add(a));
    }

    #[test]
    fn add_values() {
        let a = Fixed::from_f64(0.5);
        let b = Fixed::from_f64(0.25);
        let c = a.add(b);
        assert!((c.to_f64() - 0.75).abs() < 0.001);
    }

    #[test]
    fn sub_values() {
        let a = Fixed::from_f64(1.0);
        let b = Fixed::from_f64(0.25);
        let c = a.sub(b);
        assert!((c.to_f64() - 0.75).abs() < 0.001);
    }

    #[test]
    fn mul_values() {
        let a = Fixed::from_f64(0.5);
        let b = Fixed::from_f64(0.5);
        let c = a.mul(b);
        assert!(
            (c.to_f64() - 0.25).abs() < 0.001,
            "0.5 * 0.5 = {}, expected 0.25",
            c.to_f64()
        );
    }

    #[test]
    fn mul_negative() {
        let a = Fixed::from_f64(-0.5);
        let b = Fixed::from_f64(2.0);
        let c = a.mul(b);
        assert!(
            (c.to_f64() - (-1.0)).abs() < 0.001,
            "-0.5 * 2.0 = {}, expected -1.0",
            c.to_f64()
        );
    }

    #[test]
    fn mul_by_zero() {
        assert_eq!(Fixed::from_f64(3.5).mul(Fixed::ZERO), Fixed::ZERO);
    }

    #[test]
    fn neg_values() {
        let a = Fixed::from_f64(1.0);
        let b = a.neg();
        assert!((b.to_f64() - (-1.0)).abs() < 0.001);
        assert_eq!(a.add(b), Fixed::ZERO);
    }

    #[test]
    fn relu_positive() {
        let a = Fixed::from_f64(0.5);
        assert_eq!(a.relu(), a);
    }

    #[test]
    fn relu_negative() {
        let a = Fixed::from_f64(-0.5);
        assert_eq!(a.relu(), Fixed::ZERO);
    }

    #[test]
    fn relu_zero() {
        assert_eq!(Fixed::ZERO.relu(), Fixed::ZERO);
    }

    #[test]
    fn relu_vec_clamps_negatives() {
        let mut v = [Fixed::from_f64(-1.0), Fixed::from_f64(0.5), Fixed::ZERO];
        relu_vec(&mut v);
        assert_eq!(v, [Fixed::ZERO, Fixed::from_f64(0.5), Fixed::ZERO]);
    }

    #[test]
    fn dot_product() {
        let a = [
            Fixed::from_f64(1.0),
            Fixed::from_f64(2.0),
            Fixed::from_f64(3.0),
        ];
        let b = [
            Fixed::from_f64(4.0),
            Fixed::from_f64(5.0),
            Fixed::from_f64(6.0),
        ];
        let result = dot(&a, &b);
        // 1*4 + 2*5 + 3*6 = 32
        assert!(
            (result.to_f64() - 32.0).abs() < 0.1,
            "dot product = {}, expected 32.0",
            result.to_f64()
        );
    }

    #[test]
    fn matvec_values() {
        // Row-major [[1, 2], [3, 4]] times [5, 6] = [17, 39].
        let mat: Vec<Fixed> = [1.0, 2.0, 3.0, 4.0]
            .iter()
            .map(|&x| Fixed::from_f64(x))
            .collect();
        let v = [Fixed::from_f64(5.0), Fixed::from_f64(6.0)];
        let out = matvec(&mat, &v, 2);
        assert_eq!(out.len(), 2);
        assert!(
            (out[0].to_f64() - 17.0).abs() < 0.01,
            "row 0 = {}, expected 17.0",
            out[0].to_f64()
        );
        assert!(
            (out[1].to_f64() - 39.0).abs() < 0.01,
            "row 1 = {}, expected 39.0",
            out[1].to_f64()
        );
    }

    #[test]
    fn one_is_identity() {
        let a = Fixed::from_f64(42.0);
        let c = a.mul(Fixed::ONE);
        assert!(
            (c.to_f64() - 42.0).abs() < 0.01,
            "a * 1 = {}, expected 42.0",
            c.to_f64()
        );
    }

    #[test]
    fn inv_roundtrip() {
        let a = Fixed::from_f64(4.0);
        let b = a.inv();
        let c = a.mul(b);
        assert!(
            (c.to_f64() - 1.0).abs() < 0.01,
            "4 * inv(4) = {}, expected 1.0",
            c.to_f64()
        );
    }

    #[test]
    fn inv_negative() {
        let c = Fixed::from_f64(-2.0).inv();
        assert!(
            (c.to_f64() + 0.5).abs() < 0.01,
            "inv(-2) = {}, expected -0.5",
            c.to_f64()
        );
    }

    #[test]
    #[should_panic(expected = "cannot invert zero")]
    fn inv_zero_panics() {
        let _ = Fixed::ZERO.inv();
    }

    #[test]
    fn raw_accum_dot_matches_naive() {
        let a = [
            Fixed::from_f64(1.0),
            Fixed::from_f64(2.0),
            Fixed::from_f64(3.0),
        ];
        let b = [
            Fixed::from_f64(4.0),
            Fixed::from_f64(5.0),
            Fixed::from_f64(6.0),
        ];
        let naive = a[0].mul(b[0]).add(a[1].mul(b[1])).add(a[2].mul(b[2]));
        let fused = dot(&a, &b);
        assert!(
            (naive.to_f64() - fused.to_f64()).abs() < 0.1,
            "naive={}, fused={}",
            naive.to_f64(),
            fused.to_f64()
        );
        assert!(
            (fused.to_f64() - 32.0).abs() < 0.1,
            "fused dot = {}, expected 32.0",
            fused.to_f64()
        );
    }

    #[test]
    fn raw_accum_with_bias() {
        let mut acc = RawAccum::zero();
        // 10 + 3*4 = 22
        acc.add_bias(Fixed::from_f64(10.0));
        acc.add_prod(Fixed::from_f64(3.0), Fixed::from_f64(4.0));
        let result = acc.finish();
        assert!(
            (result.to_f64() - 22.0).abs() < 0.1,
            "bias+prod = {}, expected 22.0",
            result.to_f64()
        );
    }

    #[test]
    fn layer_norm_centers_and_scales() {
        // mean 2, deviations +-1, variance 1, so the result is the deviations.
        let mut v: Vec<Fixed> = [3.0, 1.0, 3.0, 1.0]
            .iter()
            .map(|&x| Fixed::from_f64(x))
            .collect();
        layer_norm(&mut v);
        let expected = [1.0, -1.0, 1.0, -1.0];
        for (got, want) in v.iter().zip(expected.iter()) {
            assert!(
                (got.to_f64() - want).abs() < 0.01,
                "got {}, want {}",
                got.to_f64(),
                want
            );
        }
    }

    #[test]
    fn layer_norm_constant_input_is_zeroed() {
        // Zero variance takes the epsilon branch and leaves the centered zeros.
        let mut v = [Fixed::from_f64(2.0); 4];
        layer_norm(&mut v);
        assert!(v.iter().all(|&x| x == Fixed::ZERO));
    }

    #[test]
    fn layer_norm_empty_is_noop() {
        let mut v: [Fixed; 0] = [];
        layer_norm(&mut v);
        assert!(v.is_empty());
    }

    #[test]
    fn display_four_decimals() {
        assert_eq!(Fixed::from_f64(0.5).to_string(), "0.5000");
        assert_eq!(Fixed::from_f64(-1.25).to_string(), "-1.2500");
    }

    #[test]
    fn accumulation_precision() {
        let small = Fixed::from_f64(0.001);
        // Rounding 0.001 to 66/65536 per step must not drift far over 1000 adds.
        let mut acc = Fixed::ZERO;
        for _ in 0..1000 {
            acc = acc.add(small);
        }
        assert!(
            (acc.to_f64() - 1.0).abs() < 0.1,
            "1000 * 0.001 = {}, expected ~1.0",
            acc.to_f64()
        );
    }
}