// trident/field/fixed.rs

//! Fixed-point arithmetic in the Goldilocks field.
//!
//! Scale factor S = 2^16 = 65536. Real values encoded as field elements.
//! Multiply with rescale: (a * b) * inv(S). 16-bit fractional precision.

use super::goldilocks::{Goldilocks, MODULUS};
use super::PrimeField;

/// Scale factor: 2^16 = 65536. A real value r is encoded as round(r * SCALE).
pub const SCALE: u64 = 1 << 16;

/// Half the field modulus — canonical values above this threshold decode as
/// negative reals (the field analog of two's-complement sign encoding).
const HALF_P: u64 = MODULUS / 2;

15/// Precomputed inverse of the scale factor: inv(65536) mod p.
16fn inv_scale() -> Goldilocks {
17    static INV: std::sync::OnceLock<Goldilocks> = std::sync::OnceLock::new();
18    *INV.get_or_init(|| Goldilocks::from_u64(SCALE).inv().expect("SCALE is nonzero"))
19}
20
/// Fixed-point value in Goldilocks field (scale factor 2^16).
///
/// Stored as a single field element; negative reals occupy the upper half of
/// the field (values above `HALF_P` decode as negative).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Fixed(pub Goldilocks);

25impl Fixed {
26    pub const ZERO: Self = Self(Goldilocks(0));
27    pub const ONE: Self = Self(Goldilocks(SCALE));
28
29    /// Encode an f64 as a fixed-point field element.
30    ///
31    /// Negative values map to the upper half of the field.
32    pub fn from_f64(v: f64) -> Self {
33        let scaled = v * SCALE as f64;
34        if scaled >= 0.0 {
35            Self(Goldilocks::from_u64(scaled.round() as u64))
36        } else {
37            // Negative: p - |scaled|
38            let abs = (-scaled).round() as u64;
39            Self(Goldilocks::from_u64(MODULUS - abs))
40        }
41    }
42
43    /// Decode a fixed-point field element back to f64.
44    ///
45    /// Values in the upper half of the field are treated as negative.
46    pub fn to_f64(self) -> f64 {
47        let raw = self.0.to_u64();
48        if raw <= HALF_P {
49            raw as f64 / SCALE as f64
50        } else {
51            -((MODULUS - raw) as f64 / SCALE as f64)
52        }
53    }
54
55    /// Raw field element access.
56    pub fn raw(self) -> Goldilocks {
57        self.0
58    }
59
60    /// Construct from a raw Goldilocks element (already scaled).
61    pub fn from_raw(g: Goldilocks) -> Self {
62        Self(g)
63    }
64
65    /// Fixed-point addition (field add, no rescale needed).
66    #[inline]
67    pub fn add(self, rhs: Self) -> Self {
68        Self(self.0.add(rhs.0))
69    }
70
71    /// Fixed-point subtraction (field sub, no rescale needed).
72    #[inline]
73    pub fn sub(self, rhs: Self) -> Self {
74        Self(self.0.sub(rhs.0))
75    }
76
77    /// Fixed-point multiplication: (a * b) * inv(S).
78    #[inline]
79    pub fn mul(self, rhs: Self) -> Self {
80        Self(self.0.mul(rhs.0).mul(inv_scale()))
81    }
82
83    /// Additive inverse.
84    #[inline]
85    pub fn neg(self) -> Self {
86        Self(self.0.neg())
87    }
88
89    /// Multiplicative inverse: result * self = ONE.
90    ///
91    /// self encodes real value v = self.0 / S.
92    /// We want 1/v = S / self.0, encoded as fixed-point: S^2 / self.0 = S^2 * inv(self.0).
93    pub fn inv(self) -> Self {
94        let raw_inv = self.0.inv().expect("cannot invert zero");
95        let s = Goldilocks::from_u64(SCALE);
96        Self(raw_inv.mul(s).mul(s))
97    }
98
99    /// ReLU: if value is "positive" (< p/2) return self, else zero.
100    #[inline]
101    pub fn relu(self) -> Self {
102        if self.0.to_u64() <= HALF_P {
103            self
104        } else {
105            Self::ZERO
106        }
107    }
108
109    /// Multiply-accumulate: self + a * b (fused, one rescale).
110    #[inline]
111    pub fn madd(self, a: Self, b: Self) -> Self {
112        self.add(a.mul(b))
113    }
114}
115
116impl std::fmt::Display for Fixed {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        write!(f, "{:.4}", self.to_f64())
119    }
120}
121
// ─── Fused Dot Product ─────────────────────────────────────────────

/// Raw accumulator for fused dot products.
///
/// Accumulates a.0 * b.0 in raw Goldilocks (no per-multiply rescale), so the
/// running sum carries scale S^2. Call `finish()` to apply inv(SCALE) once
/// and get a proper Fixed value.
/// For a dot product of length N: N+1 field muls instead of 2N.
pub struct RawAccum(pub Goldilocks);

