1use crate::{TokenizedInput, Tokenizer};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use trustformers_core::errors::Result;
11
/// Configuration for [`MultimodalTokenizer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultimodalConfig {
    /// Maximum number of text tokens to keep (None = unlimited).
    /// NOTE(review): verify enforcement in `tokenize_text`.
    pub max_text_length: Option<usize>,
    /// Number of `[PATCH_i]` vocab entries created and the cap on image-patch
    /// tokens per input.
    pub max_image_patches: Option<usize>,
    /// Number of `[FRAME_i]` vocab entries created and the cap on audio-frame
    /// tokens per input.
    pub max_audio_frames: Option<usize>,
    /// Edge length of an image patch in pixels.
    /// NOTE(review): stored but not read anywhere in this module — confirm
    /// intended consumer.
    pub image_patch_size: usize,
    /// Samples per audio frame.
    /// NOTE(review): stored but not read anywhere in this module.
    pub audio_frame_size: usize,
    /// Whether to seed the vocab with [CLS]/[SEP]/[PAD]/[UNK]/[MASK].
    pub include_special_tokens: bool,
    /// Cross-modal attention flag.
    /// NOTE(review): fusion behavior is selected via `fusion_strategy`; this
    /// flag is not read in this module.
    pub use_cross_modal_attention: bool,
    /// How tokens from different modalities are combined into one sequence.
    pub fusion_strategy: FusionStrategy,
    /// Optional pass-through settings for a text tokenizer (not read here).
    pub text_tokenizer_config: Option<HashMap<String, String>>,
    /// Optional pass-through settings for a vision tokenizer (not read here).
    pub vision_tokenizer_config: Option<HashMap<String, String>>,
    /// Optional pass-through settings for an audio tokenizer (not read here).
    pub audio_tokenizer_config: Option<HashMap<String, String>>,
}
38
39impl Default for MultimodalConfig {
40 fn default() -> Self {
41 Self {
42 max_text_length: Some(512),
43 max_image_patches: Some(196), max_audio_frames: Some(1000),
45 image_patch_size: 16,
46 audio_frame_size: 256,
47 include_special_tokens: true,
48 use_cross_modal_attention: true,
49 fusion_strategy: FusionStrategy::Concatenation,
50 text_tokenizer_config: None,
51 vision_tokenizer_config: None,
52 audio_tokenizer_config: None,
53 }
54 }
55}
56
/// Strategy used by `apply_fusion_strategy` to merge per-modality token
/// streams into a single sequence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FusionStrategy {
    /// Modalities appended one after another, unchanged.
    Concatenation,
    /// Round-robin interleaving: one token per modality per step.
    Interleaved,
    /// Tokens annotated with pairwise cross-modal attention weights and
    /// separated by `[CROSS_ATTN]` markers.
    CrossAttention,
    /// A `[FUSE]` marker inserted at each modality transition.
    Hierarchical,
    /// Tokens scored by per-modality gate weights, low-confidence tokens
    /// dropped, survivors re-sorted by confidence.
    Gated,
}
71
/// The kind of data a token or input segment originates from.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModalityType {
    /// Natural-language text.
    Text,
    /// Still images (patch-based).
    Image,
    /// Audio signals (frame-based).
    Audio,
    /// Video (frames of image patches).
    Video,
    /// Depth maps. NOTE(review): no tokenization path exists for this yet.
    Depth,
    /// Point clouds. NOTE(review): no tokenization path exists for this yet.
    PointCloud,
    /// Graph-structured data.
    Graph,
    /// Tabular data.
    Table,
    /// Source code.
    Code,
    /// User-defined modality identified by name.
    Custom(String),
}
86
/// A single token in a fused multimodal sequence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultimodalToken {
    /// Vocabulary id of the token.
    pub token_id: u32,
    /// Modality this token belongs to.
    pub modality: ModalityType,
    /// Position within its own modality's sub-sequence.
    pub modality_position: usize,
    /// Position within the fused sequence.
    pub global_position: usize,
    /// Optional per-token annotations (coordinates, features, confidence).
    pub metadata: Option<MultimodalTokenMetadata>,
}
101
/// Optional per-token annotations attached during tokenization and fusion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultimodalTokenMetadata {
    /// (x, y) pixel coordinates for image/video patch tokens.
    pub spatial_coords: Option<(usize, usize)>,
    /// Timestamp in seconds for audio/video tokens.
    pub temporal_coords: Option<f64>,
    /// Channel index; video tokenization stores the frame index here.
    pub channel: Option<usize>,
    /// Confidence score, populated by gated fusion.
    pub confidence: Option<f64>,
    /// Feature vector carried over from the source patch/frame.
    pub features: Option<Vec<f32>>,
    /// Cross-modal attention weights, populated by cross-attention fusion.
    pub attention_weights: Option<Vec<f32>>,
}
118
/// A rectangular region of an image.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImagePatch {
    /// Left pixel offset of the patch in the source image.
    pub x: usize,
    /// Top pixel offset of the patch in the source image.
    pub y: usize,
    /// Patch width in pixels.
    pub width: usize,
    /// Patch height in pixels.
    pub height: usize,
    /// Flat pixel buffer; `create_image_patches` sizes it width * height * 3.
    pub pixels: Vec<f32>,
    /// Optional precomputed embedding, copied into token metadata `features`.
    pub embedding: Option<Vec<f32>>,
}
133
/// A windowed slice of an audio signal.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioFrame {
    /// Frame start time in seconds.
    pub timestamp: f64,
    /// Frame length in seconds.
    pub duration: f64,
    /// Raw samples for this frame.
    pub samples: Vec<f32>,
    /// Optional precomputed features, copied into token metadata `features`.
    pub features: Option<Vec<f32>>,
}
146
/// A single video frame decomposed into image patches.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VideoFrame {
    /// Ordinal frame index.
    pub frame_number: usize,
    /// Frame time in seconds.
    pub timestamp: f64,
    /// Spatial patches of this frame.
    pub patches: Vec<ImagePatch>,
    /// Optional motion vectors — presumably one per patch (confirm with the
    /// producer); not read in this module.
    pub motion_vectors: Option<Vec<(f32, f32)>>,
}
159
/// A bundle of raw per-modality inputs for `tokenize_multimodal`; every field
/// is optional and absent modalities are skipped.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultimodalInput {
    /// Plain text, tokenized with the wrapped text tokenizer.
    pub text: Option<String>,
    /// Image patches (see `MultimodalUtils::create_image_patches`).
    pub image_patches: Option<Vec<ImagePatch>>,
    /// Audio frames (see `MultimodalUtils::create_audio_frames`).
    pub audio_frames: Option<Vec<AudioFrame>>,
    /// Video frames, each carrying its own image patches.
    pub video_frames: Option<Vec<VideoFrame>>,
    /// Tabular data, flattened to text during tokenization.
    pub table_data: Option<TableData>,
    /// Graph data. NOTE(review): accepted but never tokenized by
    /// `tokenize_multimodal` at present.
    pub graph_data: Option<GraphData>,
    /// Arbitrary named binary payloads. NOTE(review): accepted but never
    /// tokenized by `tokenize_multimodal` at present.
    pub custom_data: Option<HashMap<String, Vec<u8>>>,
}
178
/// A simple rectangular table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TableData {
    /// Column names.
    pub headers: Vec<String>,
    /// Row-major cell values as strings.
    pub rows: Vec<Vec<String>>,
    /// Optional per-column type names (free-form strings, e.g. "int");
    /// not read in this module.
    pub column_types: Option<Vec<String>>,
}
189
/// Graph-structured data with per-node feature vectors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphData {
    /// One feature vector per node.
    pub nodes: Vec<Vec<f32>>,
    /// Edges as (source, target) node-index pairs.
    pub edges: Vec<(usize, usize)>,
    /// Optional edge features — presumably parallel to `edges` (confirm with
    /// the producer).
    pub edge_features: Option<Vec<Vec<f32>>>,
    /// Optional node labels — presumably parallel to `nodes` (confirm with
    /// the producer).
    pub node_labels: Option<Vec<String>>,
}
202
/// Result of multimodal tokenization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultimodalTokenizedInput {
    /// Fused token ids, one per entry in `modality_tokens`.
    pub input_ids: Vec<u32>,
    /// Attention mask; `tokenize_multimodal` fills it with 1s.
    pub attention_mask: Option<Vec<u32>>,
    /// Per-token modality type ids (see `get_modality_type_id`).
    pub token_type_ids: Option<Vec<u32>>,
    /// Full token records including modality and metadata.
    pub modality_tokens: Vec<MultimodalToken>,
    /// Half-open (start, end) index range per modality.
    /// NOTE(review): ranges are computed *before* fusion, so reordering
    /// strategies (Interleaved/Gated) invalidate them as positions.
    pub modality_boundaries: HashMap<ModalityType, (usize, usize)>,
    /// Optional dense cross-modal attention matrix; never populated by
    /// `tokenize_multimodal` itself.
    pub cross_modal_attention: Option<Vec<Vec<f32>>>,
}
219
/// Tokenizer that combines a wrapped text tokenizer with its own vocabulary
/// of modality markers and positional patch/frame tokens.
pub struct MultimodalTokenizer<T: Tokenizer> {
    // Underlying text tokenizer (shared).
    text_tokenizer: Arc<T>,
    // Active configuration.
    config: MultimodalConfig,
    // Multimodal-only token -> id map.
    vocab: HashMap<String, u32>,
    // Reverse id -> token map for `vocab`.
    id_to_token: HashMap<u32, String>,
    // Next id to assign; ids start at 0.
    // NOTE(review): this id space can overlap the wrapped tokenizer's ids —
    // see the `Tokenizer` impl's decode/lookup precedence.
    next_id: u32,
    // Marker-token id per modality (e.g. Image -> id of "[IMG]").
    modality_token_ids: HashMap<ModalityType, u32>,
}
229
230impl<T: Tokenizer> MultimodalTokenizer<T> {
231 pub fn new(text_tokenizer: T, config: MultimodalConfig) -> Self {
233 let mut tokenizer = Self {
234 text_tokenizer: Arc::new(text_tokenizer),
235 config,
236 vocab: HashMap::new(),
237 id_to_token: HashMap::new(),
238 next_id: 0,
239 modality_token_ids: HashMap::new(),
240 };
241
242 tokenizer.initialize_vocab();
243 tokenizer
244 }
245
246 pub fn from_text_tokenizer(text_tokenizer: T) -> Self {
248 Self::new(text_tokenizer, MultimodalConfig::default())
249 }
250
251 fn initialize_vocab(&mut self) {
253 if self.config.include_special_tokens {
255 self.add_token("[CLS]");
256 self.add_token("[SEP]");
257 self.add_token("[PAD]");
258 self.add_token("[UNK]");
259 self.add_token("[MASK]");
260 }
261
262 let modality_tokens = vec![
264 (ModalityType::Text, "[TEXT]"),
265 (ModalityType::Image, "[IMG]"),
266 (ModalityType::Audio, "[AUD]"),
267 (ModalityType::Video, "[VID]"),
268 (ModalityType::Table, "[TAB]"),
269 (ModalityType::Graph, "[GRF]"),
270 (ModalityType::Code, "[COD]"),
271 ];
272
273 for (modality, token) in modality_tokens {
274 let token_id = self.add_token(token);
275 self.modality_token_ids.insert(modality, token_id);
276 }
277
278 for i in 0..self.config.max_image_patches.unwrap_or(196) {
280 self.add_token(&format!("[PATCH_{}]", i));
281 }
282
283 for i in 0..self.config.max_audio_frames.unwrap_or(1000) {
284 self.add_token(&format!("[FRAME_{}]", i));
285 }
286
287 self.add_token("[FUSE]");
289 self.add_token("[CROSS_ATTN]");
290 self.add_token("[MODAL_SEP]");
291 }
292
293 fn add_token(&mut self, token: &str) -> u32 {
295 if let Some(&id) = self.vocab.get(token) {
296 return id;
297 }
298
299 let id = self.next_id;
300 self.vocab.insert(token.to_string(), id);
301 self.id_to_token.insert(id, token.to_string());
302 self.next_id += 1;
303 id
304 }
305
306 pub fn tokenize_multimodal(&self, input: &MultimodalInput) -> Result<MultimodalTokenizedInput> {
308 let mut all_tokens = Vec::new();
309 let mut modality_boundaries = HashMap::new();
310 let mut current_position = 0;
311
312 if let Some(ref text) = input.text {
314 let text_tokens = self.tokenize_text(text, current_position)?;
315 let start_pos = current_position;
316 all_tokens.extend(text_tokens);
317 let end_pos = all_tokens.len();
318 modality_boundaries.insert(ModalityType::Text, (start_pos, end_pos));
319 current_position = end_pos;
320 }
321
322 if let Some(ref patches) = input.image_patches {
324 let image_tokens = self.tokenize_image_patches(patches, current_position)?;
325 let start_pos = current_position;
326 all_tokens.extend(image_tokens);
327 let end_pos = all_tokens.len();
328 modality_boundaries.insert(ModalityType::Image, (start_pos, end_pos));
329 current_position = end_pos;
330 }
331
332 if let Some(ref frames) = input.audio_frames {
334 let audio_tokens = self.tokenize_audio_frames(frames, current_position)?;
335 let start_pos = current_position;
336 all_tokens.extend(audio_tokens);
337 let end_pos = all_tokens.len();
338 modality_boundaries.insert(ModalityType::Audio, (start_pos, end_pos));
339 current_position = end_pos;
340 }
341
342 if let Some(ref frames) = input.video_frames {
344 let video_tokens = self.tokenize_video_frames(frames, current_position)?;
345 let start_pos = current_position;
346 all_tokens.extend(video_tokens);
347 let end_pos = all_tokens.len();
348 modality_boundaries.insert(ModalityType::Video, (start_pos, end_pos));
349 current_position = end_pos;
350 }
351
352 if let Some(ref table) = input.table_data {
354 let table_tokens = self.tokenize_table(table, current_position)?;
355 let start_pos = current_position;
356 all_tokens.extend(table_tokens);
357 let end_pos = all_tokens.len();
358 modality_boundaries.insert(ModalityType::Table, (start_pos, end_pos));
359 let _ = end_pos; }
361
362 let fused_tokens = self.apply_fusion_strategy(&all_tokens)?;
364
365 let input_ids: Vec<u32> = fused_tokens.iter().map(|t| t.token_id).collect();
367 let attention_mask = Some(vec![1u32; input_ids.len()]);
368 let token_type_ids =
369 Some(fused_tokens.iter().map(|t| self.get_modality_type_id(&t.modality)).collect());
370
371 Ok(MultimodalTokenizedInput {
372 input_ids,
373 attention_mask,
374 token_type_ids,
375 modality_tokens: fused_tokens,
376 modality_boundaries,
377 cross_modal_attention: None, })
379 }
380
381 fn tokenize_text(&self, text: &str, start_position: usize) -> Result<Vec<MultimodalToken>> {
383 let text_tokenized = self.text_tokenizer.encode(text)?;
384 let mut tokens = Vec::new();
385
386 for (i, &token_id) in text_tokenized.input_ids.iter().enumerate() {
387 tokens.push(MultimodalToken {
388 token_id,
389 modality: ModalityType::Text,
390 modality_position: i,
391 global_position: start_position + i,
392 metadata: None,
393 });
394 }
395
396 Ok(tokens)
397 }
398
399 fn tokenize_image_patches(
401 &self,
402 patches: &[ImagePatch],
403 start_position: usize,
404 ) -> Result<Vec<MultimodalToken>> {
405 let mut tokens = Vec::new();
406
407 if let Some(&img_token_id) = self.modality_token_ids.get(&ModalityType::Image) {
409 tokens.push(MultimodalToken {
410 token_id: img_token_id,
411 modality: ModalityType::Image,
412 modality_position: 0,
413 global_position: start_position,
414 metadata: None,
415 });
416 }
417
418 for (i, patch) in patches.iter().enumerate() {
420 let patch_token = format!("[PATCH_{}]", i);
421 if let Some(&token_id) = self.vocab.get(&patch_token) {
422 let metadata = Some(MultimodalTokenMetadata {
423 spatial_coords: Some((patch.x, patch.y)),
424 temporal_coords: None,
425 channel: None,
426 confidence: None,
427 features: patch.embedding.clone(),
428 attention_weights: None,
429 });
430
431 tokens.push(MultimodalToken {
432 token_id,
433 modality: ModalityType::Image,
434 modality_position: i + 1,
435 global_position: start_position + tokens.len(),
436 metadata,
437 });
438 }
439
440 if tokens.len() >= self.config.max_image_patches.unwrap_or(196) {
442 break;
443 }
444 }
445
446 Ok(tokens)
447 }
448
449 fn tokenize_audio_frames(
451 &self,
452 frames: &[AudioFrame],
453 start_position: usize,
454 ) -> Result<Vec<MultimodalToken>> {
455 let mut tokens = Vec::new();
456
457 if let Some(&aud_token_id) = self.modality_token_ids.get(&ModalityType::Audio) {
459 tokens.push(MultimodalToken {
460 token_id: aud_token_id,
461 modality: ModalityType::Audio,
462 modality_position: 0,
463 global_position: start_position,
464 metadata: None,
465 });
466 }
467
468 for (i, frame) in frames.iter().enumerate() {
470 let frame_token = format!("[FRAME_{}]", i);
471 if let Some(&token_id) = self.vocab.get(&frame_token) {
472 let metadata = Some(MultimodalTokenMetadata {
473 spatial_coords: None,
474 temporal_coords: Some(frame.timestamp),
475 channel: None,
476 confidence: None,
477 features: frame.features.clone(),
478 attention_weights: None,
479 });
480
481 tokens.push(MultimodalToken {
482 token_id,
483 modality: ModalityType::Audio,
484 modality_position: i + 1,
485 global_position: start_position + tokens.len(),
486 metadata,
487 });
488 }
489
490 if tokens.len() >= self.config.max_audio_frames.unwrap_or(1000) {
492 break;
493 }
494 }
495
496 Ok(tokens)
497 }
498
499 fn tokenize_video_frames(
501 &self,
502 frames: &[VideoFrame],
503 start_position: usize,
504 ) -> Result<Vec<MultimodalToken>> {
505 let mut tokens = Vec::new();
506
507 if let Some(&vid_token_id) = self.modality_token_ids.get(&ModalityType::Video) {
509 tokens.push(MultimodalToken {
510 token_id: vid_token_id,
511 modality: ModalityType::Video,
512 modality_position: 0,
513 global_position: start_position,
514 metadata: None,
515 });
516 }
517
518 for (frame_idx, frame) in frames.iter().enumerate() {
520 for (patch_idx, patch) in frame.patches.iter().enumerate() {
521 let patch_token = format!("[PATCH_{}]", patch_idx);
522 if let Some(&token_id) = self.vocab.get(&patch_token) {
523 let metadata = Some(MultimodalTokenMetadata {
524 spatial_coords: Some((patch.x, patch.y)),
525 temporal_coords: Some(frame.timestamp),
526 channel: Some(frame_idx),
527 confidence: None,
528 features: patch.embedding.clone(),
529 attention_weights: None,
530 });
531
532 tokens.push(MultimodalToken {
533 token_id,
534 modality: ModalityType::Video,
535 modality_position: tokens.len(),
536 global_position: start_position + tokens.len(),
537 metadata,
538 });
539 }
540 }
541 }
542
543 Ok(tokens)
544 }
545
546 fn tokenize_table(
548 &self,
549 table: &TableData,
550 start_position: usize,
551 ) -> Result<Vec<MultimodalToken>> {
552 let mut tokens = Vec::new();
553
554 if let Some(&tab_token_id) = self.modality_token_ids.get(&ModalityType::Table) {
556 tokens.push(MultimodalToken {
557 token_id: tab_token_id,
558 modality: ModalityType::Table,
559 modality_position: 0,
560 global_position: start_position,
561 metadata: None,
562 });
563 }
564
565 let mut table_text = table.headers.join(" | ");
567 for row in &table.rows {
568 table_text.push_str(" | ");
569 table_text.push_str(&row.join(" | "));
570 }
571
572 let text_tokens = self.text_tokenizer.encode(&table_text)?;
573 for (i, &token_id) in text_tokens.input_ids.iter().enumerate() {
574 tokens.push(MultimodalToken {
575 token_id,
576 modality: ModalityType::Table,
577 modality_position: i + 1,
578 global_position: start_position + tokens.len(),
579 metadata: None,
580 });
581 }
582
583 Ok(tokens)
584 }
585
586 fn apply_fusion_strategy(&self, tokens: &[MultimodalToken]) -> Result<Vec<MultimodalToken>> {
588 match self.config.fusion_strategy {
589 FusionStrategy::Concatenation => Ok(tokens.to_vec()),
590 FusionStrategy::Interleaved => self.apply_interleaved_fusion(tokens),
591 FusionStrategy::CrossAttention => self.apply_cross_attention_fusion(tokens),
592 FusionStrategy::Hierarchical => self.apply_hierarchical_fusion(tokens),
593 FusionStrategy::Gated => self.apply_gated_fusion(tokens),
594 }
595 }
596
597 fn apply_interleaved_fusion(&self, tokens: &[MultimodalToken]) -> Result<Vec<MultimodalToken>> {
599 let mut modality_groups: HashMap<ModalityType, Vec<&MultimodalToken>> = HashMap::new();
601 for token in tokens {
602 modality_groups.entry(token.modality.clone()).or_default().push(token);
603 }
604
605 let mut result = Vec::new();
607 let max_len = modality_groups.values().map(|v| v.len()).max().unwrap_or(0);
608
609 for i in 0..max_len {
610 for group in modality_groups.values() {
611 if let Some(token) = group.get(i) {
612 result.push((*token).clone());
613 }
614 }
615 }
616
617 Ok(result)
618 }
619
620 fn apply_cross_attention_fusion(
622 &self,
623 tokens: &[MultimodalToken],
624 ) -> Result<Vec<MultimodalToken>> {
625 let mut modality_groups: HashMap<ModalityType, Vec<&MultimodalToken>> = HashMap::new();
627 for token in tokens {
628 modality_groups.entry(token.modality.clone()).or_default().push(token);
629 }
630
631 if modality_groups.len() < 2 {
633 return Ok(tokens.to_vec());
634 }
635
636 let mut result = Vec::new();
637 let modalities: Vec<_> = modality_groups.keys().cloned().collect();
638
639 if let Some(&cross_attn_token_id) = self.vocab.get("[CROSS_ATTN]") {
641 for (i, (modality, group)) in modality_groups.iter().enumerate() {
642 for token in group {
644 let mut enhanced_token = (*token).clone();
645
646 let mut attention_weights = Vec::new();
648 for (j, other_modality) in modalities.iter().enumerate() {
649 if i != j {
650 let attention_score = self.calculate_cross_modal_attention_score(
652 modality,
653 other_modality,
654 token.modality_position,
655 );
656 attention_weights.push(attention_score);
657 }
658 }
659
660 if let Some(ref mut metadata) = enhanced_token.metadata {
662 metadata.attention_weights = Some(attention_weights);
663 } else {
664 enhanced_token.metadata = Some(MultimodalTokenMetadata {
665 spatial_coords: None,
666 temporal_coords: None,
667 channel: None,
668 confidence: None,
669 features: None,
670 attention_weights: Some(attention_weights),
671 });
672 }
673
674 result.push(enhanced_token);
675 }
676
677 if i < modality_groups.len() - 1 {
679 result.push(MultimodalToken {
680 token_id: cross_attn_token_id,
681 modality: ModalityType::Custom("cross_attention".to_string()),
682 modality_position: 0,
683 global_position: result.len(),
684 metadata: None,
685 });
686 }
687 }
688 } else {
689 result = tokens.to_vec();
691 }
692
693 Ok(result)
694 }
695
696 fn calculate_cross_modal_attention_score(
698 &self,
699 source_modality: &ModalityType,
700 target_modality: &ModalityType,
701 position: usize,
702 ) -> f32 {
703 let base_score = match (source_modality, target_modality) {
705 (ModalityType::Text, ModalityType::Image)
707 | (ModalityType::Image, ModalityType::Text) => 0.8,
708 (ModalityType::Text, ModalityType::Audio)
710 | (ModalityType::Audio, ModalityType::Text) => 0.7,
711 (ModalityType::Image, ModalityType::Video)
713 | (ModalityType::Video, ModalityType::Image) => 0.9,
714 (ModalityType::Audio, ModalityType::Video)
716 | (ModalityType::Video, ModalityType::Audio) => 0.75,
717 (ModalityType::Table, ModalityType::Text)
719 | (ModalityType::Text, ModalityType::Table) => 0.6,
720 (ModalityType::Code, ModalityType::Text) | (ModalityType::Text, ModalityType::Code) => {
722 0.65
723 },
724 (ModalityType::Graph, ModalityType::Table)
726 | (ModalityType::Table, ModalityType::Graph) => 0.7,
727 (a, b) if a == b => 0.5,
729 _ => 0.4,
731 };
732
733 let position_factor = 1.0 / (1.0 + (position as f32 * 0.1));
735
736 base_score * position_factor
738 }
739
740 fn apply_hierarchical_fusion(
742 &self,
743 tokens: &[MultimodalToken],
744 ) -> Result<Vec<MultimodalToken>> {
745 let mut result = Vec::new();
747 let mut current_modality = None;
748
749 if let Some(&fuse_token_id) = self.vocab.get("[FUSE]") {
750 for token in tokens {
751 if current_modality.is_some() && current_modality.as_ref() != Some(&token.modality)
752 {
753 result.push(MultimodalToken {
755 token_id: fuse_token_id,
756 modality: ModalityType::Custom("fusion".to_string()),
757 modality_position: 0,
758 global_position: result.len(),
759 metadata: None,
760 });
761 }
762 result.push(token.clone());
763 current_modality = Some(token.modality.clone());
764 }
765 }
766
767 Ok(result)
768 }
769
770 fn apply_gated_fusion(&self, tokens: &[MultimodalToken]) -> Result<Vec<MultimodalToken>> {
772 let mut modality_groups: HashMap<ModalityType, Vec<&MultimodalToken>> = HashMap::new();
774 for token in tokens {
775 modality_groups.entry(token.modality.clone()).or_default().push(token);
776 }
777
778 if modality_groups.len() < 2 {
780 return Ok(tokens.to_vec());
781 }
782
783 let mut result = Vec::new();
784
785 let mut modality_gates: HashMap<ModalityType, f32> = HashMap::new();
787 for (modality, group) in &modality_groups {
788 let gate_weight = self.calculate_modality_gate_weight(modality, group);
789 modality_gates.insert(modality.clone(), gate_weight);
790 }
791
792 let total_weight: f32 = modality_gates.values().sum();
794 if total_weight > 0.0 {
795 for weight in modality_gates.values_mut() {
796 *weight /= total_weight;
797 }
798 }
799
800 for (modality, group) in &modality_groups {
802 let gate_weight = modality_gates.get(modality).copied().unwrap_or(0.0);
803
804 for token in group {
805 let mut gated_token = (*token).clone();
806
807 let token_confidence = self.calculate_token_confidence(token, gate_weight);
809
810 if let Some(ref mut metadata) = gated_token.metadata {
812 metadata.confidence = Some(token_confidence as f64);
813 } else {
814 gated_token.metadata = Some(MultimodalTokenMetadata {
815 spatial_coords: None,
816 temporal_coords: None,
817 channel: None,
818 confidence: Some(token_confidence as f64),
819 features: None,
820 attention_weights: None,
821 });
822 }
823
824 if token_confidence > 0.1 {
826 result.push(gated_token);
827 }
828 }
829 }
830
831 result.sort_by(|a, b| {
833 let conf_a = a.metadata.as_ref().and_then(|m| m.confidence).unwrap_or(0.0);
834 let conf_b = b.metadata.as_ref().and_then(|m| m.confidence).unwrap_or(0.0);
835 conf_b.partial_cmp(&conf_a).unwrap_or(std::cmp::Ordering::Equal)
836 });
837
838 for (i, token) in result.iter_mut().enumerate() {
840 token.global_position = i;
841 }
842
843 Ok(result)
844 }
845
846 fn calculate_modality_gate_weight(
848 &self,
849 modality: &ModalityType,
850 tokens: &[&MultimodalToken],
851 ) -> f32 {
852 if tokens.is_empty() {
853 return 0.0;
854 }
855
856 let base_weight = match modality {
858 ModalityType::Text => 1.0, ModalityType::Image => 0.8, ModalityType::Video => 0.9, ModalityType::Audio => 0.7, ModalityType::Table => 0.6, ModalityType::Graph => 0.65, ModalityType::Code => 0.75, ModalityType::Custom(_) => 0.5, _ => 0.4, };
868
869 let token_count_factor = (tokens.len() as f32).sqrt() / 10.0;
871
872 let feature_richness = tokens
874 .iter()
875 .map(|token| {
876 if let Some(metadata) = &token.metadata {
877 let mut richness = 0.0;
878 if metadata.spatial_coords.is_some() {
879 richness += 0.2;
880 }
881 if metadata.temporal_coords.is_some() {
882 richness += 0.2;
883 }
884 if metadata.features.is_some() {
885 richness += 0.4;
886 }
887 if metadata.confidence.is_some() {
888 richness += 0.2;
889 }
890 richness
891 } else {
892 0.1 }
894 })
895 .sum::<f32>()
896 / tokens.len() as f32;
897
898 base_weight * (1.0 + token_count_factor) * (1.0 + feature_richness)
900 }
901
902 fn calculate_token_confidence(&self, token: &MultimodalToken, gate_weight: f32) -> f32 {
904 let mut confidence = gate_weight;
906
907 if let Some(metadata) = &token.metadata {
909 if metadata.spatial_coords.is_some()
911 && matches!(token.modality, ModalityType::Image | ModalityType::Video)
912 {
913 confidence *= 1.2;
914 }
915
916 if metadata.temporal_coords.is_some()
918 && matches!(token.modality, ModalityType::Audio | ModalityType::Video)
919 {
920 confidence *= 1.15;
921 }
922
923 if let Some(features) = &metadata.features {
925 let feature_magnitude =
926 features.iter().map(|f| f.abs()).sum::<f32>() / features.len() as f32;
927 confidence *= 1.0 + (feature_magnitude * 0.1);
928 }
929
930 if let Some(existing_confidence) = metadata.confidence {
932 confidence = (confidence + existing_confidence as f32) / 2.0;
933 }
934 }
935
936 let position_factor = 1.0 / (1.0 + (token.modality_position as f32 * 0.05));
938 confidence *= position_factor;
939
940 confidence.clamp(0.0, 1.0)
942 }
943
944 fn get_modality_type_id(&self, modality: &ModalityType) -> u32 {
946 match modality {
947 ModalityType::Text => 0,
948 ModalityType::Image => 1,
949 ModalityType::Audio => 2,
950 ModalityType::Video => 3,
951 ModalityType::Table => 4,
952 ModalityType::Graph => 5,
953 ModalityType::Code => 6,
954 ModalityType::Custom(_) => 7,
955 _ => 0,
956 }
957 }
958
959 pub fn config(&self) -> &MultimodalConfig {
961 &self.config
962 }
963
964 pub fn get_vocab(&self) -> &HashMap<String, u32> {
966 &self.vocab
967 }
968
969 pub fn text_tokenizer(&self) -> &T {
971 &self.text_tokenizer
972 }
973}
974
impl<T: Tokenizer> Tokenizer for MultimodalTokenizer<T> {
    /// Delegates plain-text encoding to the wrapped text tokenizer.
    fn encode(&self, text: &str) -> Result<TokenizedInput> {
        self.text_tokenizer.encode(text)
    }

