// ruvector_cnn — lib.rs
1//! CNN Feature Extraction for Image Embeddings
2//!
3//! This crate provides pure Rust CNN-based feature extraction with SIMD acceleration.
4//! It is designed for CPU-only deployment including WASM environments.
5//!
6//! # Features
7//!
8//! - MobileNet-V3 Small/Large backbones
9//! - SIMD acceleration (AVX2, NEON, WASM SIMD128)
10//! - INT8 quantization support
11//! - Pure Rust (no BLAS/OpenCV dependencies)
12//! - Parallel batch processing with rayon (optional)
13//!
14//! # Example
15//!
16//! ```rust,ignore
17//! use ruvector_cnn::{CnnEmbedder, EmbeddingConfig};
18//!
19//! let embedder = CnnEmbedder::new(EmbeddingConfig::default())?;
20//! let embedding = embedder.extract(&image_data, width, height)?;
21//! println!("Embedding dim: {}", embedding.len());
22//! ```
23//!
24//! # Using MobileNet Backbone
25//!
26//! ```rust,ignore
27//! use ruvector_cnn::embedding::MobileNetEmbedder;
28//!
29//! // Create a MobileNetV3-Small embedder
30//! let embedder = MobileNetEmbedder::v3_small()?;
31//!
32//! // Extract features from normalized float tensor (NCHW format)
33//! let features = embedder.extract(&image_tensor, 224, 224)?;
34//! println!("Feature dim: {}", features.len()); // 576 for V3-Small
35//! ```
36
37mod error;
38mod tensor;
39
40// Core modules (always available)
41pub mod layers;
42pub mod simd;
43
44// Optional modules (require backbone feature due to API incompatibility)
45#[cfg(feature = "backbone")]
46pub mod backbone;
47#[cfg(feature = "backbone")]
48pub mod embedding;
49
50// Contrastive learning (standalone, no backbone dependency)
51pub mod contrastive;
52
53pub use error::{CnnError, CnnResult};
54pub use tensor::Tensor;
55
56// Re-export backbone types (only when feature enabled)
57#[cfg(feature = "backbone")]
58pub use backbone::{
59    Backbone, BackboneExt, BackboneType,
60    MobileNetV3, MobileNetV3Config,
61    MobileNetV3Small, MobileNetV3Large, MobileNetConfig,
62    ConvBNActivation, InvertedResidual, SqueezeExcitation,
63    create_backbone, mobilenet_v3_small, mobilenet_v3_large,
64};
65
66// Re-export embedding types (only when feature enabled)
67#[cfg(feature = "backbone")]
68pub use embedding::{
69    MobileNetEmbedder, EmbeddingExtractorExt,
70    EmbeddingConfig as MobileNetEmbeddingConfig,
71    cosine_similarity, euclidean_distance,
72};
73
74// ParallelEmbedding requires the `parallel` feature (not yet implemented)
75// #[cfg(all(feature = "backbone", feature = "parallel"))]
76// pub use embedding::parallel::ParallelEmbedding;
77
78use serde::{Deserialize, Serialize};
79
/// Configuration for CNN embedding extraction.
///
/// Controls the expected input size, the output dimensionality, and
/// post-processing of the extracted vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Input image size in pixels (assumes square input, e.g. 224 for 224x224).
    pub input_size: u32,
    /// Output embedding dimension (length of the vector returned by `extract`).
    pub embedding_dim: usize,
    /// L2-normalize output embeddings to unit length after projection.
    pub normalize: bool,
    /// Use INT8 quantization. NOTE(review): this flag is stored but never
    /// consulted by the code visible in this file — confirm intended behavior.
    pub quantized: bool,
}
92
93impl Default for EmbeddingConfig {
94    fn default() -> Self {
95        Self {
96            input_size: 224,
97            embedding_dim: 512,
98            normalize: true,
99            quantized: false,
100        }
101    }
102}
103
/// CNN Embedder for feature extraction.
///
/// Holds the extraction configuration plus randomly initialized weights for a
/// small pipeline: 3x3 conv -> batch norm -> ReLU -> global average pool ->
/// linear projection (see the private methods on this type).
#[derive(Debug, Clone)]
pub struct CnnEmbedder {
    // Extraction settings: input size, output dim, normalization, quantization.
    config: EmbeddingConfig,
    // Randomly initialized network parameters; see `EmbedderWeights`.
    weights: EmbedderWeights,
}
110
/// Internal weights storage for the toy CNN pipeline.
///
/// Sizes are fixed by the `Default` impl: a 3x3, 3-in/16-out convolution,
/// 16-channel batch-norm parameters, and a 16x512 projection matrix.
/// Matrices are flattened; `projection` is laid out `[in_dim x out_dim]` and
/// indexed as `i * out_dim + o` in `CnnEmbedder::project`.
#[derive(Debug, Clone)]
struct EmbedderWeights {
    /// Convolution weights (simplified representation), flattened.
    conv_weights: Vec<f32>,
    /// Batch norm scale (gamma), one per channel.
    bn_gamma: Vec<f32>,
    /// Batch norm shift (beta), one per channel.
    bn_beta: Vec<f32>,
    /// Batch norm running mean, one per channel.
    bn_mean: Vec<f32>,
    /// Batch norm running variance, one per channel.
    bn_var: Vec<f32>,
    /// Final projection weights, flattened `[in_dim x out_dim]`.
    projection: Vec<f32>,
}
124
125impl Default for EmbedderWeights {
126    fn default() -> Self {
127        use rand::Rng;
128        let mut rng = rand::thread_rng();
129
130        let conv_size = 3 * 3 * 3 * 16;
131        let bn_size = 16;
132        let proj_size = 16 * 512;
133
134        Self {
135            conv_weights: (0..conv_size).map(|_| rng.gen_range(-0.1..0.1)).collect(),
136            bn_gamma: vec![1.0; bn_size],
137            bn_beta: vec![0.0; bn_size],
138            bn_mean: vec![0.0; bn_size],
139            bn_var: vec![1.0; bn_size],
140            projection: (0..proj_size).map(|_| rng.gen_range(-0.1..0.1)).collect(),
141        }
142    }
143}
144
145impl CnnEmbedder {
146    /// Create a new CNN embedder with the given configuration
147    pub fn new(config: EmbeddingConfig) -> CnnResult<Self> {
148        let weights = EmbedderWeights::default();
149        Ok(Self { config, weights })
150    }
151
152    /// Create a MobileNet-V3 Small embedder
153    pub fn new_v3_small() -> CnnResult<Self> {
154        Self::new(EmbeddingConfig {
155            input_size: 224,
156            embedding_dim: 576,
157            normalize: true,
158            quantized: false,
159        })
160    }
161
162    /// Create a MobileNet-V3 Large embedder
163    pub fn new_v3_large() -> CnnResult<Self> {
164        Self::new(EmbeddingConfig {
165            input_size: 224,
166            embedding_dim: 960,
167            normalize: true,
168            quantized: false,
169        })
170    }
171
172    /// Extract embedding from image data (RGBA format)
173    pub fn extract(&self, image_data: &[u8], width: u32, height: u32) -> CnnResult<Vec<f32>> {
174        let expected_size = (width * height * 4) as usize;
175        if image_data.len() != expected_size {
176            return Err(CnnError::InvalidInput(format!(
177                "Expected {} bytes for {}x{} RGBA image, got {}",
178                expected_size, width, height, image_data.len()
179            )));
180        }
181
182        let rgb_float = self.preprocess(image_data, width, height)?;
183        let features = self.forward(&rgb_float)?;
184        let pooled = self.global_avg_pool(&features)?;
185        let mut embedding = self.project(&pooled)?;
186
187        if self.config.normalize {
188            self.l2_normalize(&mut embedding);
189        }
190
191        Ok(embedding)
192    }
193
194    /// Get the embedding dimension
195    pub fn embedding_dim(&self) -> usize {
196        self.config.embedding_dim
197    }
198
199    /// Get the input size
200    pub fn input_size(&self) -> u32 {
201        self.config.input_size
202    }
203
204    fn preprocess(&self, image_data: &[u8], width: u32, height: u32) -> CnnResult<Vec<f32>> {
205        let pixels = (width * height) as usize;
206        let mut rgb = Vec::with_capacity(pixels * 3);
207
208        let mean = [0.485, 0.456, 0.406];
209        let std = [0.229, 0.224, 0.225];
210
211        for i in 0..pixels {
212            let offset = i * 4;
213            rgb.push((image_data[offset] as f32 / 255.0 - mean[0]) / std[0]);
214            rgb.push((image_data[offset + 1] as f32 / 255.0 - mean[1]) / std[1]);
215            rgb.push((image_data[offset + 2] as f32 / 255.0 - mean[2]) / std[2]);
216        }
217
218        Ok(rgb)
219    }
220
221    fn forward(&self, input: &[f32]) -> CnnResult<Vec<f32>> {
222        let conv_out = layers::conv2d_3x3(
223            input,
224            &self.weights.conv_weights,
225            3,
226            16,
227            self.config.input_size as usize,
228            self.config.input_size as usize,
229        );
230
231        let bn_out = layers::batch_norm(
232            &conv_out,
233            &self.weights.bn_gamma,
234            &self.weights.bn_beta,
235            &self.weights.bn_mean,
236            &self.weights.bn_var,
237            1e-5,
238        );
239
240        let activated: Vec<f32> = bn_out.iter().map(|&x| x.max(0.0)).collect();
241        Ok(activated)
242    }
243
244    fn global_avg_pool(&self, features: &[f32]) -> CnnResult<Vec<f32>> {
245        let channels = 16;
246        let spatial = features.len() / channels;
247        let mut pooled = vec![0.0f32; channels];
248
249        for i in 0..spatial {
250            for c in 0..channels {
251                pooled[c] += features[i * channels + c];
252            }
253        }
254
255        let inv_spatial = 1.0 / spatial as f32;
256        for p in pooled.iter_mut() {
257            *p *= inv_spatial;
258        }
259
260        Ok(pooled)
261    }
262
263    fn project(&self, features: &[f32]) -> CnnResult<Vec<f32>> {
264        let in_dim = features.len();
265        let out_dim = self.config.embedding_dim;
266        let mut output = vec![0.0f32; out_dim];
267
268        for o in 0..out_dim {
269            let mut sum = 0.0f32;
270            for i in 0..in_dim {
271                sum += features[i] * self.weights.projection[i * out_dim + o];
272            }
273            output[o] = sum;
274        }
275
276        Ok(output)
277    }
278
279    fn l2_normalize(&self, vec: &mut [f32]) {
280        let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
281        if norm > 1e-10 {
282            for x in vec.iter_mut() {
283                *x /= norm;
284            }
285        }
286    }
287}
288
/// Embedding extractor trait: a common interface for anything that turns raw
/// RGBA image bytes into a fixed-length embedding vector.
pub trait EmbeddingExtractor {
    /// Extract an embedding from RGBA8 image data of the given dimensions.
    fn extract(&self, image_data: &[u8], width: u32, height: u32) -> CnnResult<Vec<f32>>;
    /// Length of the vectors returned by `extract`.
    fn embedding_dim(&self) -> usize;
}
294
impl EmbeddingExtractor for CnnEmbedder {
    fn extract(&self, image_data: &[u8], width: u32, height: u32) -> CnnResult<Vec<f32>> {
        // Fully-qualified call to the inherent method; avoids any ambiguity
        // with the trait method of the same name.
        CnnEmbedder::extract(self, image_data, width, height)
    }

