Skip to main content

scirs2_neural/models/architectures/
clip.rs

1//! CLIP (Contrastive Language-Image Pre-training) Architecture
2//!
3//! This module implements a CLIP-like architecture as described in
4//! "Learning Transferable Visual Models From Natural Language Supervision"
//! (<https://arxiv.org/abs/2103.00020>).
//!
//! CLIP is a multi-modal model that learns visual concepts from natural language supervision,
7//! enabling zero-shot transfer to various visual classification tasks.
8
9use crate::error::Result;
10use crate::layers::{Dense, Layer, LayerNorm, Sequential};
11use crate::models::architectures::{ViTConfig, VisionTransformer};
12use crate::transformer::TransformerEncoderLayer;
13use crate::utils::positional_encoding::{PositionalEncoding, SinusoidalPositionalEncoding};
14use scirs2_core::ndarray::{Array, Axis, IxDyn, ScalarOperand};
15use scirs2_core::numeric::{Float, NumAssign};
16use scirs2_core::random::{rngs::SmallRng, SeedableRng};
17use serde::{Deserialize, Serialize};
18use std::fmt::Debug;
19/// Type alias for CLIP output (image embeddings, text embeddings, logit scale)
20type ClipOutput<F> = (Array<F, IxDyn>, Array<F, IxDyn>, Array<F, IxDyn>);
/// Configuration for the text encoder in CLIP.
///
/// Consumed by `CLIPTextEncoder::new`, which sizes the token-embedding layer,
/// the sinusoidal positional encoding, the transformer encoder stack and the
/// final layer norm from these fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CLIPTextConfig {
    /// Vocabulary size (input width of the token-embedding `Dense` layer).
    pub vocab_size: usize,
    /// Hidden size for embeddings and transformer layers.
    pub hidden_size: usize,
    /// Intermediate size in the transformer feed-forward layers.
    pub intermediate_size: usize,
    /// Number of transformer encoder layers.
    pub num_layers: usize,
    /// Number of attention heads per encoder layer.
    /// NOTE(review): presumably must divide `hidden_size` — confirm against
    /// `TransformerEncoderLayer`.
    pub num_heads: usize,
    /// Maximum sequence length, passed to the sinusoidal positional encoding.
    pub max_position_embeddings: usize,
    /// Dropout rate passed to each transformer encoder layer.
    pub dropout_rate: f64,
    /// Epsilon used by the encoder layers and the final layer norm.
    pub layer_norm_eps: f64,
}
/// Configuration for the full CLIP model (text encoder + vision encoder).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CLIPConfig {
    /// Text encoder configuration.
    pub text_config: CLIPTextConfig,
    /// Vision encoder (ViT backbone) configuration.
    pub vision_config: ViTConfig,
    /// Output width of both encoders' projection layers (the shared embedding
    /// size); also the input width of the optional classifier.
    pub projection_dim: usize,
    /// Whether `CLIP::new` should build the linear classifier head.
    pub include_head: bool,
    /// Output width of the classifier (only used when `include_head` is true).
    pub num_classes: usize,
}
55
56impl Default for CLIPTextConfig {
57    fn default() -> Self {
58        Self {
59            vocab_size: 49408,
60            hidden_size: 512,
61            intermediate_size: 2048,
62            num_layers: 12,
63            num_heads: 8,
64            max_position_embeddings: 77,
65            dropout_rate: 0.1,
66            layer_norm_eps: 1e-5,
67        }
68    }
69}
70
/// Text encoder for the CLIP model.
///
/// Token embedding → sinusoidal positional encoding → transformer encoder
/// layers → layer norm → first-token pooling → linear projection into the
/// shared embedding space.
#[derive(Debug, Clone)]
pub struct CLIPTextEncoder<
    F: Float
        + Debug
        + ScalarOperand
        + Send
        + Sync
        + 'static
        + scirs2_core::simd_ops::SimdUnifiedOps
        + NumAssign,
