Skip to main content

torsh_data/
text.rs

1//! Text preprocessing and NLP dataset utilities
2
3use crate::{dataset::Dataset, transforms::Transform};
4use torsh_core::error::{Result, TorshError};
5use torsh_tensor::Tensor;
6
7#[cfg(not(feature = "std"))]
8use alloc::{boxed::Box, collections::BTreeMap as HashMap, string::String, vec::Vec};
9#[cfg(feature = "std")]
10use std::collections::HashMap;
11use std::path::{Path, PathBuf};
12
/// Text sequence data container.
///
/// Holds the raw text together with its optional tokenized form
/// (`tokens`) and its optional numerically encoded form (`token_ids`).
/// The builder-style `with_*` methods attach those forms after construction.
#[derive(Debug, Clone)]
pub struct TextSequence {
    /// Raw, untokenized text.
    pub text: String,
    /// String tokens, if any have been attached.
    pub tokens: Option<Vec<String>>,
    /// Numeric token IDs, if any have been attached.
    pub token_ids: Option<Vec<usize>>,
}

impl TextSequence {
    /// Creates a sequence from raw text with no tokens or token IDs attached.
    pub fn new(text: String) -> Self {
        Self {
            text,
            tokens: None,
            token_ids: None,
        }
    }

    /// Attaches string tokens, replacing any previously attached tokens.
    pub fn with_tokens(mut self, tokens: Vec<String>) -> Self {
        self.tokens = Some(tokens);
        self
    }

    /// Attaches numeric token IDs, replacing any previously attached IDs.
    pub fn with_token_ids(mut self, token_ids: Vec<usize>) -> Self {
        self.token_ids = Some(token_ids);
        self
    }

    /// Returns the number of elements in the sequence.
    ///
    /// The count comes from the first representation that is present, in
    /// this order: `tokens`, then `token_ids`, then the number of
    /// whitespace-separated words in `text`.
    pub fn len(&self) -> usize {
        match (&self.tokens, &self.token_ids) {
            (Some(tokens), _) => tokens.len(),
            (None, Some(token_ids)) => token_ids.len(),
            (None, None) => self.text.split_whitespace().count(),
        }
    }

