scirs2/gpu_ops.rs

//! GPU-accelerated matrix operations exposed to Python.
//!
//! This module provides a GPU-dispatch API with a pure-CPU fallback that is
//! always available. The API surface is identical regardless of whether GPU
//! hardware is present, so Python callers need no conditional logic.
//!
//! # GPU path (future `cuda_bridge` feature)
//! When the `cuda_bridge` feature is enabled and `cudarc`/`candle` are linked,
//! the functions below dispatch to the GPU kernel instead of the CPU path.
//! The feature gate is wired up, but the cudarc integration itself is deferred
//! until GPU hardware is available in CI. See TODO.md L149/L151.
//!
//! # CPU path (default, pure Rust)
//! The CPU implementations use plain `Vec<f64>` arithmetic and are correct,
//! tested, and zero-dependency.
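//!
//! # Example (Python)
//! A minimal usage sketch, assuming the compiled extension is importable as
//! `scirs2` (the names below are the ones registered by `register_gpu_module`):
//! ```python
//! import scirs2
//!
//! print(scirs2.gpu_device_info())  # "cpu (cuda_bridge feature not enabled)"
//! scirs2.gpu_elementwise([-1.0, 0.5], "relu")  # [0.0, 0.5]
//! ```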

use pyo3::exceptions::{PyNotImplementedError, PyTypeError, PyValueError};
use pyo3::prelude::*;

// ── GPU device info ────────────────────────────────────────────────────────

/// Return a string describing the active compute device.
///
/// Returns `"cpu (cuda_bridge feature not enabled)"` unless the `cuda_bridge`
/// Cargo feature is enabled and CUDA hardware is detected at runtime.
#[pyfunction]
pub fn gpu_device_info() -> String {
    "cpu (cuda_bridge feature not enabled)".to_string()
}

// ── Matrix multiply ────────────────────────────────────────────────────────

/// Multiply two row-major matrices: C = A (m×k) × B (k×n).
///
/// Returns the product as a flat `Vec<f64>` of length `m * n` in row-major
/// order; the output shape is implicitly `(m, n)`.
///
/// # Arguments
/// * `a_data`  – flat row-major elements of A, length must equal `a_rows * a_cols`
/// * `a_rows`  – number of rows in A (= m)
/// * `a_cols`  – number of columns in A (= k)
/// * `b_data`  – flat row-major elements of B, length must equal `a_cols * b_cols`
/// * `b_cols`  – number of columns in B (= n)
///
/// # Errors
/// Returns `PyValueError` when any length constraint is violated.
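///
/// # Python example
/// A minimal usage sketch (assumes the module is importable as `scirs2`):
/// ```python
/// import scirs2
/// # A is 2×3, B is 3×2, both passed as flat row-major lists
/// a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
/// b = [7.0, 8.0, 9.0, 10.0, 11.0, 12.0]
/// c = scirs2.gpu_matmul(a, 2, 3, b, 2)
/// assert c == [58.0, 64.0, 139.0, 154.0]
/// ```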
#[pyfunction]
pub fn gpu_matmul(
    a_data: Vec<f64>,
    a_rows: usize,
    a_cols: usize,
    b_data: Vec<f64>,
    b_cols: usize,
) -> PyResult<Vec<f64>> {
    if a_data.len() != a_rows * a_cols {
        return Err(PyValueError::new_err(format!(
            "a_data length {} does not match a_rows * a_cols = {} * {} = {}",
            a_data.len(),
            a_rows,
            a_cols,
            a_rows * a_cols,
        )));
    }
    if b_data.len() != a_cols * b_cols {
        return Err(PyValueError::new_err(format!(
            "b_data length {} does not match a_cols * b_cols = {} * {} = {}",
            b_data.len(),
            a_cols,
            b_cols,
            a_cols * b_cols,
        )));
    }

    // CPU row-major matrix multiply (ikj loop order for cache locality: the
    // innermost j-loop walks both C's row i and B's row k contiguously).
    let mut c = vec![0.0f64; a_rows * b_cols];
    for i in 0..a_rows {
        for k in 0..a_cols {
            let a_ik = a_data[i * a_cols + k];
            for j in 0..b_cols {
                c[i * b_cols + j] += a_ik * b_data[k * b_cols + j];
            }
        }
    }
    Ok(c)
}

// ── Element-wise activation functions ─────────────────────────────────────

/// Apply an element-wise activation to every element of `data`.
///
/// Supported operations: `"exp"`, `"log"`, `"sqrt"`, `"relu"`, `"sigmoid"`,
/// `"tanh"`, `"abs"`, `"square"`.
///
/// For `"log"`, zero yields `-∞` and negative inputs yield `NaN`; for
/// `"sqrt"`, negative inputs yield `NaN`. These are the IEEE 754 results of
/// `f64::ln` and `f64::sqrt`, and they match NumPy conventions.
///
/// # Errors
/// Returns `PyValueError` if `op` is not one of the supported strings.
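///
/// # Python example
/// A minimal usage sketch (assumes the module is importable as `scirs2`):
/// ```python
/// import scirs2
/// scirs2.gpu_elementwise([-2.0, 0.0, 2.0], "relu")  # [0.0, 0.0, 2.0]
/// scirs2.gpu_elementwise([0.0, 1.0, 4.0], "sqrt")   # [0.0, 1.0, 2.0]
/// scirs2.gpu_elementwise([1.0, 2.0], "softmax")     # raises ValueError
/// ```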
#[pyfunction]
pub fn gpu_elementwise(data: Vec<f64>, op: &str) -> PyResult<Vec<f64>> {
    let result: Vec<f64> = match op {
        "exp" => data.iter().map(|&x| x.exp()).collect(),
        // `f64::ln` already returns -inf for 0.0 and NaN for negative inputs,
        // matching NumPy, so no special-casing is needed.
        "log" => data.iter().map(|&x| x.ln()).collect(),
        // `f64::sqrt` returns NaN for negative inputs, matching NumPy.
        "sqrt" => data.iter().map(|&x| x.sqrt()).collect(),
        "relu" => data.iter().map(|&x| x.max(0.0)).collect(),
        "sigmoid" => data.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect(),
        "tanh" => data.iter().map(|&x| x.tanh()).collect(),
        "abs" => data.iter().map(|&x| x.abs()).collect(),
        "square" => data.iter().map(|&x| x * x).collect(),
        _ => {
            return Err(PyValueError::new_err(format!(
                "Unknown op '{op}'. Supported: exp, log, sqrt, relu, sigmoid, tanh, abs, square"
            )))
        }
    };
    Ok(result)
}

// ── Batch matrix operations ────────────────────────────────────────────────

/// Add two row-major matrices element-wise.
///
/// Both vectors must have the same length (= rows × cols).
///
/// # Errors
/// Returns `PyValueError` on length mismatch.
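///
/// # Python example
/// A minimal usage sketch (assumes the module is importable as `scirs2`):
/// ```python
/// import scirs2
/// scirs2.gpu_matrix_add([1.0, 2.0, 3.0], [4.0, 5.0, 6.0])  # [5.0, 7.0, 9.0]
/// ```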
#[pyfunction]
pub fn gpu_matrix_add(a_data: Vec<f64>, b_data: Vec<f64>) -> PyResult<Vec<f64>> {
    if a_data.len() != b_data.len() {
        return Err(PyValueError::new_err(format!(
            "Length mismatch: a has {} elements, b has {}",
            a_data.len(),
            b_data.len(),
        )));
    }
    Ok(a_data
        .iter()
        .zip(b_data.iter())
        .map(|(&a, &b)| a + b)
        .collect())
}

/// Scale a row-major matrix by a scalar.
#[pyfunction]
pub fn gpu_matrix_scale(data: Vec<f64>, scalar: f64) -> Vec<f64> {
    data.iter().map(|&x| x * scalar).collect()
}

/// Compute the Frobenius norm of a flat matrix.
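///
/// # Python example
/// A minimal usage sketch (assumes the module is importable as `scirs2`):
/// ```python
/// import scirs2
/// # Frobenius norm of the 2×2 identity is sqrt(1² + 0² + 0² + 1²) = sqrt(2)
/// scirs2.gpu_frobenius_norm([1.0, 0.0, 0.0, 1.0])  # ≈ 1.4142135623730951
/// ```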
#[pyfunction]
pub fn gpu_frobenius_norm(data: Vec<f64>) -> f64 {
    data.iter().map(|&x| x * x).sum::<f64>().sqrt()
}

// ── CUDA tensor bridge (DLPack protocol, GPU path deferred) ───────────────

/// Multiply two PyTorch/JAX tensors via the DLPack protocol.
///
/// This is the entry point for the zero-copy GPU path. When the `cuda_bridge`
/// Cargo feature is enabled (and `cudarc` is linked), the function accepts any
/// Python object implementing `__dlpack__` and dispatches directly to a CUDA
/// GEMM kernel.
///
/// In the current CPU-only build this function returns `PyNotImplementedError`
/// with a clear message directing callers to `gpu_matmul()`.
///
/// # Python example
/// ```python
/// import torch, scirs2
/// a = torch.randn(512, 512, device='cuda')
/// b = torch.randn(512, 512, device='cuda')
/// # GPU path (when the cuda_bridge feature is enabled):
/// c = scirs2.cuda_tensor_matmul(a, b)
/// # CPU fallback that accepts flat Python lists:
/// c_data = scirs2.gpu_matmul(a.flatten().tolist(), 512, 512, b.flatten().tolist(), 512)
/// ```
#[pyfunction]
pub fn cuda_tensor_matmul<'py>(
    _py: Python<'py>,
    tensor_a: &Bound<'py, PyAny>,
    tensor_b: &Bound<'py, PyAny>,
) -> PyResult<Py<PyAny>> {
    // Verify that both inputs implement the DLPack protocol before returning
    // the not-implemented error, so that callers know their tensor types are
    // compatible with the future GPU path.
    for tensor in [tensor_a, tensor_b] {
        if !tensor.hasattr("__dlpack__").unwrap_or(false) {
            return Err(PyTypeError::new_err(
                "Tensors must implement the __dlpack__ protocol (e.g. PyTorch or JAX tensors)",
            ));
        }
    }

    // CPU-only build: direct callers to the Vec-based fallback instead.
    Err(PyNotImplementedError::new_err(
        "CUDA tensor bridge is not yet compiled in. \
         Enable the `cuda_bridge` Cargo feature and install `cudarc`. \
         For a CPU fallback that accepts Python lists, use gpu_matmul().",
    ))
}

// ── Module registration ────────────────────────────────────────────────────

/// Register all GPU-dispatch functions in the parent Python module.
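///
/// Once called from the crate's `#[pymodule]` entry point, each function below
/// appears as a module-level attribute on the Python side. A minimal sketch,
/// assuming the module is named `scirs2`:
/// ```python
/// import scirs2
/// assert hasattr(scirs2, "gpu_matmul")
/// assert hasattr(scirs2, "cuda_tensor_matmul")
/// ```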
pub fn register_gpu_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(gpu_device_info, m)?)?;
    m.add_function(wrap_pyfunction!(gpu_matmul, m)?)?;
    m.add_function(wrap_pyfunction!(gpu_elementwise, m)?)?;
    m.add_function(wrap_pyfunction!(gpu_matrix_add, m)?)?;
    m.add_function(wrap_pyfunction!(gpu_matrix_scale, m)?)?;
    m.add_function(wrap_pyfunction!(gpu_frobenius_norm, m)?)?;
    m.add_function(wrap_pyfunction!(cuda_tensor_matmul, m)?)?;
    Ok(())
}

// ── Unit tests ────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_device_info_non_empty() {
        let info = gpu_device_info();
        assert!(!info.is_empty());
        assert!(info.contains("cpu"));
    }

    #[test]
    fn test_matmul_2x2_identity() {
        // [1,0; 0,1] × [5,6; 7,8] = [5,6; 7,8]
        let id = vec![1.0, 0.0, 0.0, 1.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        let c = gpu_matmul(id, 2, 2, b.clone(), 2).expect("matmul should not fail");
        assert!((c[0] - 5.0).abs() < 1e-12);
        assert!((c[1] - 6.0).abs() < 1e-12);
        assert!((c[2] - 7.0).abs() < 1e-12);
        assert!((c[3] - 8.0).abs() < 1e-12);
    }

    #[test]
    fn test_matmul_2x2_general() {
        // [1,2; 3,4] × [5,6; 7,8] = [19,22; 43,50]
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        let c = gpu_matmul(a, 2, 2, b, 2).expect("matmul should not fail");
        assert!((c[0] - 19.0).abs() < 1e-12);
        assert!((c[1] - 22.0).abs() < 1e-12);
        assert!((c[2] - 43.0).abs() < 1e-12);
        assert!((c[3] - 50.0).abs() < 1e-12);
    }

    #[test]
    fn test_matmul_non_square() {
        // [1,2,3; 4,5,6] (2×3) × [7,8; 9,10; 11,12] (3×2) = [58,64; 139,154]
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
        let c = gpu_matmul(a, 2, 3, b, 2).expect("non-square matmul should succeed");
        assert!((c[0] - 58.0).abs() < 1e-12);
        assert!((c[1] - 64.0).abs() < 1e-12);
        assert!((c[2] - 139.0).abs() < 1e-12);
        assert!((c[3] - 154.0).abs() < 1e-12);
    }

    #[test]
    fn test_matmul_a_length_mismatch_returns_error() {
        let a = vec![1.0, 2.0]; // length=2 but a_rows=2, a_cols=2 expects 4
        let b = vec![1.0, 2.0, 3.0, 4.0];
        assert!(gpu_matmul(a, 2, 2, b, 2).is_err());
    }

    #[test]
    fn test_matmul_b_length_mismatch_returns_error() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![1.0, 2.0]; // length=2 but a_cols=2, b_cols=2 expects 4
        assert!(gpu_matmul(a, 2, 2, b, 2).is_err());
    }

    #[test]
    fn test_elementwise_relu() {
        let data = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let out = gpu_elementwise(data, "relu").expect("relu should succeed");
        assert_eq!(out, vec![0.0, 0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_elementwise_sigmoid_bounds() {
        let data = vec![-100.0, 0.0, 100.0];
        let out = gpu_elementwise(data, "sigmoid").expect("sigmoid should succeed");
        assert!(out[0] < 1e-3, "sigmoid(-100) should be near 0");
        assert!((out[1] - 0.5).abs() < 1e-12, "sigmoid(0) should be 0.5");
        assert!(out[2] > 1.0 - 1e-3, "sigmoid(100) should be near 1");
    }

    #[test]
    fn test_elementwise_tanh() {
        let data = vec![-1.0, 0.0, 1.0];
        let out = gpu_elementwise(data, "tanh").expect("tanh should succeed");
        assert!((out[1] - 0.0).abs() < 1e-12);
        assert!((out[2] - 1.0_f64.tanh()).abs() < 1e-12);
    }

    #[test]
    fn test_elementwise_exp_log_roundtrip() {
        let data = vec![1.0, 2.0, 3.0];
        let exped = gpu_elementwise(data.clone(), "exp").expect("exp should succeed");
        let logged = gpu_elementwise(exped, "log").expect("log should succeed");
        for (orig, rt) in data.iter().zip(logged.iter()) {
            assert!((orig - rt).abs() < 1e-10, "exp-log roundtrip failed");
        }
    }
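
    #[test]
    fn test_elementwise_log_zero_and_negative() {
        // ln(0) is -inf and ln(x < 0) is NaN (IEEE 754), matching NumPy.
        let out = gpu_elementwise(vec![0.0, -1.0], "log").expect("log should succeed");
        assert_eq!(out[0], f64::NEG_INFINITY);
        assert!(out[1].is_nan());
    }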

    #[test]
    fn test_elementwise_sqrt_non_negative() {
        let data = vec![0.0, 1.0, 4.0, 9.0, 16.0];
        let out = gpu_elementwise(data, "sqrt").expect("sqrt should succeed");
        assert!((out[0] - 0.0).abs() < 1e-12);
        assert!((out[1] - 1.0).abs() < 1e-12);
        assert!((out[2] - 2.0).abs() < 1e-12);
        assert!((out[4] - 4.0).abs() < 1e-12);
    }

    #[test]
    fn test_elementwise_abs() {
        let data = vec![-3.0, -1.5, 0.0, 2.5];
        let out = gpu_elementwise(data, "abs").expect("abs should succeed");
        assert_eq!(out, vec![3.0, 1.5, 0.0, 2.5]);
    }

    #[test]
    fn test_elementwise_square() {
        let data = vec![-2.0, 3.0];
        let out = gpu_elementwise(data, "square").expect("square should succeed");
        assert!((out[0] - 4.0).abs() < 1e-12);
        assert!((out[1] - 9.0).abs() < 1e-12);
    }

    #[test]
    fn test_elementwise_unknown_op_returns_error() {
        let data = vec![1.0, 2.0];
        assert!(gpu_elementwise(data, "unknown_activation").is_err());
    }

    #[test]
    fn test_matrix_add_correct() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let out = gpu_matrix_add(a, b).expect("matrix_add should succeed");
        assert_eq!(out, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_matrix_add_length_mismatch_returns_error() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0];
        assert!(gpu_matrix_add(a, b).is_err());
    }

    #[test]
    fn test_matrix_scale() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let out = gpu_matrix_scale(data, 2.5);
        assert_eq!(out, vec![2.5, 5.0, 7.5, 10.0]);
    }

    #[test]
    fn test_frobenius_norm_identity() {
        // Frobenius norm of 2×2 identity = sqrt(2)
        let id = vec![1.0, 0.0, 0.0, 1.0];
        let norm = gpu_frobenius_norm(id);
        assert!((norm - 2.0_f64.sqrt()).abs() < 1e-12);
    }
}