    fn embedding_dim(&self) -> usize {
        CnnEmbedder::embedding_dim(self)
    }
}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config should produce a 512-dimensional embedder.
    #[test]
    fn test_embedder_creation() {
        let e = CnnEmbedder::new(EmbeddingConfig::default()).unwrap();
        assert_eq!(e.embedding_dim(), 512);
    }

    /// MobileNetV3-Small preset reports 576 dimensions.
    #[test]
    fn test_v3_small() {
        assert_eq!(CnnEmbedder::new_v3_small().unwrap().embedding_dim(), 576);
    }

    /// MobileNetV3-Large preset reports 960 dimensions.
    #[test]
    fn test_v3_large() {
        assert_eq!(CnnEmbedder::new_v3_large().unwrap().embedding_dim(), 960);
    }

    /// End-to-end extraction on a tiny image yields a vector of the
    /// configured length with (approximately) unit L2 norm.
    #[test]
    fn test_extract_embedding() {
        let config = EmbeddingConfig {
            input_size: 4,
            embedding_dim: 8,
            normalize: true,
            quantized: false,
        };
        let embedder = CnnEmbedder::new(config).unwrap();

        // Uniform mid-gray 4x4 RGBA image.
        let pixels = vec![128u8; 4 * 4 * 4];
        let emb = embedder.extract(&pixels, 4, 4).unwrap();
        assert_eq!(emb.len(), 8);

        // Normalized output is unit length, unless the raw embedding was
        // (near-)zero, in which case normalization is skipped.
        let norm = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5 || norm < 1e-10);
    }

    /// A byte buffer whose length does not match width*height*4 is rejected.
    #[test]
    fn test_invalid_input() {
        let embedder = CnnEmbedder::new(EmbeddingConfig::default()).unwrap();
        let too_small = vec![0u8; 100];
        assert!(embedder.extract(&too_small, 10, 10).is_err());
    }
}