    /// Decodes by first dropping ids that map to bracketed multimodal tokens
    /// (e.g. "[IMG]", "[PATCH_3]") and then delegating to the text tokenizer.
    ///
    /// NOTE(review): multimodal ids are assigned from 0 and may overlap the
    /// wrapped tokenizer's id space; an id present in both vocabularies is
    /// treated as a multimodal token here — confirm the two id ranges are
    /// disjoint in practice.
    fn decode(&self, token_ids: &[u32]) -> Result<String> {
        let text_tokens: Vec<u32> = token_ids
            .iter()
            .copied()
            .filter(|&id| {
                // Keep the id unless it is a "[...]"-shaped special token.
                if let Some(token) = self.id_to_token.get(&id) {
                    !token.starts_with('[') || !token.ends_with(']')
                } else {
                    true
                }
            })
            .collect();

        self.text_tokenizer.decode(&text_tokens)
    }

    /// Delegates sentence-pair encoding to the wrapped text tokenizer.
    fn encode_pair(&self, text_a: &str, text_b: &str) -> Result<TokenizedInput> {
        self.text_tokenizer.encode_pair(text_a, text_b)
    }

    /// Combined size of the text vocabulary and the multimodal vocabulary.
    /// NOTE(review): tokens present in both (e.g. "[PAD]") are counted twice.
    fn vocab_size(&self) -> usize {
        self.text_tokenizer.vocab_size() + self.vocab.len()
    }

