Skip to main content

scirs2_neural/models/architectures/
clip.rs

1//! CLIP (Contrastive Language-Image Pre-training) Architecture
2//!
3//! This module implements a CLIP-like architecture as described in
4//! "Learning Transferable Visual Models From Natural Language Supervision"
5//! (<https://arxiv.org/abs/2103.00020>)
6//! CLIP is a multi-modal model that learns visual concepts from natural language supervision,
7//! enabling zero-shot transfer to various visual classification tasks.
8
9use crate::error::Result;
10use crate::layers::{Dense, Layer, LayerNorm, Sequential};
11// TODO: Re-enable once PatchEmbedding is implemented
12// use crate::models::architectures::{ViTConfig, VisionTransformer};
13use crate::models::architectures::ViTConfig;
14use crate::transformer::TransformerEncoderLayer;
15use crate::utils::positional_encoding::{PositionalEncoding, SinusoidalPositionalEncoding};
16use scirs2_core::ndarray::{Array, Axis, IxDyn, ScalarOperand};
17use scirs2_core::numeric::{Float, NumAssign};
18use scirs2_core::random::{rngs::SmallRng, SeedableRng};
19use serde::{Deserialize, Serialize};
20use std::fmt::Debug;
21/// Type alias for CLIP output (image embeddings, text embeddings, logit scale)
22type ClipOutput<F> = (Array<F, IxDyn>, Array<F, IxDyn>, Array<F, IxDyn>);
/// Configuration for the text encoder in CLIP.
///
/// Consumed by `CLIPTextEncoder::new`, which forwards these values to the token-embedding
/// `Dense` layer, the sinusoidal positional encoding, every `TransformerEncoderLayer`, and the
/// final `LayerNorm`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CLIPTextConfig {
    /// Vocabulary size. Used as the input width of the token-embedding `Dense` layer, so token
    /// inputs are presumably dense/one-hot vectors of this width — TODO confirm with callers.
    pub vocab_size: usize,
    /// Hidden width shared by the token embeddings, positional encoding, transformer layers
    /// and final layer norm.
    pub hidden_size: usize,
    /// Inner width of the feed-forward sublayer in each transformer layer.
    pub intermediate_size: usize,
    /// Number of stacked transformer encoder layers.
    pub num_layers: usize,
    /// Number of attention heads per transformer layer (presumably must divide
    /// `hidden_size` — TODO confirm against `TransformerEncoderLayer`).
    pub num_heads: usize,
    /// Maximum sequence length supported by the sinusoidal positional encoding.
    pub max_position_embeddings: usize,
    /// Dropout rate passed to each transformer encoder layer.
    pub dropout_rate: f64,
    /// Epsilon passed to the transformer layers and the final layer norm.
    pub layer_norm_eps: f64,
}
43/// Configuration for the CLIP model
44pub struct CLIPConfig {
45    /// Text encoder configuration
46    pub text_config: CLIPTextConfig,
47    /// Vision encoder configuration
48    pub vision_config: ViTConfig,
49    /// Projection dimension for both text and vision encoders
50    pub projection_dim: usize,
51    /// Whether to include the classifier
52    pub include_head: bool,
53    /// Number of classes for the classifier (if include_head is true)
54    pub num_classes: usize,
55}
56
57impl Default for CLIPTextConfig {
58    fn default() -> Self {
59        Self {
60            vocab_size: 49408,
61            hidden_size: 512,
62            intermediate_size: 2048,
63            num_layers: 12,
64            num_heads: 8,
65            max_position_embeddings: 77,
66            dropout_rate: 0.1,
67            layer_norm_eps: 1e-5,
68        }
69    }
70}
71
/// Transformer text encoder used by CLIP.
///
/// Produces one projected embedding per input sequence. The `Layer` implementation runs:
/// token embedding → sinusoidal positional encoding → transformer encoder stack → final layer
/// norm → first-position pooling → linear projection.
#[derive(Debug, Clone)]
pub struct CLIPTextEncoder<
    F: Float
        + Debug
        + ScalarOperand
        + Send
        + Sync
        + 'static
        + scirs2_core::simd_ops::SimdUnifiedOps
        + NumAssign,
> {
    /// Token embedding: a `Sequential` holding a single `Dense` layer that maps
    /// `vocab_size` → `hidden_size` (built in `new`).
    pub token_embedding: Sequential<F>,
    /// Sinusoidal positional encoding of width `hidden_size`; contributes no trainable
    /// parameters to `params()`.
    pub position_embedding: SinusoidalPositionalEncoding<F>,
    /// Stack of `num_layers` transformer encoder layers, applied in order.
    pub encoder_layers: Vec<TransformerEncoderLayer<F>>,
    /// Layer normalization applied after the last encoder layer.
    pub layer_norm: LayerNorm<F>,
    /// Final projection `hidden_size` → `projection_dim` applied to the pooled first token.
    pub projection: Dense<F>,
    /// Configuration the encoder was built from.
    pub config: CLIPTextConfig,
}
97
98impl<
99        F: Float
100            + Debug
101            + ScalarOperand
102            + Send
103            + Sync
104            + scirs2_core::simd_ops::SimdUnifiedOps
105            + NumAssign,
106    > CLIPTextEncoder<F>
107{
108    /// Create a new CLIPTextEncoder
109    pub fn new(_config: CLIPTextConfig, projection_dim: usize) -> Result<Self> {
110        // Token embedding
111        let mut token_embedding = Sequential::new();
112        let mut rng = SmallRng::from_seed([42; 32]);
113        token_embedding.add(Dense::<F>::new(
114            _config.vocab_size,
115            _config.hidden_size,
116            None,
117            &mut rng,
118        )?);
119        // Position embedding
120        let position_embedding = SinusoidalPositionalEncoding::<F>::new(
121            _config.hidden_size,
122            _config.max_position_embeddings,
123        );
124        // Transformer encoder layers
125        let mut encoder_layers = Vec::with_capacity(_config.num_layers);
126        for _i in 0.._config.num_layers {
127            encoder_layers.push(TransformerEncoderLayer::<F>::new(
128                _config.hidden_size,
129                _config.num_heads,
130                _config.intermediate_size,
131                _config.dropout_rate,
132                _config.layer_norm_eps,
133                &mut rng,
134            )?);
135        }
136        // Layer normalization
137        let layer_norm =
138            LayerNorm::<F>::new(_config.hidden_size, _config.layer_norm_eps, &mut rng)?;
139        // Projection
140        let projection = Dense::<F>::new(_config.hidden_size, projection_dim, None, &mut rng)?;
141        Ok(Self {
142            token_embedding,
143            position_embedding,
144            encoder_layers,
145            layer_norm,
146            projection,
147            config: _config,
148        })
149    }
150}
151impl<
152        F: Float
153            + Debug
154            + ScalarOperand
155            + Send
156            + Sync
157            + scirs2_core::simd_ops::SimdUnifiedOps
158            + 'static
159            + NumAssign,
160    > Layer<F> for CLIPTextEncoder<F>
161{
162    fn as_any(&self) -> &dyn std::any::Any {
163        self
164    }
165
166    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
167        self
168    }
169
170    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
171        // Apply token embedding
172        let mut x = self.token_embedding.forward(input)?;
173        // Apply position embedding
174        x = self.position_embedding.forward(&x)?;
175        // Apply transformer encoder layers
176        for layer in &self.encoder_layers {
177            x = layer.forward(&x)?;
178        }
179        // Apply layer normalization
180        x = self.layer_norm.forward(&x)?;
181        // Extract the [CLS] token embedding (assuming it's the first token)
182        let batch_size = x.shape()[0];
183        let hidden_size = x.shape()[2];
184        let cls_token = x
185            .slice_axis(Axis(1), scirs2_core::ndarray::Slice::from(0..1))
186            .into_shape_with_order((batch_size, hidden_size))?;
187        // Apply projection - convert to owned array to fix the reference type
188        let cls_token_owned = cls_token.to_owned().into_dyn();
189        let output = self.projection.forward(&cls_token_owned)?;
190        Ok(output)
191    }
192
193    fn backward(
194        &self,
195        input: &Array<F, IxDyn>,
196        grad_output: &Array<F, IxDyn>,
197    ) -> Result<Array<F, IxDyn>> {
198        // CLIPTextEncoder backward: reverse the forward pass
199        // Backward through projection
200        let grad_after_proj = self.projection.backward(grad_output, grad_output)?;
201        // Expand CLS token gradient back to full sequence
202        // Note: This is simplified - in reality we need to handle the slicing properly
203        let batch_size = input.shape()[0];
204        let seq_len = input.shape()[1];
205        let hidden_size = grad_after_proj.shape()[1];
206        // Create gradient for full sequence (most gradients go to CLS token position)
207        let mut grad_full_seq =
208            Array::<F, IxDyn>::zeros(IxDyn(&[batch_size, seq_len, hidden_size]));
209        // Put the gradient at the CLS token position (index 0)
210        for i in 0..batch_size {
211            for j in 0..hidden_size {
212                grad_full_seq[[i, 0, j]] = grad_after_proj[[i, j]];
213            }
214        }
215        let grad_full_seq = grad_full_seq.into_dyn();
216        // Backward through layer normalization
217        let mut grad = self.layer_norm.backward(&grad_full_seq, &grad_full_seq)?;
218        // Backward through transformer encoder layers in reverse order
219        for layer in self.encoder_layers.iter().rev() {
220            grad = layer.backward(&grad, &grad)?;
221        }
222        // TODO: Backward through position embedding when backward method is implemented
223        // grad = self.position_embedding.backward(&grad, &grad)?;
224        // Backward through token embedding
225        let grad_input = self.token_embedding.backward(input, &grad)?;
226        Ok(grad_input)
227    }
228
229    fn update(&mut self, learning_rate: F) -> Result<()> {
230        // Update all components
231        // Update token embedding
232        self.token_embedding.update(learning_rate)?;
233        // Update position embedding
234        self.position_embedding.update(learning_rate)?;
235        // Update all encoder layers
236        for layer in &mut self.encoder_layers {
237            layer.update(learning_rate)?;
238        }
239        // Update layer normalization
240        self.layer_norm.update(learning_rate)?;
241        // Update projection layer
242        self.projection.update(learning_rate)?;
243        Ok(())
244    }
245
246    fn params(&self) -> Vec<Array<F, IxDyn>> {
247        let mut params = Vec::new();
248        params.extend(self.token_embedding.params());
249        // position_embedding has no trainable parameters, so we skip it
250        for layer in &self.encoder_layers {
251            params.extend(layer.params());
252        }
253        params.extend(self.layer_norm.params());
254        params.extend(self.projection.params());
255        params
256    }
257
258    fn set_training(&mut self, training: bool) {
259        self.token_embedding.set_training(training);
260        self.position_embedding.set_training(training);
261        for layer in &mut self.encoder_layers {
262            layer.set_training(training);
263        }
264        self.layer_norm.set_training(training);
265        self.projection.set_training(training);
266    }
267
268    fn is_training(&self) -> bool {
269        self.token_embedding.is_training()
270    }
271}
272// Vision encoder for CLIP model (uses Vision Transformer)
273// TODO: Re-enable once PatchEmbedding is implemented
274// pub struct CLIPVisionEncoder<
275//     F: Float + Debug + ScalarOperand + Send + Sync + 'static + scirs2_core::simd_ops::SimdUnifiedOps,
276// > {
277//     /// Vision Transformer
278//     pub vision_transformer: VisionTransformer<F>,
279//     /// Projection layer
280//     pub projection: Dense<F>,
281// }
282//
283// impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static + scirs2_core::simd_ops::SimdUnifiedOps + NumAssign>
284//     CLIPVisionEncoder<F>
285// {
286//     /// Create a new CLIPVisionEncoder
287//     pub fn new(_config: ViTConfig, projection_dim: usize) -> Result<Self> {
288//         // Create ViT with a clone of the _config to avoid ownership issues
289//         let vision_transformer = VisionTransformer::<F>::new(_config.clone())?;
290//         // Projection layer
291//         let mut rng_proj = SmallRng::from_seed([42; 32]);
292//         let projection = Dense::<F>::new(_config.embed_dim, projection_dim, None, &mut rng_proj)?;
293//         Ok(Self {
294//             vision_transformer,
295//             projection,
296//         })
297//     }
298// }
299//
300// impl<
301//         F: Float
302//             + Debug
303//             + ScalarOperand
304//             + Send
305//             + Sync
306//             + scirs2_core::simd_ops::SimdUnifiedOps
307//             + 'static
308//             + NumAssign,
309//     > Layer<F> for CLIPVisionEncoder<F>
310// {
311//     fn as_any(&self) -> &dyn std::any::Any {
312//         self
313//     }
314//
315//     fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
316//         self
317//     }
318//
319//     fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
320//         // Apply vision transformer
321//         let x = self.vision_transformer.forward(input)?;
322//         // Apply projection
323//         let output = self.projection.forward(&x)?;
324//         Ok(output)
325//     }
326//
327//     fn backward(
328//         &self,
329//         input: &Array<F, IxDyn>,
330//         grad_output: &Array<F, IxDyn>,
331//     ) -> Result<Array<F, IxDyn>> {
332//         // CLIPVisionEncoder backward: reverse the forward pass
333//         // Backward through projection
334//         let grad_after_proj = self.projection.backward(grad_output, grad_output)?;
335//         // Backward through vision transformer
336//         let grad_input = self.vision_transformer.backward(input, &grad_after_proj)?;
337//         Ok(grad_input)
338//     }
339//
340//     fn update(&mut self, learning_rate: F) -> Result<()> {
341//         // Update projection
342//         self.projection.update(learning_rate)?;
343//         // Update vision transformer
344//         self.vision_transformer.update(learning_rate)?;
345//         Ok(())
346//     }
347//
348//     fn params(&self) -> Vec<Array<F, IxDyn>> {
349//         let mut params = Vec::new();
350//         params.extend(self.projection.params());
351//         params.extend(self.vision_transformer.params());
352//         params
353//     }
354//
355//     fn set_training(&mut self, training: bool) {
356//         self.projection.set_training(training);
357//         self.vision_transformer.set_training(training);
358//     }
359//
360//     fn is_training(&self) -> bool {
361//         self.vision_transformer.is_training()
362//     }
363// }
364// CLIP model implementation
365// TODO: Re-enable once PatchEmbedding is implemented
366// pub struct CLIP<
367//     F: Float + Debug + ScalarOperand + Send + Sync + 'static + scirs2_core::simd_ops::SimdUnifiedOps,
368// > {
369//     /// Vision encoder
370//     pub vision_encoder: CLIPVisionEncoder<F>,
371//     /// Text encoder
372//     pub text_encoder: CLIPTextEncoder<F>,
373//     /// Optional classifier for zero-shot classification
374//     pub classifier: Option<Dense<F>>,
375//     /// Model configuration
376//     pub _config: CLIPConfig,
377//     /// Temperature parameter for contrastive loss
378//     pub logit_scale: F,
379// }
380
381// TODO: Re-enable once PatchEmbedding is implemented
382// impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static + scirs2_core::simd_ops::SimdUnifiedOps + NumAssign>
383//     CLIP<F>
384// {
385//     /// Create a new CLIP model
386//     pub fn new(config: CLIPConfig) -> Result<Self> {
387//         // Create vision encoder
388//         let vision_encoder =
389//             CLIPVisionEncoder::<F>::new(config.vision_config.clone(), config.projection_dim)?;
390//         // Create text encoder
391//         let text_encoder =
392//             CLIPTextEncoder::<F>::new(config.text_config.clone(), config.projection_dim)?;
393//         // Create classifier if needed
394//         let classifier = if config.include_head {
395//             let mut rng_cls = SmallRng::from_seed([42; 32]);
396//             Some(Dense::<F>::new(
397//                 config.projection_dim,
398//                 config.num_classes,
399//                 None,
400//                 &mut rng_cls,
401//             )?)
402//         } else {
403//             None
404//         };
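//         // Initialize logit scale to ln(1/0.07) ≈ 2.6592 (the CLIP paper's initial temperature 0.07).
//         // NOTE: CLIP multiplies similarities by exp(logit_scale), but `compute_similarity` uses its
//         // scale argument as a direct multiplier, so pass `self.logit_scale.exp()` when re-enabling.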
406//         let logit_scale = F::from(2.6592).expect("Failed to convert constant to float");
407//         Ok(Self {
408//             vision_encoder,
409//             text_encoder,
410//             classifier,
411//             _config: config,
412//             logit_scale,
413//         })
414//     }
415//     /// Forward pass for image-text contrastive learning
416//     pub fn forward_contrastive(
417//         &self,
418//         image_input: &Array<F, IxDyn>,
419//         text_input: &Array<F, IxDyn>,
420//     ) -> Result<ClipOutput<F>> {
421//         // Get image and text embeddings
422//         let image_features = self.vision_encoder.forward(image_input)?;
423//         let text_features = self.text_encoder.forward(text_input)?;
424//         // Normalize embeddings
425//         let image_features_norm = normalize_features(&image_features)?;
426//         let text_features_norm = normalize_features(&text_features)?;
427//         // Compute similarity matrix (batch_size x batch_size)
428//         let logits_per_image =
429//             compute_similarity(&image_features_norm, &text_features_norm, self.logit_scale)?;
430//         // Transpose to get logits_pertext (currently unused but kept for API consistency)
431//         let _logits_pertext = logits_per_image.t().into_dyn();
432//         Ok((image_features, text_features, logits_per_image))
433//     }
434//
435//     /// Forward pass for zero-shot image classification using a text encoder
436//     pub fn forward_classification(
437//         &self,
438//         image_input: &Array<F, IxDyn>,
439//         text_embeddings: &Array<F, IxDyn>,
440//     ) -> Result<Array<F, IxDyn>> {
441//         // Get image embeddings
442//         let image_features = self.vision_encoder.forward(image_input)?;
443//         // Normalize image embeddings
444//         let image_features_norm = normalize_features(&image_features)?;
445//         // Compute similarity with text embeddings
446//         let logits = compute_similarity(&image_features_norm, text_embeddings, self.logit_scale)?;
447//         Ok(logits)
448//     }
449//     /// Create a CLIP model with default settings
450//     pub fn clip_base(num_classes: usize, include_head: bool) -> Result<Self> {
451//         let vision_config = ViTConfig {
452//             image_size: (224, 224),
453//             patch_size: (16, 16),
454//             in_channels: 3,
455//             num_classes,
456//             embed_dim: 768,
457//             num_heads: 12,
458//             mlp_dim: 3072,
459//             num_layers: 12,
460//             dropout_rate: 0.1,
461//             attention_dropout_rate: 0.1,
462//         };
463//         let text_config = CLIPTextConfig::default();
464//         let config = CLIPConfig {
465//             text_config,
466//             vision_config,
467//             projection_dim: 512,
468//             include_head,
469//             num_classes,
470//         };
471//         Self::new(config)
472//     }
473//     /// Create a small CLIP model
474//     pub fn clip_small(num_classes: usize, include_head: bool) -> Result<Self> {
475//         let vision_config = ViTConfig {
476//             image_size: (224, 224),
477//             patch_size: (16, 16),
478//             in_channels: 3,
479//             num_classes,
480//             embed_dim: 512,
481//             num_heads: 6,
482//             mlp_dim: 2048,
483//             num_layers: 8,
484//             dropout_rate: 0.1,
485//             attention_dropout_rate: 0.1,
486//         };
487//         let text_config = CLIPTextConfig {
488//             vocab_size: 49408,
489//             hidden_size: 384,
490//             intermediate_size: 1536,
491//             num_layers: 8,
492//             num_heads: 6,
493//             max_position_embeddings: 77,
494//             dropout_rate: 0.1,
495//             layer_norm_eps: 1e-5,
496//         };
497//         let config = CLIPConfig {
498//             text_config,
499//             vision_config,
500//             projection_dim: 256,
501//             include_head,
502//             num_classes,
503//         };
504//         Self::new(config)
505//     }
506// }
507//
508// impl<
509//         F: Float
510//             + Debug
511//             + ScalarOperand
512//             + Send
513//             + Sync
514//             + scirs2_core::simd_ops::SimdUnifiedOps
515//             + 'static
516//             + NumAssign,
517//     > Layer<F> for CLIP<F>
518// {
519//     fn as_any(&self) -> &dyn std::any::Any {
520//         self
521//     }
522//
523//     fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
524//         self
525//     }
526//
527//     fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
528//         // In a typical scenario, input would be an image
529//         // For classification tasks, we would need pre-computed text embeddings
530//         // Get image features
531//         let image_features = self.vision_encoder.forward(input)?;
532//         // If classifier is present, use it for direct classification
533//         if let Some(ref classifier) = self.classifier {
534//             return classifier.forward(&image_features);
535//         }
536//         Ok(image_features)
537//     }
538//
539//     fn backward(
540//         &self,
541//         input: &Array<F, IxDyn>,
542//         grad_output: &Array<F, IxDyn>,
543//     ) -> Result<Array<F, IxDyn>> {
544//         // CLIP backward: reverse the forward pass
545//         let mut grad = grad_output.clone();
546//         // Backward through classifier if present
547//         if let Some(ref classifier) = self.classifier {
548//             // For proper gradient computation, we need the intermediate features
549//             // This is a simplified version
550//             grad = classifier.backward(&grad, &grad)?;
551//         }
552//         // Backward through vision encoder
553//         let grad_input = self.vision_encoder.backward(input, &grad)?;
554//         Ok(grad_input)
555//     }
556//
557//     fn update(&mut self, learning_rate: F) -> Result<()> {
558//         // Update vision encoder
559//         self.vision_encoder.update(learning_rate)?;
560//         // Update text encoder
561//         self.text_encoder.update(learning_rate)?;
562//         // Update classifier if present
563//         if let Some(ref mut classifier) = self.classifier {
564//             classifier.update(learning_rate)?;
565//         }
566//         Ok(())
567//     }
568//
569//     fn params(&self) -> Vec<Array<F, IxDyn>> {
570//         let mut params = Vec::new();
571//         params.extend(self.vision_encoder.params());
572//         params.extend(self.text_encoder.params());
573//         if let Some(ref classifier) = self.classifier {
574//             params.extend(classifier.params());
575//         }
576//         params
577//     }
578//
579//     fn set_training(&mut self, training: bool) {
580//         self.vision_encoder.set_training(training);
581//         self.text_encoder.set_training(training);
582//         if let Some(ref mut classifier) = self.classifier {
583//             classifier.set_training(training);
584//         }
585//     }
586//
587//     fn is_training(&self) -> bool {
588//         self.vision_encoder.is_training()
589//     }
590// }
591/// Normalize feature vectors (L2 normalization)
592#[allow(dead_code)]
593fn normalize_features<F: Float + Debug + ScalarOperand>(
594    features: &Array<F, IxDyn>,
595) -> Result<Array<F, IxDyn>> {
596    let shape = features.shape();
597    let batch_size = shape[0];
598    let feature_dim = shape[1];
599    // Reshape to 2D for easier computation
600    let features_2d = features
601        .clone()
602        .into_shape_with_order((batch_size, feature_dim))?;
603    // Compute L2 norm along the feature dimension
604    let norm = features_2d.map_axis(Axis(1), |x| {
605        let sum_squares = x.iter().fold(F::zero(), |acc, &val| acc + val * val);
606        let norm = sum_squares.sqrt();
607        // Avoid division by zero
608        if norm > F::from(1e-12).expect("Failed to convert constant to float") {
609            norm
610        } else {
611            F::one()
612        }
613    });
614    // Expand norm to match feature dims for broadcasting
615    let norm_expanded = norm.insert_axis(Axis(1));
616    // Normalize features
617    let normalized = features_2d.clone() / norm_expanded;
618    // Reshape back to original shape
619    Ok(normalized.into_shape_with_order(shape)?)
620}
621
622/// Compute similarity matrix between two sets of features
623#[allow(dead_code)]
624fn compute_similarity<F: Float + Debug + ScalarOperand>(
625    features_a: &Array<F, IxDyn>,
626    features_b: &Array<F, IxDyn>,
627    temperature: F,
628) -> Result<Array<F, IxDyn>> {
629    // Get shapes
630    let shape_a = features_a.shape();
631    let shape_b = features_b.shape();
632    let batch_a = shape_a[0];
633    let batch_b = shape_b[0];
634    // Reshape features to 2D matrices
635    let features_a_2d = features_a
636        .clone()
637        .into_shape_with_order((batch_a, shape_a[1]))?;
638    let features_b_2d = features_b
639        .clone()
640        .into_shape_with_order((batch_b, shape_b[1]))?;
641    // Compute dot product (similarity matrix)
642    let similarity = features_a_2d.dot(&features_b_2d.t());
643    // Apply temperature scaling
644    let scaled_similarity = similarity * temperature;
645    Ok(scaled_similarity.into_dyn())
646}