131impl RawAccum {
132    #[inline]
133    pub fn zero() -> Self {
134        Self(Goldilocks(0))
135    }
136
137    /// Accumulate one product: self += a.0 * b.0 (raw, no rescale).
138    #[inline]
139    pub fn add_prod(&mut self, a: Fixed, b: Fixed) {
140        self.0 = self.0.add(a.0.mul(b.0));
141    }
142
143    /// Accumulate a pre-scaled addition: self += bias.0 * SCALE.
144    /// Used when adding a Fixed bias to a raw accumulator.
145    #[inline]
146    pub fn add_bias(&mut self, bias: Fixed) {
147        self.0 = self.0.add(bias.0.mul(Goldilocks(SCALE)));
148    }
149
150    /// Finalize: apply inv(SCALE) once to produce a proper Fixed value.
151    #[inline]
152    pub fn finish(self) -> Fixed {
153        Fixed(self.0.mul(inv_scale()))
154    }
155}
156
157// ─── Vector Operations ─────────────────────────────────────────────
158
159/// Dot product of two fixed-point vectors (fused, single rescale).
160pub fn dot(a: &[Fixed], b: &[Fixed]) -> Fixed {
161    debug_assert_eq!(a.len(), b.len());
162    let mut acc = RawAccum::zero();
163    for i in 0..a.len() {
164        acc.add_prod(a[i], b[i]);
165    }
166    acc.finish()
167}
168
169/// Matrix-vector multiply: out[i] = dot(mat[i], vec).
170/// Matrix is row-major: mat.len() = rows * cols.
171pub fn matvec(mat: &[Fixed], vec: &[Fixed], cols: usize) -> Vec<Fixed> {
172    let rows = mat.len() / cols;
173    let mut out = Vec::with_capacity(rows);
174    for r in 0..rows {
175        let row = &mat[r * cols..(r + 1) * cols];
176        out.push(dot(row, vec));
177    }
178    out
179}
180
181/// Element-wise ReLU.
182pub fn relu_vec(v: &mut [Fixed]) {
183    for x in v.iter_mut() {
184        *x = x.relu();
185    }
186}
187
188/// Layer normalization (simplified: zero-mean, unit-variance approximation).
189/// Subtracts mean, scales by inverse of approximate std deviation.
190pub fn layer_norm(v: &mut [Fixed]) {
191    let n = v.len();
192    if n == 0 {
193        return;
194    }
195    let n_fixed = Fixed::from_f64(n as f64);
196
197    // Mean
198    let mut sum = Fixed::ZERO;
199    for x in v.iter() {
200        sum = sum.add(*x);
201    }
202    let mean = sum.mul(n_fixed.inv());
203
204    // Subtract mean
205    for x in v.iter_mut() {
206        *x = x.sub(mean);
207    }
208
209    // Variance (sum of squares / n)
210    let mut var_sum = Fixed::ZERO;
211    for x in v.iter() {
212        var_sum = var_sum.madd(*x, *x);
213    }
214    let variance = var_sum.mul(n_fixed.inv());
215
216    // Approximate inv_sqrt via Newton: 1/sqrt(v) ≈ inv(v) * v ≈ just use inv(sqrt_approx)
217    // Simple approach: scale by inv(max(variance, epsilon))
218    let epsilon = Fixed::from_f64(1e-5);
219    let scale = if variance.to_f64().abs() < epsilon.to_f64() {
220        Fixed::ONE
221    } else {
222        variance.inv()
223    };
224    // This is 1/variance not 1/sqrt(variance), but for neural nets
225    // with normalized inputs, it's a reasonable approximation that avoids
226    // computing square roots in field arithmetic.
227    for x in v.iter_mut() {
228        *x = x.mul(scale);
229    }
230}
231
// ─── Tests ─────────────────────────────────────────────────────────

