// yule_api/inference.rs

use std::path::PathBuf;
use std::sync::mpsc;
use std::thread;
use std::time::Instant;

use yule_core::chat_template::{ChatTemplate, Role};
use yule_core::error::{Result, YuleError};
use yule_core::gguf::GgufParser;
use yule_core::model::ModelMetadata;
use yule_core::tokenizer::{BpeTokenizer, Tokenizer};
use yule_infer::model_runner::{ModelRunner, TransformerRunner};
use yule_infer::sampler::Sampler;
use yule_infer::weight_loader::{TransformerWeights, WeightStore};
use yule_infer::SamplingParams;
use yule_verify::merkle::MerkleTree;

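/// Handle to the dedicated inference thread. Requests are submitted over
/// `tx`; [`shutdown`](InferenceHandle::shutdown) sends a `Shutdown` request
/// and joins the thread.
///
/// A minimal usage sketch (illustrative only: it assumes an async caller on a
/// tokio runtime, a `Role::User` variant on the chat-template `Role`, and an
/// example model path):
///
/// ```ignore
/// let handle = InferenceHandle::spawn("model.gguf".into()).expect("failed to load model");
/// let (token_tx, mut token_rx) = tokio::sync::mpsc::unbounded_channel();
/// handle.tx.send(InferenceRequest::Generate {
///     messages: vec![(Role::User, "Hello".into())],
///     max_tokens: 64,
///     temperature: 0.7,
///     top_p: 0.9,
///     token_tx,
/// }).expect("inference thread has exited");
/// while let Some(event) = token_rx.recv().await {
///     match event {
///         TokenEvent::Token(text) => print!("{text}"),
///         TokenEvent::Done { .. } => break,
///         TokenEvent::Error(e) => { eprintln!("generation error: {e}"); break; }
///     }
/// }
/// ```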
pub struct InferenceHandle {
    pub tx: mpsc::Sender<InferenceRequest>,
    pub model_info: ModelInfo,
    join: thread::JoinHandle<()>,
}

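/// Summary of the loaded model, captured once during init and returned to
/// callers of [`InferenceHandle::spawn`].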
#[derive(Debug, Clone)]
pub struct ModelInfo {
    pub metadata: ModelMetadata,
    pub tensor_count: usize,
    pub file_size: u64,
    pub merkle_root: String,
}

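/// Requests accepted by the inference thread. `Generate` streams tokens back
/// over an unbounded tokio channel; `Tokenize` replies once over a oneshot.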
pub enum InferenceRequest {
    Generate {
        messages: Vec<(Role, String)>,
        max_tokens: u32,
        temperature: f32,
        top_p: f32,
        token_tx: tokio::sync::mpsc::UnboundedSender<TokenEvent>,
    },
    Tokenize {
        text: String,
        reply: tokio::sync::oneshot::Sender<TokenizeResult>,
    },
    Shutdown,
}

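/// Events streamed back to the caller during generation: one `Token` per
/// decoded piece, then a single `Done` with usage counts and timing, or an
/// `Error` if generation aborted.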
pub enum TokenEvent {
    Token(String),
    Done {
        prompt_tokens: u32,
        completion_tokens: u32,
        finish_reason: String,
        prefill_ms: f64,
        decode_ms: f64,
    },
    Error(String),
}

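/// Reply payload for [`InferenceRequest::Tokenize`].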
pub struct TokenizeResult {
    pub tokens: Vec<u32>,
}

impl InferenceHandle {
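    /// Spawns the inference thread, loads the model at `model_path`, and
    /// blocks until initialization finishes. Init errors are forwarded out of
    /// `spawn` rather than being left on the background thread.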
    pub fn spawn(model_path: PathBuf) -> Result<Self> {
        let (tx, rx) = mpsc::channel::<InferenceRequest>();
        let (init_tx, init_rx) = mpsc::channel::<Result<ModelInfo>>();

        let join = thread::Builder::new()
            .name("inference".into())
            .spawn(move || {
                match inference_thread_init(&model_path) {
                    Ok(state) => {
                        init_tx.send(Ok(state.info.clone())).ok();
                        inference_loop(state, rx);
                    }
                    Err(e) => {
                        init_tx.send(Err(e)).ok();
                    }
                }
            })
            .map_err(|e| YuleError::Api(format!("failed to spawn inference thread: {e}")))?;

        let model_info = init_rx.recv()
            .map_err(|_| YuleError::Api("inference thread died during init".into()))??;

        Ok(Self { tx, model_info, join })
    }

    pub fn shutdown(self) {
        let _ = self.tx.send(InferenceRequest::Shutdown);
        let _ = self.join.join();
    }
}

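/// Everything the inference thread owns after a successful init.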
struct InferenceState {
    info: ModelInfo,
    runner: TransformerRunner<'static>,
    tokenizer: BpeTokenizer,
    chat_template: Option<ChatTemplate>,
}

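/// Parses the GGUF file, memory-maps it, computes a Merkle root over the
/// tensor data region, builds the tokenizer and chat template, and constructs
/// the runner. The mmap is intentionally leaked so the weight views can be
/// `'static` for the lifetime of the process.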
fn inference_thread_init(model_path: &PathBuf) -> Result<InferenceState> {
    let parser = GgufParser::new();
    let gguf = parser.parse_file(model_path)?;
    let loaded = gguf.to_loaded_model()?;

    let mmap = yule_core::mmap::mmap_model(model_path)?;

    let merkle_root = if (gguf.data_offset as usize) <= mmap.len() {
        let tree = MerkleTree::new();
        let root = tree.build(&mmap[gguf.data_offset as usize..]);
        root.hash.iter().map(|b| format!("{b:02x}")).collect()
    } else {
        "none".into()
    };

    let tokenizer = BpeTokenizer::from_gguf(&gguf)?;
    let chat_template = ChatTemplate::for_architecture(&loaded.metadata.architecture);

    let info = ModelInfo {
        metadata: loaded.metadata.clone(),
        tensor_count: loaded.tensors.len(),
        file_size: loaded.file_size,
        merkle_root,
    };

    // intentional leak: mmap lives for the entire server lifetime on this thread
    let mmap_ref: &'static memmap2::Mmap = Box::leak(Box::new(mmap));
    let mmap_static: &'static [u8] = mmap_ref.as_ref();

    let store = WeightStore::from_gguf(&gguf, mmap_static)?;
    let weights = TransformerWeights::new(store);
    let runner = TransformerRunner::new(weights)?;

    Ok(InferenceState { info, runner, tokenizer, chat_template })
}

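/// Blocking request loop: runs until `Shutdown` is received or every sender
/// has been dropped (`recv` returns `Err`).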
fn inference_loop(mut state: InferenceState, rx: mpsc::Receiver<InferenceRequest>) {
    while let Ok(req) = rx.recv() {
        match req {
            InferenceRequest::Generate { messages, max_tokens, temperature, top_p, token_tx } => {
                handle_generate(&mut state, messages, max_tokens, temperature, top_p, &token_tx);
            }
            InferenceRequest::Tokenize { text, reply } => {
                match state.tokenizer.encode(&text) {
                    Ok(tokens) => { reply.send(TokenizeResult { tokens }).ok(); }
                    Err(_) => { reply.send(TokenizeResult { tokens: vec![] }).ok(); }
                }
            }
            InferenceRequest::Shutdown => break,
        }
    }
}

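/// Runs one generation: applies the chat template (or plain concatenation if
/// none is available), tokenizes with an optional BOS token, prefills, then
/// decodes up to `max_tokens` tokens, streaming each one as it is produced.
/// Stops early on EOS ("stop") or when the client disconnects; otherwise the
/// finish reason is "length".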
fn handle_generate(
    state: &mut InferenceState,
    messages: Vec<(Role, String)>,
    max_tokens: u32,
    temperature: f32,
    top_p: f32,
    token_tx: &tokio::sync::mpsc::UnboundedSender<TokenEvent>,
) {
    state.runner.reset();

    // build prompt from messages
    let prompt = if let Some(ref tmpl) = state.chat_template {
        let msg_refs: Vec<(Role, &str)> = messages.iter()
            .map(|(r, s)| (*r, s.as_str()))
            .collect();
        tmpl.apply(&msg_refs)
    } else {
        // no template, just concatenate
        messages.iter().map(|(_, s)| s.as_str()).collect::<Vec<_>>().join("\n")
    };

    // tokenize
    let mut tokens = Vec::new();
    if let Some(bos) = state.tokenizer.bos_token() {
        tokens.push(bos);
    }
    match state.tokenizer.encode(&prompt) {
        Ok(encoded) => tokens.extend(encoded),
        Err(e) => {
            token_tx.send(TokenEvent::Error(format!("tokenize failed: {e}"))).ok();
            return;
        }
    }

    let prompt_tokens = tokens.len() as u32;

    // prefill
    let prefill_start = Instant::now();
    let mut logits = match state.runner.prefill(&tokens) {
        Ok(l) => l,
        Err(e) => {
            token_tx.send(TokenEvent::Error(format!("prefill failed: {e}"))).ok();
            return;
        }
    };
    let prefill_ms = prefill_start.elapsed().as_secs_f64() * 1000.0;

    // decode
    let sampler = Sampler::new(SamplingParams {
        temperature,
        top_p,
        ..Default::default()
    });

    let eos = state.tokenizer.eos_token();
    let decode_start = Instant::now();
    let mut generated = 0u32;
    let mut finish_reason = "length".to_string();

    for _ in 0..max_tokens {
        let token = match sampler.sample(&logits) {
            Ok(t) => t,
            Err(e) => {
                token_tx.send(TokenEvent::Error(format!("sample failed: {e}"))).ok();
                return;
            }
        };

        if Some(token) == eos {
            finish_reason = "stop".to_string();
            break;
        }

        match state.tokenizer.decode(&[token]) {
            Ok(text) => {
                if token_tx.send(TokenEvent::Token(text)).is_err() {
                    return; // client disconnected
                }
            }
            Err(e) => {
                token_tx.send(TokenEvent::Error(format!("decode failed: {e}"))).ok();
                return;
            }
        }

        generated += 1;
        logits = match state.runner.decode_step(token) {
            Ok(l) => l,
            Err(e) => {
                token_tx.send(TokenEvent::Error(format!("decode_step failed: {e}"))).ok();
                return;
            }
        };
    }

    let decode_ms = decode_start.elapsed().as_secs_f64() * 1000.0;

    token_tx.send(TokenEvent::Done {
        prompt_tokens,
        completion_tokens: generated,
        finish_reason,
        prefill_ms,
        decode_ms,
    }).ok();
}