1use crate::error::Result;
10use crate::layers::{Dense, Layer, LayerNorm, Sequential};
11use crate::models::architectures::{ViTConfig, VisionTransformer};
12use crate::transformer::TransformerEncoderLayer;
13use crate::utils::positional_encoding::{PositionalEncoding, SinusoidalPositionalEncoding};
14use scirs2_core::ndarray::{Array, Axis, IxDyn, ScalarOperand};
15use scirs2_core::numeric::{Float, NumAssign};
16use scirs2_core::random::{rngs::SmallRng, SeedableRng};
17use serde::{Deserialize, Serialize};
18use std::fmt::Debug;
/// Return type of `CLIP::forward_contrastive`, in order: the image features and the text
/// features exactly as produced by their encoders (not normalised), followed by the scaled
/// image-to-text similarity logits.
type ClipOutput<F> = (Array<F, IxDyn>, Array<F, IxDyn>, Array<F, IxDyn>);
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct CLIPTextConfig {
24 pub vocab_size: usize,
26 pub hidden_size: usize,
28 pub intermediate_size: usize,
30 pub num_layers: usize,
32 pub num_heads: usize,
34 pub max_position_embeddings: usize,
36 pub dropout_rate: f64,
38 pub layer_norm_eps: f64,
40}
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct CLIPConfig {
44 pub text_config: CLIPTextConfig,
46 pub vision_config: ViTConfig,
48 pub projection_dim: usize,
50 pub include_head: bool,
52 pub num_classes: usize,
54}
55
56impl Default for CLIPTextConfig {
57 fn default() -> Self {
58 Self {
59 vocab_size: 49408,
60 hidden_size: 512,
61 intermediate_size: 2048,
62 num_layers: 12,
63 num_heads: 8,
64 max_position_embeddings: 77,
65 dropout_rate: 0.1,
66 layer_norm_eps: 1e-5,
67 }
68 }
69}
70
71#[derive(Debug, Clone)]
73pub struct CLIPTextEncoder<
74 F: Float
75 + Debug
76 + ScalarOperand
77 + Send
78 + Sync
79 + 'static
80 + scirs2_core::simd_ops::SimdUnifiedOps
81 + NumAssign,
82> {
83 pub token_embedding: Sequential<F>,
85 pub position_embedding: SinusoidalPositionalEncoding<F>,
87 pub encoder_layers: Vec<TransformerEncoderLayer<F>>,
89 pub layer_norm: LayerNorm<F>,
91 pub projection: Dense<F>,
93 pub config: CLIPTextConfig,
95}
96
97impl<
98 F: Float
99 + Debug
100 + ScalarOperand
101 + Send
102 + Sync
103 + scirs2_core::simd_ops::SimdUnifiedOps
104 + NumAssign,
105 > CLIPTextEncoder<F>
106{
107 pub fn new(_config: CLIPTextConfig, projection_dim: usize) -> Result<Self> {
109 let mut token_embedding = Sequential::new();
111 let mut rng = SmallRng::from_seed([42; 32]);
112 token_embedding.add(Dense::<F>::new(
113 _config.vocab_size,
114 _config.hidden_size,
115 None,
116 &mut rng,
117 )?);
118 let position_embedding = SinusoidalPositionalEncoding::<F>::new(
120 _config.hidden_size,
121 _config.max_position_embeddings,
122 );
123 let mut encoder_layers = Vec::with_capacity(_config.num_layers);
125 for _i in 0.._config.num_layers {
126 encoder_layers.push(TransformerEncoderLayer::<F>::new(
127 _config.hidden_size,
128 _config.num_heads,
129 _config.intermediate_size,
130 _config.dropout_rate,
131 _config.layer_norm_eps,
132 &mut rng,
133 )?);
134 }
135 let layer_norm =
137 LayerNorm::<F>::new(_config.hidden_size, _config.layer_norm_eps, &mut rng)?;
138 let projection = Dense::<F>::new(_config.hidden_size, projection_dim, None, &mut rng)?;
140 Ok(Self {
141 token_embedding,
142 position_embedding,
143 encoder_layers,
144 layer_norm,
145 projection,
146 config: _config,
147 })
148 }
149}
150impl<
151 F: Float
152 + Debug
153 + ScalarOperand
154 + Send
155 + Sync
156 + scirs2_core::simd_ops::SimdUnifiedOps
157 + 'static
158 + NumAssign,
159 > Layer<F> for CLIPTextEncoder<F>
160{
161 fn as_any(&self) -> &dyn std::any::Any {
162 self
163 }
164
165 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
166 self
167 }
168
169 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
170 let mut x = self.token_embedding.forward(input)?;
172 x = self.position_embedding.forward(&x)?;
174 for layer in &self.encoder_layers {
176 x = layer.forward(&x)?;
177 }
178 x = self.layer_norm.forward(&x)?;
180 let batch_size = x.shape()[0];
182 let hidden_size = x.shape()[2];
183 let cls_token = x
184 .slice_axis(Axis(1), scirs2_core::ndarray::Slice::from(0..1))
185 .into_shape_with_order((batch_size, hidden_size))?;
186 let cls_token_owned = cls_token.to_owned().into_dyn();
188 let output = self.projection.forward(&cls_token_owned)?;
189 Ok(output)
190 }
191
192 fn backward(
193 &self,
194 input: &Array<F, IxDyn>,
195 grad_output: &Array<F, IxDyn>,
196 ) -> Result<Array<F, IxDyn>> {
197 let grad_after_proj = self.projection.backward(grad_output, grad_output)?;
200 let batch_size = input.shape()[0];
203 let seq_len = input.shape()[1];
204 let hidden_size = grad_after_proj.shape()[1];
205 let mut grad_full_seq =
207 Array::<F, IxDyn>::zeros(IxDyn(&[batch_size, seq_len, hidden_size]));
208 for i in 0..batch_size {
210 for j in 0..hidden_size {
211 grad_full_seq[[i, 0, j]] = grad_after_proj[[i, j]];
212 }
213 }
214 let grad_full_seq = grad_full_seq.into_dyn();
215 let mut grad = self.layer_norm.backward(&grad_full_seq, &grad_full_seq)?;
217 for layer in self.encoder_layers.iter().rev() {
219 grad = layer.backward(&grad, &grad)?;
220 }
221 grad = self.position_embedding.backward(&grad, &grad)?;
225 let grad_input = self.token_embedding.backward(input, &grad)?;
227 Ok(grad_input)
228 }
229
230 fn update(&mut self, learning_rate: F) -> Result<()> {
231 self.token_embedding.update(learning_rate)?;
234 self.position_embedding.update(learning_rate)?;
236 for layer in &mut self.encoder_layers {
238 layer.update(learning_rate)?;
239 }
240 self.layer_norm.update(learning_rate)?;
242 self.projection.update(learning_rate)?;
244 Ok(())
245 }
246
247 fn params(&self) -> Vec<Array<F, IxDyn>> {
248 let mut params = Vec::new();
249 params.extend(self.token_embedding.params());
250 for layer in &self.encoder_layers {
252 params.extend(layer.params());
253 }
254 params.extend(self.layer_norm.params());
255 params.extend(self.projection.params());
256 params
257 }
258
259 fn set_training(&mut self, training: bool) {
260 self.token_embedding.set_training(training);
261 self.position_embedding.set_training(training);
262 for layer in &mut self.encoder_layers {
263 layer.set_training(training);
264 }
265 self.layer_norm.set_training(training);
266 self.projection.set_training(training);
267 }
268
269 fn is_training(&self) -> bool {
270 self.token_embedding.is_training()
271 }
272}
273pub struct CLIPVisionEncoder<
275 F: Float
276 + Debug
277 + ScalarOperand
278 + Send
279 + Sync
280 + Clone
281 + 'static
282 + scirs2_core::simd_ops::SimdUnifiedOps
283 + NumAssign,
284> {
285 pub vision_transformer: VisionTransformer<F>,
287 pub projection: Dense<F>,
289}
290
291impl<
292 F: Float
293 + Debug
294 + ScalarOperand
295 + Send
296 + Sync
297 + Clone
298 + 'static
299 + scirs2_core::simd_ops::SimdUnifiedOps
300 + NumAssign,
301 > CLIPVisionEncoder<F>
302{
303 pub fn new(config: ViTConfig, projection_dim: usize) -> Result<Self> {
305 let embed_dim = config.embed_dim;
306 let vision_transformer = VisionTransformer::<F>::new(config)?;
307 let mut rng_proj = SmallRng::from_seed([42; 32]);
308 let projection = Dense::<F>::new(embed_dim, projection_dim, None, &mut rng_proj)?;
309 Ok(Self {
310 vision_transformer,
311 projection,
312 })
313 }
314}
315
316impl<
317 F: Float
318 + Debug
319 + ScalarOperand
320 + Send
321 + Sync
322 + Clone
323 + scirs2_core::simd_ops::SimdUnifiedOps
324 + 'static
325 + NumAssign,
326 > Layer<F> for CLIPVisionEncoder<F>
327{
328 fn as_any(&self) -> &dyn std::any::Any {
329 self
330 }
331
332 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
333 self
334 }
335
336 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
337 let x = self.vision_transformer.forward(input)?;
338 self.projection.forward(&x)
339 }
340
341 fn backward(
342 &self,
343 input: &Array<F, IxDyn>,
344 grad_output: &Array<F, IxDyn>,
345 ) -> Result<Array<F, IxDyn>> {
346 let grad_after_proj = self.projection.backward(grad_output, grad_output)?;
347 self.vision_transformer.backward(input, &grad_after_proj)
348 }
349
350 fn update(&mut self, learning_rate: F) -> Result<()> {
351 self.projection.update(learning_rate)?;
352 self.vision_transformer.update(learning_rate)?;
353 Ok(())
354 }
355
356 fn params(&self) -> Vec<Array<F, IxDyn>> {
357 let mut params = Vec::new();
358 params.extend(self.projection.params());
359 params.extend(self.vision_transformer.params());
360 params
361 }
362
363 fn set_training(&mut self, training: bool) {
364 self.projection.set_training(training);
365 self.vision_transformer.set_training(training);
366 }
367
368 fn is_training(&self) -> bool {
369 self.vision_transformer.is_training()
370 }
371
372 fn layer_type(&self) -> &str {
373 "CLIPVisionEncoder"
374 }
375}
376
377pub struct CLIP<
382 F: Float
383 + Debug
384 + ScalarOperand
385 + Send
386 + Sync
387 + Clone
388 + 'static
389 + scirs2_core::simd_ops::SimdUnifiedOps
390 + NumAssign,
391> {
392 pub vision_encoder: CLIPVisionEncoder<F>,
394 pub text_encoder: CLIPTextEncoder<F>,
396 pub classifier: Option<Dense<F>>,
398 pub _config: CLIPConfig,
400 pub logit_scale: F,
402}
403
404impl<
405 F: Float
406 + Debug
407 + ScalarOperand
408 + Send
409 + Sync
410 + Clone
411 + 'static
412 + scirs2_core::simd_ops::SimdUnifiedOps
413 + NumAssign,
414 > CLIP<F>
415{
416 pub fn new(config: CLIPConfig) -> Result<Self> {
418 let vision_encoder =
419 CLIPVisionEncoder::<F>::new(config.vision_config.clone(), config.projection_dim)?;
420 let text_encoder =
421 CLIPTextEncoder::<F>::new(config.text_config.clone(), config.projection_dim)?;
422 let classifier = if config.include_head {
423 let mut rng_cls = SmallRng::from_seed([42; 32]);
424 Some(Dense::<F>::new(
425 config.projection_dim,
426 config.num_classes,
427 None,
428 &mut rng_cls,
429 )?)
430 } else {
431 None
432 };
433 let logit_scale = F::from(2.6592_f64).ok_or_else(|| {
435 crate::error::NeuralError::InvalidArchitecture(
436 "CLIP: failed to convert logit_scale to float".to_string(),
437 )
438 })?;
439 Ok(Self {
440 vision_encoder,
441 text_encoder,
442 classifier,
443 _config: config,
444 logit_scale,
445 })
446 }
447
448 pub fn forward_contrastive(
451 &self,
452 image_input: &Array<F, IxDyn>,
453 text_input: &Array<F, IxDyn>,
454 ) -> Result<ClipOutput<F>> {
455 let image_features = self.vision_encoder.forward(image_input)?;
456 let text_features = self.text_encoder.forward(text_input)?;
457 let image_features_norm = normalize_features(&image_features)?;
458 let text_features_norm = normalize_features(&text_features)?;
459 let logits_per_image =
460 compute_similarity(&image_features_norm, &text_features_norm, self.logit_scale)?;
461 Ok((image_features, text_features, logits_per_image))
462 }
463
464 pub fn forward_classification(
466 &self,
467 image_input: &Array<F, IxDyn>,
468 text_embeddings: &Array<F, IxDyn>,
469 ) -> Result<Array<F, IxDyn>> {
470 let image_features = self.vision_encoder.forward(image_input)?;
471 let image_features_norm = normalize_features(&image_features)?;
472 compute_similarity(&image_features_norm, text_embeddings, self.logit_scale)
473 }
474
475 pub fn clip_base(num_classes: usize, include_head: bool) -> Result<Self> {
477 let vision_config = ViTConfig {
478 image_size: (224, 224),
479 patch_size: (16, 16),
480 in_channels: 3,
481 num_classes,
482 embed_dim: 768,
483 num_heads: 12,
484 mlp_dim: 3072,
485 num_layers: 12,
486 dropout_rate: 0.1,
487 attention_dropout_rate: 0.1,
488 };
489 Self::new(CLIPConfig {
490 text_config: CLIPTextConfig::default(),
491 vision_config,
492 projection_dim: 512,
493 include_head,
494 num_classes,
495 })
496 }
497
498 pub fn clip_small(num_classes: usize, include_head: bool) -> Result<Self> {
500 let vision_config = ViTConfig {
501 image_size: (224, 224),
502 patch_size: (16, 16),
503 in_channels: 3,
504 num_classes,
505 embed_dim: 512,
506 num_heads: 6,
507 mlp_dim: 2048,
508 num_layers: 8,
509 dropout_rate: 0.1,
510 attention_dropout_rate: 0.1,
511 };
512 let text_config = CLIPTextConfig {
513 vocab_size: 49408,
514 hidden_size: 384,
515 intermediate_size: 1536,
516 num_layers: 8,
517 num_heads: 6,
518 max_position_embeddings: 77,
519 dropout_rate: 0.1,
520 layer_norm_eps: 1e-5,
521 };
522 Self::new(CLIPConfig {
523 text_config,
524 vision_config,
525 projection_dim: 256,
526 include_head,
527 num_classes,
528 })
529 }
530}
531
532impl<
533 F: Float
534 + Debug
535 + ScalarOperand
536 + Send
537 + Sync
538 + Clone
539 + scirs2_core::simd_ops::SimdUnifiedOps
540 + 'static
541 + NumAssign,
542 > Layer<F> for CLIP<F>
543{
544 fn as_any(&self) -> &dyn std::any::Any {
545 self
546 }
547
548 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
549 self
550 }
551
552 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
553 let image_features = self.vision_encoder.forward(input)?;
554 if let Some(ref classifier) = self.classifier {
555 return classifier.forward(&image_features);
556 }
557 Ok(image_features)
558 }
559
560 fn backward(
561 &self,
562 input: &Array<F, IxDyn>,
563 grad_output: &Array<F, IxDyn>,
564 ) -> Result<Array<F, IxDyn>> {
565 let mut grad = grad_output.clone();
566 if let Some(ref classifier) = self.classifier {
567 grad = classifier.backward(&grad, &grad)?;
568 }
569 self.vision_encoder.backward(input, &grad)
570 }
571
572 fn update(&mut self, learning_rate: F) -> Result<()> {
573 self.vision_encoder.update(learning_rate)?;
574 self.text_encoder.update(learning_rate)?;
575 if let Some(ref mut classifier) = self.classifier {
576 classifier.update(learning_rate)?;
577 }
578 Ok(())
579 }
580
581 fn params(&self) -> Vec<Array<F, IxDyn>> {
582 let mut params = Vec::new();
583 params.extend(self.vision_encoder.params());
584 params.extend(self.text_encoder.params());
585 if let Some(ref classifier) = self.classifier {
586 params.extend(classifier.params());
587 }
588 params
589 }
590
591 fn set_training(&mut self, training: bool) {
592 self.vision_encoder.set_training(training);
593 self.text_encoder.set_training(training);
594 if let Some(ref mut classifier) = self.classifier {
595 classifier.set_training(training);
596 }
597 }
598
599 fn is_training(&self) -> bool {
600 self.vision_encoder.is_training()
601 }
602
603 fn layer_type(&self) -> &str {
604 "CLIP"
605 }
606}
607#[allow(dead_code)]
609fn normalize_features<F: Float + Debug + ScalarOperand>(
610 features: &Array<F, IxDyn>,
611) -> Result<Array<F, IxDyn>> {
612 let shape = features.shape();
613 let batch_size = shape[0];
614 let feature_dim = shape[1];
615 let features_2d = features
617 .clone()
618 .into_shape_with_order((batch_size, feature_dim))?;
619 let norm = features_2d.map_axis(Axis(1), |x| {
621 let sum_squares = x.iter().fold(F::zero(), |acc, &val| acc + val * val);
622 let norm = sum_squares.sqrt();
623 if norm > F::from(1e-12).expect("Failed to convert constant to float") {
625 norm
626 } else {
627 F::one()
628 }
629 });
630 let norm_expanded = norm.insert_axis(Axis(1));
632 let normalized = features_2d.clone() / norm_expanded;
634 Ok(normalized.into_shape_with_order(shape)?)
636}
637
638#[allow(dead_code)]
640fn compute_similarity<F: Float + Debug + ScalarOperand>(
641 features_a: &Array<F, IxDyn>,
642 features_b: &Array<F, IxDyn>,
643 temperature: F,
644) -> Result<Array<F, IxDyn>> {
645 let shape_a = features_a.shape();
647 let shape_b = features_b.shape();
648 let batch_a = shape_a[0];
649 let batch_b = shape_b[0];
650 let features_a_2d = features_a
652 .clone()
653 .into_shape_with_order((batch_a, shape_a[1]))?;
654 let features_b_2d = features_b
655 .clone()
656 .into_shape_with_order((batch_b, shape_b[1]))?;
657 let similarity = features_a_2d.dot(&features_b_2d.t());
659 let scaled_similarity = similarity * temperature;
661 Ok(scaled_similarity.into_dyn())
662}