// rust_ai_core/traits.rs
1// SPDX-License-Identifier: MIT
2// Copyright 2026 Tyler Zervas
3
4//! Common traits for the rust-ai ecosystem.
5//!
6//! This module defines shared trait interfaces that enable consistent APIs
7//! across all rust-ai crates. Implementing these traits ensures interoperability
8//! and allows crates to be composed together seamlessly.
9//!
10//! ## Core Traits
11//!
12//! - [`ValidatableConfig`] - Configuration validation interface
13//! - [`Quantize`] - Tensor quantization (full precision → quantized)
14//! - [`Dequantize`] - Tensor dequantization (quantized → full precision)
15//! - [`GpuDispatchable`] - GPU/CPU kernel dispatch pattern
16//!
17//! ## Implementation Guidelines
18//!
19//! When implementing these traits:
20//!
21//! 1. **Validation**: Use `ValidatableConfig::validate()` in constructors
22//! 2. **GPU-first**: `GpuDispatchable` should prefer GPU, warn on CPU
23//! 3. **Error handling**: Return `CoreError` variants appropriately
24
25use crate::error::{CoreError, Result};
26use candle_core::{Device, Tensor};
27
/// Configuration validation trait.
///
/// All configuration structs should implement this trait to provide
/// consistent validation across the ecosystem. Call `validate()` from
/// constructors so that invalid configurations are rejected at creation
/// time rather than failing later at use sites.
///
/// # Example
///
/// ```rust
/// use rust_ai_core::{ValidatableConfig, CoreError, Result};
///
/// #[derive(Clone)]
/// struct LoraConfig {
///     rank: usize,
///     alpha: f32,
/// }
///
/// impl ValidatableConfig for LoraConfig {
///     fn validate(&self) -> Result<()> {
///         if self.rank == 0 {
///             return Err(CoreError::invalid_config("rank must be > 0"));
///         }
///         if self.alpha <= 0.0 {
///             return Err(CoreError::invalid_config("alpha must be positive"));
///         }
///         Ok(())
///     }
/// }
/// ```
pub trait ValidatableConfig: Clone + Send + Sync {
    /// Validate the configuration parameters.
    ///
    /// Implementations should check every invariant the configuration
    /// relies on (value ranges, non-zero sizes, cross-field consistency).
    ///
    /// # Returns
    ///
    /// `Ok(())` if configuration is valid.
    ///
    /// # Errors
    ///
    /// Returns `CoreError::InvalidConfig` if validation fails.
    fn validate(&self) -> Result<()>;
}
68
/// Tensor quantization trait.
///
/// Converts full-precision tensors to quantized representation.
/// Implementations may use various quantization schemes (NF4, FP4, ternary, etc.).
/// The inverse operation is provided by the companion [`Dequantize`] trait.
///
/// # Type Parameters
///
/// - `Q`: The quantized tensor type (e.g., `QuantizedTensor`, `TernaryVector`)
///
/// # Example
///
/// ```rust,ignore
/// use rust_ai_core::Quantize;
///
/// struct Nf4Quantizer;
///
/// impl Quantize<Nf4Tensor> for Nf4Quantizer {
///     fn quantize(&self, tensor: &Tensor, device: &Device) -> Result<Nf4Tensor> {
///         // Quantize to NF4 format
///     }
/// }
/// ```
pub trait Quantize<Q>: Send + Sync {
    /// Quantize a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - Full-precision input tensor
    /// * `device` - Target device for the quantized output
    ///
    /// # Returns
    ///
    /// Quantized representation of the input tensor.
    ///
    /// # Errors
    ///
    /// May return errors for unsupported dtypes, shapes, or device issues.
    fn quantize(&self, tensor: &Tensor, device: &Device) -> Result<Q>;
}
108
/// Tensor dequantization trait.
///
/// Converts quantized tensors back to full precision for computation.
/// Typically implemented alongside [`Quantize`] on the same quantizer type
/// so the two operations form a round trip.
///
/// # Type Parameters
///
/// - `Q`: The quantized tensor type
///
/// # Example
///
/// ```rust,ignore
/// use rust_ai_core::Dequantize;
///
/// impl Dequantize<Nf4Tensor> for Nf4Quantizer {
///     fn dequantize(&self, quantized: &Nf4Tensor, device: &Device) -> Result<Tensor> {
///         // Restore to f32/f16/bf16
///     }
/// }
/// ```
pub trait Dequantize<Q>: Send + Sync {
    /// Dequantize a tensor.
    ///
    /// # Arguments
    ///
    /// * `quantized` - Quantized input tensor
    /// * `device` - Target device for the dequantized output
    ///
    /// # Returns
    ///
    /// Full-precision tensor.
    ///
    /// # Errors
    ///
    /// May return errors for corrupted quantized data or device issues.
    fn dequantize(&self, quantized: &Q, device: &Device) -> Result<Tensor>;
}
145
146/// GPU/CPU dispatch trait for operations with both implementations.
147///
148/// This trait enables the CUDA-first pattern: operations that have both
149/// GPU (`CubeCL`) and CPU implementations should implement this trait to
150/// automatically route to the appropriate backend.
151///
152/// # Design Pattern
153///
154/// ```rust,ignore
155/// use rust_ai_core::{GpuDispatchable, warn_if_cpu};
156///
157/// struct FlashAttention;
158///
159/// impl GpuDispatchable for FlashAttention {
160/// type Input = (Tensor, Tensor, Tensor); // Q, K, V
161/// type Output = Tensor;
162///
163/// fn dispatch_gpu(&self, input: &Self::Input, device: &Device) -> Result<Self::Output> {
164/// // CubeCL Flash Attention kernel
165/// }
166///
167/// fn dispatch_cpu(&self, input: &Self::Input, device: &Device) -> Result<Self::Output> {
168/// // Candle-based fallback
169/// warn_if_cpu(device, "unsloth-rs");
170/// // ... fallback implementation
171/// }
172/// }
173/// ```
174pub trait GpuDispatchable: Send + Sync {
175 /// Input type for the operation.
176 type Input;
177
178 /// Output type for the operation.
179 type Output;
180
181 /// Execute operation on GPU using `CubeCL` kernels.
182 ///
183 /// # Arguments
184 ///
185 /// * `input` - Operation input
186 /// * `device` - Must be a CUDA device
187 ///
188 /// # Errors
189 ///
190 /// Returns `CoreError::KernelError` if kernel execution fails.
191 fn dispatch_gpu(&self, input: &Self::Input, device: &Device) -> Result<Self::Output>;
192
193 /// Execute operation on CPU (fallback).
194 ///
195 /// This should emit a warning via `warn_if_cpu()` before execution.
196 ///
197 /// # Arguments
198 ///
199 /// * `input` - Operation input
200 /// * `device` - CPU device
201 ///
202 /// # Errors
203 ///
204 /// Returns appropriate error if operation fails.
205 fn dispatch_cpu(&self, input: &Self::Input, device: &Device) -> Result<Self::Output>;
206
207 /// Automatically dispatch to GPU or CPU based on device.
208 ///
209 /// This is the primary entry point. It checks the device type and
210 /// routes to the appropriate implementation.
211 ///
212 /// # Arguments
213 ///
214 /// * `input` - Operation input
215 /// * `device` - Target device (CUDA or CPU)
216 ///
217 /// # Returns
218 ///
219 /// Operation result from GPU or CPU path.
220 ///
221 /// # Errors
222 ///
223 /// Returns error if the operation fails or if Metal device is used (not supported).
224 fn dispatch(&self, input: &Self::Input, device: &Device) -> Result<Self::Output> {
225 match device {
226 Device::Cuda(_) => self.dispatch_gpu(input, device),
227 Device::Cpu => self.dispatch_cpu(input, device),
228 Device::Metal(_) => Err(CoreError::device_not_available(
229 "Metal device not supported",
230 )),
231 }
232 }
233
234 /// Check if GPU dispatch is available for this operation.
235 ///
236 /// Default implementation checks if CUDA feature is enabled and
237 /// a CUDA device is available.
238 fn gpu_available(&self) -> bool {
239 #[cfg(feature = "cuda")]
240 {
241 matches!(Device::cuda_if_available(0), Ok(Device::Cuda(_)))
242 }
243 #[cfg(not(feature = "cuda"))]
244 {
245 false
246 }
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal config used to exercise the ValidatableConfig contract.
    #[derive(Clone)]
    struct TestConfig {
        value: i32,
    }

    impl ValidatableConfig for TestConfig {
        fn validate(&self) -> Result<()> {
            // Negative values are the only invalid state for this config.
            match self.value {
                v if v < 0 => Err(CoreError::invalid_config("value must be non-negative")),
                _ => Ok(()),
            }
        }
    }

    #[test]
    fn test_validatable_config() {
        // Non-negative values pass validation; negative values are rejected.
        assert!(TestConfig { value: 10 }.validate().is_ok());
        assert!(TestConfig { value: -1 }.validate().is_err());
    }
}
276}