> {
    /// Token embedding: a `Sequential` holding one `Dense(vocab_size -> hidden_size)`.
    /// NOTE(review): this presumably expects one-hot / dense token vectors of
    /// width `vocab_size` rather than integer ids — confirm with callers.
    pub token_embedding: Sequential<F>,
    /// Fixed (parameter-free) sinusoidal position encoding.
    pub position_embedding: SinusoidalPositionalEncoding<F>,
    /// Transformer encoder layers, applied in order.
    pub encoder_layers: Vec<TransformerEncoderLayer<F>>,
    /// Final layer normalization applied after the encoder stack.
    pub layer_norm: LayerNorm<F>,
    /// Projection from `hidden_size` to the shared `projection_dim`.
    pub projection: Dense<F>,
    /// Text configuration used to build this encoder.
    pub config: CLIPTextConfig,
}
96
97impl<
98        F: Float
99            + Debug
100            + ScalarOperand
101            + Send
102            + Sync
103            + scirs2_core::simd_ops::SimdUnifiedOps
104            + NumAssign,
105    > CLIPTextEncoder<F>
106{
107    /// Create a new CLIPTextEncoder
108    pub fn new(_config: CLIPTextConfig, projection_dim: usize) -> Result<Self> {
109        // Token embedding
110        let mut token_embedding = Sequential::new();
111        let mut rng = SmallRng::from_seed([42; 32]);
112        token_embedding.add(Dense::<F>::new(
113            _config.vocab_size,
114            _config.hidden_size,
115            None,
116            &mut rng,
117        )?);
118        // Position embedding
119        let position_embedding = SinusoidalPositionalEncoding::<F>::new(
120            _config.hidden_size,
121            _config.max_position_embeddings,
122        );
123        // Transformer encoder layers
124        let mut encoder_layers = Vec::with_capacity(_config.num_layers);
125        for _i in 0.._config.num_layers {
126            encoder_layers.push(TransformerEncoderLayer::<F>::new(
127                _config.hidden_size,
128                _config.num_heads,
129                _config.intermediate_size,
130                _config.dropout_rate,
131                _config.layer_norm_eps,
132                &mut rng,
133            )?);
134        }
135        // Layer normalization
136        let layer_norm =
137            LayerNorm::<F>::new(_config.hidden_size, _config.layer_norm_eps, &mut rng)?;
138        // Projection
139        let projection = Dense::<F>::new(_config.hidden_size, projection_dim, None, &mut rng)?;
140        Ok(Self {
141            token_embedding,
142            position_embedding,
143            encoder_layers,
144            layer_norm,
145            projection,
146            config: _config,
147        })
148    }
149}
150impl<
151        F: Float
152            + Debug
153            + ScalarOperand
154            + Send
155            + Sync
156            + scirs2_core::simd_ops::SimdUnifiedOps
157            + 'static
158            + NumAssign,
159    > Layer<F> for CLIPTextEncoder<F>
160{
161    fn as_any(&self) -> &dyn std::any::Any {
162        self
163    }
164
165    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
166        self
167    }
168
169    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
170        // Apply token embedding
171        let mut x = self.token_embedding.forward(input)?;
172        // Apply position embedding
173        x = self.position_embedding.forward(&x)?;
174        // Apply transformer encoder layers
175        for layer in &self.encoder_layers {
176            x = layer.forward(&x)?;
177        }
178        // Apply layer normalization
179        x = self.layer_norm.forward(&x)?;
180        // Extract the [CLS] token embedding (assuming it's the first token)
181        let batch_size = x.shape()[0];
182        let hidden_size = x.shape()[2];
183        let cls_token = x
184            .slice_axis(Axis(1), scirs2_core::ndarray::Slice::from(0..1))
185            .into_shape_with_order((batch_size, hidden_size))?;
186        // Apply projection - convert to owned array to fix the reference type
187        let cls_token_owned = cls_token.to_owned().into_dyn();
188        let output = self.projection.forward(&cls_token_owned)?;
189        Ok(output)
190    }
191
192    fn backward(
193        &self,
194        input: &Array<F, IxDyn>,
195        grad_output: &Array<F, IxDyn>,
196    ) -> Result<Array<F, IxDyn>> {
197        // CLIPTextEncoder backward: reverse the forward pass
198        // Backward through projection
199        let grad_after_proj = self.projection.backward(grad_output, grad_output)?;
200        // Expand CLS token gradient back to full sequence
201        // Note: This is simplified - in reality we need to handle the slicing properly
202        let batch_size = input.shape()[0];
203        let seq_len = input.shape()[1];
204        let hidden_size = grad_after_proj.shape()[1];
205        // Create gradient for full sequence (most gradients go to CLS token position)
206        let mut grad_full_seq =
207            Array::<F, IxDyn>::zeros(IxDyn(&[batch_size, seq_len, hidden_size]));
208        // Put the gradient at the CLS token position (index 0)
209        for i in 0..batch_size {
210            for j in 0..hidden_size {
211                grad_full_seq[[i, 0, j]] = grad_after_proj[[i, j]];
212            }
213        }
214        let grad_full_seq = grad_full_seq.into_dyn();
215        // Backward through layer normalization
216        let mut grad = self.layer_norm.backward(&grad_full_seq, &grad_full_seq)?;
217        // Backward through transformer encoder layers in reverse order
218        for layer in self.encoder_layers.iter().rev() {
219            grad = layer.backward(&grad, &grad)?;
220        }
221        // Backward through position embedding.
222        // SinusoidalPositionalEncoding is parameter-free; its backward is a pass-through
223        // (∂L/∂input = ∂L/∂output since output = input + fixed_encoding).
224        grad = self.position_embedding.backward(&grad, &grad)?;
225        // Backward through token embedding
226        let grad_input = self.token_embedding.backward(input, &grad)?;
227        Ok(grad_input)
228    }
229
230    fn update(&mut self, learning_rate: F) -> Result<()> {
231        // Update all components
232        // Update token embedding
233        self.token_embedding.update(learning_rate)?;
234        // Update position embedding
235        self.position_embedding.update(learning_rate)?;
236        // Update all encoder layers
237        for layer in &mut self.encoder_layers {
238            layer.update(learning_rate)?;
239        }
240        // Update layer normalization
241        self.layer_norm.update(learning_rate)?;
242        // Update projection layer
243        self.projection.update(learning_rate)?;
244        Ok(())
245    }
246
247    fn params(&self) -> Vec<Array<F, IxDyn>> {
248        let mut params = Vec::new();
249        params.extend(self.token_embedding.params());
250        // position_embedding has no trainable parameters, so we skip it
251        for layer in &self.encoder_layers {
252            params.extend(layer.params());
253        }
254        params.extend(self.layer_norm.params());
255        params.extend(self.projection.params());
256        params
257    }
258
259    fn set_training(&mut self, training: bool) {
260        self.token_embedding.set_training(training);
261        self.position_embedding.set_training(training);
262        for layer in &mut self.encoder_layers {
263            layer.set_training(training);
264        }
265        self.layer_norm.set_training(training);
266        self.projection.set_training(training);
267    }
268
269    fn is_training(&self) -> bool {
270        self.token_embedding.is_training()
271    }
272}
/// Vision encoder for the CLIP model: a `VisionTransformer` followed by a
/// linear projection into the shared embedding space.
pub struct CLIPVisionEncoder<
    F: Float
        + Debug
        + ScalarOperand
        + Send
        + Sync
        + Clone
        + 'static
        + scirs2_core::simd_ops::SimdUnifiedOps
        + NumAssign,
