// tritter_accel/lib.rs
//! Rust acceleration for AI training and inference.
//!
//! `tritter-accel` provides high-performance operations for both ternary
//! (BitNet-style) and conventional neural network workloads. It serves as
//! an acceleration layer that can be used from either Rust or Python.
//!
//! # Architecture
//!
//! The crate is organized into two main APIs:
//!
//! - **Rust API** (`core` module): Pure Rust interfaces for direct integration
//! - **Python API** (PyO3 bindings): NumPy-compatible functions for Python users
//!
//! # Rust Usage
//!
//! ```rust,ignore
//! use tritter_accel::core::{
//!     ternary::{PackedTernary, matmul},
//!     quantization::{quantize_absmean, QuantizeConfig},
//!     training::{GradientCompressor, TrainingConfig},
//!     inference::{InferenceEngine, InferenceConfig},
//! };
//! use candle_core::{Device, Tensor};
//!
//! // Quantize weights to ternary
//! let device = Device::Cpu;
//! let weights = Tensor::randn(0f32, 1f32, (512, 512), &device)?;
//! let result = quantize_absmean(&weights, &QuantizeConfig::default())?;
//!
//! // Create packed representation for efficient matmul
//! let packed = result.to_packed()?;
//!
//! // Run ternary matmul
//! let input = Tensor::randn(0f32, 1f32, (1, 512), &device)?;
//! let output = matmul(&input, &packed, None)?;
//!
//! // Compress gradients for distributed training
//! let compressor = GradientCompressor::new(TrainingConfig::default());
//! let gradients: Vec<f32> = vec![0.1, -0.2, 0.3, -0.4];
//! let compressed = compressor.compress(&gradients, Some(0.1))?;
//! ```
//!
//! # Python Usage
//!
//! Build with maturin:
//! ```bash
//! cd rust-ai/tritter-accel
//! maturin develop --release
//! ```
//!
//! Then in Python:
//! ```python
//! from tritter_accel import (
//!     pack_ternary_weights,
//!     unpack_ternary_weights,
//!     ternary_matmul,
//!     quantize_weights_absmean,
//!     compress_gradients_vsa,
//! )
//!
//! # Pack weights for efficient storage (returns a (packed_bytes, scales) tuple)
//! packed, scales = pack_ternary_weights(ternary_weights, scales)
//!
//! # Efficient matmul with packed weights
//! output = ternary_matmul(input, packed, scales, weight_shape)
//!
//! # Compress gradients for distributed training
//! compressed, seed = compress_gradients_vsa(gradients, compression_ratio=0.1, seed=42)
//! ```
//!
//! # Features
//!
//! - `cuda`: Enable GPU acceleration via CubeCL (requires CUDA toolkit)
//!
//! # Modules
//!
//! - [`core`]: Pure Rust API for direct integration
//! - [`bitnet`]: Re-exports from `bitnet-quantize`
//! - [`ternary`]: Re-exports from `trit-vsa`
//! - [`vsa`]: Re-exports from `vsa-optim-rs`
#![allow(clippy::type_complexity)]
#![cfg_attr(feature = "python", allow(clippy::useless_conversion))] // PyO3 macro generates false positives

// =============================================================================
// CORE RUST API
// =============================================================================

pub mod core;

// Re-export core types at crate root for convenience
pub use core::{
    inference::{InferenceConfig, InferenceEngine, TernaryLayer},
    quantization::{quantize_absmean, quantize_absmax, QuantizationResult, QuantizeConfig},
    // `matmul` is renamed here so it cannot collide with the Python-facing
    // `ternary_matmul` binding defined later in this file.
    ternary::{matmul as ternary_matmul_rust, PackedTernary, TernaryMatmulConfig},
    training::{GradientCompressor, TrainingConfig},
    vsa::{VsaConfig, VsaOps},
};

// =============================================================================
// RE-EXPORTS FROM SISTER CRATES
// =============================================================================

/// Re-exports from `bitnet-quantize` for direct access.
pub mod bitnet;

/// Re-exports from `trit-vsa` for direct access.
pub mod ternary;

/// Re-exports from `vsa-optim-rs` for direct access.
pub mod vsa;
112
// =============================================================================
// PYTHON BINDINGS (PyO3) - Only compiled with "python" feature
// =============================================================================

#[cfg(feature = "python")]
mod python_bindings {
    use numpy::{PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2, ToPyArray};
    use pyo3::exceptions::PyValueError;
    use pyo3::prelude::*;
    use pyo3::IntoPyObject;
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};

    // Delegate to sister crates
    use bitnet_quantize::{quantize_weights as bitnet_quantize_weights, BitLinear, BitNetConfig};
    use candle_nn::Module;
    use trit_vsa::{PackedTritVec, Trit};
    use vsa_optim_rs::{DeterministicPhaseConfig, DeterministicPhaseTrainer};

    // GPU kernels; compiled only when both "python" and "cuda" features are on.
    #[cfg(feature = "cuda")]
    mod gpu;
134
// =============================================================================
// PYTHON WRAPPER TYPES FOR STATEFUL OBJECTS
// =============================================================================

/// Opaque handle for Python to reference a `DeterministicPhaseTrainer`.
///
/// Uses `Arc<Mutex<...>>` for thread-safe sharing between Python calls.
/// `Clone` only bumps the `Arc` refcount, so all Python-side copies observe
/// the same underlying trainer state.
#[pyclass(name = "DeterministicTrainer")]
#[derive(Clone)]
struct PyDeterministicTrainer {
    // Lock-protected trainer; the mutex may be poisoned if a prior call panicked.
    inner: Arc<Mutex<DeterministicPhaseTrainer>>,
}

