1#![allow(clippy::type_complexity)]
83#![cfg_attr(feature = "python", allow(clippy::useless_conversion))] pub mod core;
90
91pub use core::{
93 inference::{InferenceConfig, InferenceEngine, TernaryLayer},
94 quantization::{quantize_absmean, quantize_absmax, QuantizationResult, QuantizeConfig},
95 ternary::{matmul as ternary_matmul_rust, PackedTernary, TernaryMatmulConfig},
96 training::{GradientCompressor, TrainingConfig},
97 vsa::{VsaConfig, VsaOps},
98};
99
100pub mod bitnet;
106
107pub mod ternary;
109
110pub mod vsa;
112
113#[cfg(feature = "python")]
118mod python_bindings {
119 use numpy::{PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2, ToPyArray};
120 use pyo3::exceptions::PyValueError;
121 use pyo3::prelude::*;
122 use pyo3::IntoPyObject;
123 use std::collections::HashMap;
124 use std::sync::{Arc, Mutex};
125
126 use bitnet_quantize::{quantize_weights as bitnet_quantize_weights, BitLinear, BitNetConfig};
128 use candle_nn::Module;
129 use trit_vsa::{PackedTritVec, Trit};
130 use vsa_optim_rs::{DeterministicPhaseConfig, DeterministicPhaseTrainer};
131
132 #[cfg(feature = "cuda")]
133 mod gpu;
134
/// Python-facing handle (exposed as `DeterministicTrainer`) around a
/// [`DeterministicPhaseTrainer`]; the `Arc<Mutex<..>>` lets clones share one
/// lock-protected trainer instance across Python references.
#[pyclass(name = "DeterministicTrainer")]
#[derive(Clone)]
struct PyDeterministicTrainer {
    // Shared trainer state; every binding function locks this mutex first.
    inner: Arc<Mutex<DeterministicPhaseTrainer>>,
}
147
/// Python-facing handle (exposed as `BitLinearLayer`) around a quantized
/// [`BitLinear`] layer. The layer is only read after construction (forward
/// pass, getters), so a plain `Arc` without a lock suffices.
#[pyclass(name = "BitLinearLayer")]
#[derive(Clone)]
struct PyBitLinearLayer {
    // Immutable quantized layer shared across Python references.
    inner: Arc<BitLinear>,
}
154
155#[pyfunction]
181#[pyo3(signature = (param_shapes, warmup_steps=10, full_steps=5, predict_steps=20, correct_every=5))]
182fn create_trainer(
183 param_shapes: Vec<(String, Vec<usize>)>,
184 warmup_steps: usize,
185 full_steps: usize,
186 predict_steps: usize,
187 correct_every: usize,
188) -> PyResult<PyDeterministicTrainer> {
189 let config = DeterministicPhaseConfig::default()
190 .with_warmup_steps(warmup_steps)
191 .with_full_steps(full_steps)
192 .with_predict_steps(predict_steps)
193 .with_correct_every(correct_every);
194
195 let device = candle_core::Device::Cpu;
196
197 let trainer = DeterministicPhaseTrainer::new(¶m_shapes, config, &device)
198 .map_err(|e| PyValueError::new_err(format!("Failed to create trainer: {e}")))?;
199
200 Ok(PyDeterministicTrainer {
201 inner: Arc::new(Mutex::new(trainer)),
202 })
203}
204
/// Advance the wrapped trainer by one step.
///
/// Returns a dict with: `"phase"` (str), `"needs_backward"` (bool),
/// `"total_step"` (int), `"cycle"` (int), `"predicted_gradients"` (None on
/// backward steps, otherwise a dict name -> numpy array), `"speedup"`
/// (float) and `"mean_prediction_error"` (float).
///
/// On steps that need a real backward pass, `gradients` (name -> 2-D f32
/// array) are recorded into the trainer; on predict steps the trainer's
/// predicted gradients are returned instead.
/// NOTE(review): `gradients` only accepts 2-D arrays, so 1-D bias gradients
/// cannot be passed through this entry point — confirm that is intended.
#[pyfunction]
#[pyo3(signature = (trainer, gradients=None, loss=0.0))]
fn trainer_step<'py>(
    py: Python<'py>,
    trainer: &PyDeterministicTrainer,
    gradients: Option<HashMap<String, PyReadonlyArray2<'py, f32>>>,
    loss: f32,
) -> PyResult<HashMap<String, Py<PyAny>>> {
    let mut inner = trainer
        .inner
        .lock()
        .map_err(|e| PyValueError::new_err(format!("Lock poisoned: {e}")))?;

    let step_info = inner
        .begin_step()
        .map_err(|e| PyValueError::new_err(format!("begin_step failed: {e}")))?;

    // Step metadata, converted to Python objects up front.
    let mut result: HashMap<String, Py<PyAny>> = HashMap::new();
    result.insert("phase".to_string(), step_info.phase.to_string().into_pyobject(py).unwrap().into_any().unbind());
    result.insert(
        "needs_backward".to_string(),
        // bool converts to a borrowed PyBool, hence the extra `to_owned()`.
        step_info.needs_backward.into_pyobject(py).unwrap().to_owned().into_any().unbind(),
    );
    result.insert(
        "total_step".to_string(),
        step_info.total_step.into_pyobject(py).unwrap().into_any().unbind(),
    );
    result.insert("cycle".to_string(), step_info.cycle.into_pyobject(py).unwrap().into_any().unbind());

    if step_info.needs_backward {
        // Real backward step: copy the provided numpy gradients into candle
        // CPU tensors and record them in the trainer.
        if let Some(grad_dict) = gradients {
            let device = candle_core::Device::Cpu;
            let mut tensor_grads: HashMap<String, candle_core::Tensor> = HashMap::new();

            for (name, arr) in grad_dict {
                let arr = arr.as_array();
                let shape = arr.shape();
                let data: Vec<f32> = arr.iter().copied().collect();

                let tensor =
                    candle_core::Tensor::from_vec(data, shape.to_vec(), &device).map_err(|e| {
                        PyValueError::new_err(format!("Failed to create tensor for {name}: {e}"))
                    })?;

                tensor_grads.insert(name, tensor);
            }

            inner
                .record_full_gradients(&tensor_grads)
                .map_err(|e| PyValueError::new_err(format!("record_full_gradients failed: {e}")))?;
        }
        result.insert("predicted_gradients".to_string(), py.None());
    } else {
        // Predict step: hand the trainer's predicted gradients back as numpy
        // arrays (rank-1 tensors stay flat, higher ranks are reshaped).
        let predicted = inner
            .get_predicted_gradients()
            .map_err(|e| PyValueError::new_err(format!("get_predicted_gradients failed: {e}")))?;

        let mut pred_dict: HashMap<String, Py<PyAny>> = HashMap::new();
        for (name, tensor) in predicted {
            let dims = tensor.dims();
            let flat: Vec<f32> = tensor
                .flatten_all()
                .map_err(|e| PyValueError::new_err(format!("flatten failed: {e}")))?
                .to_vec1()
                .map_err(|e| PyValueError::new_err(format!("to_vec1 failed: {e}")))?;

            if dims.len() == 1 {
                let arr = flat.to_pyarray(py);
                pred_dict.insert(name, arr.into_pyobject(py).unwrap().into_any().unbind());
            } else {
                let arr = flat
                    .to_pyarray(py)
                    .reshape(dims.to_vec())
                    .map_err(|e| PyValueError::new_err(format!("reshape failed: {e}")))?;
                pred_dict.insert(name, arr.into_pyobject(py).unwrap().into_any().unbind());
            }
        }
        result.insert("predicted_gradients".to_string(), pred_dict.into_pyobject(py).unwrap().into_any().unbind());
    }

    // Close out the step with the caller-reported loss, then attach fresh stats.
    inner
        .end_step(loss)
        .map_err(|e| PyValueError::new_err(format!("end_step failed: {e}")))?;

    let stats = inner.get_stats();
    result.insert("speedup".to_string(), stats.speedup.into_pyobject(py).unwrap().into_any().unbind());
    result.insert(
        "mean_prediction_error".to_string(),
        stats.mean_prediction_error.into_pyobject(py).unwrap().into_any().unbind(),
    );

    Ok(result)
}
330
331#[pyfunction]
339fn trainer_get_phase(trainer: &PyDeterministicTrainer) -> PyResult<String> {
340 let inner = trainer
341 .inner
342 .lock()
343 .map_err(|e| PyValueError::new_err(format!("Lock poisoned: {e}")))?;
344
345 Ok(inner.current_phase().to_string())
346}
347
348#[pyfunction]
365fn trainer_get_stats(py: Python<'_>, trainer: &PyDeterministicTrainer) -> PyResult<HashMap<String, Py<PyAny>>> {
366 let inner = trainer
367 .inner
368 .lock()
369 .map_err(|e| PyValueError::new_err(format!("Lock poisoned: {e}")))?;
370
371 let stats = inner.get_stats();
372 let mut result: HashMap<String, Py<PyAny>> = HashMap::new();
373
374 result.insert("total_steps".to_string(), stats.total_steps.into_pyobject(py).unwrap().into_any().unbind());
375 result.insert("warmup_steps".to_string(), stats.warmup_steps.into_pyobject(py).unwrap().into_any().unbind());
376 result.insert("full_steps".to_string(), stats.full_steps.into_pyobject(py).unwrap().into_any().unbind());
377 result.insert("predict_steps".to_string(), stats.predict_steps.into_pyobject(py).unwrap().into_any().unbind());
378 result.insert("correct_steps".to_string(), stats.correct_steps.into_pyobject(py).unwrap().into_any().unbind());
379 result.insert("cycles".to_string(), stats.cycles.into_pyobject(py).unwrap().into_any().unbind());
380 result.insert("speedup".to_string(), stats.speedup.into_pyobject(py).unwrap().into_any().unbind());
381 result.insert(
382 "mean_prediction_error".to_string(),
383 stats.mean_prediction_error.into_pyobject(py).unwrap().into_any().unbind(),
384 );
385 result.insert("current_loss".to_string(), stats.current_loss.into_pyobject(py).unwrap().into_any().unbind());
386
387 Ok(result)
388}
389
390#[pyfunction]
395fn trainer_reset(trainer: &PyDeterministicTrainer) -> PyResult<()> {
396 let mut inner = trainer
397 .inner
398 .lock()
399 .map_err(|e| PyValueError::new_err(format!("Lock poisoned: {e}")))?;
400
401 inner
402 .reset()
403 .map_err(|e| PyValueError::new_err(format!("reset failed: {e}")))?;
404
405 Ok(())
406}
407
408#[pyfunction]
432#[pyo3(signature = (weight, bias=None, group_size=64))]
433fn create_bitlinear<'py>(
434 weight: PyReadonlyArray2<'py, f32>,
435 bias: Option<PyReadonlyArray1<'py, f32>>,
436 group_size: usize,
437) -> PyResult<PyBitLinearLayer> {
438 let weight_arr = weight.as_array();
439 let (out_features, in_features) = (weight_arr.nrows(), weight_arr.ncols());
440
441 let device = candle_core::Device::Cpu;
442 let weight_data: Vec<f32> = weight_arr.iter().copied().collect();
443 let weight_tensor =
444 candle_core::Tensor::from_vec(weight_data, (out_features, in_features), &device)
445 .map_err(|e| PyValueError::new_err(format!("Failed to create weight tensor: {e}")))?;
446
447 let bias_tensor = if let Some(b) = bias {
448 let bias_arr = b.as_array();
449 let bias_data: Vec<f32> = bias_arr.iter().copied().collect();
450 Some(
451 candle_core::Tensor::from_vec(bias_data, (out_features,), &device)
452 .map_err(|e| PyValueError::new_err(format!("Failed to create bias tensor: {e}")))?,
453 )
454 } else {
455 None
456 };
457
458 let config = BitNetConfig::default().with_group_size(group_size);
459
460 let layer = BitLinear::from_weight(&weight_tensor, bias_tensor.as_ref(), &config)
461 .map_err(|e| PyValueError::new_err(format!("Failed to create BitLinear: {e}")))?;
462
463 Ok(PyBitLinearLayer {
464 inner: Arc::new(layer),
465 })
466}
467
468#[pyfunction]
479fn bitlinear_forward<'py>(
480 py: Python<'py>,
481 layer: &PyBitLinearLayer,
482 input: PyReadonlyArray2<'py, f32>,
483) -> PyResult<Bound<'py, PyArray2<f32>>> {
484 let input_arr = input.as_array();
485 let (batch_size, in_features) = (input_arr.nrows(), input_arr.ncols());
486
487 let device = candle_core::Device::Cpu;
488 let input_data: Vec<f32> = input_arr.iter().copied().collect();
489 let input_tensor =
490 candle_core::Tensor::from_vec(input_data, (batch_size, in_features), &device)
491 .map_err(|e| PyValueError::new_err(format!("Failed to create input tensor: {e}")))?;
492
493 let output_tensor = layer
494 .inner
495 .forward(&input_tensor)
496 .map_err(|e| PyValueError::new_err(format!("Forward pass failed: {e}")))?;
497
498 let output_dims = output_tensor.dims();
499 let output_data: Vec<f32> = output_tensor
500 .flatten_all()
501 .map_err(|e| PyValueError::new_err(format!("flatten failed: {e}")))?
502 .to_vec1()
503 .map_err(|e| PyValueError::new_err(format!("to_vec1 failed: {e}")))?;
504
505 Ok(output_data
506 .to_pyarray(py)
507 .reshape([output_dims[0], output_dims[1]])
508 .map_err(|e| PyValueError::new_err(format!("reshape failed: {e}")))?
509 .to_owned())
510}
511
/// Storage compression ratio reported by the underlying `BitLinear` layer.
#[pyfunction]
fn bitlinear_compression_ratio(layer: &PyBitLinearLayer) -> f32 {
    layer.inner.compression_ratio()
}
523
/// Weight sparsity (fraction of zero trits) reported by the layer.
#[pyfunction]
fn bitlinear_sparsity(layer: &PyBitLinearLayer) -> f32 {
    layer.inner.sparsity()
}
535
/// Number of input features the layer expects.
#[pyfunction]
fn bitlinear_in_features(layer: &PyBitLinearLayer) -> usize {
    layer.inner.in_features()
}
547
/// Number of output features the layer produces.
#[pyfunction]
fn bitlinear_out_features(layer: &PyBitLinearLayer) -> usize {
    layer.inner.out_features()
}
559
560#[pyfunction]
576fn pack_ternary_weights<'py>(
577 py: Python<'py>,
578 weights: PyReadonlyArray2<'py, f32>,
579 scales: PyReadonlyArray1<'py, f32>,
580) -> PyResult<(Bound<'py, PyArray1<u8>>, Bound<'py, PyArray1<f32>>)> {
581 let weights = weights.as_array();
582 let scales = scales.as_array();
583
584 let (rows, cols) = (weights.nrows(), weights.ncols());
585
586 let mut packed_vecs: Vec<PackedTritVec> = Vec::with_capacity(rows);
588
589 for row in weights.rows() {
590 let mut packed = PackedTritVec::new(cols);
591 for (col_idx, &val) in row.iter().enumerate() {
592 let trit = match val as i8 {
593 v if v > 0 => Trit::P, v if v < 0 => Trit::N, _ => Trit::Z, };
597 packed.set(col_idx, trit);
598 }
599 packed_vecs.push(packed);
600 }
601
602 let packed_size = cols.div_ceil(4);
604 let mut packed = vec![0u8; rows * packed_size];
605
606 for (row_idx, pvec) in packed_vecs.iter().enumerate() {
607 for col_idx in 0..cols {
608 let trit = pvec.get(col_idx);
609 let trit_bits = match trit {
610 Trit::P => 0b01, Trit::N => 0b10, Trit::Z => 0b00, };
614
615 let byte_idx = row_idx * packed_size + col_idx / 4;
616 let bit_offset = (col_idx % 4) * 2;
617 packed[byte_idx] |= trit_bits << bit_offset;
618 }
619 }
620
621 let packed_array = packed.to_pyarray(py);
622 let scales_array = scales.to_vec().to_pyarray(py);
623
624 Ok((packed_array, scales_array))
625}
626
627#[pyfunction]
639fn unpack_ternary_weights<'py>(
640 py: Python<'py>,
641 packed: PyReadonlyArray1<'py, u8>,
642 scales: PyReadonlyArray1<'py, f32>,
643 shape: (usize, usize),
644) -> PyResult<Bound<'py, PyArray2<f32>>> {
645 let packed = packed.as_array();
646 let scales = scales.as_array();
647 let (rows, cols) = shape;
648
649 let packed_size = cols.div_ceil(4);
650
651 let mut packed_vecs: Vec<PackedTritVec> = Vec::with_capacity(rows);
653
654 for row_idx in 0..rows {
655 let mut pvec = PackedTritVec::new(cols);
656 for col_idx in 0..cols {
657 let byte_idx = row_idx * packed_size + col_idx / 4;
658 let bit_offset = (col_idx % 4) * 2;
659 let trit_bits = (packed[byte_idx] >> bit_offset) & 0b11;
660
661 let trit = match trit_bits {
662 0b01 => Trit::P, 0b10 => Trit::N, _ => Trit::Z, };
666 pvec.set(col_idx, trit);
667 }
668 packed_vecs.push(pvec);
669 }
670
671 let mut weights = vec![0.0f32; rows * cols];
673
674 for (row_idx, pvec) in packed_vecs.iter().enumerate() {
675 let scale = scales[row_idx];
676 for col_idx in 0..cols {
677 let value = f32::from(pvec.get(col_idx).value()) * scale;
678 weights[row_idx * cols + col_idx] = value;
679 }
680 }
681
682 Ok(weights
683 .to_pyarray(py)
684 .reshape([rows, cols])
685 .expect("reshape failed")
686 .to_owned())
687}
688
689#[pyfunction]
704fn ternary_matmul<'py>(
705 py: Python<'py>,
706 input: PyReadonlyArray2<'py, f32>,
707 packed_weights: PyReadonlyArray1<'py, u8>,
708 scales: PyReadonlyArray1<'py, f32>,
709 weight_shape: (usize, usize),
710) -> PyResult<Bound<'py, PyArray2<f32>>> {
711 let input = input.as_array();
712 let packed = packed_weights.as_array();
713 let scales = scales.as_array();
714
715 let (batch_size, in_features) = (input.nrows(), input.ncols());
716 let (out_features, _) = weight_shape;
717
718 if in_features != weight_shape.1 {
719 return Err(PyValueError::new_err(format!(
720 "Input features {} doesn't match weight features {}",
721 in_features, weight_shape.1
722 )));
723 }
724
725 let packed_cols = in_features.div_ceil(4);
726
727 let mut weight_vecs: Vec<PackedTritVec> = Vec::with_capacity(out_features);
729
730 for o in 0..out_features {
731 let mut pvec = PackedTritVec::new(in_features);
732 for i in 0..in_features {
733 let byte_idx = o * packed_cols + i / 4;
734 let bit_offset = (i % 4) * 2;
735 let trit_bits = (packed[byte_idx] >> bit_offset) & 0b11;
736
737 let trit = match trit_bits {
738 0b01 => Trit::P,
739 0b10 => Trit::N,
740 _ => Trit::Z,
741 };
742 pvec.set(i, trit);
743 }
744 weight_vecs.push(pvec);
745 }
746
747 let mut output = vec![0.0f32; batch_size * out_features];
748
749 for b in 0..batch_size {
751 for (o, weight_vec) in weight_vecs.iter().enumerate() {
752 let scale = scales[o];
753 let mut sum = 0.0f32;
754
755 for i in 0..in_features {
757 let trit = weight_vec.get(i);
758 let x = input[[b, i]];
759 sum += match trit {
760 Trit::P => x, Trit::N => -x, Trit::Z => 0.0, };
764 }
765
766 output[b * out_features + o] = sum * scale;
767 }
768 }
769
770 Ok(output
771 .to_pyarray(py)
772 .reshape([batch_size, out_features])
773 .expect("reshape failed")
774 .to_owned())
775}
776
777#[pyfunction]
787fn quantize_weights_absmean<'py>(
788 py: Python<'py>,
789 weights: PyReadonlyArray2<'py, f32>,
790) -> PyResult<(Bound<'py, PyArray2<f32>>, Bound<'py, PyArray1<f32>>)> {
791 let weights_arr = weights.as_array();
792 let (rows, cols) = (weights_arr.nrows(), weights_arr.ncols());
793
794 let config = BitNetConfig::default().with_group_size(cols);
796
797 let device = candle_core::Device::Cpu;
799 let weight_data: Vec<f32> = weights_arr.iter().copied().collect();
800 let weight_tensor = candle_core::Tensor::from_vec(weight_data, (rows, cols), &device)
801 .map_err(|e| PyValueError::new_err(format!("Failed to create tensor: {e}")))?;
802
803 let ternary = bitnet_quantize_weights(&weight_tensor, &config)
805 .map_err(|e| PyValueError::new_err(format!("Quantization failed: {e}")))?;
806
807 let mut ternary_output = vec![0.0f32; rows * cols];
809 let scales: Vec<f32> = ternary.scales.clone();
810
811 for (row_idx, packed) in ternary.data.iter().enumerate() {
812 for col_idx in 0..cols {
813 let trit = packed.get(col_idx);
814 ternary_output[row_idx * cols + col_idx] = f32::from(trit.value());
815 }
816 }
817
818 Ok((
819 ternary_output
820 .to_pyarray(py)
821 .reshape([rows, cols])
822 .expect("reshape failed")
823 .to_owned(),
824 scales.to_pyarray(py),
825 ))
826}
827
/// Sketch a flat gradient vector into a smaller vector via a seeded random
/// sparse projection.
///
/// For every (gradient, slot) pair one uniform draw from a ChaCha8 stream
/// seeded with `seed` decides to add (r < 0.16), subtract (0.16 <= r < 0.32)
/// or skip the scaled gradient. `decompress_gradients_vsa` replays the same
/// stream in the same nesting order to approximately invert this — the
/// result is a sketch, not exact coding.
///
/// Output length is `ceil(len * compression_ratio)` with a floor of 256
/// slots. Returns `(compressed, seed)` so the caller can pass the seed back
/// to the decompressor.
#[pyfunction]
#[allow(clippy::cast_precision_loss)]
fn compress_gradients_vsa<'py>(
    py: Python<'py>,
    gradients: PyReadonlyArray1<'py, f32>,
    compression_ratio: f32,
    seed: u64,
) -> PyResult<(Bound<'py, PyArray1<f32>>, u64)> {
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaCha8Rng;

    let gradients = gradients.as_array();
    let original_dim = gradients.len();
    // NOTE: the 256-slot floor means the "compressed" form can be larger
    // than small inputs.
    let compressed_dim = ((original_dim as f32 * compression_ratio).ceil() as usize).max(256);

    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    let mut compressed = vec![0.0f32; compressed_dim];

    // 1/sqrt(n) normalization; the decompressor applies the same factor.
    let scale = 1.0 / (original_dim as f32).sqrt();
    // Loop order (gradient-major, slot-minor) is load-bearing: it fixes the
    // RNG draw assigned to each (gradient, slot) pair.
    for &g in gradients.iter() {
        for c in compressed.iter_mut() {
            let r: f32 = rng.gen();
            if r < 0.16 {
                *c += g * scale;
            } else if r < 0.32 {
                *c -= g * scale;
            }
        }
    }

    Ok((compressed.to_pyarray(py), seed))
}
874
/// Approximately invert [`compress_gradients_vsa`].
///
/// Re-seeds the same ChaCha8 stream and walks (original index, compressed
/// slot) pairs in the same nested order as compression, so each pair sees
/// the same uniform draw and hence the same add/subtract/skip decision.
/// The output is an estimate of the original gradients, not an exact
/// reconstruction.
///
/// NOTE(review): assumes `compressed` has exactly the length produced by
/// compression — a different length desynchronizes the RNG stream; confirm
/// callers always round-trip the full array.
#[pyfunction]
#[allow(clippy::cast_precision_loss)]
fn decompress_gradients_vsa<'py>(
    py: Python<'py>,
    compressed: PyReadonlyArray1<'py, f32>,
    original_dim: usize,
    seed: u64,
) -> PyResult<Bound<'py, PyArray1<f32>>> {
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaCha8Rng;

    let compressed = compressed.as_array();

    // Same seed => same decision stream as on the compress side.
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    let mut gradients = vec![0.0f32; original_dim];

    // Same 1/sqrt(original_dim) normalization as compression.
    let scale = 1.0 / (original_dim as f32).sqrt();
    for g in gradients.iter_mut() {
        for &c in compressed.iter() {
            let r: f32 = rng.gen();
            if r < 0.16 {
                *g += c * scale;
            } else if r < 0.32 {
                *g -= c * scale;
            }
        }
    }

    Ok(gradients.to_pyarray(py))
}
917
/// Crate version string, baked in at compile time from Cargo.toml.
#[pyfunction]
fn version() -> &'static str {
    env!("CARGO_PKG_VERSION")
}
923
/// Whether CUDA is usable: requires both the crate's `cuda` feature at
/// build time AND candle successfully opening CUDA device 0 at runtime.
/// Without the feature this is a compile-time `false`.
#[pyfunction]
fn cuda_available_py() -> bool {
    #[cfg(feature = "cuda")]
    {
        candle_core::Device::cuda_if_available(0)
            .map(|d| matches!(d, candle_core::Device::Cuda(_)))
            .unwrap_or(false)
    }
    #[cfg(not(feature = "cuda"))]
    {
        false
    }
}
938
/// Entry point for the `tritter_accel` Python extension module: registers
/// the trainer / BitLinear classes, the ternary pack/unpack/matmul helpers,
/// the VSA gradient sketch functions and, with the `cuda` feature, the GPU
/// VSA primitives.
#[pymodule]
fn tritter_accel(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Classes.
    m.add_class::<PyDeterministicTrainer>()?;
    m.add_class::<PyBitLinearLayer>()?;

    // Deterministic-phase trainer API.
    m.add_function(wrap_pyfunction!(create_trainer, m)?)?;
    m.add_function(wrap_pyfunction!(trainer_step, m)?)?;
    m.add_function(wrap_pyfunction!(trainer_get_phase, m)?)?;
    m.add_function(wrap_pyfunction!(trainer_get_stats, m)?)?;
    m.add_function(wrap_pyfunction!(trainer_reset, m)?)?;

    // BitLinear layer API.
    m.add_function(wrap_pyfunction!(create_bitlinear, m)?)?;
    m.add_function(wrap_pyfunction!(bitlinear_forward, m)?)?;
    m.add_function(wrap_pyfunction!(bitlinear_compression_ratio, m)?)?;
    m.add_function(wrap_pyfunction!(bitlinear_sparsity, m)?)?;
    m.add_function(wrap_pyfunction!(bitlinear_in_features, m)?)?;
    m.add_function(wrap_pyfunction!(bitlinear_out_features, m)?)?;

    // Ternary packing / matmul helpers.
    m.add_function(wrap_pyfunction!(pack_ternary_weights, m)?)?;
    m.add_function(wrap_pyfunction!(unpack_ternary_weights, m)?)?;
    m.add_function(wrap_pyfunction!(ternary_matmul, m)?)?;
    m.add_function(wrap_pyfunction!(quantize_weights_absmean, m)?)?;

    // VSA gradient sketching.
    m.add_function(wrap_pyfunction!(compress_gradients_vsa, m)?)?;
    m.add_function(wrap_pyfunction!(decompress_gradients_vsa, m)?)?;

    // Introspection helpers.
    m.add_function(wrap_pyfunction!(version, m)?)?;
    m.add_function(wrap_pyfunction!(cuda_available_py, m)?)?;

    // GPU VSA primitives, only present in CUDA builds.
    #[cfg(feature = "cuda")]
    {
        m.add_function(wrap_pyfunction!(gpu::cuda_available, m)?)?;
        m.add_function(wrap_pyfunction!(gpu::get_device, m)?)?;
        m.add_function(wrap_pyfunction!(gpu::set_device, m)?)?;
        m.add_function(wrap_pyfunction!(gpu::vsa_bind, m)?)?;
        m.add_function(wrap_pyfunction!(gpu::vsa_unbind, m)?)?;
        m.add_function(wrap_pyfunction!(gpu::vsa_bundle, m)?)?;
        m.add_function(wrap_pyfunction!(gpu::vsa_similarity, m)?)?;
        m.add_function(wrap_pyfunction!(gpu::vsa_random, m)?)?;
    }

    Ok(())
}
990} #[cfg(feature = "python")]
994pub use python_bindings::*;
995
996#[cfg(test)]
1001mod tests {
1002 use trit_vsa::{PackedTritVec, Trit};
1003 use bitnet_quantize::{quantize_weights as bitnet_quantize_weights, BitLinear, BitNetConfig};
1004 use candle_nn::Module;
1005
1006 #[test]
1007 fn test_pack_unpack_roundtrip() {
1008 let weights = vec![1.0, 0.0, -1.0, 1.0, -1.0, 1.0, 0.0, -1.0];
1010 let scales = vec![1.0, 1.0];
1011
1012 let mut packed_vecs = Vec::new();
1014 for row in 0..2 {
1015 let mut pvec = PackedTritVec::new(4);
1016 for col in 0..4 {
1017 let val = weights[row * 4 + col];
1018 let trit = match val as i8 {
1019 v if v > 0 => Trit::P,
1020 v if v < 0 => Trit::N,
1021 _ => Trit::Z,
1022 };
1023 pvec.set(col, trit);
1024 }
1025 packed_vecs.push(pvec);
1026 }
1027
1028 let mut unpacked = vec![0.0f32; 8];
1030 for (row, pvec) in packed_vecs.iter().enumerate() {
1031 let scale = scales[row];
1032 for col in 0..4 {
1033 unpacked[row * 4 + col] = f32::from(pvec.get(col).value()) * scale;
1034 }
1035 }
1036
1037 assert_eq!(unpacked, weights);
1038 }
1039
    /// BitNet absmean quantization of a 2x4 matrix (group size 4 => one
    /// group per row) must preserve the shape and yield one packed row each.
    #[test]
    fn test_quantize_absmean_with_bitnet() {
        let weights = vec![0.5, -0.3, 0.1, 0.8, -0.2, 0.6, -0.7, 0.4];

        let device = candle_core::Device::Cpu;
        let config = BitNetConfig::default().with_group_size(4);
        let tensor = candle_core::Tensor::from_vec(weights.clone(), (2, 4), &device).unwrap();

        let ternary = bitnet_quantize_weights(&tensor, &config).unwrap();

        assert_eq!(ternary.shape, (2, 4));
        assert_eq!(ternary.data.len(), 2);
    }
1056
    /// Ternary dot product: (+1)(+1) + (-1)(+1) + (0)(-1) + (+1)(0) = 0.
    #[test]
    fn test_trit_vsa_dot_product() {
        let mut a = PackedTritVec::new(4);
        let mut b = PackedTritVec::new(4);

        a.set(0, Trit::P);
        a.set(1, Trit::N);
        a.set(2, Trit::Z);
        a.set(3, Trit::P);

        b.set(0, Trit::P);
        b.set(1, Trit::P);
        b.set(2, Trit::N);
        b.set(3, Trit::Z);

        assert_eq!(a.dot(&b), 0);
    }
1078
    /// The crate-level absmean quantizer must keep the shape and emit only
    /// values from {-1, 0, +1}.
    #[test]
    fn test_rust_api_quantization() {
        use crate::core::quantization::{quantize_absmean, QuantizeConfig};

        let device = candle_core::Device::Cpu;
        let weights = candle_core::Tensor::from_vec(
            vec![0.5f32, -0.3, 0.1, 0.8, -0.2, 0.6, -0.7, 0.4],
            (2, 4),
            &device,
        )
        .unwrap();

        let config = QuantizeConfig::default();
        let result = quantize_absmean(&weights, &config).unwrap();

        assert_eq!(result.shape, (2, 4));
        assert_eq!(result.values.len(), 8);

        // Every quantized value must be a valid trit.
        for v in &result.values {
            assert!([-1, 0, 1].contains(v));
        }
    }
1103
    /// Compress/decompress round-trip must preserve the gradient vector
    /// length (values are only approximately recovered by design).
    #[test]
    fn test_rust_api_gradient_compression() {
        use crate::core::training::{GradientCompressor, TrainingConfig};

        let config = TrainingConfig::default();
        let compressor = GradientCompressor::new(config);

        // 1000 evenly spaced values in [-1, 1).
        let gradients: Vec<f32> = (0..1000).map(|i| (i as f32 - 500.0) / 500.0).collect();
        let compressed = compressor.compress(&gradients, Some(0.1)).unwrap();
        let recovered = compressor.decompress(&compressed).unwrap();

        assert_eq!(recovered.len(), gradients.len());
    }
1118
    /// A freshly built deterministic-phase trainer must start in Warmup.
    #[test]
    fn test_deterministic_phase_trainer_creation() {
        use vsa_optim_rs::{DeterministicPhaseConfig, DeterministicPhaseTrainer};

        let shapes = vec![
            ("layer.weight".to_string(), vec![16, 32]),
            ("layer.bias".to_string(), vec![16]),
        ];

        let config = DeterministicPhaseConfig::default()
            .with_warmup_steps(5)
            .with_full_steps(3)
            .with_predict_steps(10);

        let trainer =
            DeterministicPhaseTrainer::new(&shapes, config, &candle_core::Device::Cpu).unwrap();

        assert_eq!(
            trainer.current_phase(),
            vsa_optim_rs::DeterministicPhase::Warmup
        );
    }
1143
    /// Drive ten begin/record/end step cycles through the trainer and check
    /// the total step counter; backward steps feed synthetic constant
    /// gradients, predict steps just poll the predicted gradients.
    #[test]
    fn test_deterministic_trainer_step_cycle() {
        use std::collections::HashMap;
        use vsa_optim_rs::{DeterministicPhaseConfig, DeterministicPhaseTrainer};

        let shapes = vec![
            ("layer.weight".to_string(), vec![8, 16]),
            ("layer.bias".to_string(), vec![8]),
        ];

        let config = DeterministicPhaseConfig::default()
            .with_warmup_steps(3)
            .with_full_steps(2)
            .with_predict_steps(5);

        let mut trainer =
            DeterministicPhaseTrainer::new(&shapes, config, &candle_core::Device::Cpu).unwrap();

        for i in 0..10 {
            let info = trainer.begin_step().unwrap();

            if info.needs_backward {
                // Constant gradients of (i + 1) * 0.1 for both parameters.
                let mut grads = HashMap::new();
                grads.insert(
                    "layer.weight".to_string(),
                    candle_core::Tensor::ones((8, 16), candle_core::DType::F32, &candle_core::Device::Cpu)
                        .unwrap()
                        .affine((i as f64 + 1.0) * 0.1, 0.0)
                        .unwrap(),
                );
                grads.insert(
                    "layer.bias".to_string(),
                    candle_core::Tensor::ones(8, candle_core::DType::F32, &candle_core::Device::Cpu)
                        .unwrap()
                        .affine((i as f64 + 1.0) * 0.1, 0.0)
                        .unwrap(),
                );
                trainer.record_full_gradients(&grads).unwrap();
            } else {
                let _predicted = trainer.get_predicted_gradients();
            }

            // Decreasing synthetic loss 1 / (i + 1).
            trainer.end_step(1.0 / (i + 1) as f32).unwrap();
        }

        let stats = trainer.get_stats();
        assert_eq!(stats.total_steps, 10);
    }
1194
    /// A BitLinear built from a random (32, 64) weight must report the
    /// matching feature counts and a compression ratio above 1x.
    #[test]
    fn test_bitlinear_layer_creation() {
        let device = candle_core::Device::Cpu;
        let config = BitNetConfig::default().with_group_size(16);

        let weight =
            candle_core::Tensor::randn(0.0f32, 1.0, (32, 64), &device).unwrap();

        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        assert_eq!(layer.in_features(), 64);
        assert_eq!(layer.out_features(), 32);
        assert!(layer.compression_ratio() > 1.0);
    }
1209
    /// Forward pass through a (32, 64) BitLinear maps a (4, 64) batch to
    /// a (4, 32) output.
    #[test]
    fn test_bitlinear_forward_pass() {
        let device = candle_core::Device::Cpu;
        let config = BitNetConfig::default().with_group_size(16);

        let weight =
            candle_core::Tensor::randn(0.0f32, 1.0, (32, 64), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        let input = candle_core::Tensor::randn(0.0f32, 1.0, (4, 64), &device).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape().dims(), &[4, 32]);
    }
1224
    /// A BitLinear constructed with a bias must retain it and still produce
    /// the expected (4, 32) output shape.
    #[test]
    fn test_bitlinear_with_bias() {
        let device = candle_core::Device::Cpu;
        let config = BitNetConfig::default().with_group_size(16);

        let weight =
            candle_core::Tensor::randn(0.0f32, 1.0, (32, 64), &device).unwrap();
        let bias = candle_core::Tensor::randn(0.0f32, 1.0, (32,), &device).unwrap();

        let layer = BitLinear::from_weight(&weight, Some(&bias), &config).unwrap();

        assert!(layer.bias().is_some());

        let input = candle_core::Tensor::randn(0.0f32, 1.0, (4, 64), &device).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape().dims(), &[4, 32]);
    }
1243}