// SPDX-License-Identifier: MIT
// Copyright 2026 Tyler Zervas

//! Common traits for the rust-ai ecosystem.
//!
//! This module defines shared trait interfaces that enable consistent APIs
//! across all rust-ai crates. Implementing these traits ensures interoperability
//! and allows crates to be composed together seamlessly.
//!
//! ## Core Traits
//!
//! - [`ValidatableConfig`] - Configuration validation interface
//! - [`Quantize`] - Tensor quantization (full precision → quantized)
//! - [`Dequantize`] - Tensor dequantization (quantized → full precision)
//! - [`GpuDispatchable`] - GPU/CPU kernel dispatch pattern
//!
//! ## Implementation Guidelines
//!
//! When implementing these traits:
//!
//! 1. **Validation**: Use `ValidatableConfig::validate()` in constructors
//! 2. **GPU-first**: `GpuDispatchable` should prefer GPU, warn on CPU
//! 3. **Error handling**: Return `CoreError` variants appropriately

use crate::error::{CoreError, Result};
use candle_core::{Device, Tensor};

/// Configuration validation trait.
///
/// All configuration structs should implement this trait to provide
/// consistent validation across the ecosystem. Per the module guidelines,
/// constructors should call `validate()` before accepting a configuration.
///
/// # Example
///
/// ```rust
/// use rust_ai_core::{ValidatableConfig, CoreError, Result};
///
/// #[derive(Clone)]
/// struct LoraConfig {
///     rank: usize,
///     alpha: f32,
/// }
///
/// impl ValidatableConfig for LoraConfig {
///     fn validate(&self) -> Result<()> {
///         if self.rank == 0 {
///             return Err(CoreError::invalid_config("rank must be > 0"));
///         }
///         if self.alpha <= 0.0 {
///             return Err(CoreError::invalid_config("alpha must be positive"));
///         }
///         Ok(())
///     }
/// }
/// ```
pub trait ValidatableConfig: Clone + Send + Sync {
    /// Validate the configuration parameters.
    ///
    /// # Returns
    ///
    /// `Ok(())` if configuration is valid.
    ///
    /// # Errors
    ///
    /// Returns `CoreError::InvalidConfig` if validation fails.
    fn validate(&self) -> Result<()>;
}

/// Tensor quantization trait.
///
/// Converts full-precision tensors to quantized representation.
/// Implementations may use various quantization schemes (NF4, FP4, ternary, etc.).
/// The inverse operation is provided by [`Dequantize`] for the same `Q`.
///
/// # Type Parameters
///
/// - `Q`: The quantized tensor type (e.g., `QuantizedTensor`, `TernaryVector`)
///
/// # Example
///
/// ```rust,ignore
/// use rust_ai_core::Quantize;
///
/// struct Nf4Quantizer;
///
/// impl Quantize<Nf4Tensor> for Nf4Quantizer {
///     fn quantize(&self, tensor: &Tensor, device: &Device) -> Result<Nf4Tensor> {
///         // Quantize to NF4 format
///     }
/// }
/// ```
pub trait Quantize<Q>: Send + Sync {
    /// Quantize a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - Full-precision input tensor
    /// * `device` - Target device for the quantized output
    ///
    /// # Returns
    ///
    /// Quantized representation of the input tensor.
    ///
    /// # Errors
    ///
    /// May return errors for unsupported dtypes, shapes, or device issues.
    fn quantize(&self, tensor: &Tensor, device: &Device) -> Result<Q>;
}

/// Tensor dequantization trait.
///
/// Converts quantized tensors back to full precision for computation.
/// This is the inverse of [`Quantize`] for the same quantized type `Q`.
///
/// # Type Parameters
///
/// - `Q`: The quantized tensor type
///
/// # Example
///
/// ```rust,ignore
/// use rust_ai_core::Dequantize;
///
/// impl Dequantize<Nf4Tensor> for Nf4Quantizer {
///     fn dequantize(&self, quantized: &Nf4Tensor, device: &Device) -> Result<Tensor> {
///         // Restore to f32/f16/bf16
///     }
/// }
/// ```
pub trait Dequantize<Q>: Send + Sync {
    /// Dequantize a tensor.
    ///
    /// # Arguments
    ///
    /// * `quantized` - Quantized input tensor
    /// * `device` - Target device for the dequantized output
    ///
    /// # Returns
    ///
    /// Full-precision tensor.
    ///
    /// # Errors
    ///
    /// May return errors for corrupted quantized data or device issues.
    fn dequantize(&self, quantized: &Q, device: &Device) -> Result<Tensor>;
}

