Skip to main content

trustformers_tokenizers/
multimodal.rs

//! Multimodal tokenization for TrustformeRS
//!
//! This module provides tokenization support for multimodal inputs, including
//! text + images, text + audio, and other cross-modal combinations.
5
6use crate::{TokenizedInput, Tokenizer};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use trustformers_core::errors::Result;
11
12/// Configuration for multimodal tokenizer
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct MultimodalConfig {
15    /// Maximum sequence length for text
16    pub max_text_length: Option<usize>,
17    /// Maximum number of image patches
18    pub max_image_patches: Option<usize>,
19    /// Maximum number of audio frames
20    pub max_audio_frames: Option<usize>,
21    /// Image patch size
22    pub image_patch_size: usize,
23    /// Audio frame size
24    pub audio_frame_size: usize,
25    /// Whether to include special multimodal tokens
26    pub include_special_tokens: bool,
27    /// Whether to use cross-modal attention
28    pub use_cross_modal_attention: bool,
29    /// Modality fusion strategy
30    pub fusion_strategy: FusionStrategy,
31    /// Text tokenizer configuration
32    pub text_tokenizer_config: Option<HashMap<String, String>>,
33    /// Vision tokenizer configuration
34    pub vision_tokenizer_config: Option<HashMap<String, String>>,
35    /// Audio tokenizer configuration
36    pub audio_tokenizer_config: Option<HashMap<String, String>>,
37}
38
39impl Default for MultimodalConfig {
40    fn default() -> Self {
41        Self {
42            max_text_length: Some(512),
43            max_image_patches: Some(196), // 14x14 patches for 224x224 image
44            max_audio_frames: Some(1000),
45            image_patch_size: 16,
46            audio_frame_size: 256,
47            include_special_tokens: true,
48            use_cross_modal_attention: true,
49            fusion_strategy: FusionStrategy::Concatenation,
50            text_tokenizer_config: None,
51            vision_tokenizer_config: None,
52            audio_tokenizer_config: None,
53        }
54    }
55}
56
57/// Fusion strategies for multimodal data
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub enum FusionStrategy {
60    /// Simple concatenation of modalities
61    Concatenation,
62    /// Interleaved modalities
63    Interleaved,
64    /// Cross-attention between modalities
65    CrossAttention,
66    /// Hierarchical fusion
67    Hierarchical,
68    /// Gate-based fusion
69    Gated,
70}
71
72/// Types of modalities
73#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
74pub enum ModalityType {
75    Text,
76    Image,
77    Audio,
78    Video,
79    Depth,
80    PointCloud,
81    Graph,
82    Table,
83    Code,
84    Custom(String),
85}
86
87/// Multimodal token with modality information
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct MultimodalToken {
90    /// Token ID
91    pub token_id: u32,
92    /// Modality type
93    pub modality: ModalityType,
94    /// Position within modality
95    pub modality_position: usize,
96    /// Global position in sequence
97    pub global_position: usize,
98    /// Additional metadata
99    pub metadata: Option<MultimodalTokenMetadata>,
100}
101
102/// Metadata for multimodal tokens
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct MultimodalTokenMetadata {
105    /// Spatial coordinates (for images/video)
106    pub spatial_coords: Option<(usize, usize)>,
107    /// Temporal coordinates (for audio/video)
108    pub temporal_coords: Option<f64>,
109    /// Channel information
110    pub channel: Option<usize>,
111    /// Confidence score
112    pub confidence: Option<f64>,
113    /// Feature vector
114    pub features: Option<Vec<f32>>,
115    /// Attention weights
116    pub attention_weights: Option<Vec<f32>>,
117}
118
119/// Image patch representation
120#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct ImagePatch {
122    /// Patch coordinates
123    pub x: usize,
124    pub y: usize,
125    /// Patch size
126    pub width: usize,
127    pub height: usize,
128    /// Flattened pixel values
129    pub pixels: Vec<f32>,
130    /// Patch embedding
131    pub embedding: Option<Vec<f32>>,
132}
133
134/// Audio frame representation
135#[derive(Debug, Clone, Serialize, Deserialize)]
136pub struct AudioFrame {
137    /// Frame timestamp
138    pub timestamp: f64,
139    /// Frame duration
140    pub duration: f64,
141    /// Audio samples
142    pub samples: Vec<f32>,
143    /// Spectral features
144    pub features: Option<Vec<f32>>,
145}
146
147/// Video frame representation
148#[derive(Debug, Clone, Serialize, Deserialize)]
149pub struct VideoFrame {
150    /// Frame number
151    pub frame_number: usize,
152    /// Timestamp
153    pub timestamp: f64,
154    /// Image patches
155    pub patches: Vec<ImagePatch>,
156    /// Motion vectors
157    pub motion_vectors: Option<Vec<(f32, f32)>>,
158}
159
160/// Multimodal input containing different modalities
161#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct MultimodalInput {
163    /// Text content
164    pub text: Option<String>,
165    /// Image patches
166    pub image_patches: Option<Vec<ImagePatch>>,
167    /// Audio frames
168    pub audio_frames: Option<Vec<AudioFrame>>,
169    /// Video frames
170    pub video_frames: Option<Vec<VideoFrame>>,
171    /// Table data
172    pub table_data: Option<TableData>,
173    /// Graph structure
174    pub graph_data: Option<GraphData>,
175    /// Custom modality data
176    pub custom_data: Option<HashMap<String, Vec<u8>>>,
177}
178
179/// Table data representation
180#[derive(Debug, Clone, Serialize, Deserialize)]
181pub struct TableData {
182    /// Column headers
183    pub headers: Vec<String>,
184    /// Table rows
185    pub rows: Vec<Vec<String>>,
186    /// Column types
187    pub column_types: Option<Vec<String>>,
188}
189
190/// Graph data representation
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct GraphData {
193    /// Node features
194    pub nodes: Vec<Vec<f32>>,
195    /// Edge list
196    pub edges: Vec<(usize, usize)>,
197    /// Edge features
198    pub edge_features: Option<Vec<Vec<f32>>>,
199    /// Node labels
200    pub node_labels: Option<Vec<String>>,
201}
202
203/// Tokenized multimodal output
204#[derive(Debug, Clone, Serialize, Deserialize)]
205pub struct MultimodalTokenizedInput {
206    /// All token IDs in sequence
207    pub input_ids: Vec<u32>,
208    /// Attention mask
209    pub attention_mask: Option<Vec<u32>>,
210    /// Token type IDs (modality indicators)
211    pub token_type_ids: Option<Vec<u32>>,
212    /// Modality tokens with metadata
213    pub modality_tokens: Vec<MultimodalToken>,
214    /// Modality boundaries
215    pub modality_boundaries: HashMap<ModalityType, (usize, usize)>,
216    /// Cross-modal attention matrix
217    pub cross_modal_attention: Option<Vec<Vec<f32>>>,
218}
219
220/// Multimodal tokenizer implementation
221pub struct MultimodalTokenizer<T: Tokenizer> {
222    text_tokenizer: Arc<T>,
223    config: MultimodalConfig,
224    vocab: HashMap<String, u32>,
225    id_to_token: HashMap<u32, String>,
226    next_id: u32,
227    modality_token_ids: HashMap<ModalityType, u32>,
228}
229
230impl<T: Tokenizer> MultimodalTokenizer<T> {
231    /// Create a new multimodal tokenizer
232    pub fn new(text_tokenizer: T, config: MultimodalConfig) -> Self {
233        let mut tokenizer = Self {
234            text_tokenizer: Arc::new(text_tokenizer),
235            config,
236            vocab: HashMap::new(),
237            id_to_token: HashMap::new(),
238            next_id: 0,
239            modality_token_ids: HashMap::new(),
240        };
241
242        tokenizer.initialize_vocab();
243        tokenizer
244    }
245
246    /// Create with default configuration
247    pub fn from_text_tokenizer(text_tokenizer: T) -> Self {
248        Self::new(text_tokenizer, MultimodalConfig::default())
249    }
250
251    /// Initialize vocabulary with multimodal tokens
252    fn initialize_vocab(&mut self) {
253        // Add special tokens
254        if self.config.include_special_tokens {
255            self.add_token("[CLS]");
256            self.add_token("[SEP]");
257            self.add_token("[PAD]");
258            self.add_token("[UNK]");
259            self.add_token("[MASK]");
260        }
261
262        // Add modality-specific tokens
263        let modality_tokens = vec![
264            (ModalityType::Text, "[TEXT]"),
265            (ModalityType::Image, "[IMG]"),
266            (ModalityType::Audio, "[AUD]"),
267            (ModalityType::Video, "[VID]"),
268            (ModalityType::Table, "[TAB]"),
269            (ModalityType::Graph, "[GRF]"),
270            (ModalityType::Code, "[COD]"),
271        ];
272
273        for (modality, token) in modality_tokens {
274            let token_id = self.add_token(token);
275            self.modality_token_ids.insert(modality, token_id);
276        }
277
278        // Add patch and frame tokens
279        for i in 0..self.config.max_image_patches.unwrap_or(196) {
280            self.add_token(&format!("[PATCH_{}]", i));
281        }
282
283        for i in 0..self.config.max_audio_frames.unwrap_or(1000) {
284            self.add_token(&format!("[FRAME_{}]", i));
285        }
286
287        // Add fusion tokens
288        self.add_token("[FUSE]");
289        self.add_token("[CROSS_ATTN]");
290        self.add_token("[MODAL_SEP]");
291    }
292
293    /// Add token to vocabulary
294    fn add_token(&mut self, token: &str) -> u32 {
295        if let Some(&id) = self.vocab.get(token) {
296            return id;
297        }
298
299        let id = self.next_id;
300        self.vocab.insert(token.to_string(), id);
301        self.id_to_token.insert(id, token.to_string());
302        self.next_id += 1;
303        id
304    }
305
306    /// Tokenize multimodal input
307    pub fn tokenize_multimodal(&self, input: &MultimodalInput) -> Result<MultimodalTokenizedInput> {
308        let mut all_tokens = Vec::new();
309        let mut modality_boundaries = HashMap::new();
310        let mut current_position = 0;
311
312        // Tokenize text
313        if let Some(ref text) = input.text {
314            let text_tokens = self.tokenize_text(text, current_position)?;
315            let start_pos = current_position;
316            all_tokens.extend(text_tokens);
317            let end_pos = all_tokens.len();
318            modality_boundaries.insert(ModalityType::Text, (start_pos, end_pos));
319            current_position = end_pos;
320        }
321
322        // Tokenize image patches
323        if let Some(ref patches) = input.image_patches {
324            let image_tokens = self.tokenize_image_patches(patches, current_position)?;
325            let start_pos = current_position;
326            all_tokens.extend(image_tokens);
327            let end_pos = all_tokens.len();
328            modality_boundaries.insert(ModalityType::Image, (start_pos, end_pos));
329            current_position = end_pos;
330        }
331
332        // Tokenize audio frames
333        if let Some(ref frames) = input.audio_frames {
334            let audio_tokens = self.tokenize_audio_frames(frames, current_position)?;
335            let start_pos = current_position;
336            all_tokens.extend(audio_tokens);
337            let end_pos = all_tokens.len();
338            modality_boundaries.insert(ModalityType::Audio, (start_pos, end_pos));
339            current_position = end_pos;
340        }
341
342        // Tokenize video frames
343        if let Some(ref frames) = input.video_frames {
344            let video_tokens = self.tokenize_video_frames(frames, current_position)?;
345            let start_pos = current_position;
346            all_tokens.extend(video_tokens);
347            let end_pos = all_tokens.len();
348            modality_boundaries.insert(ModalityType::Video, (start_pos, end_pos));
349            current_position = end_pos;
350        }
351
352        // Tokenize table data
353        if let Some(ref table) = input.table_data {
354            let table_tokens = self.tokenize_table(table, current_position)?;
355            let start_pos = current_position;
356            all_tokens.extend(table_tokens);
357            let end_pos = all_tokens.len();
358            modality_boundaries.insert(ModalityType::Table, (start_pos, end_pos));
359            let _ = end_pos; // Track position for potential future use
360        }
361
362        // Apply fusion strategy
363        let fused_tokens = self.apply_fusion_strategy(&all_tokens)?;
364
365        // Create input IDs and other outputs
366        let input_ids: Vec<u32> = fused_tokens.iter().map(|t| t.token_id).collect();
367        let attention_mask = Some(vec![1u32; input_ids.len()]);
368        let token_type_ids =
369            Some(fused_tokens.iter().map(|t| self.get_modality_type_id(&t.modality)).collect());
370
371        Ok(MultimodalTokenizedInput {
372            input_ids,
373            attention_mask,
374            token_type_ids,
375            modality_tokens: fused_tokens,
376            modality_boundaries,
377            cross_modal_attention: None, // Would be computed during model forward pass
378        })
379    }
380
381    /// Tokenize text using the underlying text tokenizer
382    fn tokenize_text(&self, text: &str, start_position: usize) -> Result<Vec<MultimodalToken>> {
383        let text_tokenized = self.text_tokenizer.encode(text)?;
384        let mut tokens = Vec::new();
385
386        for (i, &token_id) in text_tokenized.input_ids.iter().enumerate() {
387            tokens.push(MultimodalToken {
388                token_id,
389                modality: ModalityType::Text,
390                modality_position: i,
391                global_position: start_position + i,
392                metadata: None,
393            });
394        }
395
396        Ok(tokens)
397    }
398
399    /// Tokenize image patches
400    fn tokenize_image_patches(
401        &self,
402        patches: &[ImagePatch],
403        start_position: usize,
404    ) -> Result<Vec<MultimodalToken>> {
405        let mut tokens = Vec::new();
406
407        // Add image start token
408        if let Some(&img_token_id) = self.modality_token_ids.get(&ModalityType::Image) {
409            tokens.push(MultimodalToken {
410                token_id: img_token_id,
411                modality: ModalityType::Image,
412                modality_position: 0,
413                global_position: start_position,
414                metadata: None,
415            });
416        }
417
418        // Add patch tokens
419        for (i, patch) in patches.iter().enumerate() {
420            let patch_token = format!("[PATCH_{}]", i);
421            if let Some(&token_id) = self.vocab.get(&patch_token) {
422                let metadata = Some(MultimodalTokenMetadata {
423                    spatial_coords: Some((patch.x, patch.y)),
424                    temporal_coords: None,
425                    channel: None,
426                    confidence: None,
427                    features: patch.embedding.clone(),
428                    attention_weights: None,
429                });
430
431                tokens.push(MultimodalToken {
432                    token_id,
433                    modality: ModalityType::Image,
434                    modality_position: i + 1,
435                    global_position: start_position + tokens.len(),
436                    metadata,
437                });
438            }
439
440            // Stop if we reach max patches
441            if tokens.len() >= self.config.max_image_patches.unwrap_or(196) {
442                break;
443            }
444        }
445
446        Ok(tokens)
447    }
448
449    /// Tokenize audio frames
450    fn tokenize_audio_frames(
451        &self,
452        frames: &[AudioFrame],
453        start_position: usize,
454    ) -> Result<Vec<MultimodalToken>> {
455        let mut tokens = Vec::new();
456
457        // Add audio start token
458        if let Some(&aud_token_id) = self.modality_token_ids.get(&ModalityType::Audio) {
459            tokens.push(MultimodalToken {
460                token_id: aud_token_id,
461                modality: ModalityType::Audio,
462                modality_position: 0,
463                global_position: start_position,
464                metadata: None,
465            });
466        }
467
468        // Add frame tokens
469        for (i, frame) in frames.iter().enumerate() {
470            let frame_token = format!("[FRAME_{}]", i);
471            if let Some(&token_id) = self.vocab.get(&frame_token) {
472                let metadata = Some(MultimodalTokenMetadata {
473                    spatial_coords: None,
474                    temporal_coords: Some(frame.timestamp),
475                    channel: None,
476                    confidence: None,
477                    features: frame.features.clone(),
478                    attention_weights: None,
479                });
480
481                tokens.push(MultimodalToken {
482                    token_id,
483                    modality: ModalityType::Audio,
484                    modality_position: i + 1,
485                    global_position: start_position + tokens.len(),
486                    metadata,
487                });
488            }
489
490            // Stop if we reach max frames
491            if tokens.len() >= self.config.max_audio_frames.unwrap_or(1000) {
492                break;
493            }
494        }
495
496        Ok(tokens)
497    }
498
499    /// Tokenize video frames
500    fn tokenize_video_frames(
501        &self,
502        frames: &[VideoFrame],
503        start_position: usize,
504    ) -> Result<Vec<MultimodalToken>> {
505        let mut tokens = Vec::new();
506
507        // Add video start token
508        if let Some(&vid_token_id) = self.modality_token_ids.get(&ModalityType::Video) {
509            tokens.push(MultimodalToken {
510                token_id: vid_token_id,
511                modality: ModalityType::Video,
512                modality_position: 0,
513                global_position: start_position,
514                metadata: None,
515            });
516        }
517
518        // Tokenize each frame as image patches
519        for (frame_idx, frame) in frames.iter().enumerate() {
520            for (patch_idx, patch) in frame.patches.iter().enumerate() {
521                let patch_token = format!("[PATCH_{}]", patch_idx);
522                if let Some(&token_id) = self.vocab.get(&patch_token) {
523                    let metadata = Some(MultimodalTokenMetadata {
524                        spatial_coords: Some((patch.x, patch.y)),
525                        temporal_coords: Some(frame.timestamp),
526                        channel: Some(frame_idx),
527                        confidence: None,
528                        features: patch.embedding.clone(),
529                        attention_weights: None,
530                    });
531
532                    tokens.push(MultimodalToken {
533                        token_id,
534                        modality: ModalityType::Video,
535                        modality_position: tokens.len(),
536                        global_position: start_position + tokens.len(),
537                        metadata,
538                    });
539                }
540            }
541        }
542
543        Ok(tokens)
544    }
545
546    /// Tokenize table data
547    fn tokenize_table(
548        &self,
549        table: &TableData,
550        start_position: usize,
551    ) -> Result<Vec<MultimodalToken>> {
552        let mut tokens = Vec::new();
553
554        // Add table start token
555        if let Some(&tab_token_id) = self.modality_token_ids.get(&ModalityType::Table) {
556            tokens.push(MultimodalToken {
557                token_id: tab_token_id,
558                modality: ModalityType::Table,
559                modality_position: 0,
560                global_position: start_position,
561                metadata: None,
562            });
563        }
564
565        // Tokenize headers and rows as text
566        let mut table_text = table.headers.join(" | ");
567        for row in &table.rows {
568            table_text.push_str(" | ");
569            table_text.push_str(&row.join(" | "));
570        }
571
572        let text_tokens = self.text_tokenizer.encode(&table_text)?;
573        for (i, &token_id) in text_tokens.input_ids.iter().enumerate() {
574            tokens.push(MultimodalToken {
575                token_id,
576                modality: ModalityType::Table,
577                modality_position: i + 1,
578                global_position: start_position + tokens.len(),
579                metadata: None,
580            });
581        }
582
583        Ok(tokens)
584    }
585
586    /// Apply fusion strategy to tokens
587    fn apply_fusion_strategy(&self, tokens: &[MultimodalToken]) -> Result<Vec<MultimodalToken>> {
588        match self.config.fusion_strategy {
589            FusionStrategy::Concatenation => Ok(tokens.to_vec()),
590            FusionStrategy::Interleaved => self.apply_interleaved_fusion(tokens),
591            FusionStrategy::CrossAttention => self.apply_cross_attention_fusion(tokens),
592            FusionStrategy::Hierarchical => self.apply_hierarchical_fusion(tokens),
593            FusionStrategy::Gated => self.apply_gated_fusion(tokens),
594        }
595    }
596
597    /// Apply interleaved fusion
598    fn apply_interleaved_fusion(&self, tokens: &[MultimodalToken]) -> Result<Vec<MultimodalToken>> {
599        // Group tokens by modality
600        let mut modality_groups: HashMap<ModalityType, Vec<&MultimodalToken>> = HashMap::new();
601        for token in tokens {
602            modality_groups.entry(token.modality.clone()).or_default().push(token);
603        }
604
605        // Interleave tokens from different modalities
606        let mut result = Vec::new();
607        let max_len = modality_groups.values().map(|v| v.len()).max().unwrap_or(0);
608
609        for i in 0..max_len {
610            for group in modality_groups.values() {
611                if let Some(token) = group.get(i) {
612                    result.push((*token).clone());
613                }
614            }
615        }
616
617        Ok(result)
618    }
619
620    /// Apply cross-attention fusion
621    fn apply_cross_attention_fusion(
622        &self,
623        tokens: &[MultimodalToken],
624    ) -> Result<Vec<MultimodalToken>> {
625        // Group tokens by modality
626        let mut modality_groups: HashMap<ModalityType, Vec<&MultimodalToken>> = HashMap::new();
627        for token in tokens {
628            modality_groups.entry(token.modality.clone()).or_default().push(token);
629        }
630
631        // If we have less than 2 modalities, no cross-attention needed
632        if modality_groups.len() < 2 {
633            return Ok(tokens.to_vec());
634        }
635
636        let mut result = Vec::new();
637        let modalities: Vec<_> = modality_groups.keys().cloned().collect();
638
639        // Add cross-attention token between modality groups
640        if let Some(&cross_attn_token_id) = self.vocab.get("[CROSS_ATTN]") {
641            for (i, (modality, group)) in modality_groups.iter().enumerate() {
642                // Add original tokens from this modality
643                for token in group {
644                    let mut enhanced_token = (*token).clone();
645
646                    // Calculate attention weights with other modalities
647                    let mut attention_weights = Vec::new();
648                    for (j, other_modality) in modalities.iter().enumerate() {
649                        if i != j {
650                            // Simple attention score based on position and modality compatibility
651                            let attention_score = self.calculate_cross_modal_attention_score(
652                                modality,
653                                other_modality,
654                                token.modality_position,
655                            );
656                            attention_weights.push(attention_score);
657                        }
658                    }
659
660                    // Update token metadata with attention weights
661                    if let Some(ref mut metadata) = enhanced_token.metadata {
662                        metadata.attention_weights = Some(attention_weights);
663                    } else {
664                        enhanced_token.metadata = Some(MultimodalTokenMetadata {
665                            spatial_coords: None,
666                            temporal_coords: None,
667                            channel: None,
668                            confidence: None,
669                            features: None,
670                            attention_weights: Some(attention_weights),
671                        });
672                    }
673
674                    result.push(enhanced_token);
675                }
676
677                // Add cross-attention separator between modalities (except last)
678                if i < modality_groups.len() - 1 {
679                    result.push(MultimodalToken {
680                        token_id: cross_attn_token_id,
681                        modality: ModalityType::Custom("cross_attention".to_string()),
682                        modality_position: 0,
683                        global_position: result.len(),
684                        metadata: None,
685                    });
686                }
687            }
688        } else {
689            // If cross-attention token not available, just return original tokens
690            result = tokens.to_vec();
691        }
692
693        Ok(result)
694    }
695
696    /// Calculate cross-modal attention score between two modalities
697    fn calculate_cross_modal_attention_score(
698        &self,
699        source_modality: &ModalityType,
700        target_modality: &ModalityType,
701        position: usize,
702    ) -> f32 {
703        // Base attention score based on modality compatibility
704        let base_score = match (source_modality, target_modality) {
705            // Text-Image interactions are typically strong
706            (ModalityType::Text, ModalityType::Image)
707            | (ModalityType::Image, ModalityType::Text) => 0.8,
708            // Text-Audio interactions
709            (ModalityType::Text, ModalityType::Audio)
710            | (ModalityType::Audio, ModalityType::Text) => 0.7,
711            // Image-Video interactions are very strong
712            (ModalityType::Image, ModalityType::Video)
713            | (ModalityType::Video, ModalityType::Image) => 0.9,
714            // Audio-Video interactions for multimedia content
715            (ModalityType::Audio, ModalityType::Video)
716            | (ModalityType::Video, ModalityType::Audio) => 0.75,
717            // Table-Text interactions for structured data
718            (ModalityType::Table, ModalityType::Text)
719            | (ModalityType::Text, ModalityType::Table) => 0.6,
720            // Code-Text interactions
721            (ModalityType::Code, ModalityType::Text) | (ModalityType::Text, ModalityType::Code) => {
722                0.65
723            },
724            // Graph-Table interactions for structured data
725            (ModalityType::Graph, ModalityType::Table)
726            | (ModalityType::Table, ModalityType::Graph) => 0.7,
727            // Same modality gets moderate attention
728            (a, b) if a == b => 0.5,
729            // Default for other combinations
730            _ => 0.4,
731        };
732
733        // Position-based attention decay (closer positions get higher attention)
734        let position_factor = 1.0 / (1.0 + (position as f32 * 0.1));
735
736        // Combine base score with position factor
737        base_score * position_factor
738    }
739
740    /// Apply hierarchical fusion
741    fn apply_hierarchical_fusion(
742        &self,
743        tokens: &[MultimodalToken],
744    ) -> Result<Vec<MultimodalToken>> {
745        // Group by modality and add fusion tokens between groups
746        let mut result = Vec::new();
747        let mut current_modality = None;
748
749        if let Some(&fuse_token_id) = self.vocab.get("[FUSE]") {
750            for token in tokens {
751                if current_modality.is_some() && current_modality.as_ref() != Some(&token.modality)
752                {
753                    // Add fusion token between modalities
754                    result.push(MultimodalToken {
755                        token_id: fuse_token_id,
756                        modality: ModalityType::Custom("fusion".to_string()),
757                        modality_position: 0,
758                        global_position: result.len(),
759                        metadata: None,
760                    });
761                }
762                result.push(token.clone());
763                current_modality = Some(token.modality.clone());
764            }
765        }
766
767        Ok(result)
768    }
769
770    /// Apply gated fusion
771    fn apply_gated_fusion(&self, tokens: &[MultimodalToken]) -> Result<Vec<MultimodalToken>> {
772        // Group tokens by modality
773        let mut modality_groups: HashMap<ModalityType, Vec<&MultimodalToken>> = HashMap::new();
774        for token in tokens {
775            modality_groups.entry(token.modality.clone()).or_default().push(token);
776        }
777
778        // If we have less than 2 modalities, no gating needed
779        if modality_groups.len() < 2 {
780            return Ok(tokens.to_vec());
781        }
782
783        let mut result = Vec::new();
784
785        // Calculate gate weights for each modality based on content characteristics
786        let mut modality_gates: HashMap<ModalityType, f32> = HashMap::new();
787        for (modality, group) in &modality_groups {
788            let gate_weight = self.calculate_modality_gate_weight(modality, group);
789            modality_gates.insert(modality.clone(), gate_weight);
790        }
791
792        // Normalize gate weights so they sum to 1.0
793        let total_weight: f32 = modality_gates.values().sum();
794        if total_weight > 0.0 {
795            for weight in modality_gates.values_mut() {
796                *weight /= total_weight;
797            }
798        }
799
800        // Apply gated fusion by weighting tokens based on modality gates
801        for (modality, group) in &modality_groups {
802            let gate_weight = modality_gates.get(modality).copied().unwrap_or(0.0);
803
804            for token in group {
805                let mut gated_token = (*token).clone();
806
807                // Calculate confidence based on gate weight and token characteristics
808                let token_confidence = self.calculate_token_confidence(token, gate_weight);
809
810                // Update token metadata with gate information
811                if let Some(ref mut metadata) = gated_token.metadata {
812                    metadata.confidence = Some(token_confidence as f64);
813                } else {
814                    gated_token.metadata = Some(MultimodalTokenMetadata {
815                        spatial_coords: None,
816                        temporal_coords: None,
817                        channel: None,
818                        confidence: Some(token_confidence as f64),
819                        features: None,
820                        attention_weights: None,
821                    });
822                }
823
824                // Only include tokens that pass the gating threshold
825                if token_confidence > 0.1 {
826                    result.push(gated_token);
827                }
828            }
829        }
830
831        // Sort tokens by confidence (highest first) to prioritize important content
832        result.sort_by(|a, b| {
833            let conf_a = a.metadata.as_ref().and_then(|m| m.confidence).unwrap_or(0.0);
834            let conf_b = b.metadata.as_ref().and_then(|m| m.confidence).unwrap_or(0.0);
835            conf_b.partial_cmp(&conf_a).unwrap_or(std::cmp::Ordering::Equal)
836        });
837
838        // Update global positions after sorting
839        for (i, token) in result.iter_mut().enumerate() {
840            token.global_position = i;
841        }
842
843        Ok(result)
844    }
845
846    /// Calculate gate weight for a modality based on its tokens
847    fn calculate_modality_gate_weight(
848        &self,
849        modality: &ModalityType,
850        tokens: &[&MultimodalToken],
851    ) -> f32 {
852        if tokens.is_empty() {
853            return 0.0;
854        }
855
856        // Base weight based on modality importance
857        let base_weight = match modality {
858            ModalityType::Text => 1.0,      // Text is usually most important
859            ModalityType::Image => 0.8,     // Images are highly informative
860            ModalityType::Video => 0.9,     // Video combines visual and temporal info
861            ModalityType::Audio => 0.7,     // Audio adds complementary information
862            ModalityType::Table => 0.6,     // Structured data is valuable but less dynamic
863            ModalityType::Graph => 0.65,    // Graph data is structured and informative
864            ModalityType::Code => 0.75,     // Code has high semantic value
865            ModalityType::Custom(_) => 0.5, // Custom modalities get moderate weight
866            _ => 0.4,                       // Other modalities get lower weight
867        };
868
869        // Factor in the number of tokens (more tokens might indicate more importance)
870        let token_count_factor = (tokens.len() as f32).sqrt() / 10.0;
871
872        // Factor in feature richness (tokens with more metadata are more informative)
873        let feature_richness = tokens
874            .iter()
875            .map(|token| {
876                if let Some(metadata) = &token.metadata {
877                    let mut richness = 0.0;
878                    if metadata.spatial_coords.is_some() {
879                        richness += 0.2;
880                    }
881                    if metadata.temporal_coords.is_some() {
882                        richness += 0.2;
883                    }
884                    if metadata.features.is_some() {
885                        richness += 0.4;
886                    }
887                    if metadata.confidence.is_some() {
888                        richness += 0.2;
889                    }
890                    richness
891                } else {
892                    0.1 // Base richness for tokens without metadata
893                }
894            })
895            .sum::<f32>()
896            / tokens.len() as f32;
897
898        // Combine factors
899        base_weight * (1.0 + token_count_factor) * (1.0 + feature_richness)
900    }
901
902    /// Calculate confidence for a token based on gate weight and token characteristics
903    fn calculate_token_confidence(&self, token: &MultimodalToken, gate_weight: f32) -> f32 {
904        // Start with the gate weight as base confidence
905        let mut confidence = gate_weight;
906
907        // Factor in token-specific characteristics
908        if let Some(metadata) = &token.metadata {
909            // Spatial information increases confidence for visual modalities
910            if metadata.spatial_coords.is_some()
911                && matches!(token.modality, ModalityType::Image | ModalityType::Video)
912            {
913                confidence *= 1.2;
914            }
915
916            // Temporal information increases confidence for temporal modalities
917            if metadata.temporal_coords.is_some()
918                && matches!(token.modality, ModalityType::Audio | ModalityType::Video)
919            {
920                confidence *= 1.15;
921            }
922
923            // Feature vectors indicate processed/meaningful content
924            if let Some(features) = &metadata.features {
925                let feature_magnitude =
926                    features.iter().map(|f| f.abs()).sum::<f32>() / features.len() as f32;
927                confidence *= 1.0 + (feature_magnitude * 0.1);
928            }
929
930            // Existing confidence scores should be respected
931            if let Some(existing_confidence) = metadata.confidence {
932                confidence = (confidence + existing_confidence as f32) / 2.0;
933            }
934        }
935
936        // Position-based confidence (earlier tokens might be more important)
937        let position_factor = 1.0 / (1.0 + (token.modality_position as f32 * 0.05));
938        confidence *= position_factor;
939
940        // Ensure confidence is in valid range [0, 1]
941        confidence.clamp(0.0, 1.0)
942    }
943
944    /// Get modality type ID for token type IDs
945    fn get_modality_type_id(&self, modality: &ModalityType) -> u32 {
946        match modality {
947            ModalityType::Text => 0,
948            ModalityType::Image => 1,
949            ModalityType::Audio => 2,
950            ModalityType::Video => 3,
951            ModalityType::Table => 4,
952            ModalityType::Graph => 5,
953            ModalityType::Code => 6,
954            ModalityType::Custom(_) => 7,
955            _ => 0,
956        }
957    }
958
959    /// Get configuration
960    pub fn config(&self) -> &MultimodalConfig {
961        &self.config
962    }
963
964    /// Get vocabulary
965    pub fn get_vocab(&self) -> &HashMap<String, u32> {
966        &self.vocab
967    }
968
969    /// Get underlying text tokenizer
970    pub fn text_tokenizer(&self) -> &T {
971        &self.text_tokenizer
972    }
973}
974
975impl<T: Tokenizer> Tokenizer for MultimodalTokenizer<T> {
976    fn encode(&self, text: &str) -> Result<TokenizedInput> {
977        // For plain text, just use the underlying text tokenizer
978        self.text_tokenizer.encode(text)
979    }
980
981    fn decode(&self, token_ids: &[u32]) -> Result<String> {
982        // Filter out special multimodal tokens and decode text tokens
983        let text_tokens: Vec<u32> = token_ids
984            .iter()
985            .copied()
986            .filter(|&id| {
987                if let Some(token) = self.id_to_token.get(&id) {
988                    !token.starts_with('[') || !token.ends_with(']')
989                } else {
990                    true
991                }
992            })
993            .collect();
994
995        self.text_tokenizer.decode(&text_tokens)
996    }
997
998    fn encode_pair(&self, text_a: &str, text_b: &str) -> Result<TokenizedInput> {
999        // For text pairs, use the underlying text tokenizer
1000        self.text_tokenizer.encode_pair(text_a, text_b)
1001    }
1002
1003    fn vocab_size(&self) -> usize {
1004        self.text_tokenizer.vocab_size() + self.vocab.len()
1005    }
1006
1007    fn get_vocab(&self) -> HashMap<String, u32> {
1008        let mut vocab = self.text_tokenizer.get_vocab();
1009        for (token, &id) in &self.vocab {
1010            vocab.insert(token.clone(), id);
1011        }
1012        vocab
1013    }
1014
1015    fn token_to_id(&self, token: &str) -> Option<u32> {
1016        self.vocab
1017            .get(token)
1018            .copied()
1019            .or_else(|| self.text_tokenizer.token_to_id(token))
1020    }
1021
1022    fn id_to_token(&self, id: u32) -> Option<String> {
1023        self.id_to_token
1024            .get(&id)
1025            .cloned()
1026            .or_else(|| self.text_tokenizer.id_to_token(id))
1027    }
1028}
1029
1030/// Utilities for multimodal tokenization
1031pub struct MultimodalUtils;
1032
1033impl MultimodalUtils {
1034    /// Create image patches from image dimensions
1035    pub fn create_image_patches(
1036        image_width: usize,
1037        image_height: usize,
1038        patch_size: usize,
1039    ) -> Vec<ImagePatch> {
1040        let mut patches = Vec::new();
1041
1042        for y in (0..image_height).step_by(patch_size) {
1043            for x in (0..image_width).step_by(patch_size) {
1044                let width = (patch_size).min(image_width - x);
1045                let height = (patch_size).min(image_height - y);
1046
1047                patches.push(ImagePatch {
1048                    x,
1049                    y,
1050                    width,
1051                    height,
1052                    pixels: vec![0.0; width * height * 3], // RGB
1053                    embedding: None,
1054                });
1055            }
1056        }
1057
1058        patches
1059    }
1060
1061    /// Create audio frames from audio parameters
1062    pub fn create_audio_frames(
1063        sample_rate: f64,
1064        duration: f64,
1065        frame_size: usize,
1066        hop_size: usize,
1067    ) -> Vec<AudioFrame> {
1068        let mut frames = Vec::new();
1069        let total_samples = (sample_rate * duration) as usize;
1070
1071        for start in (0..total_samples).step_by(hop_size) {
1072            let end = (start + frame_size).min(total_samples);
1073            let timestamp = start as f64 / sample_rate;
1074            let frame_duration = (end - start) as f64 / sample_rate;
1075
1076            frames.push(AudioFrame {
1077                timestamp,
1078                duration: frame_duration,
1079                samples: vec![0.0; end - start],
1080                features: None,
1081            });
1082        }
1083
1084        frames
1085    }
1086
1087    /// Convert tokenized input to multimodal format
1088    pub fn convert_to_multimodal(
1089        tokenized: TokenizedInput,
1090        modality: ModalityType,
1091    ) -> MultimodalTokenizedInput {
1092        let modality_tokens: Vec<MultimodalToken> = tokenized
1093            .input_ids
1094            .into_iter()
1095            .enumerate()
1096            .map(|(i, token_id)| MultimodalToken {
1097                token_id,
1098                modality: modality.clone(),
1099                modality_position: i,
1100                global_position: i,
1101                metadata: None,
1102            })
1103            .collect();
1104
1105        let mut boundaries = HashMap::new();
1106        boundaries.insert(modality, (0, modality_tokens.len()));
1107
1108        MultimodalTokenizedInput {
1109            input_ids: modality_tokens.iter().map(|t| t.token_id).collect(),
1110            attention_mask: Some(tokenized.attention_mask.into_iter().map(|x| x as u32).collect()),
1111            token_type_ids: tokenized.token_type_ids,
1112            modality_tokens,
1113            modality_boundaries: boundaries,
1114            cross_modal_attention: None,
1115        }
1116    }
1117
1118    /// Calculate cross-modal attention matrix
1119    pub fn calculate_cross_modal_attention(
1120        tokens: &[MultimodalToken],
1121        query_modality: &ModalityType,
1122        key_modality: &ModalityType,
1123    ) -> Vec<Vec<f32>> {
1124        let query_tokens: Vec<_> =
1125            tokens.iter().filter(|t| &t.modality == query_modality).collect();
1126        let key_tokens: Vec<_> = tokens.iter().filter(|t| &t.modality == key_modality).collect();
1127
1128        // Placeholder attention calculation
1129        let mut attention = vec![vec![0.0; key_tokens.len()]; query_tokens.len()];
1130
1131        for (i, _) in query_tokens.iter().enumerate() {
1132            for (j, _) in key_tokens.iter().enumerate() {
1133                // Simplified attention score
1134                attention[i][j] = 1.0 / (key_tokens.len() as f32);
1135            }
1136        }
1137
1138        attention
1139    }
1140}
1141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::char::CharTokenizer;
    use std::collections::HashMap;

    /// Builds a minimal `CharTokenizer` over four bracketed special tokens,
    /// a few lowercase characters and the space.
    fn create_test_char_tokenizer() -> CharTokenizer {
        let mut vocab = HashMap::new();
        vocab.insert("[PAD]".to_string(), 0);
        vocab.insert("[UNK]".to_string(), 1);
        vocab.insert("[CLS]".to_string(), 2);
        vocab.insert("[SEP]".to_string(), 3);
        vocab.insert("h".to_string(), 4);
        vocab.insert("e".to_string(), 5);
        vocab.insert("l".to_string(), 6);
        vocab.insert("o".to_string(), 7);
        vocab.insert("w".to_string(), 8);
        vocab.insert("r".to_string(), 9);
        vocab.insert("d".to_string(), 10);
        vocab.insert(" ".to_string(), 11);
        vocab.insert("t".to_string(), 12);
        vocab.insert("s".to_string(), 13);
        CharTokenizer::new(vocab)
    }

    #[test]
    fn test_multimodal_config() {
        let config = MultimodalConfig::default();
        assert_eq!(config.max_text_length, Some(512));
        assert_eq!(config.max_image_patches, Some(196));
        assert!(config.include_special_tokens);
    }

    #[test]
    fn test_multimodal_tokenizer_creation() {
        let text_tokenizer = create_test_char_tokenizer();
        let multimodal_tokenizer = MultimodalTokenizer::from_text_tokenizer(text_tokenizer);

        let vocab = multimodal_tokenizer.get_vocab();
        assert!(vocab.contains_key("[IMG]"));
        assert!(vocab.contains_key("[AUD]"));
    }

    #[test]
    fn test_text_only_tokenization() {
        let text_tokenizer = create_test_char_tokenizer();
        let multimodal_tokenizer = MultimodalTokenizer::from_text_tokenizer(text_tokenizer);

        let input = MultimodalInput {
            text: Some("hello world".to_string()),
            image_patches: None,
            audio_frames: None,
            video_frames: None,
            table_data: None,
            graph_data: None,
            custom_data: None,
        };

        let tokenized = multimodal_tokenizer
            .tokenize_multimodal(&input)
            .expect("text-only tokenization should succeed");
        assert!(!tokenized.input_ids.is_empty());
        assert!(tokenized.modality_boundaries.contains_key(&ModalityType::Text));
    }

    #[test]
    fn test_image_patch_creation() {
        let patches = MultimodalUtils::create_image_patches(224, 224, 16);
        assert_eq!(patches.len(), 14 * 14); // 196 patches

        let first_patch = &patches[0];
        assert_eq!(first_patch.x, 0);
        assert_eq!(first_patch.y, 0);
        assert_eq!(first_patch.width, 16);
        assert_eq!(first_patch.height, 16);
    }

    #[test]
    fn test_image_patch_creation_clips_edges() {
        // 20x10 with 16px patches: one row of two patches, both clipped in height,
        // and the second clipped in width.
        let patches = MultimodalUtils::create_image_patches(20, 10, 16);
        assert_eq!(patches.len(), 2);

        assert_eq!((patches[0].width, patches[0].height), (16, 10));

        let edge = &patches[1];
        assert_eq!((edge.x, edge.y), (16, 0));
        assert_eq!((edge.width, edge.height), (4, 10));
        assert_eq!(edge.pixels.len(), 4 * 10 * 3);
    }

    #[test]
    fn test_audio_frame_creation() {
        let frames = MultimodalUtils::create_audio_frames(44100.0, 1.0, 1024, 512);
        // Frame starts are 0, 512, ..., 44032 -> 87 frames.
        assert_eq!(frames.len(), 87);

        let first_frame = &frames[0];
        assert_eq!(first_frame.timestamp, 0.0);
        assert_eq!(first_frame.samples.len(), 1024);

        assert_eq!(frames[1].timestamp, 512.0 / 44100.0);

        // The final frame is truncated at the end of the signal: 44100 - 44032.
        assert_eq!(frames[86].samples.len(), 68);
    }

    #[test]
    fn test_multimodal_input_with_images() {
        let text_tokenizer = create_test_char_tokenizer();
        let multimodal_tokenizer = MultimodalTokenizer::from_text_tokenizer(text_tokenizer);

        let patches = vec![ImagePatch {
            x: 0,
            y: 0,
            width: 16,
            height: 16,
            pixels: vec![0.0; 16 * 16 * 3],
            embedding: Some(vec![1.0, 2.0, 3.0]),
        }];

        let input = MultimodalInput {
            text: Some("An image".to_string()),
            image_patches: Some(patches),
            audio_frames: None,
            video_frames: None,
            table_data: None,
            graph_data: None,
            custom_data: None,
        };

        let tokenized = multimodal_tokenizer
            .tokenize_multimodal(&input)
            .expect("text + image tokenization should succeed");
        assert!(tokenized.modality_boundaries.contains_key(&ModalityType::Text));
        assert!(tokenized.modality_boundaries.contains_key(&ModalityType::Image));
    }

    #[test]
    fn test_table_tokenization() {
        let text_tokenizer = create_test_char_tokenizer();
        let multimodal_tokenizer = MultimodalTokenizer::from_text_tokenizer(text_tokenizer);

        let table = TableData {
            headers: vec!["Name".to_string(), "Age".to_string()],
            rows: vec![
                vec!["Alice".to_string(), "25".to_string()],
                vec!["Bob".to_string(), "30".to_string()],
            ],
            column_types: Some(vec!["string".to_string(), "int".to_string()]),
        };

        let input = MultimodalInput {
            text: None,
            image_patches: None,
            audio_frames: None,
            video_frames: None,
            table_data: Some(table),
            graph_data: None,
            custom_data: None,
        };

        let tokenized = multimodal_tokenizer
            .tokenize_multimodal(&input)
            .expect("table tokenization should succeed");
        assert!(tokenized.modality_boundaries.contains_key(&ModalityType::Table));
    }

    #[test]
    fn test_fusion_strategies() {
        let text_tokenizer = create_test_char_tokenizer();
        let mut config = MultimodalConfig::default();
        config.fusion_strategy = FusionStrategy::Interleaved;
        let multimodal_tokenizer = MultimodalTokenizer::new(text_tokenizer, config);

        let tokens = vec![
            MultimodalToken {
                token_id: 1,
                modality: ModalityType::Text,
                modality_position: 0,
                global_position: 0,
                metadata: None,
            },
            MultimodalToken {
                token_id: 2,
                modality: ModalityType::Image,
                modality_position: 0,
                global_position: 1,
                metadata: None,
            },
        ];

        multimodal_tokenizer
            .apply_fusion_strategy(&tokens)
            .expect("interleaved fusion should succeed");
    }

    #[test]
    fn test_convert_to_multimodal() {
        let tokenized = TokenizedInput {
            input_ids: vec![1, 2, 3],
            attention_mask: vec![1, 1, 1],
            token_type_ids: None,
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        };

        let multimodal = MultimodalUtils::convert_to_multimodal(tokenized, ModalityType::Text);
        assert_eq!(multimodal.input_ids, vec![1, 2, 3]);
        assert_eq!(multimodal.modality_tokens.len(), 3);
        assert_eq!(
            multimodal.modality_boundaries.get(&ModalityType::Text),
            Some(&(0, 3))
        );
        assert_eq!(multimodal.attention_mask, Some(vec![1, 1, 1]));
        assert!(multimodal.cross_modal_attention.is_none());
    }

    #[test]
    fn test_cross_modal_attention() {
        let tokens = vec![
            MultimodalToken {
                token_id: 1,
                modality: ModalityType::Text,
                modality_position: 0,
                global_position: 0,
                metadata: None,
            },
            MultimodalToken {
                token_id: 2,
                modality: ModalityType::Image,
                modality_position: 0,
                global_position: 1,
                metadata: None,
            },
        ];

        let attention = MultimodalUtils::calculate_cross_modal_attention(
            &tokens,
            &ModalityType::Text,
            &ModalityType::Image,
        );

        assert_eq!(attention.len(), 1); // 1 text token
        assert_eq!(attention[0].len(), 1); // 1 image token
        assert_eq!(attention[0][0], 1.0);
    }
}