234#[cfg(test)]
235mod tests {
236    use super::*;
237
238    #[test]
239    fn roundtrip_positive() {
240        let vals = [0.0, 0.5, 1.0, 0.375, 100.0, 0.001];
241        for &v in &vals {
242            let f = Fixed::from_f64(v);
243            let back = f.to_f64();
244            assert!(
245                (back - v).abs() < 0.001,
246                "roundtrip failed for {}: got {}",
247                v,
248                back
249            );
250        }
251    }
252
253    #[test]
254    fn roundtrip_negative() {
255        let vals = [-0.5, -1.0, -100.0, -0.001];
256        for &v in &vals {
257            let f = Fixed::from_f64(v);
258            let back = f.to_f64();
259            assert!(
260                (back - v).abs() < 0.001,
261                "roundtrip failed for {}: got {}",
262                v,
263                back
264            );
265        }
266    }
267
268    #[test]
269    fn add_commutative() {
270        let a = Fixed::from_f64(0.5);
271        let b = Fixed::from_f64(0.25);
272        assert_eq!(a.add(b), b.add(a));
273    }
274
275    #[test]
276    fn add_values() {
277        let a = Fixed::from_f64(0.5);
278        let b = Fixed::from_f64(0.25);
279        let c = a.add(b);
280        assert!((c.to_f64() - 0.75).abs() < 0.001);
281    }
282
283    #[test]
284    fn sub_values() {
285        let a = Fixed::from_f64(1.0);
286        let b = Fixed::from_f64(0.25);
287        let c = a.sub(b);
288        assert!((c.to_f64() - 0.75).abs() < 0.001);
289    }
290
291    #[test]
292    fn mul_values() {
293        let a = Fixed::from_f64(0.5);
294        let b = Fixed::from_f64(0.5);
295        let c = a.mul(b);
296        assert!(
297            (c.to_f64() - 0.25).abs() < 0.001,
298            "0.5 * 0.5 = {}, expected 0.25",
299            c.to_f64()
300        );
301    }
302
303    #[test]
304    fn mul_negative() {
305        let a = Fixed::from_f64(-0.5);
306        let b = Fixed::from_f64(2.0);
307        let c = a.mul(b);
308        assert!(
309            (c.to_f64() - (-1.0)).abs() < 0.001,
310            "-0.5 * 2.0 = {}, expected -1.0",
311            c.to_f64()
312        );
313    }
314
315    #[test]
316    fn neg_values() {
317        let a = Fixed::from_f64(1.0);
318        let b = a.neg();
319        assert!((b.to_f64() - (-1.0)).abs() < 0.001);
320        assert_eq!(a.add(b), Fixed::ZERO);
321    }
322
323    #[test]
324    fn relu_positive() {
325        let a = Fixed::from_f64(0.5);
326        assert_eq!(a.relu(), a);
327    }
328
329    #[test]
330    fn relu_negative() {
331        let a = Fixed::from_f64(-0.5);
332        assert_eq!(a.relu(), Fixed::ZERO);
333    }
334
335    #[test]
336    fn relu_zero() {
337        assert_eq!(Fixed::ZERO.relu(), Fixed::ZERO);
338    }
339
340    #[test]
341    fn dot_product() {
342        let a = [
343            Fixed::from_f64(1.0),
344            Fixed::from_f64(2.0),
345            Fixed::from_f64(3.0),
346        ];
347        let b = [
348            Fixed::from_f64(4.0),
349            Fixed::from_f64(5.0),
350            Fixed::from_f64(6.0),
351        ];
352        let result = dot(&a, &b);
353        // 1*4 + 2*5 + 3*6 = 32
354        assert!(
355            (result.to_f64() - 32.0).abs() < 0.1,
356            "dot product = {}, expected 32.0",
357            result.to_f64()
358        );
359    }
360
361    #[test]
362    fn one_is_identity() {
363        let a = Fixed::from_f64(42.0);
364        let c = a.mul(Fixed::ONE);
365        assert!(
366            (c.to_f64() - 42.0).abs() < 0.01,
367            "a * 1 = {}, expected 42.0",
368            c.to_f64()
369        );
370    }
371
372    #[test]
373    fn inv_roundtrip() {
374        let a = Fixed::from_f64(4.0);
375        let b = a.inv();
376        let c = a.mul(b);
377        assert!(
378            (c.to_f64() - 1.0).abs() < 0.01,
379            "4 * inv(4) = {}, expected 1.0",
380            c.to_f64()
381        );
382    }
383
384    #[test]
385    fn raw_accum_dot_matches_naive() {
386        let a = [
387            Fixed::from_f64(1.0),
388            Fixed::from_f64(2.0),
389            Fixed::from_f64(3.0),
390        ];
391        let b = [
392            Fixed::from_f64(4.0),
393            Fixed::from_f64(5.0),
394            Fixed::from_f64(6.0),
395        ];
396        let naive = a[0].mul(b[0]).add(a[1].mul(b[1])).add(a[2].mul(b[2]));
397        let fused = dot(&a, &b);
398        assert!(
399            (naive.to_f64() - fused.to_f64()).abs() < 0.1,
400            "naive={}, fused={}",
401            naive.to_f64(),
402            fused.to_f64()
403        );
404        assert!(
405            (fused.to_f64() - 32.0).abs() < 0.1,
406            "fused dot = {}, expected 32.0",
407            fused.to_f64()
408        );
409    }
410
411    #[test]
412    fn raw_accum_with_bias() {
413        // bias + a*b = 10 + 3*4 = 22
414        let mut acc = RawAccum::zero();
415        acc.add_bias(Fixed::from_f64(10.0));
416        acc.add_prod(Fixed::from_f64(3.0), Fixed::from_f64(4.0));
417        let result = acc.finish();
418        assert!(
419            (result.to_f64() - 22.0).abs() < 0.1,
420            "bias+prod = {}, expected 22.0",
421            result.to_f64()
422        );
423    }
424
425    #[test]
426    fn accumulation_precision() {
427        // Sum 1000 copies of 0.001 — should be close to 1.0
428        let small = Fixed::from_f64(0.001);
429        let mut acc = Fixed::ZERO;
430        for _ in 0..1000 {
431            acc = acc.add(small);
432        }
433        assert!(
434            (acc.to_f64() - 1.0).abs() < 0.1,
435            "1000 * 0.001 = {}, expected ~1.0",
436            acc.to_f64()
437        );
438    }
439}