Skip to main content

ternlang_ml/
tritfloat_tensor.rs

1// SPDX-License-Identifier: LicenseRef-Ternlang-Commercial
2// Copyright (C) 2026 RFI-IRFOS. All rights reserved.
3//
4// TritFloatTensor — N-dimensional tensor of TritFloats.
5//
6// Every element carries its own confidence field. The confidence map of the
7// tensor is an emergent property of the computation that produced it — there
8// is no separate uncertainty bookkeeping layer.
9//
10// Architecture integration point:
11//   TritMatrix (weights, exact {-1,0,+1}, conf=1.0)
12//     ×
13//   TritFloatTensor (activations, float magnitudes, conf propagated)
14//     →
15//   TritFloatTensor (outputs, conf reflects the least-certain activation path)
16
17use rayon::prelude::*;
18use std::sync::atomic::{AtomicUsize, Ordering};
19
20use crate::tritfloat::TritFloat;
21use crate::{Trit, TritMatrix};
22
23// ─── Core type ────────────────────────────────────────────────────────────────
24
/// An N-dimensional tensor of TritFloats laid out in row-major order.
///
/// Each element carries its own confidence field. The tensor's overall
/// confidence is determined by `min_confidence()` or `mean_confidence()`.
///
/// Both fields are public, so the `data.len() == product(shape)` relationship
/// established by the constructors is not enforced after construction.
#[derive(Clone, Debug)]
pub struct TritFloatTensor {
    /// Flat element storage in row-major order (the last axis varies fastest).
    pub data: Vec<TritFloat>,
    /// Size of each axis, first axis first.
    pub shape: Vec<usize>,
}
34
35// ─── Constructors ─────────────────────────────────────────────────────────────
36
37impl TritFloatTensor {
38    /// All-zero tensor with neutral confidence (0.5) at every element.
39    pub fn zeros(shape: &[usize]) -> Self {
40        let numel = shape.iter().product();
41        Self { data: vec![TritFloat::zero(); numel], shape: shape.to_vec() }
42    }
43
44    /// All-ones tensor with maximum confidence at every element.
45    pub fn ones(shape: &[usize]) -> Self {
46        let numel = shape.iter().product::<usize>();
47        Self {
48            data: vec![TritFloat::from_f32(1.0); numel],
49            shape: shape.to_vec(),
50        }
51    }
52
53    /// Build from a flat f32 slice, all elements get confidence=1.0.
54    pub fn from_f32_slice(data: &[f32], shape: &[usize]) -> Self {
55        assert_eq!(data.len(), shape.iter().product::<usize>(),
56            "data length must equal product of shape dimensions");
57        Self {
58            data: data.iter().map(|&x| TritFloat::from_f32(x)).collect(),
59            shape: shape.to_vec(),
60        }
61    }
62
63    /// Build from f32 values with per-element confidence.
64    pub fn from_f32_with_confidence(vals: &[f32], conf: &[f32], shape: &[usize]) -> Self {
65        assert_eq!(vals.len(), shape.iter().product::<usize>());
66        assert_eq!(vals.len(), conf.len());
67        Self {
68            data: vals.iter().zip(conf.iter())
69                .map(|(&v, &c)| TritFloat::from_f32_with_confidence(v, c))
70                .collect(),
71            shape: shape.to_vec(),
72        }
73    }
74
75    /// Convert a TritMatrix to a 2D TritFloatTensor.
76    ///
77    /// Weights are exactly {-1, 0, +1} — they carry maximum confidence (1.0).
78    pub fn from_tritmatrix(m: &TritMatrix) -> Self {
79        let data = m.data.iter().map(|&t| {
80            let v = match t {
81                Trit::Affirm =>  1.0f32,
82                Trit::Reject => -1.0,
83                Trit::Tend   =>  0.0,
84            };
85            TritFloat::from_f32_with_confidence(v, 1.0)
86        }).collect();
87        Self { data, shape: vec![m.rows, m.cols] }
88    }
89
90    // ── Shape and access ──────────────────────────────────────────────────────
91
92    pub fn shape(&self) -> &[usize] { &self.shape }
93    pub fn ndim(&self)  -> usize { self.shape.len() }
94    pub fn numel(&self) -> usize { self.data.len() }
95
96    /// Row-major flat index from a multi-dimensional index.
97    fn flat_idx(&self, idx: &[usize]) -> usize {
98        assert_eq!(idx.len(), self.ndim(), "index rank must match tensor rank");
99        let mut flat = 0usize;
100        let mut stride = 1usize;
101        for i in (0..self.ndim()).rev() {
102            flat   += idx[i] * stride;
103            stride *= self.shape[i];
104        }
105        flat
106    }
107
108    pub fn get(&self, idx: &[usize]) -> TritFloat {
109        self.data[self.flat_idx(idx)]
110    }
111
112    pub fn set(&mut self, idx: &[usize], val: TritFloat) {
113        let flat = self.flat_idx(idx);
114        self.data[flat] = val;
115    }
116
117    // ── Matmul — TritFloatTensor × TritFloatTensor ────────────────────────────
118
119    /// 2D matrix multiply: [m, k] × [k, n] → [m, n].
120    ///
121    /// Each output element's confidence = min confidence over all contributing
122    /// (a_i × b_i) multiplications. Zero-phase activations skip their MAC
123    /// (@sparseskip at activation level) but still participate in the
124    /// confidence running-minimum so uncertain zeros don't disappear.
125    /// Rows are computed in parallel via Rayon.
126    pub fn matmul(a: &Self, b: &Self) -> Self {
127        assert_eq!(a.ndim(), 2, "matmul requires 2D tensors");
128        assert_eq!(b.ndim(), 2, "matmul requires 2D tensors");
129        let (m, k) = (a.shape[0], a.shape[1]);
130        let (k2, n) = (b.shape[0], b.shape[1]);
131        assert_eq!(k, k2, "matmul: a.cols ({k}) must equal b.rows ({k2})");
132
133        let mut out_data = vec![TritFloat::zero(); m * n];
134
135        out_data.par_chunks_mut(n).enumerate().for_each(|(row, out_row)| {
136            for col in 0..n {
137                let mut acc      = 0.0f32;
138                let mut min_conf = 1.0f32;
139                for i in 0..k {
140                    let ai = a.data[row * k + i];
141                    let bi = b.data[i * n + col];
142                    let c  = TritFloat::mul_confidence(ai, bi);
143                    if c < min_conf { min_conf = c; }
144                    if !ai.is_zero() && !bi.is_zero() {
145                        acc += ai.to_f32() * bi.to_f32();
146                    }
147                }
148                out_row[col] = TritFloat::from_f32_with_confidence(acc, min_conf);
149            }
150        });
151
152        Self { data: out_data, shape: vec![m, n] }
153    }
154
155    /// Matmul returning (result, total_macs_skipped) for sparsity instrumentation.
156    pub fn matmul_sparse(a: &Self, b: &Self) -> (Self, usize) {
157        assert_eq!(a.ndim(), 2);
158        assert_eq!(b.ndim(), 2);
159        let (m, k) = (a.shape[0], a.shape[1]);
160        let (k2, n) = (b.shape[0], b.shape[1]);
161        assert_eq!(k, k2);
162
163        let mut out_data    = vec![TritFloat::zero(); m * n];
164        let total_skipped   = AtomicUsize::new(0);
165
166        out_data.par_chunks_mut(n).enumerate().for_each(|(row, out_row)| {
167            let mut row_skipped = 0usize;
168            for col in 0..n {
169                let mut acc      = 0.0f32;
170                let mut min_conf = 1.0f32;
171                for i in 0..k {
172                    let ai = a.data[row * k + i];
173                    let bi = b.data[i * n + col];
174                    let c  = TritFloat::mul_confidence(ai, bi);
175                    if c < min_conf { min_conf = c; }
176                    if ai.is_zero() || bi.is_zero() {
177                        row_skipped += 1;
178                    } else {
179                        acc += ai.to_f32() * bi.to_f32();
180                    }
181                }
182                out_row[col] = TritFloat::from_f32_with_confidence(acc, min_conf);
183            }
184            total_skipped.fetch_add(row_skipped, Ordering::Relaxed);
185        });
186
187        let skipped = total_skipped.load(Ordering::Relaxed);
188        (Self { data: out_data, shape: vec![m, n] }, skipped)
189    }
190
191    // ── Matmul — TritFloatTensor × TritMatrix (inference hot path) ────────────
192
193    /// Multiply float activations by exact ternary weights (TritMatrix).
194    ///
195    /// This is the inference-time hot path: activations carry live confidence,
196    /// weights are exact {-1,0,+1} with confidence=1.0, so output confidence
197    /// = min activation confidence over each dot product.
198    ///
199    /// @sparseskip fires on BOTH activation zeros (is_zero()) AND weight zeros
200    /// (w == 0), giving the full combined sparsity savings.
201    ///
202    /// Returns (output_tensor, total_macs_skipped).
203    pub fn matmul_trit(activations: &Self, weights: &TritMatrix) -> (Self, usize) {
204        assert_eq!(activations.ndim(), 2,
205            "matmul_trit requires 2D activation tensor");
206        let (m, k) = (activations.shape[0], activations.shape[1]);
207        assert_eq!(k, weights.rows,
208            "activation cols ({k}) must match weight rows ({})", weights.rows);
209        let n = weights.cols;
210
211        let w_i8 = weights.to_i8_vec();
212        let mut out_data  = vec![TritFloat::zero(); m * n];
213        let total_skipped = AtomicUsize::new(0);
214
215        out_data.par_chunks_mut(n).enumerate().for_each(|(row, out_row)| {
216            let mut row_skipped = 0usize;
217            let act_row = &activations.data[row * k..(row + 1) * k];
218
219            for col in 0..n {
220                let mut acc      = 0.0f32;
221                let mut min_conf = 1.0f32;
222                for i in 0..k {
223                    let ai = act_row[i];
224                    let wi = w_i8[i * n + col];
225                    // Weight conf is 1.0, so min = activation conf
226                    let c = ai.confidence();
227                    if c < min_conf { min_conf = c; }
228                    if ai.is_zero() || wi == 0 {
229                        row_skipped += 1;
230                    } else {
231                        acc += ai.to_f32() * (wi as f32);
232                    }
233                }
234                out_row[col] = TritFloat::from_f32_with_confidence(acc, min_conf);
235            }
236            total_skipped.fetch_add(row_skipped, Ordering::Relaxed);
237        });
238
239        (Self { data: out_data, shape: vec![m, n] },
240         total_skipped.load(Ordering::Relaxed))
241    }
242
243    // ── Elementwise operations ────────────────────────────────────────────────
244
245    pub fn add_elementwise(a: &Self, b: &Self) -> Self {
246        assert_eq!(a.shape, b.shape, "elementwise add requires equal shapes");
247        Self {
248            data:  a.data.iter().zip(b.data.iter()).map(|(&ai, &bi)| ai.add(bi)).collect(),
249            shape: a.shape.clone(),
250        }
251    }
252
253    pub fn mul_elementwise(a: &Self, b: &Self) -> Self {
254        assert_eq!(a.shape, b.shape, "elementwise mul requires equal shapes");
255        Self {
256            data:  a.data.iter().zip(b.data.iter()).map(|(&ai, &bi)| ai.mul(bi)).collect(),
257            shape: a.shape.clone(),
258        }
259    }
260
261    /// Apply a function to every element in parallel.
262    pub fn map<F>(&self, f: F) -> Self
263    where
264        F: Fn(TritFloat) -> TritFloat + Sync + Send,
265    {
266        Self {
267            data:  self.data.par_iter().map(|&x| f(x)).collect(),
268            shape: self.shape.clone(),
269        }
270    }
271
272    // ── Reductions ────────────────────────────────────────────────────────────
273
274    pub fn sum_all(&self) -> TritFloat {
275        self.data.iter().fold(TritFloat::zero(), |acc, &x| acc.add(x))
276    }
277
278    pub fn mean_all(&self) -> TritFloat {
279        if self.data.is_empty() { return TritFloat::zero(); }
280        let s = self.sum_all();
281        TritFloat::from_f32_with_confidence(
282            s.to_f32() / self.data.len() as f32,
283            s.confidence(),
284        )
285    }
286
287    /// Minimum confidence across all elements: how certain is the least-certain value?
288    pub fn min_confidence(&self) -> f32 {
289        self.data.iter().map(|x| x.confidence()).fold(1.0f32, f32::min)
290    }
291
292    /// Mean confidence: average epistemic certainty across the tensor.
293    pub fn mean_confidence(&self) -> f32 {
294        if self.data.is_empty() { return 0.0; }
295        self.data.iter().map(|x| x.confidence()).sum::<f32>() / self.data.len() as f32
296    }
297
298    /// Histogram of confidence states across all elements.
299    /// Index i = count of elements with confidence ≈ i/8.
300    /// The 9 bins correspond to the 9 discrete states of the 2-trit confidence field.
301    pub fn confidence_histogram(&self) -> [usize; 9] {
302        let mut hist = [0usize; 9];
303        for x in &self.data {
304            let idx = (x.confidence() * 8.0).round() as usize;
305            hist[idx.min(8)] += 1;
306        }
307        hist
308    }
309
310    // ── Sparsity ──────────────────────────────────────────────────────────────
311
312    /// Fraction of elements with zero phase (exactly 0.0 value).
313    pub fn sparsity(&self) -> f64 {
314        let zeros = self.data.iter().filter(|x| x.is_zero()).count();
315        zeros as f64 / self.data.len().max(1) as f64
316    }
317
318    // ── Conversions ───────────────────────────────────────────────────────────
319
320    /// Extract f32 values, discarding confidence information.
321    pub fn to_f32_vec(&self) -> Vec<f32> {
322        self.data.iter().map(|x| x.to_f32()).collect()
323    }
324
325    /// Quantize to TritMatrix: positive phase → Affirm, negative → Reject, zero → Tend.
326    pub fn to_tritmatrix(&self) -> TritMatrix {
327        assert_eq!(self.ndim(), 2, "to_tritmatrix requires a 2D tensor");
328        let data = self.data.iter().map(|x| match x.phase() {
329            1  => Trit::Affirm,
330            -1 => Trit::Reject,
331            _  => Trit::Tend,
332        }).collect();
333        TritMatrix { rows: self.shape[0], cols: self.shape[1], data }
334    }
335
336    /// Apply softmax along each row of a 2D tensor.
337    pub fn softmax_rows(&self) -> Self {
338        assert_eq!(self.ndim(), 2, "softmax_rows requires a 2D tensor");
339        let (m, n) = (self.shape[0], self.shape[1]);
340        let mut out = Self::zeros(&[m, n]);
341        for row in 0..m {
342            let slice = &self.data[row * n..(row + 1) * n];
343            let sm    = TritFloat::softmax(slice);
344            out.data[row * n..(row + 1) * n].copy_from_slice(&sm);
345        }
346        out
347    }
348}
349
350// ─── Tests ────────────────────────────────────────────────────────────────────
351
352#[cfg(test)]
353mod tests {
354    use super::*;
355
356    fn approx(a: f32, b: f32, tol: f32) -> bool {
357        if b == 0.0 { return a.abs() < tol; }
358        ((a - b) / b).abs() < tol
359    }
360
361    #[test]
362    fn zeros_shape_and_values() {
363        let t = TritFloatTensor::zeros(&[3, 4]);
364        assert_eq!(t.shape(), &[3, 4]);
365        assert_eq!(t.numel(), 12);
366        assert!(t.data.iter().all(|x| x.is_zero()));
367    }
368
369    #[test]
370    fn ones_values() {
371        let t = TritFloatTensor::ones(&[2, 3]);
372        for x in &t.data {
373            assert!(approx(x.to_f32(), 1.0, 0.01));
374            assert_eq!(x.phase(), 1);
375        }
376    }
377
378    #[test]
379    fn from_f32_slice_roundtrip() {
380        let vals = vec![1.0f32, -2.0, 0.0, 3.14];
381        let t = TritFloatTensor::from_f32_slice(&vals, &[2, 2]);
382        assert_eq!(t.shape(), &[2, 2]);
383        let back = t.to_f32_vec();
384        for (a, b) in vals.iter().zip(back.iter()) {
385            assert!(approx(*b, *a, 0.01), "{a} → {b}");
386        }
387    }
388
389    #[test]
390    fn from_tritmatrix_correct_values_and_confidence() {
391        use crate::TritMatrix;
392        use crate::Trit;
393        let m = TritMatrix::from_trits(2, 2, vec![
394            Trit::Affirm, Trit::Tend, Trit::Reject, Trit::Affirm,
395        ]);
396        let t = TritFloatTensor::from_tritmatrix(&m);
397        assert_eq!(t.shape(), &[2, 2]);
398        assert!(approx(t.get(&[0, 0]).to_f32(),  1.0, 0.01));
399        assert!(t.get(&[0, 1]).is_zero());
400        assert!(approx(t.get(&[1, 0]).to_f32(), -1.0, 0.01));
401        // All weights are exactly known → confidence 1.0
402        assert!(t.data.iter().all(|x| (x.confidence() - 1.0).abs() < 0.15));
403    }
404
405    #[test]
406    fn matmul_identity() {
407        // I × A = A
408        let identity = TritFloatTensor::from_f32_slice(
409            &[1.0f32, 0.0, 0.0, 1.0], &[2, 2]
410        );
411        let a = TritFloatTensor::from_f32_slice(
412            &[3.0f32, 4.0, 5.0, 6.0], &[2, 2]
413        );
414        let r = TritFloatTensor::matmul(&identity, &a);
415        let vals = r.to_f32_vec();
416        assert!(approx(vals[0], 3.0, 0.02));
417        assert!(approx(vals[1], 4.0, 0.02));
418        assert!(approx(vals[2], 5.0, 0.02));
419        assert!(approx(vals[3], 6.0, 0.02));
420    }
421
422    #[test]
423    fn matmul_2x3_x_3x2() {
424        // [[1,2,3],[4,5,6]] × [[7,8],[9,10],[11,12]] = [[58,64],[139,154]]
425        let a = TritFloatTensor::from_f32_slice(
426            &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]
427        );
428        let b = TritFloatTensor::from_f32_slice(
429            &[7.0f32, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2]
430        );
431        let r = TritFloatTensor::matmul(&a, &b);
432        assert_eq!(r.shape(), &[2, 2]);
433        let v = r.to_f32_vec();
434        assert!(approx(v[0],  58.0, 0.02), "got {}", v[0]);
435        assert!(approx(v[1],  64.0, 0.02), "got {}", v[1]);
436        assert!(approx(v[2], 139.0, 0.02), "got {}", v[2]);
437        assert!(approx(v[3], 154.0, 0.02), "got {}", v[3]);
438    }
439
440    #[test]
441    fn matmul_confidence_propagates() {
442        // Certain weights × uncertain activations → uncertain outputs
443        let acts = TritFloatTensor::from_f32_with_confidence(
444            &[1.0f32, 1.0], &[0.125f32, 0.125], &[1, 2]
445        );
446        let weights = TritFloatTensor::from_f32_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2]);
447        let r = TritFloatTensor::matmul(&acts, &weights);
448        assert!(r.min_confidence() < 0.3, "low-conf inputs → low-conf output");
449    }
450
451    #[test]
452    fn matmul_sparse_skip_count() {
453        // Half zeros in activations → roughly half skips
454        let acts = TritFloatTensor::from_f32_slice(
455            &[1.0f32, 0.0, 1.0, 0.0], &[1, 4]
456        );
457        let w = TritFloatTensor::from_f32_slice(
458            &[1.0f32; 8], &[4, 2]
459        );
460        let (_, skips) = TritFloatTensor::matmul_sparse(&acts, &w);
461        assert!(skips > 0, "zero activations should produce skips");
462    }
463
464    #[test]
465    fn matmul_trit_matches_dense() {
466        // [1, -1] × [[1, 0], [-1, 1]] = [1*1 + (-1)*(-1), 1*0 + (-1)*1] = [2, -1]
467        use crate::TritMatrix;
468        let acts = TritFloatTensor::from_f32_slice(&[1.0f32, -1.0], &[1, 2]);
469        let mut w = TritMatrix::new(2, 2);
470        w.set(0, 0, Trit::Affirm);   // (0,0) = +1
471        w.set(0, 1, Trit::Tend);     // (0,1) = 0
472        w.set(1, 0, Trit::Reject);   // (1,0) = -1
473        w.set(1, 1, Trit::Affirm);   // (1,1) = +1
474
475        let (r, _) = TritFloatTensor::matmul_trit(&acts, &w);
476        assert_eq!(r.shape(), &[1, 2]);
477        let v = r.to_f32_vec();
478        // 1*1 + (-1)*(-1) = 1 + 1 = 2
479        assert!(approx(v[0], 2.0, 0.02), "col0: expected 2, got {}", v[0]);
480        // 1*0 + (-1)*1 = -1
481        assert!(approx(v[1], -1.0, 0.02), "col1: expected -1, got {}", v[1]);
482    }
483
484    #[test]
485    fn elementwise_add_and_mul() {
486        let a = TritFloatTensor::from_f32_slice(&[1.0f32, 2.0, 3.0], &[3]);
487        let b = TritFloatTensor::from_f32_slice(&[4.0f32, 5.0, 6.0], &[3]);
488        let s = TritFloatTensor::add_elementwise(&a, &b);
489        let p = TritFloatTensor::mul_elementwise(&a, &b);
490        let sv = s.to_f32_vec();
491        let pv = p.to_f32_vec();
492        assert!(approx(sv[0], 5.0, 0.02));
493        assert!(approx(sv[2], 9.0, 0.02));
494        assert!(approx(pv[0], 4.0, 0.02));
495        assert!(approx(pv[2], 18.0, 0.02));
496    }
497
498    #[test]
499    fn map_applies_function() {
500        let t = TritFloatTensor::from_f32_slice(&[1.0f32, 4.0, 9.0], &[3]);
501        let r = t.map(|x| x.sqrt());
502        let v = r.to_f32_vec();
503        assert!(approx(v[0], 1.0, 0.02));
504        assert!(approx(v[1], 2.0, 0.02));
505        assert!(approx(v[2], 3.0, 0.02));
506    }
507
508    #[test]
509    fn sparsity_correct() {
510        // 2 zeros out of 4 = 50%
511        let t = TritFloatTensor::from_f32_slice(&[1.0f32, 0.0, -1.0, 0.0], &[2, 2]);
512        assert!((t.sparsity() - 0.5).abs() < 1e-6);
513    }
514
515    #[test]
516    fn confidence_histogram_bins() {
517        let t = TritFloatTensor::from_f32_with_confidence(
518            &[1.0f32, 1.0, 1.0],
519            &[0.0f32, 0.5, 1.0],
520            &[3],
521        );
522        let hist = t.confidence_histogram();
523        assert_eq!(hist[0], 1, "one element at conf=0");
524        assert_eq!(hist[4], 1, "one element at conf=0.5");
525        assert_eq!(hist[8], 1, "one element at conf=1.0");
526    }
527
528    #[test]
529    fn min_and_mean_confidence() {
530        let t = TritFloatTensor::from_f32_with_confidence(
531            &[1.0f32, 1.0],
532            &[0.125f32, 1.0],
533            &[2],
534        );
535        assert!((t.min_confidence() - 0.125).abs() < 0.15);
536        let mean = t.mean_confidence();
537        assert!(mean > 0.125 && mean < 1.0, "mean should be between min and max");
538    }
539
540    #[test]
541    fn to_tritmatrix_roundtrip() {
542        let t = TritFloatTensor::from_f32_slice(&[1.0f32, -1.0, 0.0, 0.5], &[2, 2]);
543        let m = t.to_tritmatrix();
544        assert_eq!(m.get(0, 0), Trit::Affirm);
545        assert_eq!(m.get(0, 1), Trit::Reject);
546        assert_eq!(m.get(1, 0), Trit::Tend);
547        assert_eq!(m.get(1, 1), Trit::Affirm);
548    }
549
550    #[test]
551    fn softmax_rows_sums_to_one() {
552        let t = TritFloatTensor::from_f32_slice(
553            &[1.0f32, 2.0, 3.0, 0.1, 0.2, 0.3], &[2, 3]
554        );
555        let sm = t.softmax_rows();
556        for row in 0..2 {
557            let row_sum: f32 = sm.data[row * 3..(row + 1) * 3]
558                .iter().map(|x| x.to_f32()).sum();
559            // TritFloat roundtrip precision ~0.3% per element; 3 elements → up to ~1% accumulated
560            assert!((row_sum - 1.0).abs() < 0.005, "row {row} sum = {row_sum}");
561        }
562    }
563
564    #[test]
565    fn matmul_sparse_matches_matmul() {
566        let a = TritFloatTensor::from_f32_slice(
567            &[1.0f32, 0.0, 2.0, 0.0, 1.0, 3.0], &[2, 3]
568        );
569        let b = TritFloatTensor::from_f32_slice(
570            &[1.0f32, 2.0, 0.0, 3.0, 4.0, 1.0], &[3, 2]
571        );
572        let r1 = TritFloatTensor::matmul(&a, &b);
573        let (r2, _) = TritFloatTensor::matmul_sparse(&a, &b);
574        for (x, y) in r1.to_f32_vec().iter().zip(r2.to_f32_vec().iter()) {
575            assert!(approx(*x, *y, 0.001), "sparse and dense matmul disagree: {x} vs {y}");
576        }
577    }
578}