    /// Returns `true` when [`len`](Self::len) is zero.
    ///
    /// Defined in terms of `len` so the two can never disagree, e.g. for
    /// whitespace-only text or an explicitly attached empty token list.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
54
55/// Simple vocabulary for text tokenization
56#[derive(Debug, Clone)]
57pub struct Vocabulary {
58    token_to_id: HashMap<String, usize>,
59    id_to_token: Vec<String>,
60    special_tokens: HashMap<String, usize>,
61}
62
63impl Vocabulary {
64    pub fn new() -> Self {
65        Self {
66            token_to_id: HashMap::new(),
67            id_to_token: Vec::new(),
68            special_tokens: HashMap::new(),
69        }
70    }
71
72    /// Build vocabulary from text corpus
73    pub fn build_from_texts(&mut self, texts: &[String], min_freq: usize) -> Result<()> {
74        // Count token frequencies
75        let mut token_counts = HashMap::new();
76
77        for text in texts {
78            for token in Self::simple_tokenize(text) {
79                *token_counts.entry(token).or_insert(0) += 1;
80            }
81        }
82
83        // Add special tokens first
84        self.add_special_token("<UNK>".to_string());
85        self.add_special_token("<PAD>".to_string());
86        self.add_special_token("<SOS>".to_string());
87        self.add_special_token("<EOS>".to_string());
88
89        // Add tokens that meet frequency threshold
90        let mut sorted_tokens: Vec<(String, usize)> = token_counts.into_iter().collect();
91        sorted_tokens.sort_by(|a, b| b.1.cmp(&a.1)); // Sort by frequency (descending)
92
93        for (token, count) in sorted_tokens {
94            if count >= min_freq && !self.token_to_id.contains_key(&token) {
95                self.add_token(token);
96            }
97        }
98
99        Ok(())
100    }
101
102    /// Add a regular token
103    pub fn add_token(&mut self, token: String) {
104        if !self.token_to_id.contains_key(&token) {
105            let id = self.id_to_token.len();
106            self.token_to_id.insert(token.clone(), id);
107            self.id_to_token.push(token);
108        }
109    }
110
111    /// Add a special token
112    pub fn add_special_token(&mut self, token: String) {
113        if !self.token_to_id.contains_key(&token) {
114            let id = self.id_to_token.len();
115            self.token_to_id.insert(token.clone(), id);
116            self.special_tokens.insert(token.clone(), id);
117            self.id_to_token.push(token);
118        }
119    }
120
121    /// Convert token to ID
122    pub fn token_to_id(&self, token: &str) -> usize {
123        self.token_to_id
124            .get(token)
125            .copied()
126            .unwrap_or_else(|| self.unk_id())
127    }
128
129    /// Convert ID to token
130    pub fn id_to_token(&self, id: usize) -> Option<&str> {
131        self.id_to_token.get(id).map(|s| s.as_str())
132    }
133
134    /// Get unknown token ID
135    pub fn unk_id(&self) -> usize {
136        self.special_tokens.get("<UNK>").copied().unwrap_or(0)
137    }
138
139    /// Get padding token ID
140    pub fn pad_id(&self) -> usize {
141        self.special_tokens.get("<PAD>").copied().unwrap_or(1)
142    }
143
144    /// Get start of sequence token ID
145    pub fn sos_id(&self) -> usize {
146        self.special_tokens.get("<SOS>").copied().unwrap_or(2)
147    }
148
149    /// Get end of sequence token ID
150    pub fn eos_id(&self) -> usize {
151        self.special_tokens.get("<EOS>").copied().unwrap_or(3)
152    }
153
154    /// Get vocabulary size
155    pub fn len(&self) -> usize {
156        self.id_to_token.len()
157    }
158
159    pub fn is_empty(&self) -> bool {
160        self.id_to_token.is_empty()
161    }
162
163    /// Simple whitespace tokenization
164    fn simple_tokenize(text: &str) -> Vec<String> {
165        text.split_whitespace().map(|s| s.to_lowercase()).collect()
166    }
167
168    /// Tokenize text and convert to IDs
169    pub fn encode(&self, text: &str) -> Vec<usize> {
170        Self::simple_tokenize(text)
171            .into_iter()
172            .map(|token| self.token_to_id(&token))
173            .collect()
174    }
175
176    /// Convert IDs back to text
177    pub fn decode(&self, ids: &[usize]) -> String {
178        ids.iter()
179            .filter_map(|&id| self.id_to_token(id))
180            .filter(|&token| !self.special_tokens.contains_key(token) || token == "<UNK>")
181            .collect::<Vec<_>>()
182            .join(" ")
183    }
184}
185
186impl Default for Vocabulary {
187    fn default() -> Self {
188        Self::new()
189    }
190}
191
192/// Text dataset for classification tasks
193pub struct TextClassificationDataset {
194    texts: Vec<String>,
195    labels: Vec<usize>,
196    vocabulary: Vocabulary,
197    max_length: Option<usize>,
198    transform: Option<Box<dyn Transform<TextSequence, Output = Tensor<f32>>>>,
199}
200
201impl TextClassificationDataset {
202    /// Create a new text classification dataset
203    pub fn new(texts: Vec<String>, labels: Vec<usize>) -> Result<Self> {
204        if texts.len() != labels.len() {
205            return Err(TorshError::InvalidArgument(
206                "Number of texts must match number of labels".to_string(),
207            ));
208        }
209
210        let mut vocabulary = Vocabulary::new();
211        vocabulary.build_from_texts(&texts, 1)?;
212
213        Ok(Self {
214            texts,
215            labels,
216            vocabulary,
217            max_length: None,
218            transform: None,
219        })
220    }
221
222    /// Set maximum sequence length
223    pub fn with_max_length(mut self, max_length: usize) -> Self {
224        self.max_length = Some(max_length);
225        self
226    }
227
228    /// Set transform
229    pub fn with_transform<T>(mut self, transform: T) -> Self
230    where
231        T: Transform<TextSequence, Output = Tensor<f32>> + 'static,
232    {
233        self.transform = Some(Box::new(transform));
234        self
235    }
236
237    /// Get vocabulary reference
238    pub fn vocabulary(&self) -> &Vocabulary {
239        &self.vocabulary
240    }
241
242    /// Get number of classes
243    pub fn num_classes(&self) -> usize {
244        self.labels.iter().max().map(|&x| x + 1).unwrap_or(0)
245    }
246}
247
248impl Dataset for TextClassificationDataset {
249    type Item = (Tensor<f32>, usize);
250
251    fn len(&self) -> usize {
252        self.texts.len()
253    }
254
255    fn get(&self, index: usize) -> Result<Self::Item> {
256        if index >= self.texts.len() {
257            return Err(TorshError::IndexError {
258                index,
259                size: self.texts.len(),
260            });
261        }
262
263        let text = &self.texts[index];
264        let label = self.labels[index];
265
266        // Tokenize and encode
267        let token_ids = self.vocabulary.encode(text);
268        let tokens = Vocabulary::simple_tokenize(text);
269
270        let mut sequence = TextSequence::new(text.clone())
271            .with_tokens(tokens)
272            .with_token_ids(token_ids);
273
274        // Apply max length if specified
275        if let Some(max_len) = self.max_length {
276            if let Some(ref mut token_ids) = sequence.token_ids {
277                if token_ids.len() > max_len {
278                    token_ids.truncate(max_len);
279                } else {
280                    // Pad with padding token
281                    let pad_id = self.vocabulary.pad_id();
282                    token_ids.resize(max_len, pad_id);
283                }
284            }
285        }
286
287        let tensor = if let Some(ref transform) = self.transform {
288            transform.transform(sequence)?
289        } else {
290            // Default: convert token IDs to tensor
291            TokenIdsToTensor.transform(sequence)?
292        };
293
294        Ok((tensor, label))
295    }
296}
297
298/// Text file dataset for reading text files from directories
299pub struct TextFileDataset {
300    files: Vec<(PathBuf, usize)>,
301    classes: Vec<String>,
302    vocabulary: Vocabulary,
303    max_length: Option<usize>,
304    transform: Option<Box<dyn Transform<TextSequence, Output = Tensor<f32>>>>,
305}
306
307impl TextFileDataset {
308    /// Create a new text file dataset
309    pub fn new<P: AsRef<Path>>(root: P) -> Result<Self> {
310        let root = root.as_ref().to_path_buf();
311
312        if !root.exists() {
313            return Err(TorshError::IoError(format!(
314                "Directory does not exist: {root:?}"
315            )));
316        }
317
318        let mut classes = Vec::new();
319        let mut files = Vec::new();
320        let mut all_texts = Vec::new();
321
322        // Scan subdirectories for classes
323        for entry in std::fs::read_dir(&root).map_err(|e| TorshError::IoError(e.to_string()))? {
324            let entry = entry.map_err(|e| TorshError::IoError(e.to_string()))?;
325            let path = entry.path();
326
327            if path.is_dir() {
328                let class_name = path
329                    .file_name()
330                    .and_then(|n| n.to_str())
331                    .ok_or_else(|| TorshError::IoError("Invalid class directory name".to_string()))?
332                    .to_string();
333
334                let class_idx = classes.len();
335                classes.push(class_name);
336
337                // Scan text files in class directory
338                for file_entry in
339                    std::fs::read_dir(&path).map_err(|e| TorshError::IoError(e.to_string()))?
340                {
341                    let file_entry = file_entry.map_err(|e| TorshError::IoError(e.to_string()))?;
342                    let file_path = file_entry.path();
343
344                    if Self::is_text_file(&file_path) {
345                        files.push((file_path.clone(), class_idx));
346
347                        // Read file content for vocabulary building
348                        if let Ok(content) = std::fs::read_to_string(&file_path) {
349                            all_texts.push(content);
350                        }
351                    }
352                }
353            }
354        }
355
356        // Build vocabulary
357        let mut vocabulary = Vocabulary::new();
358        vocabulary.build_from_texts(&all_texts, 2)?;
359
360        Ok(Self {
361            files,
362            classes,
363            vocabulary,
364            max_length: None,
365            transform: None,
366        })
367    }
368
369    /// Check if file is a text file
370    fn is_text_file(path: &Path) -> bool {
371        if let Some(extension) = path.extension().and_then(|ext| ext.to_str()) {
372            matches!(
373                extension.to_lowercase().as_str(),
374                "txt" | "text" | "md" | "rst" | "csv" | "json"
375            )
376        } else {
377            false
378        }
379    }
380
381    /// Set maximum sequence length
382    pub fn with_max_length(mut self, max_length: usize) -> Self {
383        self.max_length = Some(max_length);
384        self
385    }
386
387    /// Set transform
388    pub fn with_transform<T>(mut self, transform: T) -> Self
389    where
390        T: Transform<TextSequence, Output = Tensor<f32>> + 'static,
391    {
392        self.transform = Some(Box::new(transform));
393        self
394    }
395
396    /// Get class names
397    pub fn classes(&self) -> &[String] {
398        &self.classes
399    }
400
401    /// Get vocabulary reference
402    pub fn vocabulary(&self) -> &Vocabulary {
403        &self.vocabulary
404    }
405}
406
407impl Dataset for TextFileDataset {
408    type Item = (Tensor<f32>, usize);
409
410    fn len(&self) -> usize {
411        self.files.len()
412    }
413
414    fn get(&self, index: usize) -> Result<Self::Item> {
415        if index >= self.files.len() {
416            return Err(TorshError::IndexError {
417                index,
418                size: self.files.len(),
419            });
420        }
421
422        let (ref path, class_idx) = self.files[index];
423
424        // Read file content
425        let text = std::fs::read_to_string(path)
426            .map_err(|e| TorshError::IoError(format!("Failed to read file {path:?}: {e}")))?;
427
428        // Tokenize and encode
429        let token_ids = self.vocabulary.encode(&text);
430        let tokens = Vocabulary::simple_tokenize(&text);
431
432        let mut sequence = TextSequence::new(text)
433            .with_tokens(tokens)
434            .with_token_ids(token_ids);
435
436        // Apply max length if specified
437        if let Some(max_len) = self.max_length {
438            if let Some(ref mut token_ids) = sequence.token_ids {
439                if token_ids.len() > max_len {
440                    token_ids.truncate(max_len);
441                } else {
442                    // Pad with padding token
443                    let pad_id = self.vocabulary.pad_id();
444                    token_ids.resize(max_len, pad_id);
445                }
446            }
447        }
448
449        let tensor = if let Some(ref transform) = self.transform {
450            transform.transform(sequence)?
451        } else {
452            // Default: convert token IDs to tensor
453            TokenIdsToTensor.transform(sequence)?
454        };
455
456        Ok((tensor, class_idx))
457    }
458}
459
460/// Transform to convert token IDs to tensor
461pub struct TokenIdsToTensor;
462
463impl Transform<TextSequence> for TokenIdsToTensor {
464    type Output = Tensor<f32>;
465
466    fn transform(&self, input: TextSequence) -> Result<Self::Output> {
467        if let Some(token_ids) = input.token_ids {
468            // Convert token IDs to f32 tensor
469            let len = token_ids.len();
470            let data: Vec<f32> = token_ids.into_iter().map(|id| id as f32).collect();
471            Tensor::from_data(data, vec![len], torsh_core::device::DeviceType::Cpu)
472        } else {
473            Err(TorshError::InvalidArgument(
474                "TextSequence must have token_ids for tensor conversion".to_string(),
475            ))
476        }
477    }
478}
479
480/// Text preprocessing transforms
481pub mod transforms {
482    use super::*;
483    use crate::transforms::Transform;
484
485    /// Convert text to lowercase
486    pub struct ToLowercase;
487
488    impl Transform<TextSequence> for ToLowercase {
489        type Output = TextSequence;
490
491        fn transform(&self, mut input: TextSequence) -> Result<Self::Output> {
492            input.text = input.text.to_lowercase();
493            if let Some(ref mut tokens) = input.tokens {
494                for token in tokens.iter_mut() {
495                    *token = token.to_lowercase();
496                }
497            }
498            Ok(input)
499        }
500    }
501
502    /// Remove punctuation from text
503    pub struct RemovePunctuation;
504
505    impl Transform<TextSequence> for RemovePunctuation {
506        type Output = TextSequence;
507
508        fn transform(&self, mut input: TextSequence) -> Result<Self::Output> {
509            input.text = input
510                .text
511                .chars()
512                .filter(|c| c.is_alphanumeric() || c.is_whitespace())
513                .collect();
514
515            if let Some(ref mut tokens) = input.tokens {
516                for token in tokens.iter_mut() {
517                    *token = token.chars().filter(|c| c.is_alphanumeric()).collect();
518                }
519                // Remove empty tokens
520                tokens.retain(|token| !token.is_empty());
521            }
522            Ok(input)
523        }
524    }
525
526    /// Truncate or pad sequence to fixed length
527    pub struct FixedLength {
528        length: usize,
529        pad_token_id: usize,
530    }
531
532    impl FixedLength {
533        pub fn new(length: usize, pad_token_id: usize) -> Self {
534            Self {
535                length,
536                pad_token_id,
537            }
538        }
539    }
540
541    impl Transform<TextSequence> for FixedLength {
542        type Output = TextSequence;
543
544        fn transform(&self, mut input: TextSequence) -> Result<Self::Output> {
545            if let Some(ref mut token_ids) = input.token_ids {
546                if token_ids.len() > self.length {
547                    token_ids.truncate(self.length);
548                } else {
549                    token_ids.resize(self.length, self.pad_token_id);
550                }
551            }
552
553            if let Some(ref mut tokens) = input.tokens {
554                if tokens.len() > self.length {
555                    tokens.truncate(self.length);
556                } else {
557                    tokens.resize(self.length, "<PAD>".to_string());
558                }
559            }
560
561            Ok(input)
562        }
563    }
564
565    /// Add start and end tokens
566    pub struct AddSpecialTokens {
567        sos_token_id: usize,
568        eos_token_id: usize,
569    }
570
571    impl AddSpecialTokens {
572        pub fn new(sos_token_id: usize, eos_token_id: usize) -> Self {
573            Self {
574                sos_token_id,
575                eos_token_id,
576            }
577        }
578    }
579
580    impl Transform<TextSequence> for AddSpecialTokens {
581        type Output = TextSequence;
582
583        fn transform(&self, mut input: TextSequence) -> Result<Self::Output> {
584            if let Some(ref mut token_ids) = input.token_ids {
585                token_ids.insert(0, self.sos_token_id);
586                token_ids.push(self.eos_token_id);
587            }
588
589            if let Some(ref mut tokens) = input.tokens {
590                tokens.insert(0, "<SOS>".to_string());
591                tokens.push("<EOS>".to_string());
592            }
593
594            Ok(input)
595        }
596    }
597
598    /// Simple n-gram extraction
599    pub struct NGrams {
600        n: usize,
601    }
602
603    impl NGrams {
604        pub fn new(n: usize) -> Self {
605            assert!(n > 0, "n must be positive");
606            Self { n }
607        }
608    }
609
610    impl Transform<TextSequence> for NGrams {
611        type Output = TextSequence;
612
613        fn transform(&self, input: TextSequence) -> Result<Self::Output> {
614            let tokens = if let Some(tokens) = input.tokens {
615                tokens
616            } else {
617                Vocabulary::simple_tokenize(&input.text)
618            };
619
620            let mut ngrams = Vec::new();
621            for window in tokens.windows(self.n) {
622                let ngram = window.join("_");
623                ngrams.push(ngram);
624            }
625
626            let ngram_text = ngrams.join(" ");
627            Ok(TextSequence::new(ngram_text).with_tokens(ngrams))
628        }
629    }
630
631    /// Character-level tokenization
632    pub struct CharTokenizer;
633
634    impl Transform<TextSequence> for CharTokenizer {
635        type Output = TextSequence;
636
637        fn transform(&self, input: TextSequence) -> Result<Self::Output> {
638            let chars: Vec<String> = input.text.chars().map(|c| c.to_string()).collect();
639            Ok(input.with_tokens(chars))
640        }
641    }
642
643    /// Byte Pair Encoding (simplified version)
644    pub struct SimpleBPE {
645        #[allow(dead_code)]
646        vocab_size: usize,
647    }
648
649    impl SimpleBPE {
650        pub fn new(vocab_size: usize) -> Self {
651            Self { vocab_size }
652        }
653    }
654
655    impl Transform<TextSequence> for SimpleBPE {
656        type Output = TextSequence;
657
658        fn transform(&self, input: TextSequence) -> Result<Self::Output> {
659            // Simplified BPE - in practice you'd need a trained BPE model
660            // For now, just do character-level tokenization with word boundaries
661            let mut tokens = Vec::new();
662
663            for word in input.text.split_whitespace() {
664                // Add word-level tokens for short words
665                if word.len() <= 3 {
666                    tokens.push(word.to_string());
667                } else {
668                    // Split longer words into subwords (simplified)
669                    let chars: Vec<char> = word.chars().collect();
670                    for chunk in chars.chunks(2) {
671                        let subword: String = chunk.iter().collect();
672                        tokens.push(subword);
673                    }
674                }
675            }
676
677            Ok(input.with_tokens(tokens))
678        }
679    }
680}
681
682/// Common NLP datasets
683pub mod datasets {
684    use super::*;
685
686    /// IMDB movie reviews dataset (simplified)
687    pub struct IMDB {
688        #[allow(dead_code)]
689        root: PathBuf,
690        #[allow(dead_code)]
691        split: String,
692        vocabulary: Vocabulary,
693        samples: Vec<(String, usize)>, // (text, label)
694        transform: Option<Box<dyn Transform<TextSequence, Output = Tensor<f32>>>>,
695    }
696
697    impl IMDB {
698        /// Create IMDB dataset
699        pub fn new<P: AsRef<Path>>(root: P, split: &str) -> Result<Self> {
700            let root = root.as_ref().to_path_buf();
701
702            // In a real implementation, you would:
703            // 1. Download IMDB dataset from official source
704            // 2. Parse the data files
705            // 3. Load reviews and sentiment labels
706
707            // For now, create dummy data
708            let samples = vec![
709                ("This movie is great!".to_string(), 1),          // positive
710                ("Terrible film, waste of time.".to_string(), 0), // negative
711                ("Amazing cinematography and acting.".to_string(), 1),
712                ("Boring and predictable plot.".to_string(), 0),
713            ];
714
715            let texts: Vec<String> = samples.iter().map(|(text, _)| text.clone()).collect();
716            let mut vocabulary = Vocabulary::new();
717            vocabulary.build_from_texts(&texts, 1)?;
718
719            Ok(Self {
720                root,
721                split: split.to_string(),
722                vocabulary,
723                samples,
724                transform: None,
725            })
726        }
727
728        /// Set transform
729        pub fn with_transform<T>(mut self, transform: T) -> Self
730        where
731            T: Transform<TextSequence, Output = Tensor<f32>> + 'static,
732        {
733            self.transform = Some(Box::new(transform));
734            self
735        }
736
737        /// Get vocabulary reference
738        pub fn vocabulary(&self) -> &Vocabulary {
739            &self.vocabulary
740        }
741    }
742
743    impl Dataset for IMDB {
744        type Item = (Tensor<f32>, usize);
745
746        fn len(&self) -> usize {
747            self.samples.len()
748        }
749
750        fn get(&self, index: usize) -> Result<Self::Item> {
751            if index >= self.samples.len() {
752                return Err(TorshError::IndexError {
753                    index,
754                    size: self.samples.len(),
755                });
756            }
757
758            let (ref text, label) = self.samples[index];
759
760            // Tokenize and encode
761            let token_ids = self.vocabulary.encode(text);
762            let tokens = Vocabulary::simple_tokenize(text);
763
764            let sequence = TextSequence::new(text.clone())
765                .with_tokens(tokens)
766                .with_token_ids(token_ids);
767
768            let tensor = if let Some(ref transform) = self.transform {
769                transform.transform(sequence)?
770            } else {
771                TokenIdsToTensor.transform(sequence)?
772            };
773
774            Ok((tensor, label))
775        }
776    }
777}
778
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building owned strings in fixtures.
    fn owned(text: &str) -> String {
        text.to_string()
    }