    /// Merged vocabulary; multimodal entries overwrite text entries that
    /// share the same token string.
    fn get_vocab(&self) -> HashMap<String, u32> {
        let mut vocab = self.text_tokenizer.get_vocab();
        for (token, &id) in &self.vocab {
            vocab.insert(token.clone(), id);
        }
        vocab
    }

    /// Looks the token up in the multimodal vocabulary first, then falls back
    /// to the wrapped text tokenizer.
    fn token_to_id(&self, token: &str) -> Option<u32> {
        self.vocab
            .get(token)
            .copied()
            .or_else(|| self.text_tokenizer.token_to_id(token))
    }

    /// Inverse lookup with the same multimodal-first precedence as
    /// `token_to_id`.
    fn id_to_token(&self, id: u32) -> Option<String> {
        self.id_to_token
            .get(&id)
            .cloned()
            .or_else(|| self.text_tokenizer.id_to_token(id))
    }
}
1029
1030pub struct MultimodalUtils;
1032
1033impl MultimodalUtils {
1034 pub fn create_image_patches(
1036 image_width: usize,
1037 image_height: usize,
1038 patch_size: usize,
1039 ) -> Vec<ImagePatch> {
1040 let mut patches = Vec::new();
1041
1042 for y in (0..image_height).step_by(patch_size) {
1043 for x in (0..image_width).step_by(patch_size) {
1044 let width = (patch_size).min(image_width - x);
1045 let height = (patch_size).min(image_height - y);
1046
1047 patches.push(ImagePatch {
1048 x,
1049 y,
1050 width,
1051 height,
1052 pixels: vec![0.0; width * height * 3], embedding: None,
1054 });
1055 }
1056 }
1057
1058 patches
1059 }
1060
1061 pub fn create_audio_frames(
1063 sample_rate: f64,
1064 duration: f64,
1065 frame_size: usize,
1066 hop_size: usize,
1067 ) -> Vec<AudioFrame> {
1068 let mut frames = Vec::new();
1069 let total_samples = (sample_rate * duration) as usize;
1070
1071 for start in (0..total_samples).step_by(hop_size) {
1072 let end = (start + frame_size).min(total_samples);
1073 let timestamp = start as f64 / sample_rate;
1074 let frame_duration = (end - start) as f64 / sample_rate;
1075
1076 frames.push(AudioFrame {
1077 timestamp,
1078 duration: frame_duration,
1079 samples: vec![0.0; end - start],
1080 features: None,
1081 });
1082 }
1083
1084 frames
1085 }
1086
1087 pub fn convert_to_multimodal(
1089 tokenized: TokenizedInput,
1090 modality: ModalityType,
1091 ) -> MultimodalTokenizedInput {
1092 let modality_tokens: Vec<MultimodalToken> = tokenized
1093 .input_ids
1094 .into_iter()
1095 .enumerate()
1096 .map(|(i, token_id)| MultimodalToken {
1097 token_id,
1098 modality: modality.clone(),
1099 modality_position: i,
1100 global_position: i,
1101 metadata: None,
1102 })
1103 .collect();
1104
1105 let mut boundaries = HashMap::new();
1106 boundaries.insert(modality, (0, modality_tokens.len()));
1107
1108 MultimodalTokenizedInput {
1109 input_ids: modality_tokens.iter().map(|t| t.token_id).collect(),
1110 attention_mask: Some(tokenized.attention_mask.into_iter().map(|x| x as u32).collect()),
1111 token_type_ids: tokenized.token_type_ids,
1112 modality_tokens,
1113 modality_boundaries: boundaries,
1114 cross_modal_attention: None,
1115 }
1116 }
1117
1118 pub fn calculate_cross_modal_attention(
1120 tokens: &[MultimodalToken],
1121 query_modality: &ModalityType,
1122 key_modality: &ModalityType,
1123 ) -> Vec<Vec<f32>> {
1124 let query_tokens: Vec<_> =
1125 tokens.iter().filter(|t| &t.modality == query_modality).collect();
1126 let key_tokens: Vec<_> = tokens.iter().filter(|t| &t.modality == key_modality).collect();
1127
1128 let mut attention = vec![vec![0.0; key_tokens.len()]; query_tokens.len()];
1130
1131 for (i, _) in query_tokens.iter().enumerate() {
1132 for (j, _) in key_tokens.iter().enumerate() {
1133 attention[i][j] = 1.0 / (key_tokens.len() as f32);
1135 }
1136 }
1137
1138 attention
1139 }
1140}
1141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::char::CharTokenizer;
    use std::collections::HashMap;

    // Builds a tiny character-level tokenizer whose vocab covers exactly the
    // characters used by the test strings below.
    fn create_test_char_tokenizer() -> CharTokenizer {
        let mut vocab = HashMap::new();
        vocab.insert("[PAD]".to_string(), 0);
        vocab.insert("[UNK]".to_string(), 1);
        vocab.insert("[CLS]".to_string(), 2);
        vocab.insert("[SEP]".to_string(), 3);
        vocab.insert("h".to_string(), 4);
        vocab.insert("e".to_string(), 5);
        vocab.insert("l".to_string(), 6);
        vocab.insert("o".to_string(), 7);
        vocab.insert("w".to_string(), 8);
        vocab.insert("r".to_string(), 9);
        vocab.insert("d".to_string(), 10);
        vocab.insert(" ".to_string(), 11);
        vocab.insert("t".to_string(), 12);
        vocab.insert("s".to_string(), 13);
        CharTokenizer::new(vocab)
    }

    #[test]
    fn test_multimodal_config() {
        // Defaults should match the documented values.
        let config = MultimodalConfig::default();
        assert_eq!(config.max_text_length, Some(512));
        assert_eq!(config.max_image_patches, Some(196));
        assert!(config.include_special_tokens);
    }

    #[test]
    fn test_multimodal_tokenizer_creation() {
        let text_tokenizer = create_test_char_tokenizer();
        let multimodal_tokenizer = MultimodalTokenizer::from_text_tokenizer(text_tokenizer);

        // Modality marker tokens are interned during construction.
        assert!(multimodal_tokenizer.get_vocab().contains_key("[IMG]"));
        assert!(multimodal_tokenizer.get_vocab().contains_key("[AUD]"));
    }

    #[test]
    fn test_text_only_tokenization() {
        let text_tokenizer = create_test_char_tokenizer();
        let multimodal_tokenizer = MultimodalTokenizer::from_text_tokenizer(text_tokenizer);

        // Only the text field is populated; other modalities are skipped.
        let input = MultimodalInput {
            text: Some("hello world".to_string()),
            image_patches: None,
            audio_frames: None,
            video_frames: None,
            table_data: None,
            graph_data: None,
            custom_data: None,
        };

        let result = multimodal_tokenizer.tokenize_multimodal(&input);
        assert!(result.is_ok());
        let tokenized = result.expect("Operation failed in test");
        assert!(!tokenized.input_ids.is_empty());
        assert!(tokenized.modality_boundaries.contains_key(&ModalityType::Text));
    }

    #[test]
    fn test_image_patch_creation() {
        let patches = MultimodalUtils::create_image_patches(224, 224, 16);
        // 224 / 16 = 14 patches per side.
        assert_eq!(patches.len(), 14 * 14);

        let first_patch = &patches[0];
        assert_eq!(first_patch.x, 0);
        assert_eq!(first_patch.y, 0);
        assert_eq!(first_patch.width, 16);
        assert_eq!(first_patch.height, 16);
    }

    #[test]
    fn test_audio_frame_creation() {
        // 1 second at 44.1 kHz, 1024-sample frames, 512-sample hop.
        let frames = MultimodalUtils::create_audio_frames(44100.0, 1.0, 1024, 512);
        assert!(!frames.is_empty());

        let first_frame = &frames[0];
        assert_eq!(first_frame.timestamp, 0.0);
        assert_eq!(first_frame.samples.len(), 1024);
    }

    #[test]
    fn test_multimodal_input_with_images() {
        let text_tokenizer = create_test_char_tokenizer();
        let multimodal_tokenizer = MultimodalTokenizer::from_text_tokenizer(text_tokenizer);

        // A single 16x16 patch with a precomputed embedding.
        let patches = vec![ImagePatch {
            x: 0,
            y: 0,
            width: 16,
            height: 16,
            pixels: vec![0.0; 16 * 16 * 3],
            embedding: Some(vec![1.0, 2.0, 3.0]),
        }];

        let input = MultimodalInput {
            text: Some("An image".to_string()),
            image_patches: Some(patches),
            audio_frames: None,
            video_frames: None,
            table_data: None,
            graph_data: None,
            custom_data: None,
        };

        let result = multimodal_tokenizer.tokenize_multimodal(&input);
        assert!(result.is_ok());
        let tokenized = result.expect("Operation failed in test");
        // Both modalities should have recorded boundary ranges.
        assert!(tokenized.modality_boundaries.contains_key(&ModalityType::Text));
        assert!(tokenized.modality_boundaries.contains_key(&ModalityType::Image));
    }

    #[test]
    fn test_table_tokenization() {
        let text_tokenizer = create_test_char_tokenizer();
        let multimodal_tokenizer = MultimodalTokenizer::from_text_tokenizer(text_tokenizer);

        let table = TableData {
            headers: vec!["Name".to_string(), "Age".to_string()],
            rows: vec![
                vec!["Alice".to_string(), "25".to_string()],
                vec!["Bob".to_string(), "30".to_string()],
            ],
            column_types: Some(vec!["string".to_string(), "int".to_string()]),
        };

        let input = MultimodalInput {
            text: None,
            image_patches: None,
            audio_frames: None,
            video_frames: None,
            table_data: Some(table),
            graph_data: None,
            custom_data: None,
        };

        let result = multimodal_tokenizer.tokenize_multimodal(&input);
        assert!(result.is_ok());
        let tokenized = result.expect("Operation failed in test");
        assert!(tokenized.modality_boundaries.contains_key(&ModalityType::Table));
    }

    #[test]
    fn test_fusion_strategies() {
        let text_tokenizer = create_test_char_tokenizer();
        let mut config = MultimodalConfig::default();
        config.fusion_strategy = FusionStrategy::Interleaved;
        let multimodal_tokenizer = MultimodalTokenizer::new(text_tokenizer, config);

        // One token per modality is enough to exercise the interleaver.
        let tokens = vec![
            MultimodalToken {
                token_id: 1,
                modality: ModalityType::Text,
                modality_position: 0,
                global_position: 0,
                metadata: None,
            },
            MultimodalToken {
                token_id: 2,
                modality: ModalityType::Image,
                modality_position: 0,
                global_position: 1,
                metadata: None,
            },
        ];

        let result = multimodal_tokenizer.apply_fusion_strategy(&tokens);
        assert!(result.is_ok());
    }

    #[test]
    fn test_convert_to_multimodal() {
        let tokenized = TokenizedInput {
            input_ids: vec![1, 2, 3],
            attention_mask: vec![1, 1, 1],
            token_type_ids: None,
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        };

        let multimodal = MultimodalUtils::convert_to_multimodal(tokenized, ModalityType::Text);
        assert_eq!(multimodal.input_ids, vec![1, 2, 3]);
        assert_eq!(multimodal.modality_tokens.len(), 3);
        assert!(multimodal.modality_boundaries.contains_key(&ModalityType::Text));
    }

    #[test]
    fn test_cross_modal_attention() {
        let tokens = vec![
            MultimodalToken {
                token_id: 1,
                modality: ModalityType::Text,
                modality_position: 0,
                global_position: 0,
                metadata: None,
            },
            MultimodalToken {
                token_id: 2,
                modality: ModalityType::Image,
                modality_position: 0,
                global_position: 1,
                metadata: None,
            },
        ];

        let attention = MultimodalUtils::calculate_cross_modal_attention(
            &tokens,
            &ModalityType::Text,
            &ModalityType::Image,
        );

        // One text query token vs one image key token -> uniform weight 1.0.
        assert_eq!(attention.len(), 1);
        assert_eq!(attention[0].len(), 1);
        assert_eq!(attention[0][0], 1.0);
    }
}