Skip to main content

wordchipper_cli_util/
model_selector.rs

1use std::sync::Arc;
2
3use wordchipper::{
4    Tokenizer,
5    UnifiedTokenVocab,
6    disk_cache::WordchipperDiskCache,
7};
8
9/// Model selector arg group.
10#[derive(clap::Args, Debug)]
11#[group(required = true, multiple = false)]
12pub struct ModelSelectorArgs {
13    /// Model to use for encoding.
14    #[arg(long, default_value = "openai:r50k_base")]
15    model: String,
16}
17
18impl ModelSelectorArgs {
19    /// Get the model name.
20    pub fn model(&self) -> &str {
21        &self.model
22    }
23
24    /// Load the vocabulary.
25    pub fn load_vocab(
26        &self,
27        disk_cache: &mut WordchipperDiskCache,
28    ) -> Result<Arc<UnifiedTokenVocab<u32>>, Box<dyn std::error::Error>> {
29        let vocab = wordchipper::load_vocab(self.model(), disk_cache)?
30            .vocab()
31            .clone();
32
33        Ok(vocab)
34    }
35
36    /// Load the tokenizer.
37    pub fn load_tokenizer(
38        &self,
39        disk_cache: &mut WordchipperDiskCache,
40    ) -> Result<Arc<Tokenizer<u32>>, Box<dyn std::error::Error>> {
41        let vocab = self.load_vocab(disk_cache)?;
42        let tokenizer = wordchipper::TokenizerOptions::default().build(vocab);
43        Ok(tokenizer)
44    }
45}