// ruvector_cnn — crate root (lib.rs)
//! CNN Feature Extraction for Image Embeddings
//!
//! This crate provides pure Rust CNN-based feature extraction with SIMD acceleration.
//! It is designed for CPU-only deployment including WASM environments.
//!
//! # Features
//!
//! - MobileNet-V3 Small/Large backbones
//! - SIMD acceleration (AVX2, NEON, WASM SIMD128)
//! - INT8 quantization support
//! - Pure Rust (no BLAS/OpenCV dependencies)
//! - Parallel batch processing with rayon (optional)
//!
//! # Example
//!
//! ```rust,ignore
//! use ruvector_cnn::{CnnEmbedder, EmbeddingConfig};
//!
//! let embedder = CnnEmbedder::new(EmbeddingConfig::default())?;
//! let embedding = embedder.extract(&image_data, width, height)?;
//! println!("Embedding dim: {}", embedding.len());
//! ```
//!
//! # Using MobileNet Backbone
//!
//! ```rust,ignore
//! use ruvector_cnn::embedding::MobileNetEmbedder;
//!
//! // Create a MobileNetV3-Small embedder
//! let embedder = MobileNetEmbedder::v3_small()?;
//!
//! // Extract features from normalized float tensor (NCHW format)
//! let features = embedder.extract(&image_tensor, 224, 224)?;
//! println!("Feature dim: {}", features.len()); // 576 for V3-Small
//! ```
37mod error;
38mod tensor;
39
40// Core modules (always available)
41pub mod layers;
42pub mod simd;
43pub mod kernels;
44
45// Quantization support (INT8 optimization)
46pub mod quantize;
47pub mod int8;
48
49// Optional modules (require backbone feature due to API incompatibility)
50#[cfg(feature = "backbone")]
51pub mod backbone;
52#[cfg(feature = "backbone")]
53pub mod embedding;
54
55// Contrastive learning (standalone, no backbone dependency)
56pub mod contrastive;
57
58pub use error::{CnnError, CnnResult};
59pub use tensor::Tensor;
60
61// Re-export backbone types (only when feature enabled)
62#[cfg(feature = "backbone")]
63pub use backbone::{
64    Backbone, BackboneExt, BackboneType,
65    MobileNetV3, MobileNetV3Config,
66    MobileNetV3Small, MobileNetV3Large, MobileNetConfig,
67    ConvBNActivation, InvertedResidual, SqueezeExcitation,
68    create_backbone, mobilenet_v3_small, mobilenet_v3_large,
69};
70
71// Re-export embedding types (only when feature enabled)
72#[cfg(feature = "backbone")]
73pub use embedding::{
74    MobileNetEmbedder, EmbeddingExtractorExt,
75    EmbeddingConfig as MobileNetEmbeddingConfig,
76    cosine_similarity, euclidean_distance,
77};
78
79// ParallelEmbedding requires the `parallel` feature (not yet implemented)
80// #[cfg(all(feature = "backbone", feature = "parallel"))]
81// pub use embedding::parallel::ParallelEmbedding;
82
83use serde::{Deserialize, Serialize};
84
85/// Configuration for CNN embedding extraction
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct EmbeddingConfig {
88    /// Input image size (assumes square input)
89    pub input_size: u32,
90    /// Output embedding dimension
91    pub embedding_dim: usize,
92    /// L2 normalize output embeddings
93    pub normalize: bool,
94    /// Use INT8 quantization
95    pub quantized: bool,
96}
97
98impl Default for EmbeddingConfig {
99    fn default() -> Self {
100        Self {
101            input_size: 224,
102            embedding_dim: 512,
103            normalize: true,
104            quantized: false,
105        }
106    }
107}
108
109/// CNN Embedder for feature extraction
110#[derive(Debug, Clone)]
111pub struct CnnEmbedder {
112    config: EmbeddingConfig,
113    weights: EmbedderWeights,
114}
115
/// Internal weights storage for the simplified single-conv-layer network.
#[derive(Debug, Clone)]
struct EmbedderWeights {
    // 3x3 convolution weights, 3 input channels -> 16 output channels
    // (simplified representation as a flat buffer).
    conv_weights: Vec<f32>,
    // Batch-norm parameters, one entry per output channel (16).
    bn_gamma: Vec<f32>,
    bn_beta: Vec<f32>,
    bn_mean: Vec<f32>,
    bn_var: Vec<f32>,
    // Final linear projection from 16 pooled channels to the embedding dim,
    // stored row-major as projection[in * out_dim + out].
    projection: Vec<f32>,
}
130impl Default for EmbedderWeights {
131    fn default() -> Self {
132        use rand::Rng;
133        let mut rng = rand::thread_rng();
134
135        let conv_size = 3 * 3 * 3 * 16;
136        let bn_size = 16;
137        let proj_size = 16 * 512;
138
139        Self {
140            conv_weights: (0..conv_size).map(|_| rng.gen_range(-0.1..0.1)).collect(),
141            bn_gamma: vec![1.0; bn_size],
142            bn_beta: vec![0.0; bn_size],
143            bn_mean: vec![0.0; bn_size],
144            bn_var: vec![1.0; bn_size],
145            projection: (0..proj_size).map(|_| rng.gen_range(-0.1..0.1)).collect(),
146        }
147    }
148}
149
150impl CnnEmbedder {
151    /// Create a new CNN embedder with the given configuration
152    pub fn new(config: EmbeddingConfig) -> CnnResult<Self> {
153        let weights = EmbedderWeights::default();
154        Ok(Self { config, weights })
155    }
156
157    /// Create a MobileNet-V3 Small embedder
158    pub fn new_v3_small() -> CnnResult<Self> {
159        Self::new(EmbeddingConfig {
160            input_size: 224,
161            embedding_dim: 576,
162            normalize: true,
163            quantized: false,
164        })
165    }
166
167    /// Create a MobileNet-V3 Large embedder
168    pub fn new_v3_large() -> CnnResult<Self> {
169        Self::new(EmbeddingConfig {
170            input_size: 224,
171            embedding_dim: 960,
172            normalize: true,
173            quantized: false,
174        })
175    }
176
177    /// Extract embedding from image data (RGBA format)
178    pub fn extract(&self, image_data: &[u8], width: u32, height: u32) -> CnnResult<Vec<f32>> {
179        let expected_size = (width * height * 4) as usize;
180        if image_data.len() != expected_size {
181            return Err(CnnError::InvalidInput(format!(
182                "Expected {} bytes for {}x{} RGBA image, got {}",
183                expected_size, width, height, image_data.len()
184            )));
185        }
186
187        let rgb_float = self.preprocess(image_data, width, height)?;
188        let features = self.forward(&rgb_float)?;
189        let pooled = self.global_avg_pool(&features)?;
190        let mut embedding = self.project(&pooled)?;
191
192        if self.config.normalize {
193            self.l2_normalize(&mut embedding);
194        }
195
196        Ok(embedding)
197    }
198
199    /// Get the embedding dimension
200    pub fn embedding_dim(&self) -> usize {
201        self.config.embedding_dim
202    }
203
204    /// Get the input size
205    pub fn input_size(&self) -> u32 {
206        self.config.input_size
207    }
208
209    fn preprocess(&self, image_data: &[u8], width: u32, height: u32) -> CnnResult<Vec<f32>> {
210        let pixels = (width * height) as usize;
211        let mut rgb = Vec::with_capacity(pixels * 3);
212
213        let mean = [0.485, 0.456, 0.406];
214        let std = [0.229, 0.224, 0.225];
215
216        for i in 0..pixels {
217            let offset = i * 4;
218            rgb.push((image_data[offset] as f32 / 255.0 - mean[0]) / std[0]);
219            rgb.push((image_data[offset + 1] as f32 / 255.0 - mean[1]) / std[1]);
220            rgb.push((image_data[offset + 2] as f32 / 255.0 - mean[2]) / std[2]);
221        }
222
223        Ok(rgb)
224    }
225
226    fn forward(&self, input: &[f32]) -> CnnResult<Vec<f32>> {
227        let conv_out = layers::conv2d_3x3(
228            input,
229            &self.weights.conv_weights,
230            3,
231            16,
232            self.config.input_size as usize,
233            self.config.input_size as usize,
234        );
235
236        let bn_out = layers::batch_norm(
237            &conv_out,
238            &self.weights.bn_gamma,
239            &self.weights.bn_beta,
240            &self.weights.bn_mean,
241            &self.weights.bn_var,
242            1e-5,
243        );
244
245        let activated: Vec<f32> = bn_out.iter().map(|&x| x.max(0.0)).collect();
246        Ok(activated)
247    }
248
249    fn global_avg_pool(&self, features: &[f32]) -> CnnResult<Vec<f32>> {
250        let channels = 16;
251        let spatial = features.len() / channels;
252        let mut pooled = vec![0.0f32; channels];
253
254        for i in 0..spatial {
255            for c in 0..channels {
256                pooled[c] += features[i * channels + c];
257            }
258        }
259
260        let inv_spatial = 1.0 / spatial as f32;
261        for p in pooled.iter_mut() {
262            *p *= inv_spatial;
263        }
264
265        Ok(pooled)
266    }
267
268    fn project(&self, features: &[f32]) -> CnnResult<Vec<f32>> {
269        let in_dim = features.len();
270        let out_dim = self.config.embedding_dim;
271        let mut output = vec![0.0f32; out_dim];
272
273        for o in 0..out_dim {
274            let mut sum = 0.0f32;
275            for i in 0..in_dim {
276                sum += features[i] * self.weights.projection[i * out_dim + o];
277            }
278            output[o] = sum;
279        }
280
281        Ok(output)
282    }
283
284    fn l2_normalize(&self, vec: &mut [f32]) {
285        let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
286        if norm > 1e-10 {
287            for x in vec.iter_mut() {
288                *x /= norm;
289            }
290        }
291    }
292}
293
294/// Embedding extractor trait
295pub trait EmbeddingExtractor {
296    fn extract(&self, image_data: &[u8], width: u32, height: u32) -> CnnResult<Vec<f32>>;
297    fn embedding_dim(&self) -> usize;
298}
299
300impl EmbeddingExtractor for CnnEmbedder {
301    fn extract(&self, image_data: &[u8], width: u32, height: u32) -> CnnResult<Vec<f32>> {
302        CnnEmbedder::extract(self, image_data, width, height)
303    }
304
305    fn embedding_dim(&self) -> usize {
306        CnnEmbedder::embedding_dim(self)
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedder_creation() {
        let embedder = CnnEmbedder::new(EmbeddingConfig::default()).unwrap();
        assert_eq!(embedder.embedding_dim(), 512);
    }

    #[test]
    fn test_v3_small() {
        let embedder = CnnEmbedder::new_v3_small().unwrap();
        assert_eq!(embedder.embedding_dim(), 576);
    }

    #[test]
    fn test_v3_large() {
        let embedder = CnnEmbedder::new_v3_large().unwrap();
        assert_eq!(embedder.embedding_dim(), 960);
    }

    #[test]
    fn test_extract_embedding() {
        // Tiny 4x4 configuration keeps the forward pass fast.
        let embedder = CnnEmbedder::new(EmbeddingConfig {
            input_size: 4,
            embedding_dim: 8,
            normalize: true,
            quantized: false,
        }).unwrap();

        // Uniform mid-gray RGBA image: 4x4 pixels x 4 bytes each.
        let image_data = vec![128u8; 4 * 4 * 4];
        let embedding = embedder.extract(&image_data, 4, 4).unwrap();

        assert_eq!(embedding.len(), 8);
        // With normalize=true the result is unit-length, unless the raw
        // embedding was (near-)zero, in which case it is left as-is.
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5 || norm < 1e-10);
    }

    #[test]
    fn test_invalid_input() {
        let embedder = CnnEmbedder::new(EmbeddingConfig::default()).unwrap();
        // 10x10 RGBA needs 400 bytes; 100 must be rejected.
        let image_data = vec![0u8; 100];
        let result = embedder.extract(&image_data, 10, 10);
        assert!(result.is_err());
    }
}