/// Opaque handle for Python to reference a `BitLinear` layer.
///
/// No mutex here: the bindings below only call `&self` methods on the layer,
/// so a plain `Arc` suffices for sharing across Python calls.
#[pyclass(name = "BitLinearLayer")]
#[derive(Clone)]
struct PyBitLinearLayer {
    inner: Arc<BitLinear>,
}
154
155// =============================================================================
156// DETERMINISTIC PHASE TRAINER BINDINGS
157// =============================================================================
158
159/// Create a deterministic phase trainer.
160///
161/// # Arguments
162/// * `param_shapes` - List of (name, shape) tuples for model parameters.
163///   Shape can be 1D or 2D (e.g., `[("layer.weight", [768, 768])]`).
164/// * `warmup_steps` - Number of warmup steps before prediction begins (default: 10).
165/// * `full_steps` - Full gradient steps per cycle after warmup (default: 5).
166/// * `predict_steps` - Prediction steps per cycle (default: 20).
167/// * `correct_every` - Correction frequency during prediction (default: 5).
168///
169/// # Returns
170/// A `DeterministicTrainer` handle to use with `trainer_step` and `trainer_get_phase`.
171///
172/// # Example
173/// ```python
174/// trainer = create_trainer(
175///     param_shapes=[("layer.weight", [768, 768]), ("layer.bias", [768])],
176///     warmup_steps=10,
177///     predict_steps=20,
178/// )
179/// ```
180#[pyfunction]
181#[pyo3(signature = (param_shapes, warmup_steps=10, full_steps=5, predict_steps=20, correct_every=5))]
182fn create_trainer(
183    param_shapes: Vec<(String, Vec<usize>)>,
184    warmup_steps: usize,
185    full_steps: usize,
186    predict_steps: usize,
187    correct_every: usize,
188) -> PyResult<PyDeterministicTrainer> {
189    let config = DeterministicPhaseConfig::default()
190        .with_warmup_steps(warmup_steps)
191        .with_full_steps(full_steps)
192        .with_predict_steps(predict_steps)
193        .with_correct_every(correct_every);
194
195    let device = candle_core::Device::Cpu;
196
197    let trainer = DeterministicPhaseTrainer::new(&param_shapes, config, &device)
198        .map_err(|e| PyValueError::new_err(format!("Failed to create trainer: {e}")))?;
199
200    Ok(PyDeterministicTrainer {
201        inner: Arc::new(Mutex::new(trainer)),
202    })
203}
204
/// Process one training step with the deterministic phase trainer.
///
/// # Arguments
/// * `trainer` - The trainer handle from `create_trainer`.
/// * `gradients` - Dictionary mapping parameter names to gradient arrays.
///   Required during WARMUP, FULL, and CORRECT phases. Can be `None` during PREDICT.
///   NOTE(review): values must be 2D float32 arrays; 1D parameters (e.g. biases,
///   which `create_trainer` accepts) cannot be passed through this signature —
///   confirm whether that is intended.
/// * `loss` - Loss value for this step.
///
/// # Returns
/// Dictionary with step information:
/// - `phase`: Current phase name ("WARMUP", "FULL", "PREDICT", "CORRECT")
/// - `needs_backward`: Whether backward pass is needed next step
/// - `total_step`: Total steps taken
/// - `cycle`: Current training cycle
/// - `speedup`: Effective speedup ratio
/// - `mean_prediction_error`: Mean prediction error so far
/// - `predicted_gradients`: If in PREDICT phase, the predicted gradients
///   (dict of arrays); otherwise `None`
///
/// # Example
/// ```python
/// result = trainer_step(trainer, gradients={"layer.weight": grad_array}, loss=0.5)
/// if result["needs_backward"]:
///     # Compute gradients via backpropagation
///     pass
/// else:
///     # Use predicted_gradients from result
///     predicted = result["predicted_gradients"]
/// ```
#[pyfunction]
#[pyo3(signature = (trainer, gradients=None, loss=0.0))]
fn trainer_step<'py>(
    py: Python<'py>,
    trainer: &PyDeterministicTrainer,
    gradients: Option<HashMap<String, PyReadonlyArray2<'py, f32>>>,
    loss: f32,
) -> PyResult<HashMap<String, Py<PyAny>>> {
    // Serialize access to the trainer; a poisoned lock (a prior call panicked)
    // surfaces as a Python ValueError rather than a Rust panic.
    let mut inner = trainer
        .inner
        .lock()
        .map_err(|e| PyValueError::new_err(format!("Lock poisoned: {e}")))?;

    // Begin step to get phase info
    let step_info = inner
        .begin_step()
        .map_err(|e| PyValueError::new_err(format!("begin_step failed: {e}")))?;

    // The `unwrap()`s below assume `into_pyobject` is infallible for these
    // primitive conversions (NOTE(review): confirm for the field types used).
    // The extra `.to_owned()` on the bool exists because bool converts to a
    // borrowed reference to a Python singleton.
    let mut result: HashMap<String, Py<PyAny>> = HashMap::new();
    result.insert("phase".to_string(), step_info.phase.to_string().into_pyobject(py).unwrap().into_any().unbind());
    result.insert(
        "needs_backward".to_string(),
        step_info.needs_backward.into_pyobject(py).unwrap().to_owned().into_any().unbind(),
    );
    result.insert(
        "total_step".to_string(),
        step_info.total_step.into_pyobject(py).unwrap().into_any().unbind(),
    );
    result.insert("cycle".to_string(), step_info.cycle.into_pyobject(py).unwrap().into_any().unbind());

    // Handle gradients based on phase
    if step_info.needs_backward {
        // WARMUP, FULL, or CORRECT: record provided gradients
        if let Some(grad_dict) = gradients {
            let device = candle_core::Device::Cpu;
            let mut tensor_grads: HashMap<String, candle_core::Tensor> = HashMap::new();

            for (name, arr) in grad_dict {
                // Copy the NumPy data into a CPU candle tensor of the same shape.
                let arr = arr.as_array();
                let shape = arr.shape();
                let data: Vec<f32> = arr.iter().copied().collect();

                let tensor =
                    candle_core::Tensor::from_vec(data, shape.to_vec(), &device).map_err(|e| {
                        PyValueError::new_err(format!("Failed to create tensor for {name}: {e}"))
                    })?;

                tensor_grads.insert(name, tensor);
            }

            inner
                .record_full_gradients(&tensor_grads)
                .map_err(|e| PyValueError::new_err(format!("record_full_gradients failed: {e}")))?;
        }
        result.insert("predicted_gradients".to_string(), py.None());
    } else {
        // PREDICT: get predicted gradients
        let predicted = inner
            .get_predicted_gradients()
            .map_err(|e| PyValueError::new_err(format!("get_predicted_gradients failed: {e}")))?;

        // Convert each predicted candle tensor back to a NumPy array,
        // preserving its original dimensionality (1D as-is, ND via reshape).
        let mut pred_dict: HashMap<String, Py<PyAny>> = HashMap::new();
        for (name, tensor) in predicted {
            let dims = tensor.dims();
            let flat: Vec<f32> = tensor
                .flatten_all()
                .map_err(|e| PyValueError::new_err(format!("flatten failed: {e}")))?
                .to_vec1()
                .map_err(|e| PyValueError::new_err(format!("to_vec1 failed: {e}")))?;

            if dims.len() == 1 {
                let arr = flat.to_pyarray(py);
                pred_dict.insert(name, arr.into_pyobject(py).unwrap().into_any().unbind());
            } else {
                let arr = flat
                    .to_pyarray(py)
                    .reshape(dims.to_vec())
                    .map_err(|e| PyValueError::new_err(format!("reshape failed: {e}")))?;
                pred_dict.insert(name, arr.into_pyobject(py).unwrap().into_any().unbind());
            }
        }
        result.insert("predicted_gradients".to_string(), pred_dict.into_pyobject(py).unwrap().into_any().unbind());
    }

    // End step (feeds the loss back into the trainer's schedule/statistics).
    inner
        .end_step(loss)
        .map_err(|e| PyValueError::new_err(format!("end_step failed: {e}")))?;

    // Add stats
    let stats = inner.get_stats();
    result.insert("speedup".to_string(), stats.speedup.into_pyobject(py).unwrap().into_any().unbind());
    result.insert(
        "mean_prediction_error".to_string(),
        stats.mean_prediction_error.into_pyobject(py).unwrap().into_any().unbind(),
    );

    Ok(result)
}
330
331/// Get the current phase name from the trainer.
332///
333/// # Arguments
334/// * `trainer` - The trainer handle from `create_trainer`.
335///
336/// # Returns
337/// Phase name as string: "WARMUP", "FULL", "PREDICT", or "CORRECT".
338#[pyfunction]
339fn trainer_get_phase(trainer: &PyDeterministicTrainer) -> PyResult<String> {
340    let inner = trainer
341        .inner
342        .lock()
343        .map_err(|e| PyValueError::new_err(format!("Lock poisoned: {e}")))?;
344
345    Ok(inner.current_phase().to_string())
346}
347
/// Get training statistics from the trainer.
///
/// # Arguments
/// * `trainer` - The trainer handle from `create_trainer`.
///
/// # Returns
/// Dictionary with training statistics:
/// - `total_steps`: Total steps taken
/// - `warmup_steps`: Warmup steps taken
/// - `full_steps`: Full gradient steps taken
/// - `predict_steps`: Prediction steps taken
/// - `correct_steps`: Correction steps taken
/// - `cycles`: Training cycles completed
/// - `speedup`: Effective speedup ratio
/// - `mean_prediction_error`: Mean prediction error
/// - `current_loss`: Most recent loss
#[pyfunction]
fn trainer_get_stats(py: Python<'_>, trainer: &PyDeterministicTrainer) -> PyResult<HashMap<String, Py<PyAny>>> {
    // Read-only access still goes through the mutex; poisoning is reported
    // as a Python ValueError.
    let inner = trainer
        .inner
        .lock()
        .map_err(|e| PyValueError::new_err(format!("Lock poisoned: {e}")))?;

    let stats = inner.get_stats();
    let mut result: HashMap<String, Py<PyAny>> = HashMap::new();

    // Each stat is converted to a Python object; the `unwrap()`s assume these
    // primitive conversions cannot fail (NOTE(review): confirm field types).
    result.insert("total_steps".to_string(), stats.total_steps.into_pyobject(py).unwrap().into_any().unbind());
    result.insert("warmup_steps".to_string(), stats.warmup_steps.into_pyobject(py).unwrap().into_any().unbind());
    result.insert("full_steps".to_string(), stats.full_steps.into_pyobject(py).unwrap().into_any().unbind());
    result.insert("predict_steps".to_string(), stats.predict_steps.into_pyobject(py).unwrap().into_any().unbind());
    result.insert("correct_steps".to_string(), stats.correct_steps.into_pyobject(py).unwrap().into_any().unbind());
    result.insert("cycles".to_string(), stats.cycles.into_pyobject(py).unwrap().into_any().unbind());
    result.insert("speedup".to_string(), stats.speedup.into_pyobject(py).unwrap().into_any().unbind());
    result.insert(
        "mean_prediction_error".to_string(),
        stats.mean_prediction_error.into_pyobject(py).unwrap().into_any().unbind(),
    );
    result.insert("current_loss".to_string(), stats.current_loss.into_pyobject(py).unwrap().into_any().unbind());

    Ok(result)
}
389
390/// Reset the trainer state.
391///
392/// # Arguments
393/// * `trainer` - The trainer handle from `create_trainer`.
394#[pyfunction]
395fn trainer_reset(trainer: &PyDeterministicTrainer) -> PyResult<()> {
396    let mut inner = trainer
397        .inner
398        .lock()
399        .map_err(|e| PyValueError::new_err(format!("Lock poisoned: {e}")))?;
400
401    inner
402        .reset()
403        .map_err(|e| PyValueError::new_err(format!("reset failed: {e}")))?;
404
405    Ok(())
406}
407
408// =============================================================================
409// BITLINEAR LAYER BINDINGS
410// =============================================================================
411
412/// Create a BitLinear layer from weights.
413///
414/// BitLinear uses ternary weights {-1, 0, +1} with per-group scales,
415/// providing significant compression while maintaining accuracy.
416///
417/// # Arguments
418/// * `weight` - 2D weight array [out_features, in_features].
419/// * `bias` - Optional 1D bias array [out_features].
420/// * `group_size` - Group size for weight quantization (default: 64).
421///
422/// # Returns
423/// A `BitLinearLayer` handle to use with `bitlinear_forward`.
424///
425/// # Example
426/// ```python
427/// layer = create_bitlinear(weight_array, bias=bias_array, group_size=64)
428/// output = bitlinear_forward(layer, input_array)
429/// print(f"Compression: {bitlinear_compression_ratio(layer):.2f}x")
430/// ```
431#[pyfunction]
432#[pyo3(signature = (weight, bias=None, group_size=64))]
433fn create_bitlinear<'py>(
434    weight: PyReadonlyArray2<'py, f32>,
435    bias: Option<PyReadonlyArray1<'py, f32>>,
436    group_size: usize,
437) -> PyResult<PyBitLinearLayer> {
438    let weight_arr = weight.as_array();
439    let (out_features, in_features) = (weight_arr.nrows(), weight_arr.ncols());
440
441    let device = candle_core::Device::Cpu;
442    let weight_data: Vec<f32> = weight_arr.iter().copied().collect();
443    let weight_tensor =
444        candle_core::Tensor::from_vec(weight_data, (out_features, in_features), &device)
445            .map_err(|e| PyValueError::new_err(format!("Failed to create weight tensor: {e}")))?;
446
447    let bias_tensor = if let Some(b) = bias {
448        let bias_arr = b.as_array();
449        let bias_data: Vec<f32> = bias_arr.iter().copied().collect();
450        Some(
451            candle_core::Tensor::from_vec(bias_data, (out_features,), &device)
452                .map_err(|e| PyValueError::new_err(format!("Failed to create bias tensor: {e}")))?,
453        )
454    } else {
455        None
456    };
457
458    let config = BitNetConfig::default().with_group_size(group_size);
459
460    let layer = BitLinear::from_weight(&weight_tensor, bias_tensor.as_ref(), &config)
461        .map_err(|e| PyValueError::new_err(format!("Failed to create BitLinear: {e}")))?;
462
463    Ok(PyBitLinearLayer {
464        inner: Arc::new(layer),
465    })
466}
467
/// Forward pass through a BitLinear layer.
///
/// # Arguments
/// * `layer` - The BitLinear layer handle from `create_bitlinear`.
/// * `input` - 2D input array [batch_size, in_features].
///   NOTE(review): earlier docs also advertised 3D [batch_size, seq_len,
///   in_features] input, but this signature only accepts 2D arrays —
///   flatten 3D input to 2D before calling.
///
/// # Returns
/// Output array with shape [batch_size, out_features].
#[pyfunction]
fn bitlinear_forward<'py>(
    py: Python<'py>,
    layer: &PyBitLinearLayer,
    input: PyReadonlyArray2<'py, f32>,
) -> PyResult<Bound<'py, PyArray2<f32>>> {
    let input_arr = input.as_array();
    let (batch_size, in_features) = (input_arr.nrows(), input_arr.ncols());

    // Copy the NumPy input into a CPU candle tensor.
    let device = candle_core::Device::Cpu;
    let input_data: Vec<f32> = input_arr.iter().copied().collect();
    let input_tensor =
        candle_core::Tensor::from_vec(input_data, (batch_size, in_features), &device)
            .map_err(|e| PyValueError::new_err(format!("Failed to create input tensor: {e}")))?;

    // Run the quantized linear layer (candle_nn::Module::forward).
    let output_tensor = layer
        .inner
        .forward(&input_tensor)
        .map_err(|e| PyValueError::new_err(format!("Forward pass failed: {e}")))?;

    // Flatten the output, then rebuild a 2D NumPy array from the first two dims.
    let output_dims = output_tensor.dims();
    let output_data: Vec<f32> = output_tensor
        .flatten_all()
        .map_err(|e| PyValueError::new_err(format!("flatten failed: {e}")))?
        .to_vec1()
        .map_err(|e| PyValueError::new_err(format!("to_vec1 failed: {e}")))?;

    Ok(output_data
        .to_pyarray(py)
        .reshape([output_dims[0], output_dims[1]])
        .map_err(|e| PyValueError::new_err(format!("reshape failed: {e}")))?
        .to_owned())
}
511
/// Get the compression ratio of a BitLinear layer.
///
/// # Arguments
/// * `layer` - The BitLinear layer handle from `create_bitlinear`.
///
/// # Returns
/// Compression ratio (e.g., 8.0 means 8x compression vs float32).
#[pyfunction]
fn bitlinear_compression_ratio(layer: &PyBitLinearLayer) -> f32 {
    // Thin delegation to bitnet-quantize's BitLinear.
    layer.inner.compression_ratio()
}
523
/// Get the sparsity of a BitLinear layer.
///
/// # Arguments
/// * `layer` - The BitLinear layer handle from `create_bitlinear`.
///
/// # Returns
/// Sparsity ratio (fraction of weights that are zero).
#[pyfunction]
fn bitlinear_sparsity(layer: &PyBitLinearLayer) -> f32 {
    // Thin delegation to bitnet-quantize's BitLinear.
    layer.inner.sparsity()
}
535
/// Get the input features dimension of a BitLinear layer.
///
/// # Arguments
/// * `layer` - The BitLinear layer handle from `create_bitlinear`.
///
/// # Returns
/// Number of input features.
#[pyfunction]
fn bitlinear_in_features(layer: &PyBitLinearLayer) -> usize {
    // Thin delegation to bitnet-quantize's BitLinear.
    layer.inner.in_features()
}
547
/// Get the output features dimension of a BitLinear layer.
///
/// # Arguments
/// * `layer` - The BitLinear layer handle from `create_bitlinear`.
///
/// # Returns
/// Number of output features.
#[pyfunction]
fn bitlinear_out_features(layer: &PyBitLinearLayer) -> usize {
    // Thin delegation to bitnet-quantize's BitLinear.
    layer.inner.out_features()
}
559
560// =============================================================================
561// ORIGINAL TERNARY BINDINGS
562// =============================================================================
563
/// Pack ternary weights into efficient 2-bit representation.
///
/// # Arguments
/// * `weights` - 2D array of ternary values {-1, 0, +1}
/// * `scales` - Per-row scale factors (passed through unmodified)
///
/// # Returns
/// Tuple of (packed_bytes, scales) for storage/transmission.
/// Layout: 4 trits per byte; trit k of a row occupies bits [2*(k%4), 2*(k%4)+1]
/// with encoding 0b01 = +1, 0b10 = -1, 0b00 = 0.
///
/// Note: Internally uses trit-vsa's bitsliced storage, but returns
/// a compatible 2-bit packed format for interoperability.
#[pyfunction]
fn pack_ternary_weights<'py>(
    py: Python<'py>,
    weights: PyReadonlyArray2<'py, f32>,
    scales: PyReadonlyArray1<'py, f32>,
) -> PyResult<(Bound<'py, PyArray1<u8>>, Bound<'py, PyArray1<f32>>)> {
    let weights = weights.as_array();
    let scales = scales.as_array();

    let (rows, cols) = (weights.nrows(), weights.ncols());

    // Convert to trit-vsa PackedTritVec per row
    let mut packed_vecs: Vec<PackedTritVec> = Vec::with_capacity(rows);

    for row in weights.rows() {
        let mut packed = PackedTritVec::new(cols);
        for (col_idx, &val) in row.iter().enumerate() {
            // `val as i8` truncates toward zero, so any |val| < 1.0 maps to
            // Trit::Z; inputs are expected to already be exact {-1, 0, +1}.
            let trit = match val as i8 {
                v if v > 0 => Trit::P,  // +1
                v if v < 0 => Trit::N,  // -1
                _ => Trit::Z,           // 0
            };
            packed.set(col_idx, trit);
        }
        packed_vecs.push(packed);
    }

    // Convert bitsliced representation to 2-bit packed format for Python compatibility
    let packed_size = cols.div_ceil(4);
    let mut packed = vec![0u8; rows * packed_size];

    for (row_idx, pvec) in packed_vecs.iter().enumerate() {
        for col_idx in 0..cols {
            let trit = pvec.get(col_idx);
            let trit_bits = match trit {
                Trit::P => 0b01, // +1
                Trit::N => 0b10, // -1
                Trit::Z => 0b00, // 0
            };

            // 4 trits per byte, 2 bits each, low bits first.
            let byte_idx = row_idx * packed_size + col_idx / 4;
            let bit_offset = (col_idx % 4) * 2;
            packed[byte_idx] |= trit_bits << bit_offset;
        }
    }

    // Scales are copied into a fresh NumPy array and returned unchanged.
    let packed_array = packed.to_pyarray(py);
    let scales_array = scales.to_vec().to_pyarray(py);

    Ok((packed_array, scales_array))
}
626
627/// Unpack ternary weights from 2-bit representation.
628///
629/// # Arguments
630/// * `packed` - Packed byte array
631/// * `scales` - Per-row scale factors
632/// * `shape` - Original (rows, cols) shape
633///
634/// # Returns
635/// 2D array of dequantized weights.
636///
637/// Note: Uses trit-vsa for intermediate storage.
638#[pyfunction]
639fn unpack_ternary_weights<'py>(
640    py: Python<'py>,
641    packed: PyReadonlyArray1<'py, u8>,
642    scales: PyReadonlyArray1<'py, f32>,
643    shape: (usize, usize),
644) -> PyResult<Bound<'py, PyArray2<f32>>> {
645    let packed = packed.as_array();
646    let scales = scales.as_array();
647    let (rows, cols) = shape;
648
649    let packed_size = cols.div_ceil(4);
650
651    // Convert 2-bit packed format to trit-vsa PackedTritVec
652    let mut packed_vecs: Vec<PackedTritVec> = Vec::with_capacity(rows);
653
654    for row_idx in 0..rows {
655        let mut pvec = PackedTritVec::new(cols);
656        for col_idx in 0..cols {
657            let byte_idx = row_idx * packed_size + col_idx / 4;
658            let bit_offset = (col_idx % 4) * 2;
659            let trit_bits = (packed[byte_idx] >> bit_offset) & 0b11;
660
661            let trit = match trit_bits {
662                0b01 => Trit::P,  // +1
663                0b10 => Trit::N,  // -1
664                _ => Trit::Z,     // 0
665            };
666            pvec.set(col_idx, trit);
667        }
668        packed_vecs.push(pvec);
669    }
670
671    // Dequantize using scales
672    let mut weights = vec![0.0f32; rows * cols];
673
674    for (row_idx, pvec) in packed_vecs.iter().enumerate() {
675        let scale = scales[row_idx];
676        for col_idx in 0..cols {
677            let value = f32::from(pvec.get(col_idx).value()) * scale;
678            weights[row_idx * cols + col_idx] = value;
679        }
680    }
681
682    Ok(weights
683        .to_pyarray(py)
684        .reshape([rows, cols])
685        .expect("reshape failed")
686        .to_owned())
687}
688
689/// Efficient matrix multiplication with packed ternary weights.
690///
691/// Computes: output = input @ weights.T
692///
693/// # Arguments
694/// * `input` - 2D input tensor (batch, in_features)
695/// * `packed_weights` - Packed ternary weights
696/// * `scales` - Per-output-channel scale factors
697/// * `weight_shape` - (out_features, in_features)
698///
699/// # Returns
700/// Output tensor (batch, out_features)
701///
702/// Note: Uses trit-vsa PackedTritVec for efficient dot products.
703#[pyfunction]
704fn ternary_matmul<'py>(
705    py: Python<'py>,
706    input: PyReadonlyArray2<'py, f32>,
707    packed_weights: PyReadonlyArray1<'py, u8>,
708    scales: PyReadonlyArray1<'py, f32>,
709    weight_shape: (usize, usize),
710) -> PyResult<Bound<'py, PyArray2<f32>>> {
711    let input = input.as_array();
712    let packed = packed_weights.as_array();
713    let scales = scales.as_array();
714
715    let (batch_size, in_features) = (input.nrows(), input.ncols());
716    let (out_features, _) = weight_shape;
717
718    if in_features != weight_shape.1 {
719        return Err(PyValueError::new_err(format!(
720            "Input features {} doesn't match weight features {}",
721            in_features, weight_shape.1
722        )));
723    }
724
725    let packed_cols = in_features.div_ceil(4);
726
727    // Convert packed weights to trit-vsa PackedTritVec for each output row
728    let mut weight_vecs: Vec<PackedTritVec> = Vec::with_capacity(out_features);
729
730    for o in 0..out_features {
731        let mut pvec = PackedTritVec::new(in_features);
732        for i in 0..in_features {
733            let byte_idx = o * packed_cols + i / 4;
734            let bit_offset = (i % 4) * 2;
735            let trit_bits = (packed[byte_idx] >> bit_offset) & 0b11;
736
737            let trit = match trit_bits {
738                0b01 => Trit::P,
739                0b10 => Trit::N,
740                _ => Trit::Z,
741            };
742            pvec.set(i, trit);
743        }
744        weight_vecs.push(pvec);
745    }
746
747    let mut output = vec![0.0f32; batch_size * out_features];
748
749    // Compute matmul: for each batch, compute dot product with each weight row
750    for b in 0..batch_size {
751        for (o, weight_vec) in weight_vecs.iter().enumerate() {
752            let scale = scales[o];
753            let mut sum = 0.0f32;
754
755            // Use the trit values to select add/subtract/skip
756            for i in 0..in_features {
757                let trit = weight_vec.get(i);
758                let x = input[[b, i]];
759                sum += match trit {
760                    Trit::P => x,   // +1: add
761                    Trit::N => -x,  // -1: subtract
762                    Trit::Z => 0.0, // 0: skip
763                };
764            }
765
766            output[b * out_features + o] = sum * scale;
767        }
768    }
769
770    Ok(output
771        .to_pyarray(py)
772        .reshape([batch_size, out_features])
773        .expect("reshape failed")
774        .to_owned())
775}
776
777/// Quantize weights to ternary using AbsMean scaling.
778///
779/// # Arguments
780/// * `weights` - 2D weight tensor
781///
782/// # Returns
783/// Tuple of (ternary_weights, scales)
784///
785/// Note: Delegates to bitnet-quantize for the core algorithm.
786#[pyfunction]
787fn quantize_weights_absmean<'py>(
788    py: Python<'py>,
789    weights: PyReadonlyArray2<'py, f32>,
790) -> PyResult<(Bound<'py, PyArray2<f32>>, Bound<'py, PyArray1<f32>>)> {
791    let weights_arr = weights.as_array();
792    let (rows, cols) = (weights_arr.nrows(), weights_arr.ncols());
793
794    // Use bitnet-quantize with group_size = cols (per-row quantization)
795    let config = BitNetConfig::default().with_group_size(cols);
796
797    // Convert numpy array to candle tensor
798    let device = candle_core::Device::Cpu;
799    let weight_data: Vec<f32> = weights_arr.iter().copied().collect();
800    let weight_tensor = candle_core::Tensor::from_vec(weight_data, (rows, cols), &device)
801        .map_err(|e| PyValueError::new_err(format!("Failed to create tensor: {e}")))?;
802
803    // Quantize using bitnet-quantize
804    let ternary = bitnet_quantize_weights(&weight_tensor, &config)
805        .map_err(|e| PyValueError::new_err(format!("Quantization failed: {e}")))?;
806
807    // Extract ternary values and scales
808    let mut ternary_output = vec![0.0f32; rows * cols];
809    let scales: Vec<f32> = ternary.scales.clone();
810
811    for (row_idx, packed) in ternary.data.iter().enumerate() {
812        for col_idx in 0..cols {
813            let trit = packed.get(col_idx);
814            ternary_output[row_idx * cols + col_idx] = f32::from(trit.value());
815        }
816    }
817
818    Ok((
819        ternary_output
820            .to_pyarray(py)
821            .reshape([rows, cols])
822            .expect("reshape failed")
823            .to_owned(),
824        scales.to_pyarray(py),
825    ))
826}
827
828/// Compress gradients using VSA random projection.
829///
830/// # Arguments
831/// * `gradients` - Flattened gradient tensor
832/// * `compression_ratio` - Target compression (0.0 to 1.0)
833/// * `seed` - Random seed for reproducibility
834///
835/// # Returns
836/// Tuple of (compressed_gradients, projection_seed)
837///
838/// Note: Uses a simplified random projection for Python compatibility.
839/// For full VSA with bind/bundle/unbind, use vsa-optim-rs directly.
840#[pyfunction]
841#[allow(clippy::cast_precision_loss)]
842fn compress_gradients_vsa<'py>(
843    py: Python<'py>,
844    gradients: PyReadonlyArray1<'py, f32>,
845    compression_ratio: f32,
846    seed: u64,
847) -> PyResult<(Bound<'py, PyArray1<f32>>, u64)> {
848    use rand::{Rng, SeedableRng};
849    use rand_chacha::ChaCha8Rng;
850
851    let gradients = gradients.as_array();
852    let original_dim = gradients.len();
853    let compressed_dim = ((original_dim as f32 * compression_ratio).ceil() as usize).max(256);
854
855    let mut rng = ChaCha8Rng::seed_from_u64(seed);
856    let mut compressed = vec![0.0f32; compressed_dim];
857
858    // Random projection (Johnson-Lindenstrauss style)
859    let scale = 1.0 / (original_dim as f32).sqrt();
860    for &g in gradients.iter() {
861        for c in compressed.iter_mut() {
862            // Sparse random projection: ~68% zeros, 16% +1, 16% -1
863            let r: f32 = rng.gen();
864            if r < 0.16 {
865                *c += g * scale;
866            } else if r < 0.32 {
867                *c -= g * scale;
868            }
869        }
870    }
871
872    Ok((compressed.to_pyarray(py), seed))
873}
874
875/// Decompress gradients from VSA projection.
876///
877/// # Arguments
878/// * `compressed` - Compressed gradient tensor
879/// * `original_dim` - Original gradient dimension
880/// * `seed` - Random seed (must match compression)
881///
882/// # Returns
883/// Reconstructed gradient tensor (approximate)
884///
885/// Note: Uses simplified inverse projection for Python compatibility.
886#[pyfunction]
887#[allow(clippy::cast_precision_loss)]
888fn decompress_gradients_vsa<'py>(
889    py: Python<'py>,
890    compressed: PyReadonlyArray1<'py, f32>,
891    original_dim: usize,
892    seed: u64,
893) -> PyResult<Bound<'py, PyArray1<f32>>> {
894    use rand::{Rng, SeedableRng};
895    use rand_chacha::ChaCha8Rng;
896
897    let compressed = compressed.as_array();
898
899    let mut rng = ChaCha8Rng::seed_from_u64(seed);
900    let mut gradients = vec![0.0f32; original_dim];
901
902    // Inverse projection (transpose of forward projection)
903    let scale = 1.0 / (original_dim as f32).sqrt();
904    for g in gradients.iter_mut() {
905        for &c in compressed.iter() {
906            let r: f32 = rng.gen();
907            if r < 0.16 {
908                *g += c * scale;
909            } else if r < 0.32 {
910                *g -= c * scale;
911            }
912        }
913    }
914
915    Ok(gradients.to_pyarray(py))
916}
917
918/// Get version information.
919#[pyfunction]
920fn version() -> &'static str {
921    env!("CARGO_PKG_VERSION")
922}
923
924/// Check if CUDA is available (for Python).
925#[pyfunction]
926fn cuda_available_py() -> bool {
927    #[cfg(feature = "cuda")]
928    {
929        candle_core::Device::cuda_if_available(0)
930            .map(|d| matches!(d, candle_core::Device::Cuda(_)))
931            .unwrap_or(false)
932    }
933    #[cfg(not(feature = "cuda"))]
934    {
935        false
936    }
937}
938
/// Python module definition.
///
/// Registers every class and function exposed to Python as the
/// `tritter_accel` module. Registration is grouped by feature area; the
/// GPU/VSA group at the end is compiled in only when the crate is built
/// with the `cuda` feature, so its absence at runtime is expected on
/// CPU-only builds.
#[pymodule]
fn tritter_accel(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Register class types
    m.add_class::<PyDeterministicTrainer>()?;
    m.add_class::<PyBitLinearLayer>()?;

    // Deterministic phase trainer functions
    m.add_function(wrap_pyfunction!(create_trainer, m)?)?;
    m.add_function(wrap_pyfunction!(trainer_step, m)?)?;
    m.add_function(wrap_pyfunction!(trainer_get_phase, m)?)?;
    m.add_function(wrap_pyfunction!(trainer_get_stats, m)?)?;
    m.add_function(wrap_pyfunction!(trainer_reset, m)?)?;

    // BitLinear layer functions
    m.add_function(wrap_pyfunction!(create_bitlinear, m)?)?;
    m.add_function(wrap_pyfunction!(bitlinear_forward, m)?)?;
    m.add_function(wrap_pyfunction!(bitlinear_compression_ratio, m)?)?;
    m.add_function(wrap_pyfunction!(bitlinear_sparsity, m)?)?;
    m.add_function(wrap_pyfunction!(bitlinear_in_features, m)?)?;
    m.add_function(wrap_pyfunction!(bitlinear_out_features, m)?)?;

    // Core ternary functions
    m.add_function(wrap_pyfunction!(pack_ternary_weights, m)?)?;
    m.add_function(wrap_pyfunction!(unpack_ternary_weights, m)?)?;
    m.add_function(wrap_pyfunction!(ternary_matmul, m)?)?;
    m.add_function(wrap_pyfunction!(quantize_weights_absmean, m)?)?;

    // Gradient compression
    m.add_function(wrap_pyfunction!(compress_gradients_vsa, m)?)?;
    m.add_function(wrap_pyfunction!(decompress_gradients_vsa, m)?)?;

    // Utilities
    m.add_function(wrap_pyfunction!(version, m)?)?;
    m.add_function(wrap_pyfunction!(cuda_available_py, m)?)?;

    // GPU-accelerated VSA operations (when cuda feature enabled)
    #[cfg(feature = "cuda")]
    {
        m.add_function(wrap_pyfunction!(gpu::cuda_available, m)?)?;
        m.add_function(wrap_pyfunction!(gpu::get_device, m)?)?;
        m.add_function(wrap_pyfunction!(gpu::set_device, m)?)?;
        m.add_function(wrap_pyfunction!(gpu::vsa_bind, m)?)?;
        m.add_function(wrap_pyfunction!(gpu::vsa_unbind, m)?)?;
        m.add_function(wrap_pyfunction!(gpu::vsa_bundle, m)?)?;
        m.add_function(wrap_pyfunction!(gpu::vsa_similarity, m)?)?;
        m.add_function(wrap_pyfunction!(gpu::vsa_random, m)?)?;
    }

    Ok(())
}
990} // End of python_bindings module
991
992// Re-export the Python module entry point when the python feature is enabled
993#[cfg(feature = "python")]
994pub use python_bindings::*;
995
996// =============================================================================
997// TESTS
998// =============================================================================
999
#[cfg(test)]
mod tests {
    use bitnet_quantize::{quantize_weights as bitnet_quantize_weights, BitLinear, BitNetConfig};
    use candle_nn::Module;
    use trit_vsa::{PackedTritVec, Trit};

