scirs2_neural/models/architectures/clip.rs
1//! CLIP (Contrastive Language-Image Pre-training) Architecture
2//!
3//! This module implements a CLIP-like architecture as described in
4//! "Learning Transferable Visual Models From Natural Language Supervision"
5//! (<https://arxiv.org/abs/2103.00020>)
6//! CLIP is a multi-modal model that learns visual concepts from natural language supervision,
7//! enabling zero-shot transfer to various visual classification tasks.
8
9use crate::error::Result;
10use crate::layers::{Dense, Layer, LayerNorm, Sequential};
11// TODO: Re-enable once PatchEmbedding is implemented
12// use crate::models::architectures::{ViTConfig, VisionTransformer};
13use crate::models::architectures::ViTConfig;
14use crate::transformer::TransformerEncoderLayer;
15use crate::utils::positional_encoding::{PositionalEncoding, SinusoidalPositionalEncoding};
16use scirs2_core::ndarray::{Array, Axis, IxDyn, ScalarOperand};
17use scirs2_core::numeric::{Float, NumAssign};
18use scirs2_core::random::{rngs::SmallRng, SeedableRng};
19use serde::{Deserialize, Serialize};
20use std::fmt::Debug;
21/// Type alias for CLIP output (image embeddings, text embeddings, logit scale)
22type ClipOutput<F> = (Array<F, IxDyn>, Array<F, IxDyn>, Array<F, IxDyn>);
23/// Configuration for the text encoder in CLIP
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct CLIPTextConfig {
26 /// Vocabulary size
27 pub vocab_size: usize,
28 /// Hidden size for embeddings and transformer
29 pub hidden_size: usize,
30 /// Intermediate size in feed-forward layers
31 pub intermediate_size: usize,
32 /// Number of transformer layers
33 pub num_layers: usize,
34 /// Number of attention heads
35 pub num_heads: usize,
36 /// Maximum sequence length
37 pub max_position_embeddings: usize,
38 /// Dropout rate
39 pub dropout_rate: f64,
40 /// Layer norm epsilon
41 pub layer_norm_eps: f64,
42}
43/// Configuration for the CLIP model
44pub struct CLIPConfig {
45 /// Text encoder configuration
46 pub text_config: CLIPTextConfig,
47 /// Vision encoder configuration
48 pub vision_config: ViTConfig,
49 /// Projection dimension for both text and vision encoders
50 pub projection_dim: usize,
51 /// Whether to include the classifier
52 pub include_head: bool,
53 /// Number of classes for the classifier (if include_head is true)
54 pub num_classes: usize,
55}
56
57impl Default for CLIPTextConfig {
58 fn default() -> Self {
59 Self {
60 vocab_size: 49408,
61 hidden_size: 512,
62 intermediate_size: 2048,
63 num_layers: 12,
64 num_heads: 8,
65 max_position_embeddings: 77,
66 dropout_rate: 0.1,
67 layer_norm_eps: 1e-5,
68 }
69 }
70}
71
72/// Text encoder for CLIP model
73#[derive(Debug, Clone)]
74pub struct CLIPTextEncoder<
75 F: Float
76 + Debug
77 + ScalarOperand
78 + Send
79 + Sync
80 + 'static
81 + scirs2_core::simd_ops::SimdUnifiedOps
82 + NumAssign,
83> {
84 /// Token embedding
85 pub token_embedding: Sequential<F>,
86 /// Position embedding
87 pub position_embedding: SinusoidalPositionalEncoding<F>,
88 /// Transformer encoder layers
89 pub encoder_layers: Vec<TransformerEncoderLayer<F>>,
90 /// Layer normalization
91 pub layer_norm: LayerNorm<F>,
92 /// Final projection layer
93 pub projection: Dense<F>,
94 /// Text configuration
95 pub config: CLIPTextConfig,
96}
97
98impl<
99 F: Float
100 + Debug
101 + ScalarOperand
102 + Send
103 + Sync
104 + scirs2_core::simd_ops::SimdUnifiedOps
105 + NumAssign,
106 > CLIPTextEncoder<F>
107{
108 /// Create a new CLIPTextEncoder
109 pub fn new(_config: CLIPTextConfig, projection_dim: usize) -> Result<Self> {
110 // Token embedding
111 let mut token_embedding = Sequential::new();
112 let mut rng = SmallRng::from_seed([42; 32]);
113 token_embedding.add(Dense::<F>::new(
114 _config.vocab_size,
115 _config.hidden_size,
116 None,
117 &mut rng,
118 )?);
119 // Position embedding
120 let position_embedding = SinusoidalPositionalEncoding::<F>::new(
121 _config.hidden_size,
122 _config.max_position_embeddings,
123 );
124 // Transformer encoder layers
125 let mut encoder_layers = Vec::with_capacity(_config.num_layers);
126 for _i in 0.._config.num_layers {
127 encoder_layers.push(TransformerEncoderLayer::<F>::new(
128 _config.hidden_size,
129 _config.num_heads,
130 _config.intermediate_size,
131 _config.dropout_rate,
132 _config.layer_norm_eps,
133 &mut rng,
134 )?);
135 }
136 // Layer normalization
137 let layer_norm =
138 LayerNorm::<F>::new(_config.hidden_size, _config.layer_norm_eps, &mut rng)?;
139 // Projection
140 let projection = Dense::<F>::new(_config.hidden_size, projection_dim, None, &mut rng)?;
141 Ok(Self {
142 token_embedding,
143 position_embedding,
144 encoder_layers,
145 layer_norm,
146 projection,
147 config: _config,
148 })
149 }
150}
151impl<
152 F: Float
153 + Debug
154 + ScalarOperand
155 + Send
156 + Sync
157 + scirs2_core::simd_ops::SimdUnifiedOps
158 + 'static
159 + NumAssign,
160 > Layer<F> for CLIPTextEncoder<F>
161{
162 fn as_any(&self) -> &dyn std::any::Any {
163 self
164 }
165
166 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
167 self
168 }
169
170 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
171 // Apply token embedding
172 let mut x = self.token_embedding.forward(input)?;
173 // Apply position embedding
174 x = self.position_embedding.forward(&x)?;
175 // Apply transformer encoder layers
176 for layer in &self.encoder_layers {
177 x = layer.forward(&x)?;
178 }
179 // Apply layer normalization
180 x = self.layer_norm.forward(&x)?;
181 // Extract the [CLS] token embedding (assuming it's the first token)
182 let batch_size = x.shape()[0];
183 let hidden_size = x.shape()[2];
184 let cls_token = x
185 .slice_axis(Axis(1), scirs2_core::ndarray::Slice::from(0..1))
186 .into_shape_with_order((batch_size, hidden_size))?;
187 // Apply projection - convert to owned array to fix the reference type
188 let cls_token_owned = cls_token.to_owned().into_dyn();
189 let output = self.projection.forward(&cls_token_owned)?;
190 Ok(output)
191 }
192
193 fn backward(
194 &self,
195 input: &Array<F, IxDyn>,
196 grad_output: &Array<F, IxDyn>,
197 ) -> Result<Array<F, IxDyn>> {
198 // CLIPTextEncoder backward: reverse the forward pass
199 // Backward through projection
200 let grad_after_proj = self.projection.backward(grad_output, grad_output)?;
201 // Expand CLS token gradient back to full sequence
202 // Note: This is simplified - in reality we need to handle the slicing properly
203 let batch_size = input.shape()[0];
204 let seq_len = input.shape()[1];
205 let hidden_size = grad_after_proj.shape()[1];
206 // Create gradient for full sequence (most gradients go to CLS token position)
207 let mut grad_full_seq =
208 Array::<F, IxDyn>::zeros(IxDyn(&[batch_size, seq_len, hidden_size]));
209 // Put the gradient at the CLS token position (index 0)
210 for i in 0..batch_size {
211 for j in 0..hidden_size {
212 grad_full_seq[[i, 0, j]] = grad_after_proj[[i, j]];
213 }
214 }
215 let grad_full_seq = grad_full_seq.into_dyn();
216 // Backward through layer normalization
217 let mut grad = self.layer_norm.backward(&grad_full_seq, &grad_full_seq)?;
218 // Backward through transformer encoder layers in reverse order
219 for layer in self.encoder_layers.iter().rev() {
220 grad = layer.backward(&grad, &grad)?;
221 }
222 // TODO: Backward through position embedding when backward method is implemented
223 // grad = self.position_embedding.backward(&grad, &grad)?;
224 // Backward through token embedding
225 let grad_input = self.token_embedding.backward(input, &grad)?;
226 Ok(grad_input)
227 }
228
229 fn update(&mut self, learning_rate: F) -> Result<()> {
230 // Update all components
231 // Update token embedding
232 self.token_embedding.update(learning_rate)?;
233 // Update position embedding
234 self.position_embedding.update(learning_rate)?;
235 // Update all encoder layers
236 for layer in &mut self.encoder_layers {
237 layer.update(learning_rate)?;
238 }
239 // Update layer normalization
240 self.layer_norm.update(learning_rate)?;
241 // Update projection layer
242 self.projection.update(learning_rate)?;
243 Ok(())
244 }
245
246 fn params(&self) -> Vec<Array<F, IxDyn>> {
247 let mut params = Vec::new();
248 params.extend(self.token_embedding.params());
249 // position_embedding has no trainable parameters, so we skip it
250 for layer in &self.encoder_layers {
251 params.extend(layer.params());
252 }
253 params.extend(self.layer_norm.params());
254 params.extend(self.projection.params());
255 params
256 }
257
258 fn set_training(&mut self, training: bool) {
259 self.token_embedding.set_training(training);
260 self.position_embedding.set_training(training);
261 for layer in &mut self.encoder_layers {
262 layer.set_training(training);
263 }
264 self.layer_norm.set_training(training);
265 self.projection.set_training(training);
266 }
267
268 fn is_training(&self) -> bool {
269 self.token_embedding.is_training()
270 }
271}
272// Vision encoder for CLIP model (uses Vision Transformer)
273// TODO: Re-enable once PatchEmbedding is implemented
274// pub struct CLIPVisionEncoder<
275// F: Float + Debug + ScalarOperand + Send + Sync + 'static + scirs2_core::simd_ops::SimdUnifiedOps,
276// > {
277// /// Vision Transformer
278// pub vision_transformer: VisionTransformer<F>,
279// /// Projection layer
280// pub projection: Dense<F>,
281// }
282//
283// impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static + scirs2_core::simd_ops::SimdUnifiedOps + NumAssign>
284// CLIPVisionEncoder<F>
285// {
286// /// Create a new CLIPVisionEncoder
287// pub fn new(_config: ViTConfig, projection_dim: usize) -> Result<Self> {
288// // Create ViT with a clone of the _config to avoid ownership issues
289// let vision_transformer = VisionTransformer::<F>::new(_config.clone())?;
290// // Projection layer
291// let mut rng_proj = SmallRng::from_seed([42; 32]);
292// let projection = Dense::<F>::new(_config.embed_dim, projection_dim, None, &mut rng_proj)?;
293// Ok(Self {
294// vision_transformer,
295// projection,
296// })
297// }
298// }
299//
300// impl<
301// F: Float
302// + Debug
303// + ScalarOperand
304// + Send
305// + Sync
306// + scirs2_core::simd_ops::SimdUnifiedOps
307// + 'static
308// + NumAssign,
309// > Layer<F> for CLIPVisionEncoder<F>
310// {
311// fn as_any(&self) -> &dyn std::any::Any {
312// self
313// }
314//
315// fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
316// self
317// }
318//
319// fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
320// // Apply vision transformer
321// let x = self.vision_transformer.forward(input)?;
322// // Apply projection
323// let output = self.projection.forward(&x)?;
324// Ok(output)
325// }
326//
327// fn backward(
328// &self,
329// input: &Array<F, IxDyn>,
330// grad_output: &Array<F, IxDyn>,
331// ) -> Result<Array<F, IxDyn>> {
332// // CLIPVisionEncoder backward: reverse the forward pass
333// // Backward through projection
334// let grad_after_proj = self.projection.backward(grad_output, grad_output)?;
335// // Backward through vision transformer
336// let grad_input = self.vision_transformer.backward(input, &grad_after_proj)?;
337// Ok(grad_input)
338// }
339//
340// fn update(&mut self, learning_rate: F) -> Result<()> {
341// // Update projection
342// self.projection.update(learning_rate)?;
343// // Update vision transformer
344// self.vision_transformer.update(learning_rate)?;
345// Ok(())
346// }
347//
348// fn params(&self) -> Vec<Array<F, IxDyn>> {
349// let mut params = Vec::new();
350// params.extend(self.projection.params());
351// params.extend(self.vision_transformer.params());
352// params
353// }
354//
355// fn set_training(&mut self, training: bool) {
356// self.projection.set_training(training);
357// self.vision_transformer.set_training(training);
358// }
359//
360// fn is_training(&self) -> bool {
361// self.vision_transformer.is_training()
362// }
363// }
364// CLIP model implementation
365// TODO: Re-enable once PatchEmbedding is implemented
366// pub struct CLIP<
367// F: Float + Debug + ScalarOperand + Send + Sync + 'static + scirs2_core::simd_ops::SimdUnifiedOps,
368// > {
369// /// Vision encoder
370// pub vision_encoder: CLIPVisionEncoder<F>,
371// /// Text encoder
372// pub text_encoder: CLIPTextEncoder<F>,
373// /// Optional classifier for zero-shot classification
374// pub classifier: Option<Dense<F>>,
375// /// Model configuration
376// pub _config: CLIPConfig,
377// /// Temperature parameter for contrastive loss
378// pub logit_scale: F,
379// }
380
381// TODO: Re-enable once PatchEmbedding is implemented
382// impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static + scirs2_core::simd_ops::SimdUnifiedOps + NumAssign>
383// CLIP<F>
384// {
385// /// Create a new CLIP model
386// pub fn new(config: CLIPConfig) -> Result<Self> {
387// // Create vision encoder
388// let vision_encoder =
389// CLIPVisionEncoder::<F>::new(config.vision_config.clone(), config.projection_dim)?;
390// // Create text encoder
391// let text_encoder =
392// CLIPTextEncoder::<F>::new(config.text_config.clone(), config.projection_dim)?;
393// // Create classifier if needed
394// let classifier = if config.include_head {
395// let mut rng_cls = SmallRng::from_seed([42; 32]);
396// Some(Dense::<F>::new(
397// config.projection_dim,
398// config.num_classes,
399// None,
400// &mut rng_cls,
401// )?)
402// } else {
403// None
404// };
405// // Initialize logit scale (typically ln(1/0.07))
406// let logit_scale = F::from(2.6592).expect("Failed to convert constant to float");
407// Ok(Self {
408// vision_encoder,
409// text_encoder,
410// classifier,
411// _config: config,
412// logit_scale,
413// })
414// }
415// /// Forward pass for image-text contrastive learning
416// pub fn forward_contrastive(
417// &self,
418// image_input: &Array<F, IxDyn>,
419// text_input: &Array<F, IxDyn>,
420// ) -> Result<ClipOutput<F>> {
421// // Get image and text embeddings
422// let image_features = self.vision_encoder.forward(image_input)?;
423// let text_features = self.text_encoder.forward(text_input)?;
424// // Normalize embeddings
425// let image_features_norm = normalize_features(&image_features)?;
426// let text_features_norm = normalize_features(&text_features)?;
427// // Compute similarity matrix (batch_size x batch_size)
428// let logits_per_image =
429// compute_similarity(&image_features_norm, &text_features_norm, self.logit_scale)?;
430// // Transpose to get logits_pertext (currently unused but kept for API consistency)
431// let _logits_pertext = logits_per_image.t().into_dyn();
432// Ok((image_features, text_features, logits_per_image))
433// }
434//
435// /// Forward pass for zero-shot image classification using a text encoder
436// pub fn forward_classification(
437// &self,
438// image_input: &Array<F, IxDyn>,
439// text_embeddings: &Array<F, IxDyn>,
440// ) -> Result<Array<F, IxDyn>> {
441// // Get image embeddings
442// let image_features = self.vision_encoder.forward(image_input)?;
443// // Normalize image embeddings
444// let image_features_norm = normalize_features(&image_features)?;
445// // Compute similarity with text embeddings
446// let logits = compute_similarity(&image_features_norm, text_embeddings, self.logit_scale)?;
447// Ok(logits)
448// }
449// /// Create a CLIP model with default settings
450// pub fn clip_base(num_classes: usize, include_head: bool) -> Result<Self> {
451// let vision_config = ViTConfig {
452// image_size: (224, 224),
453// patch_size: (16, 16),
454// in_channels: 3,
455// num_classes,
456// embed_dim: 768,
457// num_heads: 12,
458// mlp_dim: 3072,
459// num_layers: 12,
460// dropout_rate: 0.1,
461// attention_dropout_rate: 0.1,
462// };
463// let text_config = CLIPTextConfig::default();
464// let config = CLIPConfig {
465// text_config,
466// vision_config,
467// projection_dim: 512,
468// include_head,
469// num_classes,
470// };
471// Self::new(config)
472// }
473// /// Create a small CLIP model
474// pub fn clip_small(num_classes: usize, include_head: bool) -> Result<Self> {
475// let vision_config = ViTConfig {
476// image_size: (224, 224),
477// patch_size: (16, 16),
478// in_channels: 3,
479// num_classes,
480// embed_dim: 512,
481// num_heads: 6,
482// mlp_dim: 2048,
483// num_layers: 8,
484// dropout_rate: 0.1,
485// attention_dropout_rate: 0.1,
486// };
487// let text_config = CLIPTextConfig {
488// vocab_size: 49408,
489// hidden_size: 384,
490// intermediate_size: 1536,
491// num_layers: 8,
492// num_heads: 6,
493// max_position_embeddings: 77,
494// dropout_rate: 0.1,
495// layer_norm_eps: 1e-5,
496// };
497// let config = CLIPConfig {
498// text_config,
499// vision_config,
500// projection_dim: 256,
501// include_head,
502// num_classes,
503// };
504// Self::new(config)
505// }
506// }
507//
508// impl<
509// F: Float
510// + Debug
511// + ScalarOperand
512// + Send
513// + Sync
514// + scirs2_core::simd_ops::SimdUnifiedOps
515// + 'static
516// + NumAssign,
517// > Layer<F> for CLIP<F>
518// {
519// fn as_any(&self) -> &dyn std::any::Any {
520// self
521// }
522//
523// fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
524// self
525// }
526//
527// fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
528// // In a typical scenario, input would be an image
529// // For classification tasks, we would need pre-computed text embeddings
530// // Get image features
531// let image_features = self.vision_encoder.forward(input)?;
532// // If classifier is present, use it for direct classification
533// if let Some(ref classifier) = self.classifier {
534// return classifier.forward(&image_features);
535// }
536// Ok(image_features)
537// }
538//
539// fn backward(
540// &self,
541// input: &Array<F, IxDyn>,
542// grad_output: &Array<F, IxDyn>,
543// ) -> Result<Array<F, IxDyn>> {
544// // CLIP backward: reverse the forward pass
545// let mut grad = grad_output.clone();
546// // Backward through classifier if present
547// if let Some(ref classifier) = self.classifier {
548// // For proper gradient computation, we need the intermediate features
549// // This is a simplified version
550// grad = classifier.backward(&grad, &grad)?;
551// }
552// // Backward through vision encoder
553// let grad_input = self.vision_encoder.backward(input, &grad)?;
554// Ok(grad_input)
555// }
556//
557// fn update(&mut self, learning_rate: F) -> Result<()> {
558// // Update vision encoder
559// self.vision_encoder.update(learning_rate)?;
560// // Update text encoder
561// self.text_encoder.update(learning_rate)?;
562// // Update classifier if present
563// if let Some(ref mut classifier) = self.classifier {
564// classifier.update(learning_rate)?;
565// }
566// Ok(())
567// }
568//
569// fn params(&self) -> Vec<Array<F, IxDyn>> {
570// let mut params = Vec::new();
571// params.extend(self.vision_encoder.params());
572// params.extend(self.text_encoder.params());
573// if let Some(ref classifier) = self.classifier {
574// params.extend(classifier.params());
575// }
576// params
577// }
578//
579// fn set_training(&mut self, training: bool) {
580// self.vision_encoder.set_training(training);
581// self.text_encoder.set_training(training);
582// if let Some(ref mut classifier) = self.classifier {
583// classifier.set_training(training);
584// }
585// }
586//
587// fn is_training(&self) -> bool {
588// self.vision_encoder.is_training()
589// }
590// }
591/// Normalize feature vectors (L2 normalization)
592#[allow(dead_code)]
593fn normalize_features<F: Float + Debug + ScalarOperand>(
594 features: &Array<F, IxDyn>,
595) -> Result<Array<F, IxDyn>> {
596 let shape = features.shape();
597 let batch_size = shape[0];
598 let feature_dim = shape[1];
599 // Reshape to 2D for easier computation
600 let features_2d = features
601 .clone()
602 .into_shape_with_order((batch_size, feature_dim))?;
603 // Compute L2 norm along the feature dimension
604 let norm = features_2d.map_axis(Axis(1), |x| {
605 let sum_squares = x.iter().fold(F::zero(), |acc, &val| acc + val * val);
606 let norm = sum_squares.sqrt();
607 // Avoid division by zero
608 if norm > F::from(1e-12).expect("Failed to convert constant to float") {
609 norm
610 } else {
611 F::one()
612 }
613 });
614 // Expand norm to match feature dims for broadcasting
615 let norm_expanded = norm.insert_axis(Axis(1));
616 // Normalize features
617 let normalized = features_2d.clone() / norm_expanded;
618 // Reshape back to original shape
619 Ok(normalized.into_shape_with_order(shape)?)
620}
621
622/// Compute similarity matrix between two sets of features
623#[allow(dead_code)]
624fn compute_similarity<F: Float + Debug + ScalarOperand>(
625 features_a: &Array<F, IxDyn>,
626 features_b: &Array<F, IxDyn>,
627 temperature: F,
628) -> Result<Array<F, IxDyn>> {
629 // Get shapes
630 let shape_a = features_a.shape();
631 let shape_b = features_b.shape();
632 let batch_a = shape_a[0];
633 let batch_b = shape_b[0];
634 // Reshape features to 2D matrices
635 let features_a_2d = features_a
636 .clone()
637 .into_shape_with_order((batch_a, shape_a[1]))?;
638 let features_b_2d = features_b
639 .clone()
640 .into_shape_with_order((batch_b, shape_b[1]))?;
641 // Compute dot product (similarity matrix)
642 let similarity = features_a_2d.dot(&features_b_2d.t());
643 // Apply temperature scaling
644 let scaled_similarity = similarity * temperature;
645 Ok(scaled_similarity.into_dyn())
646}