1use crate::{dataset::Dataset, transforms::Transform};
4use torsh_core::error::{Result, TorshError};
5use torsh_tensor::Tensor;
6
7#[cfg(not(feature = "std"))]
8use alloc::{boxed::Box, collections::BTreeMap as HashMap, string::String, vec::Vec};
9#[cfg(feature = "std")]
10use std::collections::HashMap;
11use std::path::{Path, PathBuf};
12
/// A piece of text together with its optional tokenized representations.
#[derive(Debug, Clone)]
pub struct TextSequence {
    /// The raw text.
    pub text: String,
    /// String tokens, when the text has been tokenized.
    pub tokens: Option<Vec<String>>,
    /// Numeric token ids, when the tokens have been encoded.
    pub token_ids: Option<Vec<usize>>,
}

impl TextSequence {
    /// Creates a sequence holding only raw text; `tokens` and `token_ids`
    /// start out as `None`.
    pub fn new(text: String) -> Self {
        Self {
            text,
            tokens: None,
            token_ids: None,
        }
    }

    /// Attaches string tokens, replacing any previously stored tokens.
    pub fn with_tokens(mut self, tokens: Vec<String>) -> Self {
        self.tokens = Some(tokens);
        self
    }

    /// Attaches token ids, replacing any previously stored ids.
    pub fn with_token_ids(mut self, token_ids: Vec<usize>) -> Self {
        self.token_ids = Some(token_ids);
        self
    }

    /// Returns the sequence length.
    ///
    /// The length is taken from `tokens` when present, otherwise from
    /// `token_ids`, otherwise from the number of whitespace-separated words
    /// in `text`.
    pub fn len(&self) -> usize {
        match (&self.tokens, &self.token_ids) {
            (Some(tokens), _) => tokens.len(),
            (None, Some(token_ids)) => token_ids.len(),
            (None, None) => self.text.split_whitespace().count(),
        }
    }

    /// Returns `true` exactly when [`TextSequence::len`] is zero.
    ///
    /// Defined in terms of `len` so the two can never disagree (for example
    /// on whitespace-only text, or on empty text with tokens attached).
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
54
/// Bidirectional mapping between string tokens and integer ids.
///
/// Ids are assigned densely in insertion order, so an id is also the index of
/// its token in `id_to_token`.
#[derive(Debug, Clone)]
pub struct Vocabulary {
    /// Token -> id lookup (covers regular and special tokens).
    token_to_id: HashMap<String, usize>,
    /// Id -> token lookup; the index is the id.
    id_to_token: Vec<String>,
    /// Subset of tokens registered through `add_special_token`.
    special_tokens: HashMap<String, usize>,
}

impl Vocabulary {
    /// Creates an empty vocabulary with no tokens at all (not even specials).
    pub fn new() -> Self {
        Self {
            token_to_id: HashMap::new(),
            id_to_token: Vec::new(),
            special_tokens: HashMap::new(),
        }
    }

    /// Extends the vocabulary with the tokens found in `texts`.
    ///
    /// The special tokens `<UNK>`, `<PAD>`, `<SOS>` and `<EOS>` are registered
    /// first (in that order), then every token occurring at least `min_freq`
    /// times is added in order of decreasing frequency. Tokens with equal
    /// frequency are ordered lexicographically so that id assignment is
    /// reproducible and does not depend on map iteration order.
    ///
    /// Tokens already present keep their existing ids. This implementation
    /// always returns `Ok(())`.
    pub fn build_from_texts(&mut self, texts: &[String], min_freq: usize) -> Result<()> {
        let mut token_counts: HashMap<String, usize> = HashMap::new();

        for text in texts {
            for token in Self::simple_tokenize(text) {
                *token_counts.entry(token).or_insert(0) += 1;
            }
        }

        self.add_special_token("<UNK>".to_string());
        self.add_special_token("<PAD>".to_string());
        self.add_special_token("<SOS>".to_string());
        self.add_special_token("<EOS>".to_string());

        let mut sorted_tokens: Vec<(String, usize)> = token_counts.into_iter().collect();
        // Most frequent first; ties broken by token text. Tokens are unique,
        // so this is a total order and an unstable sort is deterministic.
        sorted_tokens.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));