    #[test]
    fn test_pack_unpack_roundtrip() {
        // Packing ternary values into trit-vsa vectors and unpacking them
        // again (with unit scales) must reproduce the input exactly.
        let weights = vec![1.0, 0.0, -1.0, 1.0, -1.0, 1.0, 0.0, -1.0];
        let scales = vec![1.0, 1.0];

        // Build one PackedTritVec per 4-wide row.
        let packed_rows: Vec<PackedTritVec> = (0..2)
            .map(|row| {
                let mut pvec = PackedTritVec::new(4);
                for col in 0..4 {
                    let trit = match weights[row * 4 + col] as i8 {
                        v if v > 0 => Trit::P,
                        v if v < 0 => Trit::N,
                        _ => Trit::Z,
                    };
                    pvec.set(col, trit);
                }
                pvec
            })
            .collect();

        // Unpack and rescale each row.
        let mut recovered = vec![0.0f32; 8];
        for (row, pvec) in packed_rows.iter().enumerate() {
            for col in 0..4 {
                recovered[row * 4 + col] = f32::from(pvec.get(col).value()) * scales[row];
            }
        }

        assert_eq!(recovered, weights);
    }

    #[test]
    fn test_quantize_absmean_with_bitnet() {
        // bitnet-quantize should accept a 2x4 tensor and emit one packed
        // row per input row.
        let device = candle_core::Device::Cpu;
        let weights = vec![0.5, -0.3, 0.1, 0.8, -0.2, 0.6, -0.7, 0.4];
        let tensor = candle_core::Tensor::from_vec(weights, (2, 4), &device).unwrap();
        let config = BitNetConfig::default().with_group_size(4);

        let ternary = bitnet_quantize_weights(&tensor, &config).unwrap();

        assert_eq!(ternary.shape, (2, 4));
        assert_eq!(ternary.data.len(), 2);
    }