> {
    /// Vision Transformer backbone.
    /// NOTE(review): its forward output is fed straight into `projection`
    /// (built as `Dense(embed_dim -> projection_dim)`). It is therefore
    /// presumably `(batch, embed_dim)` features rather than class logits —
    /// confirm.
    pub vision_transformer: VisionTransformer<F>,
    /// Linear projection from `embed_dim` into the shared embedding space.
    pub projection: Dense<F>,
}
290
291impl<
292        F: Float
293            + Debug
294            + ScalarOperand
295            + Send
296            + Sync
297            + Clone
298            + 'static
299            + scirs2_core::simd_ops::SimdUnifiedOps
300            + NumAssign,
301    > CLIPVisionEncoder<F>
302{
303    /// Create a new CLIPVisionEncoder.
304    pub fn new(config: ViTConfig, projection_dim: usize) -> Result<Self> {
305        let embed_dim = config.embed_dim;
306        let vision_transformer = VisionTransformer::<F>::new(config)?;
307        let mut rng_proj = SmallRng::from_seed([42; 32]);
308        let projection = Dense::<F>::new(embed_dim, projection_dim, None, &mut rng_proj)?;
309        Ok(Self {
310            vision_transformer,
311            projection,
312        })
313    }
314}
315
316impl<
317        F: Float
318            + Debug
319            + ScalarOperand
320            + Send
321            + Sync
322            + Clone
323            + scirs2_core::simd_ops::SimdUnifiedOps
324            + 'static
325            + NumAssign,
326    > Layer<F> for CLIPVisionEncoder<F>
327{
328    fn as_any(&self) -> &dyn std::any::Any {
329        self
330    }
331
332    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
333        self
334    }
335
336    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
337        let x = self.vision_transformer.forward(input)?;
338        self.projection.forward(&x)
339    }
340
341    fn backward(
342        &self,
343        input: &Array<F, IxDyn>,
344        grad_output: &Array<F, IxDyn>,
345    ) -> Result<Array<F, IxDyn>> {
346        let grad_after_proj = self.projection.backward(grad_output, grad_output)?;
347        self.vision_transformer.backward(input, &grad_after_proj)
348    }
349
350    fn update(&mut self, learning_rate: F) -> Result<()> {
351        self.projection.update(learning_rate)?;
352        self.vision_transformer.update(learning_rate)?;
353        Ok(())
354    }
355
356    fn params(&self) -> Vec<Array<F, IxDyn>> {
357        let mut params = Vec::new();
358        params.extend(self.projection.params());
359        params.extend(self.vision_transformer.params());
360        params
361    }
362
363    fn set_training(&mut self, training: bool) {
364        self.projection.set_training(training);
365        self.vision_transformer.set_training(training);
366    }
367
368    fn is_training(&self) -> bool {
369        self.vision_transformer.is_training()
370    }
371
372    fn layer_type(&self) -> &str {
373        "CLIPVisionEncoder"
374    }
375}
376
/// CLIP — Contrastive Language-Image Pre-training model.
///
/// Combines a vision encoder (ViT) and a text encoder (Transformer) that both
/// project into a shared embedding space, optimised with a contrastive
/// objective. An optional linear classifier can sit on top of the image
/// embeddings.
pub struct CLIP<
    F: Float
        + Debug
        + ScalarOperand
        + Send
        + Sync
        + Clone
        + 'static
        + scirs2_core::simd_ops::SimdUnifiedOps
        + NumAssign,
