tokenizers/decoders/sequence.rs

1use crate::decoders::DecoderWrapper;
2use crate::tokenizer::{Decoder, Result};
3use crate::utils::macro_rules_attribute;
4use serde::{Deserialize, Serialize};
5
6#[derive(Clone, Debug)]
7#[macro_rules_attribute(impl_serde_type!)]
8pub struct Sequence {
9    decoders: Vec<DecoderWrapper>,
10}
11
12impl Sequence {
13    pub fn new(decoders: Vec<DecoderWrapper>) -> Self {
14        Self { decoders }
15    }
16
17    pub fn get_decoders(&self) -> &[DecoderWrapper] {
18        &self.decoders
19    }
20
21    pub fn get_decoders_mut(&mut self) -> &mut [DecoderWrapper] {
22        &mut self.decoders
23    }
24}
25
26impl Decoder for Sequence {
27    fn decode_chain(&self, mut tokens: Vec<String>) -> Result<Vec<String>> {
28        for decoder in &self.decoders {
29            tokens = decoder.decode_chain(tokens)?;
30        }
31        Ok(tokens)
32    }
33}
34
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decoders::ctc::CTC;
    use crate::pre_tokenizers::metaspace::Metaspace;

    /// Converts string literals into the owned token list that `Decoder` methods take.
    fn to_tokens(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|&s| s.to_owned()).collect()
    }

    #[test]
    fn sequence_basic() {
        let decoders = vec![
            DecoderWrapper::CTC(CTC::default()),
            DecoderWrapper::Metaspace(Metaspace::default()),
        ];
        let decoder = Sequence::new(decoders);
        assert_eq!(decoder.get_decoders().len(), 2);
        let tokens = to_tokens(&["▁", "▁", "H", "H", "i", "i", "▁", "y", "o", "u"]);
        let out_tokens = decoder.decode(tokens).unwrap();
        assert_eq!(out_tokens, "Hi you");
    }

    #[test]
    fn sequence_empty_is_identity() {
        // With no decoders there is nothing to apply, so the tokens must come back untouched.
        let decoder = Sequence::new(vec![]);
        assert!(decoder.get_decoders().is_empty());
        let tokens = to_tokens(&["▁Hello", "▁", "world"]);
        assert_eq!(decoder.decode_chain(tokens.clone()).unwrap(), tokens);
    }
}
55}