    #[test]
    fn test_trit_vsa_dot_product() {
        // dot([+1, -1, 0, +1], [+1, +1, -1, 0]) = 1 - 1 + 0 + 0 = 0
        let lhs_trits = [Trit::P, Trit::N, Trit::Z, Trit::P];
        let rhs_trits = [Trit::P, Trit::P, Trit::N, Trit::Z];

        let mut lhs = PackedTritVec::new(4);
        let mut rhs = PackedTritVec::new(4);
        for (i, (&l, &r)) in lhs_trits.iter().zip(rhs_trits.iter()).enumerate() {
            lhs.set(i, l);
            rhs.set(i, r);
        }

        assert_eq!(lhs.dot(&rhs), 0);
    }

    #[test]
    fn test_rust_api_quantization() {
        // Exercise the pure-Rust quantization API directly.
        use crate::core::quantization::{quantize_absmean, QuantizeConfig};

        let device = candle_core::Device::Cpu;
        let data = vec![0.5f32, -0.3, 0.1, 0.8, -0.2, 0.6, -0.7, 0.4];
        let weights = candle_core::Tensor::from_vec(data, (2, 4), &device).unwrap();

        let result = quantize_absmean(&weights, &QuantizeConfig::default()).unwrap();

        assert_eq!(result.shape, (2, 4));
        assert_eq!(result.values.len(), 8);

        // Every quantized entry must be in {-1, 0, +1}.
        assert!(result.values.iter().all(|v| [-1, 0, 1].contains(v)));
    }