146/// GPU/CPU dispatch trait for operations with both implementations.
147///
148/// This trait enables the CUDA-first pattern: operations that have both
149/// GPU (`CubeCL`) and CPU implementations should implement this trait to
150/// automatically route to the appropriate backend.
151///
152/// # Design Pattern
153///
154/// ```rust,ignore
155/// use rust_ai_core::{GpuDispatchable, warn_if_cpu};
156///
157/// struct FlashAttention;
158///
159/// impl GpuDispatchable for FlashAttention {
160///     type Input = (Tensor, Tensor, Tensor); // Q, K, V
161///     type Output = Tensor;
162///
163///     fn dispatch_gpu(&self, input: &Self::Input, device: &Device) -> Result<Self::Output> {
164///         // CubeCL Flash Attention kernel
165///     }
166///
167///     fn dispatch_cpu(&self, input: &Self::Input, device: &Device) -> Result<Self::Output> {
168///         // Candle-based fallback
169///         warn_if_cpu(device, "unsloth-rs");
170///         // ... fallback implementation
171///     }
172/// }
173/// ```
174pub trait GpuDispatchable: Send + Sync {
175    /// Input type for the operation.
176    type Input;
177
178    /// Output type for the operation.
179    type Output;
180
181    /// Execute operation on GPU using `CubeCL` kernels.
182    ///
183    /// # Arguments
184    ///
185    /// * `input` - Operation input
186    /// * `device` - Must be a CUDA device
187    ///
188    /// # Errors
189    ///
190    /// Returns `CoreError::KernelError` if kernel execution fails.
191    fn dispatch_gpu(&self, input: &Self::Input, device: &Device) -> Result<Self::Output>;
192
193    /// Execute operation on CPU (fallback).
194    ///
195    /// This should emit a warning via `warn_if_cpu()` before execution.
196    ///
197    /// # Arguments
198    ///
199    /// * `input` - Operation input
200    /// * `device` - CPU device
201    ///
202    /// # Errors
203    ///
204    /// Returns appropriate error if operation fails.
205    fn dispatch_cpu(&self, input: &Self::Input, device: &Device) -> Result<Self::Output>;
206
207    /// Automatically dispatch to GPU or CPU based on device.
208    ///
209    /// This is the primary entry point. It checks the device type and
210    /// routes to the appropriate implementation.
211    ///
212    /// # Arguments
213    ///
214    /// * `input` - Operation input
215    /// * `device` - Target device (CUDA or CPU)
216    ///
217    /// # Returns
218    ///
219    /// Operation result from GPU or CPU path.
220    ///
221    /// # Errors
222    ///
223    /// Returns error if the operation fails or if Metal device is used (not supported).
224    fn dispatch(&self, input: &Self::Input, device: &Device) -> Result<Self::Output> {
225        match device {
226            Device::Cuda(_) => self.dispatch_gpu(input, device),
227            Device::Cpu => self.dispatch_cpu(input, device),
228            Device::Metal(_) => Err(CoreError::device_not_available(
229                "Metal device not supported",
230            )),
231        }
232    }
233
234    /// Check if GPU dispatch is available for this operation.
235    ///
236    /// Default implementation checks if CUDA feature is enabled and
237    /// a CUDA device is available.
238    fn gpu_available(&self) -> bool {
239        #[cfg(feature = "cuda")]
240        {
241            matches!(Device::cuda_if_available(0), Ok(Device::Cuda(_)))
242        }
243        #[cfg(not(feature = "cuda"))]
244        {
245            false
246        }
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal config type used to exercise `ValidatableConfig`.
    #[derive(Clone)]
    struct TestConfig {
        value: i32,
    }

    impl ValidatableConfig for TestConfig {
        fn validate(&self) -> Result<()> {
            // Negative values are the only invalid state for this test type.
            if self.value >= 0 {
                Ok(())
            } else {
                Err(CoreError::invalid_config("value must be non-negative"))
            }
        }
    }

    #[test]
    fn test_validatable_config() {
        let ok_cfg = TestConfig { value: 10 };
        assert!(ok_cfg.validate().is_ok());

        let bad_cfg = TestConfig { value: -1 };
        assert!(bad_cfg.validate().is_err());
    }
}