    /// Building a vocabulary registers the special tokens plus every corpus
    /// word, and encode/decode round-trips known words.
    #[test]
    fn test_vocabulary() {
        let corpus = vec![owned("hello world"), owned("world hello"), owned("foo bar")];

        let mut vocab = Vocabulary::new();
        vocab.build_from_texts(&corpus, 1).unwrap();

        // 4 special + hello, world, foo, bar
        assert!(vocab.len() >= 6);

        let encoded = vocab.encode("hello world");
        assert_eq!(vocab.decode(&encoded), "hello world");
    }

    /// A sequence with tokens attached reports the token count as its length.
    #[test]
    fn test_text_sequence() {
        let tokens = vec![owned("hello"), owned("world")];
        let seq = TextSequence::new(owned("hello world"))
            .with_tokens(tokens)
            .with_token_ids(vec![1, 2]);

        assert_eq!(seq.len(), 2);
        assert!(!seq.is_empty());
    }

    /// The classification dataset exposes its size, class count and samples.
    #[test]
    fn test_text_classification_dataset() {
        let texts = vec![owned("positive example"), owned("negative example")];
        let labels = vec![1, 0];

        let dataset = TextClassificationDataset::new(texts, labels).unwrap();
        assert_eq!(dataset.len(), 2);
        assert_eq!(dataset.num_classes(), 2);

        let (tensor, label) = dataset.get(0).unwrap();
        assert_eq!(label, 1);
        assert!(tensor.ndim() > 0);
    }