    #[test]
    fn test_rust_api_gradient_compression() {
        // Compress then decompress must round-trip the gradient length.
        use crate::core::training::{GradientCompressor, TrainingConfig};

        let compressor = GradientCompressor::new(TrainingConfig::default());

        // Linear ramp from -1.0 up to just under +1.0.
        let grads: Vec<f32> = (0..1000).map(|i| (i as f32 - 500.0) / 500.0).collect();
        let packed = compressor.compress(&grads, Some(0.1)).unwrap();
        let restored = compressor.decompress(&packed).unwrap();

        assert_eq!(restored.len(), grads.len());
    }

    #[test]
    fn test_deterministic_phase_trainer_creation() {
        // A freshly constructed trainer must start in the warmup phase.
        use vsa_optim_rs::{DeterministicPhaseConfig, DeterministicPhaseTrainer};

        let shapes = vec![
            ("layer.weight".to_string(), vec![16, 32]),
            ("layer.bias".to_string(), vec![16]),
        ];

        let config = DeterministicPhaseConfig::default()
            .with_warmup_steps(5)
            .with_full_steps(3)
            .with_predict_steps(10);

        let trainer =
            DeterministicPhaseTrainer::new(&shapes, config, &candle_core::Device::Cpu).unwrap();

        assert!(matches!(
            trainer.current_phase(),
            vsa_optim_rs::DeterministicPhase::Warmup
        ));
    }