        for (token, count) in sorted_tokens {
            if count >= min_freq {
                // `add_token` ignores tokens that are already known.
                self.add_token(token);
            }
        }

        Ok(())
    }

    /// Adds `token` with the next free id; does nothing if it already exists.
    pub fn add_token(&mut self, token: String) {
        if self.token_to_id.contains_key(&token) {
            return;
        }
        let id = self.id_to_token.len();
        self.token_to_id.insert(token.clone(), id);
        self.id_to_token.push(token);
    }

    /// Adds `token` with the next free id and records it as a special token.
    ///
    /// Does nothing if the token already exists; in particular, a token that
    /// was previously added as a regular token is not promoted to special.
    pub fn add_special_token(&mut self, token: String) {
        if self.token_to_id.contains_key(&token) {
            return;
        }
        let id = self.id_to_token.len();
        self.token_to_id.insert(token.clone(), id);
        self.special_tokens.insert(token.clone(), id);
        self.id_to_token.push(token);
    }

    /// Returns the id of `token`, or [`Vocabulary::unk_id`] when it is unknown.
    pub fn token_to_id(&self, token: &str) -> usize {
        self.token_to_id
            .get(token)
            .copied()
            .unwrap_or_else(|| self.unk_id())
    }

    /// Returns the token stored under `id`, or `None` when `id` is out of range.
    pub fn id_to_token(&self, id: usize) -> Option<&str> {
        self.id_to_token.get(id).map(String::as_str)
    }

    /// Id of the `<UNK>` special token; falls back to `0` if it is not registered.
    pub fn unk_id(&self) -> usize {
        self.special_tokens.get("<UNK>").copied().unwrap_or(0)
    }

    /// Id of the `<PAD>` special token; falls back to `1` if it is not registered.
    pub fn pad_id(&self) -> usize {
        self.special_tokens.get("<PAD>").copied().unwrap_or(1)
    }

    /// Id of the `<SOS>` special token; falls back to `2` if it is not registered.
    pub fn sos_id(&self) -> usize {
        self.special_tokens.get("<SOS>").copied().unwrap_or(2)
    }

    /// Id of the `<EOS>` special token; falls back to `3` if it is not registered.
    pub fn eos_id(&self) -> usize {
        self.special_tokens.get("<EOS>").copied().unwrap_or(3)
    }

    /// Number of tokens in the vocabulary, special tokens included.
    pub fn len(&self) -> usize {
        self.id_to_token.len()
    }

    /// Returns `true` when the vocabulary holds no tokens.
    pub fn is_empty(&self) -> bool {
        self.id_to_token.is_empty()
    }

    /// Splits `text` on whitespace and lowercases every piece.
    fn simple_tokenize(text: &str) -> Vec<String> {
        text.split_whitespace().map(|s| s.to_lowercase()).collect()
    }

    /// Tokenizes `text` and maps each token to its id (unknown tokens map to
    /// the `<UNK>` id).
    pub fn encode(&self, text: &str) -> Vec<usize> {
        Self::simple_tokenize(text)
            .into_iter()
            .map(|token| self.token_to_id(&token))
            .collect()
    }

    /// Converts ids back into a space-separated string.
    ///
    /// Ids without a matching token are skipped, and special tokens are
    /// omitted from the output with the exception of `<UNK>`.
    pub fn decode(&self, ids: &[usize]) -> String {
        ids.iter()
            .filter_map(|&id| self.id_to_token(id))
            .filter(|&token| !self.special_tokens.contains_key(token) || token == "<UNK>")
            .collect::<Vec<_>>()
            .join(" ")
    }
}