    /// Token IDs become a 1-D tensor holding the same values as floats.
    #[test]
    fn test_token_ids_to_tensor() {
        let seq = TextSequence::new(owned("test")).with_token_ids(vec![1, 2, 3]);

        let tensor = TokenIdsToTensor.transform(seq).unwrap();

        assert_eq!(tensor.shape().dims(), &[3]);
        let values = tensor.to_vec().unwrap();
        assert_eq!(values, vec![1.0, 2.0, 3.0]);
    }

    /// Exercises lowercase, punctuation removal, fixed length and special
    /// token insertion.
    #[test]
    fn test_text_transforms() {
        use transforms::*;

        let seq = TextSequence::new(owned("Hello, World!"))
            .with_tokens(vec![owned("Hello,"), owned("World!")]);

        // Lowercase
        let lowered = ToLowercase.transform(seq.clone()).unwrap();
        assert_eq!(lowered.text, "hello, world!");

        // Punctuation removal
        let stripped = RemovePunctuation.transform(seq.clone()).unwrap();
        assert_eq!(stripped.text, "Hello World");

        // Fixed length (pads with the given ID)
        let padded = FixedLength::new(4, 0)
            .transform(seq.with_token_ids(vec![1, 2]))
            .unwrap();
        assert_eq!(padded.token_ids.unwrap(), vec![1, 2, 0, 0]);

        // Special tokens
        let plain = TextSequence::new(owned("test")).with_token_ids(vec![1, 2]);
        let wrapped = AddSpecialTokens::new(100, 101).transform(plain).unwrap();
        assert_eq!(wrapped.token_ids.unwrap(), vec![100, 1, 2, 101]);
    }