    #[test]
    fn test_deterministic_trainer_step_cycle() {
        use std::collections::HashMap;
        use vsa_optim_rs::{DeterministicPhaseConfig, DeterministicPhaseTrainer};

        let device = candle_core::Device::Cpu;
        let shapes = vec![
            ("layer.weight".to_string(), vec![8, 16]),
            ("layer.bias".to_string(), vec![8]),
        ];

        let config = DeterministicPhaseConfig::default()
            .with_warmup_steps(3)
            .with_full_steps(2)
            .with_predict_steps(5);

        let mut trainer = DeterministicPhaseTrainer::new(&shapes, config, &device).unwrap();

        // Drive the trainer through ten steps, feeding mock gradients on
        // steps that require a backward pass.
        for step in 0..10 {
            let info = trainer.begin_step().unwrap();

            if info.needs_backward {
                // Constant gradients whose magnitude grows with the step index.
                let magnitude = (step as f64 + 1.0) * 0.1;
                let mut grads = HashMap::new();
                grads.insert(
                    "layer.weight".to_string(),
                    candle_core::Tensor::ones((8, 16), candle_core::DType::F32, &device)
                        .unwrap()
                        .affine(magnitude, 0.0)
                        .unwrap(),
                );
                grads.insert(
                    "layer.bias".to_string(),
                    candle_core::Tensor::ones(8, candle_core::DType::F32, &device)
                        .unwrap()
                        .affine(magnitude, 0.0)
                        .unwrap(),
                );
                trainer.record_full_gradients(&grads).unwrap();
            } else {
                let _ = trainer.get_predicted_gradients();
            }

            trainer.end_step(1.0 / (step + 1) as f32).unwrap();
        }

        assert_eq!(trainer.get_stats().total_steps, 10);
    }