> {
    /// Vision encoder (ViT backbone + projection).
    pub vision_encoder: CLIPVisionEncoder<F>,
    /// Text encoder (transformer + projection).
    pub text_encoder: CLIPTextEncoder<F>,
    /// Optional classifier (`projection_dim -> num_classes`) used by `Layer::forward`.
    pub classifier: Option<Dense<F>>,
    /// Model configuration the model was built from.
    pub _config: CLIPConfig,
    /// Log of the inverse softmax temperature for the contrastive logits.
    /// `CLIP::new` initialises it to ln(1/0.07) ≈ 2.6592.
    pub logit_scale: F,
}
403
404impl<
405        F: Float
406            + Debug
407            + ScalarOperand
408            + Send
409            + Sync
410            + Clone
411            + 'static
412            + scirs2_core::simd_ops::SimdUnifiedOps
413            + NumAssign,
414    > CLIP<F>
415{
416    /// Create a new CLIP model from config.
417    pub fn new(config: CLIPConfig) -> Result<Self> {
418        let vision_encoder =
419            CLIPVisionEncoder::<F>::new(config.vision_config.clone(), config.projection_dim)?;
420        let text_encoder =
421            CLIPTextEncoder::<F>::new(config.text_config.clone(), config.projection_dim)?;
422        let classifier = if config.include_head {
423            let mut rng_cls = SmallRng::from_seed([42; 32]);
424            Some(Dense::<F>::new(
425                config.projection_dim,
426                config.num_classes,
427                None,
428                &mut rng_cls,
429            )?)
430        } else {
431            None
432        };
433        // ln(1/0.07) ≈ 2.6592
434        let logit_scale = F::from(2.6592_f64).ok_or_else(|| {
435            crate::error::NeuralError::InvalidArchitecture(
436                "CLIP: failed to convert logit_scale to float".to_string(),
437            )
438        })?;
439        Ok(Self {
440            vision_encoder,
441            text_encoder,
442            classifier,
443            _config: config,
444            logit_scale,
445        })
446    }
447
448    /// Forward pass for image-text contrastive learning.
449    /// Returns `(image_embeddings, text_embeddings, similarity_matrix)`.
450    pub fn forward_contrastive(
451        &self,
452        image_input: &Array<F, IxDyn>,
453        text_input: &Array<F, IxDyn>,
454    ) -> Result<ClipOutput<F>> {
455        let image_features = self.vision_encoder.forward(image_input)?;
456        let text_features = self.text_encoder.forward(text_input)?;
457        let image_features_norm = normalize_features(&image_features)?;
458        let text_features_norm = normalize_features(&text_features)?;
459        let logits_per_image =
460            compute_similarity(&image_features_norm, &text_features_norm, self.logit_scale)?;
461        Ok((image_features, text_features, logits_per_image))
462    }
463
464    /// Forward pass for zero-shot image classification using pre-computed text embeddings.
465    pub fn forward_classification(
466        &self,
467        image_input: &Array<F, IxDyn>,
468        text_embeddings: &Array<F, IxDyn>,
469    ) -> Result<Array<F, IxDyn>> {
470        let image_features = self.vision_encoder.forward(image_input)?;
471        let image_features_norm = normalize_features(&image_features)?;
472        compute_similarity(&image_features_norm, text_embeddings, self.logit_scale)
473    }
474
475    /// Create a CLIP-Base model (ViT-B/16 + 12-layer text encoder).
476    pub fn clip_base(num_classes: usize, include_head: bool) -> Result<Self> {
477        let vision_config = ViTConfig {
478            image_size: (224, 224),
479            patch_size: (16, 16),
480            in_channels: 3,
481            num_classes,
482            embed_dim: 768,
483            num_heads: 12,
484            mlp_dim: 3072,
485            num_layers: 12,
486            dropout_rate: 0.1,
487            attention_dropout_rate: 0.1,
488        };
489        Self::new(CLIPConfig {
490            text_config: CLIPTextConfig::default(),
491            vision_config,
492            projection_dim: 512,
493            include_head,
494            num_classes,
495        })
496    }
497
498    /// Create a small CLIP model.
499    pub fn clip_small(num_classes: usize, include_head: bool) -> Result<Self> {
500        let vision_config = ViTConfig {
501            image_size: (224, 224),
502            patch_size: (16, 16),
503            in_channels: 3,
504            num_classes,
505            embed_dim: 512,
506            num_heads: 6,
507            mlp_dim: 2048,
508            num_layers: 8,
509            dropout_rate: 0.1,
510            attention_dropout_rate: 0.1,
511        };
512        let text_config = CLIPTextConfig {
513            vocab_size: 49408,
514            hidden_size: 384,
515            intermediate_size: 1536,
516            num_layers: 8,
517            num_heads: 6,
518            max_position_embeddings: 77,
519            dropout_rate: 0.1,
520            layer_norm_eps: 1e-5,
521        };
522        Self::new(CLIPConfig {
523            text_config,
524            vision_config,
525            projection_dim: 256,
526            include_head,
527            num_classes,
528        })
529    }
530}
531
532impl<
533        F: Float
534            + Debug
535            + ScalarOperand
536            + Send
537            + Sync
538            + Clone
539            + scirs2_core::simd_ops::SimdUnifiedOps
540            + 'static
541            + NumAssign,
542    > Layer<F> for CLIP<F>
543{
544    fn as_any(&self) -> &dyn std::any::Any {
545        self
546    }
547
548    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
549        self
550    }
551
552    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
553        let image_features = self.vision_encoder.forward(input)?;
554        if let Some(ref classifier) = self.classifier {
555            return classifier.forward(&image_features);
556        }
557        Ok(image_features)
558    }
559
560    fn backward(
561        &self,
562        input: &Array<F, IxDyn>,
563        grad_output: &Array<F, IxDyn>,
564    ) -> Result<Array<F, IxDyn>> {
565        let mut grad = grad_output.clone();
566        if let Some(ref classifier) = self.classifier {
567            grad = classifier.backward(&grad, &grad)?;
568        }
569        self.vision_encoder.backward(input, &grad)
570    }
571
572    fn update(&mut self, learning_rate: F) -> Result<()> {
573        self.vision_encoder.update(learning_rate)?;
574        self.text_encoder.update(learning_rate)?;
575        if let Some(ref mut classifier) = self.classifier {
576            classifier.update(learning_rate)?;
577        }
578        Ok(())
579    }
580
581    fn params(&self) -> Vec<Array<F, IxDyn>> {
582        let mut params = Vec::new();
583        params.extend(self.vision_encoder.params());
584        params.extend(self.text_encoder.params());
585        if let Some(ref classifier) = self.classifier {
586            params.extend(classifier.params());
587        }
588        params
589    }
590
591    fn set_training(&mut self, training: bool) {
592        self.vision_encoder.set_training(training);
593        self.text_encoder.set_training(training);
594        if let Some(ref mut classifier) = self.classifier {
595            classifier.set_training(training);
596        }
597    }
598
599    fn is_training(&self) -> bool {
600        self.vision_encoder.is_training()
601    }
602
603    fn layer_type(&self) -> &str {
604        "CLIP"
605    }
606}
607/// Normalize feature vectors (L2 normalization)
608#[allow(dead_code)]
609fn normalize_features<F: Float + Debug + ScalarOperand>(
610    features: &Array<F, IxDyn>,
611) -> Result<Array<F, IxDyn>> {
612    let shape = features.shape();
613    let batch_size = shape[0];
614    let feature_dim = shape[1];
615    // Reshape to 2D for easier computation
616    let features_2d = features
617        .clone()
618        .into_shape_with_order((batch_size, feature_dim))?;
619    // Compute L2 norm along the feature dimension
620    let norm = features_2d.map_axis(Axis(1), |x| {
621        let sum_squares = x.iter().fold(F::zero(), |acc, &val| acc + val * val);
622        let norm = sum_squares.sqrt();
623        // Avoid division by zero
624        if norm > F::from(1e-12).expect("Failed to convert constant to float") {
625            norm
626        } else {
627            F::one()
628        }
629    });
630    // Expand norm to match feature dims for broadcasting
631    let norm_expanded = norm.insert_axis(Axis(1));
632    // Normalize features
633    let normalized = features_2d.clone() / norm_expanded;
634    // Reshape back to original shape
635    Ok(normalized.into_shape_with_order(shape)?)
636}
637
638/// Compute similarity matrix between two sets of features
639#[allow(dead_code)]
640fn compute_similarity<F: Float + Debug + ScalarOperand>(
641    features_a: &Array<F, IxDyn>,
642    features_b: &Array<F, IxDyn>,
643    temperature: F,
644) -> Result<Array<F, IxDyn>> {
645    // Get shapes
646    let shape_a = features_a.shape();
647    let shape_b = features_b.shape();
648    let batch_a = shape_a[0];
649    let batch_b = shape_b[0];
650    // Reshape features to 2D matrices
651    let features_a_2d = features_a
652        .clone()
653        .into_shape_with_order((batch_a, shape_a[1]))?;
654    let features_b_2d = features_b
655        .clone()
656        .into_shape_with_order((batch_b, shape_b[1]))?;
657    // Compute dot product (similarity matrix)
658    let similarity = features_a_2d.dot(&features_b_2d.t());
659    // Apply temperature scaling
660    let scaled_similarity = similarity * temperature;
661    Ok(scaled_similarity.into_dyn())
662}