use std::path::{Path, PathBuf};
use std::sync::mpsc;
use std::thread;
use std::time::Instant;

use yule_core::chat_template::{ChatTemplate, Role};
use yule_core::error::{Result, YuleError};
use yule_core::gguf::GgufParser;
use yule_core::model::ModelMetadata;
use yule_core::tokenizer::{BpeTokenizer, Tokenizer};
use yule_infer::model_runner::{ModelRunner, TransformerRunner};
use yule_infer::sampler::Sampler;
use yule_infer::weight_loader::{TransformerWeights, WeightStore};
use yule_infer::SamplingParams;
use yule_verify::merkle::MerkleTree;

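/// Handle to the dedicated inference thread.
///
/// All model state (weights, KV cache, tokenizer) lives on a single OS
/// thread; requests are serialized over `tx`, and results stream back over
/// the Tokio channels carried inside each request.
///
/// A minimal usage sketch. Illustrative only: it assumes the items are in
/// scope, that `Role` has a `User` variant, and that a Tokio runtime drives
/// the receiving side; the model path is a placeholder.
///
/// ```ignore
/// let handle = InferenceHandle::spawn("models/example.gguf".into())?;
/// let (token_tx, mut token_rx) = tokio::sync::mpsc::unbounded_channel();
/// handle
///     .tx
///     .send(InferenceRequest::Generate {
///         messages: vec![(Role::User, "Hello!".into())],
///         max_tokens: 64,
///         temperature: 0.7,
///         top_p: 0.9,
///         token_tx,
///     })
///     .expect("inference thread alive");
/// while let Some(ev) = token_rx.recv().await {
///     match ev {
///         TokenEvent::Token(text) => print!("{text}"),
///         TokenEvent::Done { .. } | TokenEvent::Error(_) => break,
///     }
/// }
/// handle.shutdown();
/// ```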
pub struct InferenceHandle {
    pub tx: mpsc::Sender<InferenceRequest>,
    pub model_info: ModelInfo,
    join: thread::JoinHandle<()>,
}

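/// Snapshot of the loaded model, captured once at init and cheap to clone
/// for status endpoints. `merkle_root` is the hex-encoded root hash over the
/// tensor payload, or `"none"` if the GGUF data offset was out of range.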
#[derive(Debug, Clone)]
pub struct ModelInfo {
    pub metadata: ModelMetadata,
    pub tensor_count: usize,
    pub file_size: u64,
    pub merkle_root: String,
}

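/// A request accepted by the inference thread. `Generate` streams
/// `TokenEvent`s over an unbounded channel, `Tokenize` answers exactly once
/// over a oneshot channel, and `Shutdown` ends the serve loop.
///
/// Tokenize round-trip sketch (illustrative; assumes an async caller with a
/// `handle: InferenceHandle` in scope):
///
/// ```ignore
/// let (reply, rx) = tokio::sync::oneshot::channel();
/// handle
///     .tx
///     .send(InferenceRequest::Tokenize { text: "hello".into(), reply })
///     .expect("inference thread alive");
/// let token_count = rx.await.expect("reply sent").tokens.len();
/// ```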
pub enum InferenceRequest {
    Generate {
        messages: Vec<(Role, String)>,
        max_tokens: u32,
        temperature: f32,
        top_p: f32,
        token_tx: tokio::sync::mpsc::UnboundedSender<TokenEvent>,
    },
    Tokenize {
        text: String,
        reply: tokio::sync::oneshot::Sender<TokenizeResult>,
    },
    Shutdown,
}

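/// Events streamed back for one `Generate` request: zero or more `Token`s,
/// followed by exactly one terminal `Done` (with token counts and per-phase
/// timings in milliseconds) or `Error`.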
pub enum TokenEvent {
    Token(String),
    Done {
        prompt_tokens: u32,
        completion_tokens: u32,
        finish_reason: String,
        prefill_ms: f64,
        decode_ms: f64,
    },
    Error(String),
}

pub struct TokenizeResult {
    pub tokens: Vec<u32>,
}

impl InferenceHandle {
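    /// Spawns the inference thread and blocks until model loading has
    /// finished, so a handle is only ever returned for a model that loaded
    /// successfully.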
    pub fn spawn(model_path: PathBuf) -> Result<Self> {
        let (tx, rx) = mpsc::channel::<InferenceRequest>();
        let (init_tx, init_rx) = mpsc::channel::<Result<ModelInfo>>();

        let join = thread::Builder::new()
            .name("inference".into())
            .spawn(move || match inference_thread_init(&model_path) {
                Ok(state) => {
                    init_tx.send(Ok(state.info.clone())).ok();
                    inference_loop(state, rx);
                }
                Err(e) => {
                    init_tx.send(Err(e)).ok();
                }
            })
            .map_err(|e| YuleError::Api(format!("failed to spawn inference thread: {e}")))?;

        let model_info = init_rx
            .recv()
            .map_err(|_| YuleError::Api("inference thread died during init".into()))??;

        Ok(Self { tx, model_info, join })
    }

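    /// Requests a graceful stop and joins the thread. Send and join errors
    /// are ignored: if the thread is already gone, there is nothing to do.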
    pub fn shutdown(self) {
        let _ = self.tx.send(InferenceRequest::Shutdown);
        let _ = self.join.join();
    }
}

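/// Everything owned by the inference thread: the runner (weights plus KV
/// cache), the tokenizer, and the chat template selected by architecture,
/// if one is known.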
struct InferenceState {
    info: ModelInfo,
    runner: TransformerRunner<'static>,
    tokenizer: BpeTokenizer,
    chat_template: Option<ChatTemplate>,
}

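/// Parses and memory-maps the GGUF file, hashes the tensor payload for the
/// Merkle root, and builds the tokenizer, chat template, and runner.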
fn inference_thread_init(model_path: &Path) -> Result<InferenceState> {
    let parser = GgufParser::new();
    let gguf = parser.parse_file(model_path)?;
    let loaded = gguf.to_loaded_model()?;

    let mmap = yule_core::mmap::mmap_model(model_path)?;

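    // Hash only the tensor payload (bytes past the GGUF header) so the root
    // identifies the weights themselves, independent of metadata layout.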
    let merkle_root = if (gguf.data_offset as usize) <= mmap.len() {
        let tree = MerkleTree::new();
        let root = tree.build(&mmap[gguf.data_offset as usize..]);
        root.hash.iter().map(|b| format!("{b:02x}")).collect()
    } else {
        "none".into()
    };

    let tokenizer = BpeTokenizer::from_gguf(&gguf)?;
    let chat_template = ChatTemplate::for_architecture(&loaded.metadata.architecture);

    let info = ModelInfo {
        metadata: loaded.metadata.clone(),
        tensor_count: loaded.tensors.len(),
        file_size: loaded.file_size,
        merkle_root,
    };

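    // Deliberate one-time leak: the runner borrows the mapping for 'static,
    // and a single model mapped once per process is never reclaimed anyway.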
    let mmap_ref: &'static memmap2::Mmap = Box::leak(Box::new(mmap));
    let mmap_static: &'static [u8] = mmap_ref.as_ref();

    let store = WeightStore::from_gguf(&gguf, mmap_static)?;
    let weights = TransformerWeights::new(store);
    let runner = TransformerRunner::new(weights)?;

    Ok(InferenceState { info, runner, tokenizer, chat_template })
}

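/// Serves requests in order, one at a time, until the channel closes or a
/// `Shutdown` request arrives.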
fn inference_loop(mut state: InferenceState, rx: mpsc::Receiver<InferenceRequest>) {
    while let Ok(req) = rx.recv() {
        match req {
            InferenceRequest::Generate { messages, max_tokens, temperature, top_p, token_tx } => {
                handle_generate(&mut state, messages, max_tokens, temperature, top_p, &token_tx);
            }
            InferenceRequest::Tokenize { text, reply } => {
                // On tokenizer error, reply with an empty list rather than
                // dropping the oneshot, so the caller is never left hanging.
                let tokens = state.tokenizer.encode(&text).unwrap_or_default();
                reply.send(TokenizeResult { tokens }).ok();
            }
            InferenceRequest::Shutdown => break,
        }
    }
}

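/// Runs a single generation: renders the chat template, tokenizes the
/// prompt, prefills the model, then decodes token by token, streaming each
/// decoded piece of text as it is produced.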
fn handle_generate(
    state: &mut InferenceState,
    messages: Vec<(Role, String)>,
    max_tokens: u32,
    temperature: f32,
    top_p: f32,
    token_tx: &tokio::sync::mpsc::UnboundedSender<TokenEvent>,
) {
    // Each request starts from a clean slate.
    state.runner.reset();

    // Render the conversation with the architecture's chat template when one
    // is known; otherwise fall back to joining the raw message bodies.
    let prompt = if let Some(ref tmpl) = state.chat_template {
        let msg_refs: Vec<(Role, &str)> = messages.iter()
            .map(|(r, s)| (*r, s.as_str()))
            .collect();
        tmpl.apply(&msg_refs)
    } else {
        messages.iter().map(|(_, s)| s.as_str()).collect::<Vec<_>>().join("\n")
    };

    // Prepend BOS when the tokenizer defines one, then encode the prompt.
    let mut tokens = Vec::new();
    if let Some(bos) = state.tokenizer.bos_token() {
        tokens.push(bos);
    }
    match state.tokenizer.encode(&prompt) {
        Ok(encoded) => tokens.extend(encoded),
        Err(e) => {
            token_tx.send(TokenEvent::Error(format!("tokenize failed: {e}"))).ok();
            return;
        }
    }

    let prompt_tokens = tokens.len() as u32;

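    // Prefill: run the whole prompt through the model in one pass; the
    // returned logits predict the first generated token.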
    let prefill_start = Instant::now();
    let mut logits = match state.runner.prefill(&tokens) {
        Ok(l) => l,
        Err(e) => {
            token_tx.send(TokenEvent::Error(format!("prefill failed: {e}"))).ok();
            return;
        }
    };
    let prefill_ms = prefill_start.elapsed().as_secs_f64() * 1000.0;

    let sampler = Sampler::new(SamplingParams {
        temperature,
        top_p,
        ..Default::default()
    });

    let eos = state.tokenizer.eos_token();
    let decode_start = Instant::now();
    let mut generated = 0u32;
    // Default to "length"; overwritten with "stop" if EOS is sampled before
    // the token budget runs out.
    let mut finish_reason = "length".to_string();

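    // Decode: sample from the current logits, stream the token's text, then
    // advance the model one step to produce the next logits.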
    for _ in 0..max_tokens {
        let token = match sampler.sample(&logits) {
            Ok(t) => t,
            Err(e) => {
                token_tx.send(TokenEvent::Error(format!("sample failed: {e}"))).ok();
                return;
            }
        };

        if Some(token) == eos {
            finish_reason = "stop".to_string();
            break;
        }

        match state.tokenizer.decode(&[token]) {
            Ok(text) => {
                // A closed channel means the client went away; stop quietly.
                if token_tx.send(TokenEvent::Token(text)).is_err() {
                    return;
                }
            }
            Err(e) => {
                token_tx.send(TokenEvent::Error(format!("decode failed: {e}"))).ok();
                return;
            }
        }

        generated += 1;
        logits = match state.runner.decode_step(token) {
            Ok(l) => l,
            Err(e) => {
                token_tx.send(TokenEvent::Error(format!("decode_step failed: {e}"))).ok();
                return;
            }
        };
    }

    let decode_ms = decode_start.elapsed().as_secs_f64() * 1000.0;

    token_tx.send(TokenEvent::Done {
        prompt_tokens,
        completion_tokens: generated,
        finish_reason,
        prefill_ms,
        decode_ms,
    }).ok();
}