impl Default for Vocabulary {
    /// Equivalent to [`Vocabulary::new`].
    fn default() -> Self {
        Self::new()
    }
}
191
192pub struct TextClassificationDataset {
194 texts: Vec<String>,
195 labels: Vec<usize>,
196 vocabulary: Vocabulary,
197 max_length: Option<usize>,
198 transform: Option<Box<dyn Transform<TextSequence, Output = Tensor<f32>>>>,
199}
200
201impl TextClassificationDataset {
202 pub fn new(texts: Vec<String>, labels: Vec<usize>) -> Result<Self> {
204 if texts.len() != labels.len() {
205 return Err(TorshError::InvalidArgument(
206 "Number of texts must match number of labels".to_string(),
207 ));
208 }
209
210 let mut vocabulary = Vocabulary::new();
211 vocabulary.build_from_texts(&texts, 1)?;
212
213 Ok(Self {
214 texts,
215 labels,
216 vocabulary,
217 max_length: None,
218 transform: None,
219 })
220 }
221
222 pub fn with_max_length(mut self, max_length: usize) -> Self {
224 self.max_length = Some(max_length);
225 self
226 }
227
228 pub fn with_transform<T>(mut self, transform: T) -> Self
230 where
231 T: Transform<TextSequence, Output = Tensor<f32>> + 'static,
232 {
233 self.transform = Some(Box::new(transform));
234 self
235 }
236
237 pub fn vocabulary(&self) -> &Vocabulary {
239 &self.vocabulary
240 }
241
242 pub fn num_classes(&self) -> usize {
244 self.labels.iter().max().map(|&x| x + 1).unwrap_or(0)
245 }
246}
247
248impl Dataset for TextClassificationDataset {
249 type Item = (Tensor<f32>, usize);
250
251 fn len(&self) -> usize {
252 self.texts.len()
253 }
254
255 fn get(&self, index: usize) -> Result<Self::Item> {
256 if index >= self.texts.len() {
257 return Err(TorshError::IndexError {
258 index,
259 size: self.texts.len(),
260 });
261 }
262
263 let text = &self.texts[index];
264 let label = self.labels[index];
265
266 let token_ids = self.vocabulary.encode(text);
268 let tokens = Vocabulary::simple_tokenize(text);
269
270 let mut sequence = TextSequence::new(text.clone())
271 .with_tokens(tokens)
272 .with_token_ids(token_ids);
273
274 if let Some(max_len) = self.max_length {
276 if let Some(ref mut token_ids) = sequence.token_ids {
277 if token_ids.len() > max_len {
278 token_ids.truncate(max_len);
279 } else {
280 let pad_id = self.vocabulary.pad_id();
282 token_ids.resize(max_len, pad_id);
283 }
284 }
285 }
286
287 let tensor = if let Some(ref transform) = self.transform {
288 transform.transform(sequence)?
289 } else {
290 TokenIdsToTensor.transform(sequence)?
292 };
293
294 Ok((tensor, label))
295 }
296}
297
298pub struct TextFileDataset {
300 files: Vec<(PathBuf, usize)>,
301 classes: Vec<String>,
302 vocabulary: Vocabulary,
303 max_length: Option<usize>,
304 transform: Option<Box<dyn Transform<TextSequence, Output = Tensor<f32>>>>,
305}
306
307impl TextFileDataset {
308 pub fn new<P: AsRef<Path>>(root: P) -> Result<Self> {
310 let root = root.as_ref().to_path_buf();
311
312 if !root.exists() {
313 return Err(TorshError::IoError(format!(
314 "Directory does not exist: {root:?}"
315 )));
316 }
317
318 let mut classes = Vec::new();
319 let mut files = Vec::new();
320 let mut all_texts = Vec::new();
321
322 for entry in std::fs::read_dir(&root).map_err(|e| TorshError::IoError(e.to_string()))? {
324 let entry = entry.map_err(|e| TorshError::IoError(e.to_string()))?;
325 let path = entry.path();
326
327 if path.is_dir() {
328 let class_name = path
329 .file_name()
330 .and_then(|n| n.to_str())
331 .ok_or_else(|| TorshError::IoError("Invalid class directory name".to_string()))?
332 .to_string();
333
334 let class_idx = classes.len();
335 classes.push(class_name);
336
337 for file_entry in
339 std::fs::read_dir(&path).map_err(|e| TorshError::IoError(e.to_string()))?
340 {
341 let file_entry = file_entry.map_err(|e| TorshError::IoError(e.to_string()))?;
342 let file_path = file_entry.path();
343
344 if Self::is_text_file(&file_path) {
345 files.push((file_path.clone(), class_idx));
346
347 if let Ok(content) = std::fs::read_to_string(&file_path) {
349 all_texts.push(content);
350 }
351 }
352 }
353 }
354 }
355
356 let mut vocabulary = Vocabulary::new();
358 vocabulary.build_from_texts(&all_texts, 2)?;
359
360 Ok(Self {
361 files,
362 classes,
363 vocabulary,
364 max_length: None,
365 transform: None,
366 })
367 }
368
369 fn is_text_file(path: &Path) -> bool {
371 if let Some(extension) = path.extension().and_then(|ext| ext.to_str()) {
372 matches!(
373 extension.to_lowercase().as_str(),
374 "txt" | "text" | "md" | "rst" | "csv" | "json"
375 )
376 } else {
377 false
378 }
379 }
380
381 pub fn with_max_length(mut self, max_length: usize) -> Self {
383 self.max_length = Some(max_length);
384 self
385 }
386
387 pub fn with_transform<T>(mut self, transform: T) -> Self
389 where
390 T: Transform<TextSequence, Output = Tensor<f32>> + 'static,
391 {
392 self.transform = Some(Box::new(transform));
393 self
394 }
395
396 pub fn classes(&self) -> &[String] {
398 &self.classes
399 }
400
401 pub fn vocabulary(&self) -> &Vocabulary {
403 &self.vocabulary
404 }
405}
406
407impl Dataset for TextFileDataset {
408 type Item = (Tensor<f32>, usize);
409
410 fn len(&self) -> usize {
411 self.files.len()
412 }
413
414 fn get(&self, index: usize) -> Result<Self::Item> {
415 if index >= self.files.len() {
416 return Err(TorshError::IndexError {
417 index,
418 size: self.files.len(),
419 });
420 }
421
422 let (ref path, class_idx) = self.files[index];
423
424 let text = std::fs::read_to_string(path)
426 .map_err(|e| TorshError::IoError(format!("Failed to read file {path:?}: {e}")))?;
427
428 let token_ids = self.vocabulary.encode(&text);
430 let tokens = Vocabulary::simple_tokenize(&text);
431
432 let mut sequence = TextSequence::new(text)
433 .with_tokens(tokens)
434 .with_token_ids(token_ids);
435
436 if let Some(max_len) = self.max_length {
438 if let Some(ref mut token_ids) = sequence.token_ids {
439 if token_ids.len() > max_len {
440 token_ids.truncate(max_len);
441 } else {
442 let pad_id = self.vocabulary.pad_id();
444 token_ids.resize(max_len, pad_id);
445 }
446 }
447 }
448
449 let tensor = if let Some(ref transform) = self.transform {
450 transform.transform(sequence)?
451 } else {
452 TokenIdsToTensor.transform(sequence)?
454 };
455
456 Ok((tensor, class_idx))
457 }
458}
459
460pub struct TokenIdsToTensor;
462
463impl Transform<TextSequence> for TokenIdsToTensor {
464 type Output = Tensor<f32>;
465
466 fn transform(&self, input: TextSequence) -> Result<Self::Output> {
467 if let Some(token_ids) = input.token_ids {
468 let len = token_ids.len();
470 let data: Vec<f32> = token_ids.into_iter().map(|id| id as f32).collect();
471 Tensor::from_data(data, vec![len], torsh_core::device::DeviceType::Cpu)
472 } else {
473 Err(TorshError::InvalidArgument(
474 "TextSequence must have token_ids for tensor conversion".to_string(),
475 ))
476 }
477 }
478}
479
480pub mod transforms {
482 use super::*;
483 use crate::transforms::Transform;
484
485 pub struct ToLowercase;
487
488 impl Transform<TextSequence> for ToLowercase {
489 type Output = TextSequence;
490
491 fn transform(&self, mut input: TextSequence) -> Result<Self::Output> {
492 input.text = input.text.to_lowercase();
493 if let Some(ref mut tokens) = input.tokens {
494 for token in tokens.iter_mut() {
495 *token = token.to_lowercase();
496 }
497 }
498 Ok(input)
499 }
500 }
501
502 pub struct RemovePunctuation;
504
505 impl Transform<TextSequence> for RemovePunctuation {
506 type Output = TextSequence;
507
508 fn transform(&self, mut input: TextSequence) -> Result<Self::Output> {
509 input.text = input
510 .text
511 .chars()
512 .filter(|c| c.is_alphanumeric() || c.is_whitespace())
513 .collect();
514
515 if let Some(ref mut tokens) = input.tokens {
516 for token in tokens.iter_mut() {
517 *token = token.chars().filter(|c| c.is_alphanumeric()).collect();
518 }
519 tokens.retain(|token| !token.is_empty());
521 }
522 Ok(input)
523 }
524 }
525
526 pub struct FixedLength {
528 length: usize,
529 pad_token_id: usize,
530 }
531
532 impl FixedLength {
533 pub fn new(length: usize, pad_token_id: usize) -> Self {
534 Self {
535 length,
536 pad_token_id,
537 }
538 }
539 }
540
541 impl Transform<TextSequence> for FixedLength {
542 type Output = TextSequence;
543
544 fn transform(&self, mut input: TextSequence) -> Result<Self::Output> {
545 if let Some(ref mut token_ids) = input.token_ids {
546 if token_ids.len() > self.length {
547 token_ids.truncate(self.length);
548 } else {
549 token_ids.resize(self.length, self.pad_token_id);
550 }
551 }
552
553 if let Some(ref mut tokens) = input.tokens {
554 if tokens.len() > self.length {
555 tokens.truncate(self.length);
556 } else {
557 tokens.resize(self.length, "<PAD>".to_string());
558 }
559 }
560
561 Ok(input)
562 }
563 }
564
565 pub struct AddSpecialTokens {
567 sos_token_id: usize,
568 eos_token_id: usize,
569 }
570
571 impl AddSpecialTokens {
572 pub fn new(sos_token_id: usize, eos_token_id: usize) -> Self {
573 Self {
574 sos_token_id,
575 eos_token_id,
576 }
577 }
578 }
579
580 impl Transform<TextSequence> for AddSpecialTokens {
581 type Output = TextSequence;
582
583 fn transform(&self, mut input: TextSequence) -> Result<Self::Output> {
584 if let Some(ref mut token_ids) = input.token_ids {
585 token_ids.insert(0, self.sos_token_id);
586 token_ids.push(self.eos_token_id);
587 }
588
589 if let Some(ref mut tokens) = input.tokens {
590 tokens.insert(0, "<SOS>".to_string());
591 tokens.push("<EOS>".to_string());
592 }
593
594 Ok(input)
595 }
596 }
597
598 pub struct NGrams {
600 n: usize,
601 }
602
603 impl NGrams {
604 pub fn new(n: usize) -> Self {
605 assert!(n > 0, "n must be positive");
606 Self { n }
607 }
608 }
609
610 impl Transform<TextSequence> for NGrams {
611 type Output = TextSequence;
612
613 fn transform(&self, input: TextSequence) -> Result<Self::Output> {
614 let tokens = if let Some(tokens) = input.tokens {
615 tokens
616 } else {
617 Vocabulary::simple_tokenize(&input.text)
618 };
619
620 let mut ngrams = Vec::new();
621 for window in tokens.windows(self.n) {
622 let ngram = window.join("_");
623 ngrams.push(ngram);
624 }
625
626 let ngram_text = ngrams.join(" ");
627 Ok(TextSequence::new(ngram_text).with_tokens(ngrams))
628 }
629 }
630
631 pub struct CharTokenizer;
633
634 impl Transform<TextSequence> for CharTokenizer {
635 type Output = TextSequence;
636
637 fn transform(&self, input: TextSequence) -> Result<Self::Output> {
638 let chars: Vec<String> = input.text.chars().map(|c| c.to_string()).collect();
639 Ok(input.with_tokens(chars))
640 }
641 }
642
643 pub struct SimpleBPE {
645 #[allow(dead_code)]
646 vocab_size: usize,
647 }
648
649 impl SimpleBPE {
650 pub fn new(vocab_size: usize) -> Self {
651 Self { vocab_size }
652 }
653 }
654
655 impl Transform<TextSequence> for SimpleBPE {
656 type Output = TextSequence;
657
658 fn transform(&self, input: TextSequence) -> Result<Self::Output> {
659 let mut tokens = Vec::new();
662
663 for word in input.text.split_whitespace() {
664 if word.len() <= 3 {
666 tokens.push(word.to_string());
667 } else {
668 let chars: Vec<char> = word.chars().collect();
670 for chunk in chars.chunks(2) {
671 let subword: String = chunk.iter().collect();
672 tokens.push(subword);
673 }
674 }
675 }
676
677 Ok(input.with_tokens(tokens))
678 }
679 }
680}
681
682pub mod datasets {
684 use super::*;
685
686 pub struct IMDB {
688 #[allow(dead_code)]
689 root: PathBuf,
690 #[allow(dead_code)]
691 split: String,
692 vocabulary: Vocabulary,
693 samples: Vec<(String, usize)>, transform: Option<Box<dyn Transform<TextSequence, Output = Tensor<f32>>>>,
695 }
696
697 impl IMDB {
698 pub fn new<P: AsRef<Path>>(root: P, split: &str) -> Result<Self> {
700 let root = root.as_ref().to_path_buf();
701
702 let samples = vec![
709 ("This movie is great!".to_string(), 1), ("Terrible film, waste of time.".to_string(), 0), ("Amazing cinematography and acting.".to_string(), 1),
712 ("Boring and predictable plot.".to_string(), 0),
713 ];
714
715 let texts: Vec<String> = samples.iter().map(|(text, _)| text.clone()).collect();
716 let mut vocabulary = Vocabulary::new();
717 vocabulary.build_from_texts(&texts, 1)?;
718
719 Ok(Self {
720 root,
721 split: split.to_string(),
722 vocabulary,
723 samples,
724 transform: None,
725 })
726 }
727
728 pub fn with_transform<T>(mut self, transform: T) -> Self
730 where
731 T: Transform<TextSequence, Output = Tensor<f32>> + 'static,
732 {
733 self.transform = Some(Box::new(transform));
734 self
735 }
736
737 pub fn vocabulary(&self) -> &Vocabulary {
739 &self.vocabulary
740 }
741 }
742
743 impl Dataset for IMDB {
744 type Item = (Tensor<f32>, usize);
745
746 fn len(&self) -> usize {
747 self.samples.len()
748 }
749
750 fn get(&self, index: usize) -> Result<Self::Item> {
751 if index >= self.samples.len() {
752 return Err(TorshError::IndexError {
753 index,
754 size: self.samples.len(),
755 });
756 }
757
758 let (ref text, label) = self.samples[index];
759
760 let token_ids = self.vocabulary.encode(text);
762 let tokens = Vocabulary::simple_tokenize(text);
763
764 let sequence = TextSequence::new(text.clone())
765 .with_tokens(tokens)
766 .with_token_ids(token_ids);
767
768 let tensor = if let Some(ref transform) = self.transform {
769 transform.transform(sequence)?
770 } else {
771 TokenIdsToTensor.transform(sequence)?
772 };
773
774 Ok((tensor, label))
775 }
776 }
777}
778
779#[cfg(test)]
780mod tests {
781 use super::*;
782
/// Building a vocabulary yields at least the expected number of entries and
/// round-trips known text through encode/decode.
#[test]
fn test_vocabulary() {
    let corpus: Vec<String> = ["hello world", "world hello", "foo bar"]
        .iter()
        .map(|line| line.to_string())
        .collect();

    let mut vocab = Vocabulary::new();
    vocab.build_from_texts(&corpus, 1).unwrap();

    assert!(vocab.len() >= 6);

    let round_trip = vocab.decode(&vocab.encode("hello world"));
    assert_eq!(round_trip, "hello world");
}
802
/// A sequence with tokens reports the token count and is not empty.
#[test]
fn test_text_sequence() {
    let words = vec!["hello".to_string(), "world".to_string()];
    let sequence = TextSequence::new("hello world".to_string())
        .with_tokens(words)
        .with_token_ids(vec![1, 2]);

    assert_eq!(sequence.len(), 2);
    assert!(!sequence.is_empty());
}
812
/// The classification dataset reports its size and class count, and yields a
/// tensor with the matching label.
#[test]
fn test_text_classification_dataset() {
    let dataset = TextClassificationDataset::new(
        vec![
            "positive example".to_string(),
            "negative example".to_string(),
        ],
        vec![1, 0],
    )
    .unwrap();

    assert_eq!(dataset.len(), 2);
    assert_eq!(dataset.num_classes(), 2);

    let (first_tensor, first_label) = dataset.get(0).unwrap();
    assert_eq!(first_label, 1);
    assert!(first_tensor.ndim() > 0);
}
829
/// Token ids become a 1-D tensor holding the same values as floats.
#[test]
fn test_token_ids_to_tensor() {
    let sequence = TextSequence::new("test".to_string()).with_token_ids(vec![1, 2, 3]);

    let tensor = TokenIdsToTensor.transform(sequence).unwrap();

    assert_eq!(tensor.shape().dims(), &[3]);
    let values = tensor.to_vec().unwrap();
    assert_eq!(values, vec![1.0, 2.0, 3.0]);
}
841
/// Exercises lowercasing, punctuation removal, fixed-length padding and
/// special-token wrapping.
#[test]
fn test_text_transforms() {
    use transforms::*;

    let base = TextSequence::new("Hello, World!".to_string())
        .with_tokens(vec!["Hello,".to_string(), "World!".to_string()]);

    // Lowercasing affects the raw text.
    let lowered = ToLowercase.transform(base.clone()).unwrap();
    assert_eq!(lowered.text, "hello, world!");

    // Punctuation is removed from the raw text.
    let stripped = RemovePunctuation.transform(base.clone()).unwrap();
    assert_eq!(stripped.text, "Hello World");

    // Short id sequences are padded up to the requested length.
    let padded = FixedLength::new(4, 0)
        .transform(base.with_token_ids(vec![1, 2]))
        .unwrap();
    assert_eq!(padded.token_ids.unwrap(), vec![1, 2, 0, 0]);

    // Start and end markers are placed around the ids.
    let wrapped = AddSpecialTokens::new(100, 101)
        .transform(TextSequence::new("test".to_string()).with_token_ids(vec![1, 2]))
        .unwrap();
    assert_eq!(wrapped.token_ids.unwrap(), vec![100, 1, 2, 101]);
}
871
/// Bigrams of a four-token sequence are the three adjacent pairs.
#[test]
fn test_ngrams() {
    use transforms::*;

    let words: Vec<String> = ["the", "quick", "brown", "fox"]
        .iter()
        .map(|w| w.to_string())
        .collect();
    let sequence = TextSequence::new("the quick brown fox".to_string()).with_tokens(words);

    let output = NGrams::new(2).transform(sequence).unwrap();

    let expected: Vec<String> = ["the_quick", "quick_brown", "brown_fox"]
        .iter()
        .map(|g| g.to_string())
        .collect();
    assert_eq!(output.tokens.unwrap(), expected);
}
893
/// The IMDB stub exposes its four built-in samples.
#[test]
fn test_imdb_dataset() {
    use datasets::*;

    let imdb = IMDB::new("/tmp", "train").unwrap();
    assert_eq!(imdb.len(), 4);

    let (first_tensor, first_label) = imdb.get(0).unwrap();
    assert_eq!(first_label, 1);
    assert!(first_tensor.ndim() > 0);
}
905
/// Character tokenization yields one token per character.
#[test]
fn test_char_tokenizer() {
    use transforms::*;

    let output = CharTokenizer
        .transform(TextSequence::new("abc".to_string()))
        .unwrap();

    let expected: Vec<String> = ["a", "b", "c"].iter().map(|c| c.to_string()).collect();
    assert_eq!(output.tokens.unwrap(), expected);
}
919
/// The simplified BPE transform produces a non-empty token list.
#[test]
fn test_simple_bpe() {
    use transforms::*;

    let output = SimpleBPE::new(1000)
        .transform(TextSequence::new("hello world".to_string()))
        .unwrap();

    assert!(output.tokens.is_some());
    assert!(!output.tokens.unwrap().is_empty());
}
932}