1use rayon::prelude::*;
18use std::sync::atomic::{AtomicUsize, Ordering};
19
20use crate::tritfloat::TritFloat;
21use crate::{Trit, TritMatrix};
22
23#[derive(Clone, Debug)]
30pub struct TritFloatTensor {
31 pub data: Vec<TritFloat>,
32 pub shape: Vec<usize>,
33}
34
35impl TritFloatTensor {
38 pub fn zeros(shape: &[usize]) -> Self {
40 let numel = shape.iter().product();
41 Self { data: vec![TritFloat::zero(); numel], shape: shape.to_vec() }
42 }
43
44 pub fn ones(shape: &[usize]) -> Self {
46 let numel = shape.iter().product::<usize>();
47 Self {
48 data: vec![TritFloat::from_f32(1.0); numel],
49 shape: shape.to_vec(),
50 }
51 }
52
53 pub fn from_f32_slice(data: &[f32], shape: &[usize]) -> Self {
55 assert_eq!(data.len(), shape.iter().product::<usize>(),
56 "data length must equal product of shape dimensions");
57 Self {
58 data: data.iter().map(|&x| TritFloat::from_f32(x)).collect(),
59 shape: shape.to_vec(),
60 }
61 }
62
63 pub fn from_f32_with_confidence(vals: &[f32], conf: &[f32], shape: &[usize]) -> Self {
65 assert_eq!(vals.len(), shape.iter().product::<usize>());
66 assert_eq!(vals.len(), conf.len());
67 Self {
68 data: vals.iter().zip(conf.iter())
69 .map(|(&v, &c)| TritFloat::from_f32_with_confidence(v, c))
70 .collect(),
71 shape: shape.to_vec(),
72 }
73 }
74
75 pub fn from_tritmatrix(m: &TritMatrix) -> Self {
79 let data = m.data.iter().map(|&t| {
80 let v = match t {
81 Trit::Affirm => 1.0f32,
82 Trit::Reject => -1.0,
83 Trit::Tend => 0.0,
84 };
85 TritFloat::from_f32_with_confidence(v, 1.0)
86 }).collect();
87 Self { data, shape: vec![m.rows, m.cols] }
88 }
89
90 pub fn shape(&self) -> &[usize] { &self.shape }
93 pub fn ndim(&self) -> usize { self.shape.len() }
94 pub fn numel(&self) -> usize { self.data.len() }
95
96 fn flat_idx(&self, idx: &[usize]) -> usize {
98 assert_eq!(idx.len(), self.ndim(), "index rank must match tensor rank");
99 let mut flat = 0usize;
100 let mut stride = 1usize;
101 for i in (0..self.ndim()).rev() {
102 flat += idx[i] * stride;
103 stride *= self.shape[i];
104 }
105 flat
106 }
107
108 pub fn get(&self, idx: &[usize]) -> TritFloat {
109 self.data[self.flat_idx(idx)]
110 }
111
112 pub fn set(&mut self, idx: &[usize], val: TritFloat) {
113 let flat = self.flat_idx(idx);
114 self.data[flat] = val;
115 }
116
117 pub fn matmul(a: &Self, b: &Self) -> Self {
127 assert_eq!(a.ndim(), 2, "matmul requires 2D tensors");
128 assert_eq!(b.ndim(), 2, "matmul requires 2D tensors");
129 let (m, k) = (a.shape[0], a.shape[1]);
130 let (k2, n) = (b.shape[0], b.shape[1]);
131 assert_eq!(k, k2, "matmul: a.cols ({k}) must equal b.rows ({k2})");
132
133 let mut out_data = vec![TritFloat::zero(); m * n];
134
135 out_data.par_chunks_mut(n).enumerate().for_each(|(row, out_row)| {
136 for col in 0..n {
137 let mut acc = 0.0f32;
138 let mut min_conf = 1.0f32;
139 for i in 0..k {
140 let ai = a.data[row * k + i];
141 let bi = b.data[i * n + col];
142 let c = TritFloat::mul_confidence(ai, bi);
143 if c < min_conf { min_conf = c; }
144 if !ai.is_zero() && !bi.is_zero() {
145 acc += ai.to_f32() * bi.to_f32();
146 }
147 }
148 out_row[col] = TritFloat::from_f32_with_confidence(acc, min_conf);
149 }
150 });
151
152 Self { data: out_data, shape: vec![m, n] }
153 }
154
155 pub fn matmul_sparse(a: &Self, b: &Self) -> (Self, usize) {
157 assert_eq!(a.ndim(), 2);
158 assert_eq!(b.ndim(), 2);
159 let (m, k) = (a.shape[0], a.shape[1]);
160 let (k2, n) = (b.shape[0], b.shape[1]);
161 assert_eq!(k, k2);
162
163 let mut out_data = vec![TritFloat::zero(); m * n];
164 let total_skipped = AtomicUsize::new(0);
165
166 out_data.par_chunks_mut(n).enumerate().for_each(|(row, out_row)| {
167 let mut row_skipped = 0usize;
168 for col in 0..n {
169 let mut acc = 0.0f32;
170 let mut min_conf = 1.0f32;
171 for i in 0..k {
172 let ai = a.data[row * k + i];
173 let bi = b.data[i * n + col];
174 let c = TritFloat::mul_confidence(ai, bi);
175 if c < min_conf { min_conf = c; }
176 if ai.is_zero() || bi.is_zero() {
177 row_skipped += 1;
178 } else {
179 acc += ai.to_f32() * bi.to_f32();
180 }
181 }
182 out_row[col] = TritFloat::from_f32_with_confidence(acc, min_conf);
183 }
184 total_skipped.fetch_add(row_skipped, Ordering::Relaxed);
185 });
186
187 let skipped = total_skipped.load(Ordering::Relaxed);
188 (Self { data: out_data, shape: vec![m, n] }, skipped)
189 }
190
191 pub fn matmul_trit(activations: &Self, weights: &TritMatrix) -> (Self, usize) {
204 assert_eq!(activations.ndim(), 2,
205 "matmul_trit requires 2D activation tensor");
206 let (m, k) = (activations.shape[0], activations.shape[1]);
207 assert_eq!(k, weights.rows,
208 "activation cols ({k}) must match weight rows ({})", weights.rows);
209 let n = weights.cols;
210
211 let w_i8 = weights.to_i8_vec();
212 let mut out_data = vec![TritFloat::zero(); m * n];
213 let total_skipped = AtomicUsize::new(0);
214
215 out_data.par_chunks_mut(n).enumerate().for_each(|(row, out_row)| {
216 let mut row_skipped = 0usize;
217 let act_row = &activations.data[row * k..(row + 1) * k];
218
219 for col in 0..n {
220 let mut acc = 0.0f32;
221 let mut min_conf = 1.0f32;
222 for i in 0..k {
223 let ai = act_row[i];
224 let wi = w_i8[i * n + col];
225 let c = ai.confidence();
227 if c < min_conf { min_conf = c; }
228 if ai.is_zero() || wi == 0 {
229 row_skipped += 1;
230 } else {
231 acc += ai.to_f32() * (wi as f32);
232 }
233 }
234 out_row[col] = TritFloat::from_f32_with_confidence(acc, min_conf);
235 }
236 total_skipped.fetch_add(row_skipped, Ordering::Relaxed);
237 });
238
239 (Self { data: out_data, shape: vec![m, n] },
240 total_skipped.load(Ordering::Relaxed))
241 }
242
243 pub fn add_elementwise(a: &Self, b: &Self) -> Self {
246 assert_eq!(a.shape, b.shape, "elementwise add requires equal shapes");
247 Self {
248 data: a.data.iter().zip(b.data.iter()).map(|(&ai, &bi)| ai.add(bi)).collect(),
249 shape: a.shape.clone(),
250 }
251 }
252
253 pub fn mul_elementwise(a: &Self, b: &Self) -> Self {
254 assert_eq!(a.shape, b.shape, "elementwise mul requires equal shapes");
255 Self {
256 data: a.data.iter().zip(b.data.iter()).map(|(&ai, &bi)| ai.mul(bi)).collect(),
257 shape: a.shape.clone(),
258 }
259 }
260
261 pub fn map<F>(&self, f: F) -> Self
263 where
264 F: Fn(TritFloat) -> TritFloat + Sync + Send,
265 {
266 Self {
267 data: self.data.par_iter().map(|&x| f(x)).collect(),
268 shape: self.shape.clone(),
269 }
270 }
271
272 pub fn sum_all(&self) -> TritFloat {
275 self.data.iter().fold(TritFloat::zero(), |acc, &x| acc.add(x))
276 }
277
278 pub fn mean_all(&self) -> TritFloat {
279 if self.data.is_empty() { return TritFloat::zero(); }
280 let s = self.sum_all();
281 TritFloat::from_f32_with_confidence(
282 s.to_f32() / self.data.len() as f32,
283 s.confidence(),
284 )
285 }
286
287 pub fn min_confidence(&self) -> f32 {
289 self.data.iter().map(|x| x.confidence()).fold(1.0f32, f32::min)
290 }
291
292 pub fn mean_confidence(&self) -> f32 {
294 if self.data.is_empty() { return 0.0; }
295 self.data.iter().map(|x| x.confidence()).sum::<f32>() / self.data.len() as f32
296 }
297
298 pub fn confidence_histogram(&self) -> [usize; 9] {
302 let mut hist = [0usize; 9];
303 for x in &self.data {
304 let idx = (x.confidence() * 8.0).round() as usize;
305 hist[idx.min(8)] += 1;
306 }
307 hist
308 }
309
310 pub fn sparsity(&self) -> f64 {
314 let zeros = self.data.iter().filter(|x| x.is_zero()).count();
315 zeros as f64 / self.data.len().max(1) as f64
316 }
317
318 pub fn to_f32_vec(&self) -> Vec<f32> {
322 self.data.iter().map(|x| x.to_f32()).collect()
323 }
324
325 pub fn to_tritmatrix(&self) -> TritMatrix {
327 assert_eq!(self.ndim(), 2, "to_tritmatrix requires a 2D tensor");
328 let data = self.data.iter().map(|x| match x.phase() {
329 1 => Trit::Affirm,
330 -1 => Trit::Reject,
331 _ => Trit::Tend,
332 }).collect();
333 TritMatrix { rows: self.shape[0], cols: self.shape[1], data }
334 }
335
336 pub fn softmax_rows(&self) -> Self {
338 assert_eq!(self.ndim(), 2, "softmax_rows requires a 2D tensor");
339 let (m, n) = (self.shape[0], self.shape[1]);
340 let mut out = Self::zeros(&[m, n]);
341 for row in 0..m {
342 let slice = &self.data[row * n..(row + 1) * n];
343 let sm = TritFloat::softmax(slice);
344 out.data[row * n..(row + 1) * n].copy_from_slice(&sm);
345 }
346 out
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Relative comparison of `a` against `b`. Falls back to an absolute
    /// tolerance when `b` is exactly zero, where relative error is undefined.
    fn approx(a: f32, b: f32, tol: f32) -> bool {
        if b == 0.0 { return a.abs() < tol; }
        ((a - b) / b).abs() < tol
    }

    #[test]
    fn zeros_shape_and_values() {
        let t = TritFloatTensor::zeros(&[3, 4]);
        assert_eq!(t.shape(), &[3, 4]);
        assert_eq!(t.numel(), 12);
        assert!(t.data.iter().all(|x| x.is_zero()));
    }

    #[test]
    fn ones_values() {
        let t = TritFloatTensor::ones(&[2, 3]);
        for x in &t.data {
            assert!(approx(x.to_f32(), 1.0, 0.01));
            assert_eq!(x.phase(), 1);
        }
    }

    #[test]
    fn from_f32_slice_roundtrip() {
        // PI instead of a hand-written 3.14: the literal trips the deny-by-default
        // `clippy::approx_constant` lint.
        let vals = [1.0f32, -2.0, 0.0, std::f32::consts::PI];
        let t = TritFloatTensor::from_f32_slice(&vals, &[2, 2]);
        assert_eq!(t.shape(), &[2, 2]);
        let back = t.to_f32_vec();
        for (a, b) in vals.iter().zip(back.iter()) {
            assert!(approx(*b, *a, 0.01), "{a} → {b}");
        }
    }

    #[test]
    fn from_tritmatrix_correct_values_and_confidence() {
        let m = TritMatrix::from_trits(2, 2, vec![
            Trit::Affirm, Trit::Tend, Trit::Reject, Trit::Affirm,
        ]);
        let t = TritFloatTensor::from_tritmatrix(&m);
        assert_eq!(t.shape(), &[2, 2]);
        assert!(approx(t.get(&[0, 0]).to_f32(), 1.0, 0.01));
        assert!(t.get(&[0, 1]).is_zero());
        assert!(approx(t.get(&[1, 0]).to_f32(), -1.0, 0.01));
        assert!(approx(t.get(&[1, 1]).to_f32(), 1.0, 0.01));
        assert!(t.data.iter().all(|x| (x.confidence() - 1.0).abs() < 0.15));
    }

    #[test]
    fn set_then_get_uses_row_major_layout() {
        let mut t = TritFloatTensor::zeros(&[2, 3, 4]);
        t.set(&[1, 2, 3], TritFloat::from_f32(5.0));
        // Row-major offset: 1*(3*4) + 2*4 + 3 = 23.
        assert!(approx(t.data[23].to_f32(), 5.0, 0.02));
        assert!(approx(t.get(&[1, 2, 3]).to_f32(), 5.0, 0.02));
        assert_eq!(t.data.iter().filter(|x| !x.is_zero()).count(), 1);
    }

    #[test]
    fn matmul_identity() {
        let identity = TritFloatTensor::from_f32_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2]);
        let a = TritFloatTensor::from_f32_slice(&[3.0f32, 4.0, 5.0, 6.0], &[2, 2]);
        let r = TritFloatTensor::matmul(&identity, &a);
        let vals = r.to_f32_vec();
        assert!(approx(vals[0], 3.0, 0.02));
        assert!(approx(vals[1], 4.0, 0.02));
        assert!(approx(vals[2], 5.0, 0.02));
        assert!(approx(vals[3], 6.0, 0.02));
    }

    #[test]
    fn matmul_2x3_x_3x2() {
        let a = TritFloatTensor::from_f32_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let b = TritFloatTensor::from_f32_slice(&[7.0f32, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2]);
        let r = TritFloatTensor::matmul(&a, &b);
        assert_eq!(r.shape(), &[2, 2]);
        let v = r.to_f32_vec();
        assert!(approx(v[0], 58.0, 0.02), "got {}", v[0]);
        assert!(approx(v[1], 64.0, 0.02), "got {}", v[1]);
        assert!(approx(v[2], 139.0, 0.02), "got {}", v[2]);
        assert!(approx(v[3], 154.0, 0.02), "got {}", v[3]);
    }

    #[test]
    fn matmul_confidence_propagates() {
        let acts = TritFloatTensor::from_f32_with_confidence(
            &[1.0f32, 1.0], &[0.125f32, 0.125], &[1, 2],
        );
        let weights = TritFloatTensor::from_f32_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2]);
        let r = TritFloatTensor::matmul(&acts, &weights);
        assert!(r.min_confidence() < 0.3, "low-conf inputs → low-conf output");
    }

    #[test]
    fn matmul_sparse_skip_count() {
        let acts = TritFloatTensor::from_f32_slice(&[1.0f32, 0.0, 1.0, 0.0], &[1, 4]);
        let w = TritFloatTensor::from_f32_slice(&[1.0f32; 8], &[4, 2]);
        let (_, skips) = TritFloatTensor::matmul_sparse(&acts, &w);
        // 2 zero activations × 2 output columns.
        assert_eq!(skips, 4, "zero activations should produce skips");
    }

    #[test]
    fn matmul_trit_matches_dense() {
        let acts = TritFloatTensor::from_f32_slice(&[1.0f32, -1.0], &[1, 2]);
        let mut w = TritMatrix::new(2, 2);
        w.set(0, 0, Trit::Affirm);
        w.set(0, 1, Trit::Tend);
        w.set(1, 0, Trit::Reject);
        w.set(1, 1, Trit::Affirm);

        let (r, _) = TritFloatTensor::matmul_trit(&acts, &w);
        assert_eq!(r.shape(), &[1, 2]);
        let v = r.to_f32_vec();
        assert!(approx(v[0], 2.0, 0.02), "col0: expected 2, got {}", v[0]);
        assert!(approx(v[1], -1.0, 0.02), "col1: expected -1, got {}", v[1]);

        // The ternary path must agree with the dense path on the same weights.
        let dense = TritFloatTensor::matmul(&acts, &TritFloatTensor::from_tritmatrix(&w));
        for (x, y) in v.iter().zip(dense.to_f32_vec().iter()) {
            assert!(approx(*x, *y, 0.02), "trit and dense matmul disagree: {x} vs {y}");
        }
    }

    #[test]
    fn elementwise_add_and_mul() {
        let a = TritFloatTensor::from_f32_slice(&[1.0f32, 2.0, 3.0], &[3]);
        let b = TritFloatTensor::from_f32_slice(&[4.0f32, 5.0, 6.0], &[3]);
        let s = TritFloatTensor::add_elementwise(&a, &b);
        let p = TritFloatTensor::mul_elementwise(&a, &b);
        let sv = s.to_f32_vec();
        let pv = p.to_f32_vec();
        assert!(approx(sv[0], 5.0, 0.02));
        assert!(approx(sv[2], 9.0, 0.02));
        assert!(approx(pv[0], 4.0, 0.02));
        assert!(approx(pv[2], 18.0, 0.02));
    }

    #[test]
    fn map_applies_function() {
        let t = TritFloatTensor::from_f32_slice(&[1.0f32, 4.0, 9.0], &[3]);
        let r = t.map(|x| x.sqrt());
        let v = r.to_f32_vec();
        assert!(approx(v[0], 1.0, 0.02));
        assert!(approx(v[1], 2.0, 0.02));
        assert!(approx(v[2], 3.0, 0.02));
    }

    #[test]
    fn mean_all_of_values_and_empty() {
        let t = TritFloatTensor::from_f32_slice(&[1.0f32, 2.0, 3.0, 6.0], &[4]);
        assert!(approx(t.mean_all().to_f32(), 3.0, 0.02));
        let empty = TritFloatTensor::zeros(&[0]);
        assert!(empty.mean_all().is_zero());
        assert_eq!(empty.sparsity(), 0.0);
    }

    #[test]
    fn sparsity_correct() {
        let t = TritFloatTensor::from_f32_slice(&[1.0f32, 0.0, -1.0, 0.0], &[2, 2]);
        assert!((t.sparsity() - 0.5).abs() < 1e-6);
    }

    #[test]
    fn confidence_histogram_bins() {
        let t = TritFloatTensor::from_f32_with_confidence(
            &[1.0f32, 1.0, 1.0],
            &[0.0f32, 0.5, 1.0],
            &[3],
        );
        let hist = t.confidence_histogram();
        assert_eq!(hist[0], 1, "one element at conf=0");
        assert_eq!(hist[4], 1, "one element at conf=0.5");
        assert_eq!(hist[8], 1, "one element at conf=1.0");
    }

    #[test]
    fn min_and_mean_confidence() {
        let t = TritFloatTensor::from_f32_with_confidence(&[1.0f32, 1.0], &[0.125f32, 1.0], &[2]);
        assert!((t.min_confidence() - 0.125).abs() < 0.15);
        let mean = t.mean_confidence();
        assert!(mean > 0.125 && mean < 1.0, "mean should be between min and max");
    }

    #[test]
    fn to_tritmatrix_roundtrip() {
        let t = TritFloatTensor::from_f32_slice(&[1.0f32, -1.0, 0.0, 0.5], &[2, 2]);
        let m = t.to_tritmatrix();
        assert_eq!(m.get(0, 0), Trit::Affirm);
        assert_eq!(m.get(0, 1), Trit::Reject);
        assert_eq!(m.get(1, 0), Trit::Tend);
        assert_eq!(m.get(1, 1), Trit::Affirm);
    }

    #[test]
    fn softmax_rows_sums_to_one() {
        let t = TritFloatTensor::from_f32_slice(&[1.0f32, 2.0, 3.0, 0.1, 0.2, 0.3], &[2, 3]);
        let sm = t.softmax_rows();
        for row in 0..2 {
            let row_sum: f32 = sm.data[row * 3..(row + 1) * 3].iter().map(|x| x.to_f32()).sum();
            assert!((row_sum - 1.0).abs() < 0.005, "row {row} sum = {row_sum}");
        }
    }

    #[test]
    fn matmul_sparse_matches_matmul() {
        let a = TritFloatTensor::from_f32_slice(&[1.0f32, 0.0, 2.0, 0.0, 1.0, 3.0], &[2, 3]);
        let b = TritFloatTensor::from_f32_slice(&[1.0f32, 2.0, 0.0, 3.0, 4.0, 1.0], &[3, 2]);
        let r1 = TritFloatTensor::matmul(&a, &b);
        let (r2, _) = TritFloatTensor::matmul_sparse(&a, &b);
        for (x, y) in r1.to_f32_vec().iter().zip(r2.to_f32_vec().iter()) {
            assert!(approx(*x, *y, 0.001), "sparse and dense matmul disagree: {x} vs {y}");
        }
    }
}