    #[test]
    fn test_bitlinear_layer_creation() {
        // A 32x64 weight matrix yields a layer with 64 inputs, 32 outputs,
        // and a compression ratio greater than one.
        let device = candle_core::Device::Cpu;
        let config = BitNetConfig::default().with_group_size(16);
        let weight = candle_core::Tensor::randn(0.0f32, 1.0, (32, 64), &device).unwrap();

        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        assert_eq!(layer.in_features(), 64);
        assert_eq!(layer.out_features(), 32);
        assert!(layer.compression_ratio() > 1.0);
    }

    #[test]
    fn test_bitlinear_forward_pass() {
        // Forward through a quantized 64->32 layer maps (4, 64) to (4, 32).
        let device = candle_core::Device::Cpu;
        let config = BitNetConfig::default().with_group_size(16);

        let weight = candle_core::Tensor::randn(0.0f32, 1.0, (32, 64), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        let activations = candle_core::Tensor::randn(0.0f32, 1.0, (4, 64), &device).unwrap();
        let out = layer.forward(&activations).unwrap();

        assert_eq!(out.shape().dims(), &[4, 32]);
    }

    #[test]
    fn test_bitlinear_with_bias() {
        // Constructing with an explicit bias keeps it and preserves the
        // output shape of the forward pass.
        let device = candle_core::Device::Cpu;
        let config = BitNetConfig::default().with_group_size(16);

        let weight = candle_core::Tensor::randn(0.0f32, 1.0, (32, 64), &device).unwrap();
        let bias = candle_core::Tensor::randn(0.0f32, 1.0, (32,), &device).unwrap();

        let layer = BitLinear::from_weight(&weight, Some(&bias), &config).unwrap();

        assert!(layer.bias().is_some());

        let activations = candle_core::Tensor::randn(0.0f32, 1.0, (4, 64), &device).unwrap();
        let out = layer.forward(&activations).unwrap();

        assert_eq!(out.shape().dims(), &[4, 32]);
    }
}