Skip to main content

wax_core/
engine.rs

1use std::{
2    fs,
3    path::{Path, PathBuf},
4    time::Instant,
5};
6
7use candle_core::{
8    quantized::{gguf_file, tokenizer::TokenizerFromGguf},
9    DType, Device, Tensor,
10};
11use candle_nn::VarBuilder;
12use candle_transformers::models::{
13    llama::{self, Llama},
14    quantized_llama,
15};
16use tokenizers::Tokenizer;
17
18use crate::{
19    device::{device_label, dtype_label, select_device, select_dtype},
20    loader::{resolve_model_source, ModelConfig, ModelSource},
21    sampler::{Sampler, SamplingConfig},
22    stats::{GenerateStats, StopReason},
23    token_stream::TokenOutputStream,
24    DTypeChoice, DeviceChoice, Result, WaxError,
25};
26
27pub trait StreamSink {
28    fn token(&mut self, text: &str) -> Result<()>;
29}
30
31impl<F> StreamSink for F
32where
33    F: FnMut(&str) -> Result<()>,
34{
35    fn token(&mut self, text: &str) -> Result<()> {
36        self(text)
37    }
38}
39
40#[derive(Debug, Clone)]
41pub struct EngineConfig {
42    pub model_dir: PathBuf,
43    pub device: DeviceChoice,
44    pub dtype: DTypeChoice,
45}
46
47impl EngineConfig {
48    pub fn new(model_dir: impl Into<PathBuf>) -> Self {
49        Self {
50            model_dir: model_dir.into(),
51            device: DeviceChoice::Auto,
52            dtype: DTypeChoice::Auto,
53        }
54    }
55}
56
57#[derive(Debug, Clone)]
58pub struct GenerateRequest {
59    pub prompt: String,
60    pub max_new_tokens: usize,
61    pub sampling: SamplingConfig,
62    pub stream: bool,
63}
64
65impl Default for GenerateRequest {
66    fn default() -> Self {
67        Self {
68            prompt: String::new(),
69            max_new_tokens: 64,
70            sampling: SamplingConfig::default(),
71            stream: true,
72        }
73    }
74}
75
76pub struct Engine {
77    model_dir: PathBuf,
78    model_name: String,
79    backend: ModelBackend,
80    tokenizer: Tokenizer,
81    eos_token_ids: Vec<u32>,
82    device: Device,
83    dtype: DType,
84    dtype_label: String,
85}
86
87enum ModelBackend {
88    Safetensors {
89        model: Llama,
90        llama_config: llama::Config,
91    },
92    Gguf {
93        model: quantized_llama::ModelWeights,
94    },
95}
96
97impl Engine {
98    pub fn load(config: EngineConfig) -> Result<Self> {
99        let model_dir = config.model_dir;
100        validate_model_path(&model_dir)?;
101
102        let source = resolve_model_source(&model_dir)?;
103        let device = select_device(config.device)?;
104        let dtype = select_dtype(config.dtype, &device);
105        let model_name = model_display_name(&model_dir);
106        let (backend, tokenizer, eos_token_ids, dtype_label) =
107            load_backend(&model_dir, source, &device, dtype)?;
108
109        Ok(Self {
110            model_dir,
111            model_name,
112            backend,
113            tokenizer,
114            eos_token_ids,
115            device,
116            dtype,
117            dtype_label,
118        })
119    }
120
121    pub fn model_dir(&self) -> &Path {
122        &self.model_dir
123    }
124
125    pub fn device_label(&self) -> String {
126        device_label(&self.device)
127    }
128
129    pub fn dtype_label(&self) -> String {
130        self.dtype_label.clone()
131    }
132
133    pub fn generate<S: StreamSink>(
134        &mut self,
135        request: GenerateRequest,
136        mut stream: S,
137    ) -> Result<GenerateStats> {
138        validate_generate_request(&request)?;
139
140        let mut all_tokens = self
141            .tokenizer
142            .encode(request.prompt.as_str(), true)
143            .map_err(WaxError::tokenizer)?
144            .get_ids()
145            .to_vec();
146        if all_tokens.is_empty() {
147            return Err(WaxError::InvalidRequest(
148                "prompt produced no tokens".to_string(),
149            ));
150        }
151
152        let prompt_tokens = all_tokens.len();
153        let mut cache = self.backend.new_cache(self.dtype, &self.device)?;
154        let mut sampler = Sampler::new(request.sampling)?;
155        let mut output = TokenOutputStream::new(self.tokenizer.clone());
156
157        let total_start = Instant::now();
158        let prefill_start = Instant::now();
159        let input = Tensor::new(all_tokens.as_slice(), &self.device)?.unsqueeze(0)?;
160        let mut logits = self
161            .backend
162            .forward(&input, 0, cache.as_mut())?
163            .squeeze(0)?;
164        let prefill_ms = prefill_start.elapsed().as_secs_f64() * 1000.0;
165
166        let mut generated_tokens = 0usize;
167        let mut ttft_ms = None;
168        let mut decode_forward_secs = 0.0f64;
169        let mut stop_reason = StopReason::MaxTokens;
170
171        for (step, index_pos) in (0..request.max_new_tokens).zip(prompt_tokens..) {
172            let next_token = sampler.sample(&logits, &all_tokens)?;
173            generated_tokens += 1;
174
175            if ttft_ms.is_none() {
176                ttft_ms = Some(total_start.elapsed().as_secs_f64() * 1000.0);
177            }
178
179            all_tokens.push(next_token);
180            if self.is_eos(next_token) {
181                stop_reason = StopReason::Eos;
182                break;
183            }
184
185            if request.stream {
186                if let Some(delta) = output.next_token(next_token)? {
187                    stream.token(&delta)?;
188                }
189            }
190
191            if step + 1 == request.max_new_tokens {
192                break;
193            }
194
195            let decode_start = Instant::now();
196            let input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?;
197            logits = self
198                .backend
199                .forward(&input, index_pos, cache.as_mut())?
200                .squeeze(0)?;
201            decode_forward_secs += decode_start.elapsed().as_secs_f64();
202        }
203
204        if request.stream {
205            if let Some(rest) = output.decode_rest()? {
206                stream.token(&rest)?;
207            }
208        }
209
210        let decode_tok_s = if generated_tokens > 1 && decode_forward_secs > 0.0 {
211            Some((generated_tokens - 1) as f64 / decode_forward_secs)
212        } else {
213            None
214        };
215
216        Ok(GenerateStats {
217            model: self.model_name.clone(),
218            device: self.device_label(),
219            dtype: self.dtype_label(),
220            prompt_tokens,
221            generated_tokens,
222            prefill_ms,
223            ttft_ms,
224            decode_tok_s,
225            total_ms: total_start.elapsed().as_secs_f64() * 1000.0,
226            stop_reason,
227        })
228    }
229
230    fn is_eos(&self, token: u32) -> bool {
231        self.eos_token_ids.contains(&token)
232    }
233}
234
235impl ModelBackend {
236    fn new_cache(&self, dtype: DType, device: &Device) -> Result<Option<llama::Cache>> {
237        match self {
238            Self::Safetensors { llama_config, .. } => {
239                Ok(Some(llama::Cache::new(true, dtype, llama_config, device)?))
240            }
241            Self::Gguf { .. } => Ok(None),
242        }
243    }
244
245    fn forward(
246        &mut self,
247        input: &Tensor,
248        index_pos: usize,
249        cache: Option<&mut llama::Cache>,
250    ) -> Result<Tensor> {
251        match self {
252            Self::Safetensors { model, .. } => {
253                let cache = cache.ok_or_else(|| {
254                    WaxError::InvalidRequest("missing safetensors KV cache".to_string())
255                })?;
256                Ok(model.forward(input, index_pos, cache)?)
257            }
258            Self::Gguf { model } => Ok(model.forward(input, index_pos)?),
259        }
260    }
261}
262
263fn load_backend(
264    model_dir: &Path,
265    source: ModelSource,
266    device: &Device,
267    dtype: DType,
268) -> Result<(ModelBackend, Tokenizer, Vec<u32>, String)> {
269    match source {
270        ModelSource::Safetensors { files } => {
271            let tokenizer = load_tokenizer_json(model_dir)?;
272            let model_config = ModelConfig::load(model_dir)?;
273            let eos_token_ids = eos_token_ids(&tokenizer, model_config.llama.eos_token_id.as_ref());
274            let vb = unsafe { VarBuilder::from_mmaped_safetensors(&files, dtype, device)? };
275            let model = Llama::load(vb, &model_config.llama)?;
276            Ok((
277                ModelBackend::Safetensors {
278                    model,
279                    llama_config: model_config.llama,
280                },
281                tokenizer,
282                eos_token_ids,
283                dtype_label(dtype),
284            ))
285        }
286        ModelSource::Gguf { file } => {
287            let mut reader = fs::File::open(&file)?;
288            let content = gguf_file::Content::read(&mut reader)
289                .map_err(|err| err.with_path(file.clone()))?;
290            let tokenizer_base = if model_dir.is_file() {
291                file.parent().unwrap_or_else(|| Path::new("."))
292            } else {
293                model_dir
294            };
295            let tokenizer = match load_tokenizer_json(tokenizer_base) {
296                Ok(tokenizer) => tokenizer,
297                Err(WaxError::MissingModelFile(_)) => {
298                    Tokenizer::from_gguf(&content).map_err(WaxError::tokenizer)?
299                }
300                Err(err) => return Err(err),
301            };
302            let eos_token_ids = eos_token_ids(&tokenizer, None);
303            let model = quantized_llama::ModelWeights::from_gguf(content, &mut reader, device)?;
304            Ok((
305                ModelBackend::Gguf { model },
306                tokenizer,
307                eos_token_ids,
308                "gguf".to_string(),
309            ))
310        }
311        ModelSource::Mlx { .. } => Err(WaxError::UnsupportedModelFormat {
312            format: "mlx",
313            message: "MLX model folders are not directly executable by Candle. Convert the model to Hugging Face safetensors or GGUF, then load that converted folder/file with wax.".to_string(),
314        }),
315    }
316}
317
318fn load_tokenizer_json(model_dir: &Path) -> Result<Tokenizer> {
319    let tokenizer_path = model_dir.join("tokenizer.json");
320    if !tokenizer_path.is_file() {
321        return Err(WaxError::MissingModelFile(tokenizer_path));
322    }
323    Tokenizer::from_file(&tokenizer_path).map_err(WaxError::tokenizer)
324}
325
326fn eos_token_ids(tokenizer: &Tokenizer, config_eos: Option<&llama::LlamaEosToks>) -> Vec<u32> {
327    let mut ids = match config_eos {
328        Some(llama::LlamaEosToks::Single(id)) => vec![*id],
329        Some(llama::LlamaEosToks::Multiple(ids)) => ids.clone(),
330        None => Vec::new(),
331    };
332
333    for token in ["</s>", "<|end_of_text|>", "<|endoftext|>"] {
334        if let Some(id) = tokenizer.token_to_id(token) {
335            if !ids.contains(&id) {
336                ids.push(id);
337            }
338        }
339    }
340    ids
341}
342
343fn validate_model_path(model_dir: &Path) -> Result<()> {
344    if !model_dir.is_dir() && !model_dir.is_file() {
345        return Err(WaxError::InvalidModelFolder {
346            path: model_dir.to_path_buf(),
347            reason: "path is not a directory or .gguf file".to_string(),
348        });
349    }
350    if model_dir.is_file() && model_dir.extension().is_none_or(|ext| ext != "gguf") {
351        return Err(WaxError::InvalidModelFolder {
352            path: model_dir.to_path_buf(),
353            reason: "file model paths must have a .gguf extension".to_string(),
354        });
355    }
356    Ok(())
357}
358
/// Derives a short display name for a model path.
///
/// An existing file contributes its stem (`model-q8_0.gguf` becomes `model-q8_0`).
/// Anything else contributes its final component unchanged, so dotted directory names
/// such as `...-v1.0` are kept whole. Falls back to `"local"` when there is no final
/// component or it is not valid UTF-8.
fn model_display_name(model_path: &Path) -> String {
    let component = match model_path.is_file() {
        true => model_path.file_stem(),
        false => model_path.file_name(),
    };
    match component.and_then(|os_name| os_name.to_str()) {
        Some(name) => name.to_owned(),
        None => String::from("local"),
    }
}
369
370fn validate_generate_request(request: &GenerateRequest) -> Result<()> {
371    if request.prompt.is_empty() {
372        return Err(WaxError::InvalidRequest(
373            "prompt must not be empty".to_string(),
374        ));
375    }
376    if request.max_new_tokens == 0 {
377        return Err(WaxError::InvalidRequest(
378            "max-new-tokens must be > 0".to_string(),
379        ));
380    }
381    request.sampling.validate()
382}
383
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::{GenerateRequest, SamplingConfig};

    #[test]
    fn default_request_streams_sixty_four_tokens_max() {
        let request = GenerateRequest {
            prompt: "hello".to_string(),
            ..GenerateRequest::default()
        };

        assert!(request.stream);
        assert_eq!(request.max_new_tokens, 64);
    }

    #[test]
    fn request_validation_rejects_empty_prompt() {
        let err = super::validate_generate_request(&GenerateRequest {
            prompt: String::new(),
            max_new_tokens: 1,
            sampling: SamplingConfig::default(),
            stream: true,
        })
        .unwrap_err();

        assert!(err.to_string().contains("prompt"));
    }

    /// A zero token budget is rejected before the sampling config is consulted.
    #[test]
    fn request_validation_rejects_zero_max_new_tokens() {
        let err = super::validate_generate_request(&GenerateRequest {
            prompt: "hello".to_string(),
            max_new_tokens: 0,
            ..GenerateRequest::default()
        })
        .unwrap_err();

        assert!(err.to_string().contains("max-new-tokens"));
    }

    #[test]
    fn directory_model_name_preserves_version_suffix() {
        let path = Path::new("/tmp/TinyLlama-1.1B-Chat-v1.0");

        assert_eq!(super::model_display_name(path), "TinyLlama-1.1B-Chat-v1.0");
    }

    #[test]
    fn gguf_file_model_name_removes_extension() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("model-q8_0.gguf");
        std::fs::write(&file, b"").unwrap();

        assert_eq!(super::model_display_name(&file), "model-q8_0");
    }

    /// A path with no final component (the filesystem root) falls back to "local".
    #[test]
    fn model_name_without_final_component_falls_back_to_local() {
        assert_eq!(super::model_display_name(Path::new("/")), "local");
    }
}