1mod error;
38mod tensor;
39
40pub mod layers;
42pub mod simd;
43pub mod kernels;
44
45pub mod quantize;
47pub mod int8;
48
49#[cfg(feature = "backbone")]
51pub mod backbone;
52#[cfg(feature = "backbone")]
53pub mod embedding;
54
55pub mod contrastive;
57
58pub use error::{CnnError, CnnResult};
59pub use tensor::Tensor;
60
61#[cfg(feature = "backbone")]
63pub use backbone::{
64 Backbone, BackboneExt, BackboneType,
65 MobileNetV3, MobileNetV3Config,
66 MobileNetV3Small, MobileNetV3Large, MobileNetConfig,
67 ConvBNActivation, InvertedResidual, SqueezeExcitation,
68 create_backbone, mobilenet_v3_small, mobilenet_v3_large,
69};
70
71#[cfg(feature = "backbone")]
73pub use embedding::{
74 MobileNetEmbedder, EmbeddingExtractorExt,
75 EmbeddingConfig as MobileNetEmbeddingConfig,
76 cosine_similarity, euclidean_distance,
77};
78
79use serde::{Deserialize, Serialize};
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct EmbeddingConfig {
88 pub input_size: u32,
90 pub embedding_dim: usize,
92 pub normalize: bool,
94 pub quantized: bool,
96}
97
98impl Default for EmbeddingConfig {
99 fn default() -> Self {
100 Self {
101 input_size: 224,
102 embedding_dim: 512,
103 normalize: true,
104 quantized: false,
105 }
106 }
107}
108
109#[derive(Debug, Clone)]
111pub struct CnnEmbedder {
112 config: EmbeddingConfig,
113 weights: EmbedderWeights,
114}
115
/// Flat parameter buffers for the embedder's stem, batch norm and projection.
#[derive(Clone, Debug)]
struct EmbedderWeights {
    /// 3x3 convolution kernels mapping 3 input channels to 16 output channels
    /// (3 * 3 * 3 * 16 values when built by `Default`).
    /// NOTE(review): element ordering is whatever `layers::conv2d_3x3` expects — confirm.
    conv_weights: Vec<f32>,
    /// Per-channel batch-norm scale.
    bn_gamma: Vec<f32>,
    /// Per-channel batch-norm shift.
    bn_beta: Vec<f32>,
    /// Per-channel running mean used by batch norm.
    bn_mean: Vec<f32>,
    /// Per-channel running variance used by batch norm.
    bn_var: Vec<f32>,
    /// Row-major projection matrix: row `i` holds the weights from pooled channel `i`
    /// to every output embedding component.
    projection: Vec<f32>,
}
129
130impl Default for EmbedderWeights {
131 fn default() -> Self {
132 use rand::Rng;
133 let mut rng = rand::thread_rng();
134
135 let conv_size = 3 * 3 * 3 * 16;
136 let bn_size = 16;
137 let proj_size = 16 * 512;
138
139 Self {
140 conv_weights: (0..conv_size).map(|_| rng.gen_range(-0.1..0.1)).collect(),
141 bn_gamma: vec![1.0; bn_size],
142 bn_beta: vec![0.0; bn_size],
143 bn_mean: vec![0.0; bn_size],
144 bn_var: vec![1.0; bn_size],
145 projection: (0..proj_size).map(|_| rng.gen_range(-0.1..0.1)).collect(),
146 }
147 }
148}
149
150impl CnnEmbedder {
151 pub fn new(config: EmbeddingConfig) -> CnnResult<Self> {
153 let weights = EmbedderWeights::default();
154 Ok(Self { config, weights })
155 }
156
157 pub fn new_v3_small() -> CnnResult<Self> {
159 Self::new(EmbeddingConfig {
160 input_size: 224,
161 embedding_dim: 576,
162 normalize: true,
163 quantized: false,
164 })
165 }
166
167 pub fn new_v3_large() -> CnnResult<Self> {
169 Self::new(EmbeddingConfig {
170 input_size: 224,
171 embedding_dim: 960,
172 normalize: true,
173 quantized: false,
174 })
175 }
176
177 pub fn extract(&self, image_data: &[u8], width: u32, height: u32) -> CnnResult<Vec<f32>> {
179 let expected_size = (width * height * 4) as usize;
180 if image_data.len() != expected_size {
181 return Err(CnnError::InvalidInput(format!(
182 "Expected {} bytes for {}x{} RGBA image, got {}",
183 expected_size, width, height, image_data.len()
184 )));
185 }
186
187 let rgb_float = self.preprocess(image_data, width, height)?;
188 let features = self.forward(&rgb_float)?;
189 let pooled = self.global_avg_pool(&features)?;
190 let mut embedding = self.project(&pooled)?;
191
192 if self.config.normalize {
193 self.l2_normalize(&mut embedding);
194 }
195
196 Ok(embedding)
197 }
198
199 pub fn embedding_dim(&self) -> usize {
201 self.config.embedding_dim
202 }
203
204 pub fn input_size(&self) -> u32 {
206 self.config.input_size
207 }
208
209 fn preprocess(&self, image_data: &[u8], width: u32, height: u32) -> CnnResult<Vec<f32>> {
210 let pixels = (width * height) as usize;
211 let mut rgb = Vec::with_capacity(pixels * 3);
212
213 let mean = [0.485, 0.456, 0.406];
214 let std = [0.229, 0.224, 0.225];
215
216 for i in 0..pixels {
217 let offset = i * 4;
218 rgb.push((image_data[offset] as f32 / 255.0 - mean[0]) / std[0]);
219 rgb.push((image_data[offset + 1] as f32 / 255.0 - mean[1]) / std[1]);
220 rgb.push((image_data[offset + 2] as f32 / 255.0 - mean[2]) / std[2]);
221 }
222
223 Ok(rgb)
224 }
225
226 fn forward(&self, input: &[f32]) -> CnnResult<Vec<f32>> {
227 let conv_out = layers::conv2d_3x3(
228 input,
229 &self.weights.conv_weights,
230 3,
231 16,
232 self.config.input_size as usize,
233 self.config.input_size as usize,
234 );
235
236 let bn_out = layers::batch_norm(
237 &conv_out,
238 &self.weights.bn_gamma,
239 &self.weights.bn_beta,
240 &self.weights.bn_mean,
241 &self.weights.bn_var,
242 1e-5,
243 );
244
245 let activated: Vec<f32> = bn_out.iter().map(|&x| x.max(0.0)).collect();
246 Ok(activated)
247 }
248
249 fn global_avg_pool(&self, features: &[f32]) -> CnnResult<Vec<f32>> {
250 let channels = 16;
251 let spatial = features.len() / channels;
252 let mut pooled = vec![0.0f32; channels];
253
254 for i in 0..spatial {
255 for c in 0..channels {
256 pooled[c] += features[i * channels + c];
257 }
258 }
259
260 let inv_spatial = 1.0 / spatial as f32;
261 for p in pooled.iter_mut() {
262 *p *= inv_spatial;
263 }
264
265 Ok(pooled)
266 }
267
268 fn project(&self, features: &[f32]) -> CnnResult<Vec<f32>> {
269 let in_dim = features.len();
270 let out_dim = self.config.embedding_dim;
271 let mut output = vec![0.0f32; out_dim];
272
273 for o in 0..out_dim {
274 let mut sum = 0.0f32;
275 for i in 0..in_dim {
276 sum += features[i] * self.weights.projection[i * out_dim + o];
277 }
278 output[o] = sum;
279 }
280
281 Ok(output)
282 }
283
284 fn l2_normalize(&self, vec: &mut [f32]) {
285 let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
286 if norm > 1e-10 {
287 for x in vec.iter_mut() {
288 *x /= norm;
289 }
290 }
291 }
292}
293
294pub trait EmbeddingExtractor {
296 fn extract(&self, image_data: &[u8], width: u32, height: u32) -> CnnResult<Vec<f32>>;
297 fn embedding_dim(&self) -> usize;
298}
299
300impl EmbeddingExtractor for CnnEmbedder {
301 fn extract(&self, image_data: &[u8], width: u32, height: u32) -> CnnResult<Vec<f32>> {
302 CnnEmbedder::extract(self, image_data, width, height)
303 }
304
305 fn embedding_dim(&self) -> usize {
306 CnnEmbedder::embedding_dim(self)
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an embedder from `config`, failing the test if construction errors.
    fn embedder_for(config: EmbeddingConfig) -> CnnEmbedder {
        CnnEmbedder::new(config).unwrap()
    }

    #[test]
    fn test_embedder_creation() {
        let dim = embedder_for(EmbeddingConfig::default()).embedding_dim();
        assert_eq!(dim, 512);
    }

    #[test]
    fn test_v3_small() {
        let small = CnnEmbedder::new_v3_small().unwrap();
        assert_eq!(small.embedding_dim(), 576);
    }

    #[test]
    fn test_v3_large() {
        let large = CnnEmbedder::new_v3_large().unwrap();
        assert_eq!(large.embedding_dim(), 960);
    }

    #[test]
    fn test_extract_embedding() {
        let side: u32 = 4;
        let embedder = embedder_for(EmbeddingConfig {
            input_size: side,
            embedding_dim: 8,
            normalize: true,
            quantized: false,
        });

        // Uniform mid-grey RGBA image of `side` x `side` pixels.
        let pixels = vec![128u8; (side * side * 4) as usize];
        let embedding = embedder.extract(&pixels, side, side).unwrap();

        assert_eq!(embedding.len(), 8);
        let norm = embedding.iter().map(|v| v * v).sum::<f32>().sqrt();
        let is_unit = (norm - 1.0).abs() < 1e-5;
        let is_zero = norm < 1e-10;
        assert!(is_unit || is_zero);
    }

    #[test]
    fn test_invalid_input() {
        let embedder = embedder_for(EmbeddingConfig::default());
        // 100 bytes cannot be a 10x10 RGBA image (which needs 400).
        let truncated = [0u8; 100];
        assert!(embedder.extract(&truncated, 10, 10).is_err());
    }
}