1use std::{
2 fs,
3 path::{Path, PathBuf},
4 time::Instant,
5};
6
7use candle_core::{
8 quantized::{gguf_file, tokenizer::TokenizerFromGguf},
9 DType, Device, Tensor,
10};
11use candle_nn::VarBuilder;
12use candle_transformers::models::{
13 llama::{self, Llama},
14 quantized_llama,
15};
16use tokenizers::Tokenizer;
17
18use crate::{
19 device::{device_label, dtype_label, select_device, select_dtype},
20 loader::{resolve_model_source, ModelConfig, ModelSource},
21 sampler::{Sampler, SamplingConfig},
22 stats::{GenerateStats, StopReason},
23 token_stream::TokenOutputStream,
24 DTypeChoice, DeviceChoice, Result, WaxError,
25};
26
27pub trait StreamSink {
28 fn token(&mut self, text: &str) -> Result<()>;
29}
30
31impl<F> StreamSink for F
32where
33 F: FnMut(&str) -> Result<()>,
34{
35 fn token(&mut self, text: &str) -> Result<()> {
36 self(text)
37 }
38}
39
40#[derive(Debug, Clone)]
41pub struct EngineConfig {
42 pub model_dir: PathBuf,
43 pub device: DeviceChoice,
44 pub dtype: DTypeChoice,
45}
46
47impl EngineConfig {
48 pub fn new(model_dir: impl Into<PathBuf>) -> Self {
49 Self {
50 model_dir: model_dir.into(),
51 device: DeviceChoice::Auto,
52 dtype: DTypeChoice::Auto,
53 }
54 }
55}
56
57#[derive(Debug, Clone)]
58pub struct GenerateRequest {
59 pub prompt: String,
60 pub max_new_tokens: usize,
61 pub sampling: SamplingConfig,
62 pub stream: bool,
63}
64
65impl Default for GenerateRequest {
66 fn default() -> Self {
67 Self {
68 prompt: String::new(),
69 max_new_tokens: 64,
70 sampling: SamplingConfig::default(),
71 stream: true,
72 }
73 }
74}
75
76pub struct Engine {
77 model_dir: PathBuf,
78 model_name: String,
79 backend: ModelBackend,
80 tokenizer: Tokenizer,
81 eos_token_ids: Vec<u32>,
82 device: Device,
83 dtype: DType,
84 dtype_label: String,
85}
86
87enum ModelBackend {
88 Safetensors {
89 model: Llama,
90 llama_config: llama::Config,
91 },
92 Gguf {
93 model: quantized_llama::ModelWeights,
94 },
95}
96
97impl Engine {
98 pub fn load(config: EngineConfig) -> Result<Self> {
99 let model_dir = config.model_dir;
100 validate_model_path(&model_dir)?;
101
102 let source = resolve_model_source(&model_dir)?;
103 let device = select_device(config.device)?;
104 let dtype = select_dtype(config.dtype, &device);
105 let model_name = model_display_name(&model_dir);
106 let (backend, tokenizer, eos_token_ids, dtype_label) =
107 load_backend(&model_dir, source, &device, dtype)?;
108
109 Ok(Self {
110 model_dir,
111 model_name,
112 backend,
113 tokenizer,
114 eos_token_ids,
115 device,
116 dtype,
117 dtype_label,
118 })
119 }
120
121 pub fn model_dir(&self) -> &Path {
122 &self.model_dir
123 }
124
125 pub fn device_label(&self) -> String {
126 device_label(&self.device)
127 }
128
129 pub fn dtype_label(&self) -> String {
130 self.dtype_label.clone()
131 }
132
133 pub fn generate<S: StreamSink>(
134 &mut self,
135 request: GenerateRequest,
136 mut stream: S,
137 ) -> Result<GenerateStats> {
138 validate_generate_request(&request)?;
139
140 let mut all_tokens = self
141 .tokenizer
142 .encode(request.prompt.as_str(), true)
143 .map_err(WaxError::tokenizer)?
144 .get_ids()
145 .to_vec();
146 if all_tokens.is_empty() {
147 return Err(WaxError::InvalidRequest(
148 "prompt produced no tokens".to_string(),
149 ));
150 }
151
152 let prompt_tokens = all_tokens.len();
153 let mut cache = self.backend.new_cache(self.dtype, &self.device)?;
154 let mut sampler = Sampler::new(request.sampling)?;
155 let mut output = TokenOutputStream::new(self.tokenizer.clone());
156
157 let total_start = Instant::now();
158 let prefill_start = Instant::now();
159 let input = Tensor::new(all_tokens.as_slice(), &self.device)?.unsqueeze(0)?;
160 let mut logits = self
161 .backend
162 .forward(&input, 0, cache.as_mut())?
163 .squeeze(0)?;
164 let prefill_ms = prefill_start.elapsed().as_secs_f64() * 1000.0;
165
166 let mut generated_tokens = 0usize;
167 let mut ttft_ms = None;
168 let mut decode_forward_secs = 0.0f64;
169 let mut stop_reason = StopReason::MaxTokens;
170
171 for (step, index_pos) in (0..request.max_new_tokens).zip(prompt_tokens..) {
172 let next_token = sampler.sample(&logits, &all_tokens)?;
173 generated_tokens += 1;
174
175 if ttft_ms.is_none() {
176 ttft_ms = Some(total_start.elapsed().as_secs_f64() * 1000.0);
177 }
178
179 all_tokens.push(next_token);
180 if self.is_eos(next_token) {
181 stop_reason = StopReason::Eos;
182 break;
183 }
184
185 if request.stream {
186 if let Some(delta) = output.next_token(next_token)? {
187 stream.token(&delta)?;
188 }
189 }
190
191 if step + 1 == request.max_new_tokens {
192 break;
193 }
194
195 let decode_start = Instant::now();
196 let input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?;
197 logits = self
198 .backend
199 .forward(&input, index_pos, cache.as_mut())?
200 .squeeze(0)?;
201 decode_forward_secs += decode_start.elapsed().as_secs_f64();
202 }
203
204 if request.stream {
205 if let Some(rest) = output.decode_rest()? {
206 stream.token(&rest)?;
207 }
208 }
209
210 let decode_tok_s = if generated_tokens > 1 && decode_forward_secs > 0.0 {
211 Some((generated_tokens - 1) as f64 / decode_forward_secs)
212 } else {
213 None
214 };
215
216 Ok(GenerateStats {
217 model: self.model_name.clone(),
218 device: self.device_label(),
219 dtype: self.dtype_label(),
220 prompt_tokens,
221 generated_tokens,
222 prefill_ms,
223 ttft_ms,
224 decode_tok_s,
225 total_ms: total_start.elapsed().as_secs_f64() * 1000.0,
226 stop_reason,
227 })
228 }
229
230 fn is_eos(&self, token: u32) -> bool {
231 self.eos_token_ids.contains(&token)
232 }
233}
234
235impl ModelBackend {
236 fn new_cache(&self, dtype: DType, device: &Device) -> Result<Option<llama::Cache>> {
237 match self {
238 Self::Safetensors { llama_config, .. } => {
239 Ok(Some(llama::Cache::new(true, dtype, llama_config, device)?))
240 }
241 Self::Gguf { .. } => Ok(None),
242 }
243 }
244
245 fn forward(
246 &mut self,
247 input: &Tensor,
248 index_pos: usize,
249 cache: Option<&mut llama::Cache>,
250 ) -> Result<Tensor> {
251 match self {
252 Self::Safetensors { model, .. } => {
253 let cache = cache.ok_or_else(|| {
254 WaxError::InvalidRequest("missing safetensors KV cache".to_string())
255 })?;
256 Ok(model.forward(input, index_pos, cache)?)
257 }
258 Self::Gguf { model } => Ok(model.forward(input, index_pos)?),
259 }
260 }
261}
262
263fn load_backend(
264 model_dir: &Path,
265 source: ModelSource,
266 device: &Device,
267 dtype: DType,
268) -> Result<(ModelBackend, Tokenizer, Vec<u32>, String)> {
269 match source {
270 ModelSource::Safetensors { files } => {
271 let tokenizer = load_tokenizer_json(model_dir)?;
272 let model_config = ModelConfig::load(model_dir)?;
273 let eos_token_ids = eos_token_ids(&tokenizer, model_config.llama.eos_token_id.as_ref());
274 let vb = unsafe { VarBuilder::from_mmaped_safetensors(&files, dtype, device)? };
275 let model = Llama::load(vb, &model_config.llama)?;
276 Ok((
277 ModelBackend::Safetensors {
278 model,
279 llama_config: model_config.llama,
280 },
281 tokenizer,
282 eos_token_ids,
283 dtype_label(dtype),
284 ))
285 }
286 ModelSource::Gguf { file } => {
287 let mut reader = fs::File::open(&file)?;
288 let content = gguf_file::Content::read(&mut reader)
289 .map_err(|err| err.with_path(file.clone()))?;
290 let tokenizer_base = if model_dir.is_file() {
291 file.parent().unwrap_or_else(|| Path::new("."))
292 } else {
293 model_dir
294 };
295 let tokenizer = match load_tokenizer_json(tokenizer_base) {
296 Ok(tokenizer) => tokenizer,
297 Err(WaxError::MissingModelFile(_)) => {
298 Tokenizer::from_gguf(&content).map_err(WaxError::tokenizer)?
299 }
300 Err(err) => return Err(err),
301 };
302 let eos_token_ids = eos_token_ids(&tokenizer, None);
303 let model = quantized_llama::ModelWeights::from_gguf(content, &mut reader, device)?;
304 Ok((
305 ModelBackend::Gguf { model },
306 tokenizer,
307 eos_token_ids,
308 "gguf".to_string(),
309 ))
310 }
311 ModelSource::Mlx { .. } => Err(WaxError::UnsupportedModelFormat {
312 format: "mlx",
313 message: "MLX model folders are not directly executable by Candle. Convert the model to Hugging Face safetensors or GGUF, then load that converted folder/file with wax.".to_string(),
314 }),
315 }
316}
317
318fn load_tokenizer_json(model_dir: &Path) -> Result<Tokenizer> {
319 let tokenizer_path = model_dir.join("tokenizer.json");
320 if !tokenizer_path.is_file() {
321 return Err(WaxError::MissingModelFile(tokenizer_path));
322 }
323 Tokenizer::from_file(&tokenizer_path).map_err(WaxError::tokenizer)
324}
325
326fn eos_token_ids(tokenizer: &Tokenizer, config_eos: Option<&llama::LlamaEosToks>) -> Vec<u32> {
327 let mut ids = match config_eos {
328 Some(llama::LlamaEosToks::Single(id)) => vec![*id],
329 Some(llama::LlamaEosToks::Multiple(ids)) => ids.clone(),
330 None => Vec::new(),
331 };
332
333 for token in ["</s>", "<|end_of_text|>", "<|endoftext|>"] {
334 if let Some(id) = tokenizer.token_to_id(token) {
335 if !ids.contains(&id) {
336 ids.push(id);
337 }
338 }
339 }
340 ids
341}
342
343fn validate_model_path(model_dir: &Path) -> Result<()> {
344 if !model_dir.is_dir() && !model_dir.is_file() {
345 return Err(WaxError::InvalidModelFolder {
346 path: model_dir.to_path_buf(),
347 reason: "path is not a directory or .gguf file".to_string(),
348 });
349 }
350 if model_dir.is_file() && model_dir.extension().is_none_or(|ext| ext != "gguf") {
351 return Err(WaxError::InvalidModelFolder {
352 path: model_dir.to_path_buf(),
353 reason: "file model paths must have a .gguf extension".to_string(),
354 });
355 }
356 Ok(())
357}
358
/// Derives a display name from the model path.
///
/// For an existing file the extension is dropped (`model-q8_0.gguf` becomes
/// `model-q8_0`); for anything else the final path component is used as is.
/// Falls back to `"local"` when the path has no usable final component or it
/// is not valid UTF-8.
fn model_display_name(model_path: &Path) -> String {
    let component = match model_path.is_file() {
        true => model_path.file_stem(),
        false => model_path.file_name(),
    };
    component
        .and_then(|raw| raw.to_str())
        .map_or_else(|| String::from("local"), String::from)
}
369
370fn validate_generate_request(request: &GenerateRequest) -> Result<()> {
371 if request.prompt.is_empty() {
372 return Err(WaxError::InvalidRequest(
373 "prompt must not be empty".to_string(),
374 ));
375 }
376 if request.max_new_tokens == 0 {
377 return Err(WaxError::InvalidRequest(
378 "max-new-tokens must be > 0".to_string(),
379 ));
380 }
381 request.sampling.validate()
382}
383
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::{GenerateRequest, SamplingConfig};

    /// The defaults keep streaming on and cap generation at 64 tokens.
    #[test]
    fn default_request_streams_sixty_four_tokens_max() {
        let GenerateRequest {
            stream,
            max_new_tokens,
            ..
        } = GenerateRequest {
            prompt: "hello".to_string(),
            ..GenerateRequest::default()
        };

        assert!(stream);
        assert_eq!(max_new_tokens, 64);
    }

    /// An empty prompt is rejected with an error that mentions the prompt.
    #[test]
    fn request_validation_rejects_empty_prompt() {
        let request = GenerateRequest {
            prompt: String::new(),
            max_new_tokens: 1,
            sampling: SamplingConfig::default(),
            stream: true,
        };

        let message = super::validate_generate_request(&request)
            .unwrap_err()
            .to_string();

        assert!(message.contains("prompt"));
    }

    /// A dotted directory name is not mistaken for a file extension.
    #[test]
    fn directory_model_name_preserves_version_suffix() {
        let expected = "TinyLlama-1.1B-Chat-v1.0";
        let path = Path::new("/tmp").join(expected);

        assert_eq!(super::model_display_name(&path), expected);
    }

    /// For an existing `.gguf` file the extension is dropped from the name.
    #[test]
    fn gguf_file_model_name_removes_extension() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("model-q8_0.gguf");
        std::fs::write(&file, b"").unwrap();

        let name = super::model_display_name(&file);

        assert_eq!(name, "model-q8_0");
    }
}