    /// Bigrams of four tokens are the three adjacent pairs joined by `_`.
    #[test]
    fn test_ngrams() {
        use transforms::*;

        let words = vec![owned("the"), owned("quick"), owned("brown"), owned("fox")];
        let seq = TextSequence::new(owned("the quick brown fox")).with_tokens(words);

        let result = NGrams::new(2).transform(seq).unwrap();

        let expected_tokens = vec![
            owned("the_quick"),
            owned("quick_brown"),
            owned("brown_fox"),
        ];
        assert_eq!(result.tokens.unwrap(), expected_tokens);
    }

    /// The placeholder IMDB dataset yields its four dummy samples.
    #[test]
    fn test_imdb_dataset() {
        use datasets::*;

        let dataset = IMDB::new("/tmp", "train").unwrap();
        assert_eq!(dataset.len(), 4);

        let (tensor, label) = dataset.get(0).unwrap();
        assert_eq!(label, 1); // positive
        assert!(tensor.ndim() > 0);
    }

    /// Character tokenization yields one token per character.
    #[test]
    fn test_char_tokenizer() {
        use transforms::*;

        let result = CharTokenizer
            .transform(TextSequence::new(owned("abc")))
            .unwrap();

        assert_eq!(
            result.tokens.unwrap(),
            vec![owned("a"), owned("b"), owned("c")]
        );
    }

    /// The simplified BPE produces a non-empty token list.
    #[test]
    fn test_simple_bpe() {
        use transforms::*;

        let result = SimpleBPE::new(1000)
            .transform(TextSequence::new(owned("hello world")))
            .unwrap();

        // Should have some form of subword tokenization
        assert!(result.tokens.is_some());
        assert!(!result.tokens.unwrap().is_empty());
    }
}