1use crate::inference::traits::InferenceBackend;
7use crate::NativeError;
8use futures_util::stream::Stream;
9use spn_core::{ChatOptions, ChatResponse, LoadConfig, ModelInfo};
10use std::path::PathBuf;
11
12#[cfg(feature = "inference")]
13use spn_core::ChatRole;
14#[cfg(feature = "inference")]
15use std::path::Path;
16#[cfg(feature = "inference")]
17use std::sync::Arc;
18#[cfg(feature = "inference")]
19use tracing::{debug, info};
20
21#[cfg(feature = "inference")]
22use mistralrs::{
23 GgufModelBuilder, MemoryGpuConfig, Model, PagedAttentionMetaBuilder, RequestBuilder,
24 TextMessageRole, TextMessages,
25};
26#[cfg(feature = "inference")]
27use tokio::sync::RwLock;
28
/// Native inference runtime backed by mistral.rs.
///
/// When the `inference` feature is disabled only the metadata fields remain
/// and several items go unused, hence the `dead_code` allowance.
#[allow(dead_code)] pub struct NativeRuntime {
    /// Handle to the loaded mistral.rs model; `Arc<RwLock<..>>` so clones of
    /// the runtime share the same underlying model.
    #[cfg(feature = "inference")]
    model: Option<Arc<RwLock<Model>>>,

    /// Metadata (name, size, quantization) captured when the model was loaded.
    model_info: Option<ModelInfo>,

    /// Filesystem path the current model was loaded from.
    model_path: Option<PathBuf>,

    /// Load configuration supplied when the current model was loaded.
    config: Option<LoadConfig>,
}
59
60impl std::fmt::Debug for NativeRuntime {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 f.debug_struct("NativeRuntime")
64 .field("model_info", &self.model_info)
65 .field("model_path", &self.model_path)
66 .field("config", &self.config)
67 .field("is_loaded", &self.is_loaded())
68 .finish()
69 }
70}
71
72impl Clone for NativeRuntime {
74 fn clone(&self) -> Self {
75 Self {
76 #[cfg(feature = "inference")]
77 model: self.model.clone(),
78 model_info: self.model_info.clone(),
79 model_path: self.model_path.clone(),
80 config: self.config.clone(),
81 }
82 }
83}
84
85impl NativeRuntime {
86 #[must_use]
91 pub fn new() -> Self {
92 Self {
93 #[cfg(feature = "inference")]
94 model: None,
95 model_info: None,
96 model_path: None,
97 config: None,
98 }
99 }
100
101 #[must_use]
103 pub fn model_path(&self) -> Option<&PathBuf> {
104 self.model_path.as_ref()
105 }
106
107 #[must_use]
109 pub fn config(&self) -> Option<&LoadConfig> {
110 self.config.as_ref()
111 }
112
113 #[cfg(feature = "inference")]
115 #[allow(dead_code)] fn convert_role(role: ChatRole) -> TextMessageRole {
117 match role {
118 ChatRole::System => TextMessageRole::System,
119 ChatRole::User => TextMessageRole::User,
120 ChatRole::Assistant => TextMessageRole::Assistant,
121 }
122 }
123}
124
125impl Default for NativeRuntime {
126 fn default() -> Self {
127 Self::new()
128 }
129}
130
#[cfg(feature = "inference")]
impl InferenceBackend for NativeRuntime {
    /// Loads a GGUF model from `model_path`, replacing any already-loaded
    /// model.
    ///
    /// # Errors
    ///
    /// Returns [`NativeError::ModelNotFound`] when the path does not exist,
    /// and [`NativeError::InvalidConfig`] when the path has no filename or
    /// the mistral.rs builder fails.
    async fn load(&mut self, model_path: PathBuf, config: LoadConfig) -> Result<(), NativeError> {
        info!(?model_path, "Loading GGUF model");

        // Drop the previous model first so two models are never resident at once.
        if self.model.is_some() {
            self.unload().await?;
        }

        if !model_path.exists() {
            return Err(NativeError::ModelNotFound {
                repo: "local".to_string(),
                filename: model_path.to_string_lossy().to_string(),
            });
        }

        // GgufModelBuilder takes the containing directory and the bare
        // filename separately, so split the path here.
        let parent = model_path
            .parent()
            .map(|p| p.to_string_lossy().to_string())
            .unwrap_or_else(|| ".".to_string());
        let filename = model_path
            .file_name()
            .map(|f| f.to_string_lossy().to_string())
            .ok_or_else(|| {
                NativeError::InvalidConfig("Invalid model path: no filename".to_string())
            })?;

        debug!(gpu_layers = config.gpu_layers, %parent, %filename, "Building model");

        // Default context window when the caller did not request one.
        let context_size = config.context_size.unwrap_or(2048);
        let model = GgufModelBuilder::new(parent, vec![filename])
            .with_logging()
            .with_paged_attn(|| {
                PagedAttentionMetaBuilder::default()
                    .with_block_size(32)
                    .with_gpu_memory(MemoryGpuConfig::ContextSize(context_size as usize))
                    .build()
            })
            .map_err(|e| NativeError::InvalidConfig(format!("PagedAttention config error: {e}")))?
            .build()
            .await
            .map_err(|e| NativeError::InvalidConfig(format!("Failed to build model: {e}")))?;

        // Best-effort metadata; size falls back to 0 if the file cannot be
        // stat'ed (e.g. removed between the exists() check and here).
        let info = ModelInfo {
            name: model_path
                .file_stem()
                .map(|s| s.to_string_lossy().to_string())
                .unwrap_or_else(|| "unknown".to_string()),
            size: tokio::fs::metadata(&model_path)
                .await
                .map(|m| m.len())
                .unwrap_or(0),
            quantization: extract_quantization_from_path(&model_path),
            parameters: None,
            digest: None,
        };

        self.model = Some(Arc::new(RwLock::new(model)));
        self.model_info = Some(info);
        self.model_path = Some(model_path);
        self.config = Some(config);

        info!("Model loaded successfully");
        Ok(())
    }

    /// Releases the current model and clears all associated state.
    /// No-op when nothing is loaded.
    async fn unload(&mut self) -> Result<(), NativeError> {
        if self.model.is_some() {
            info!("Unloading model");
            // NOTE(review): clones of this runtime share the model `Arc`, so
            // the weights are only freed once the last clone drops its handle.
            self.model = None;
            self.model_info = None;
            self.model_path = None;
            self.config = None;
        }
        Ok(())
    }

    fn is_loaded(&self) -> bool {
        self.model.is_some()
    }

    fn model_info(&self) -> Option<&ModelInfo> {
        self.model_info.as_ref()
    }

    /// Runs a single non-streaming chat completion for `prompt`.
    ///
    /// # Errors
    ///
    /// Returns [`NativeError::ModelNotLoaded`] when no model is loaded, and
    /// [`NativeError::InvalidConfig`] when inference fails or the model
    /// returns no choices.
    async fn infer(&self, prompt: &str, options: ChatOptions) -> Result<ChatResponse, NativeError> {
        let model = self.model.as_ref().ok_or(NativeError::ModelNotLoaded)?;

        // Read lock: inference only needs shared access to the model.
        let model = model.read().await;

        // The whole prompt is sent as a single user message.
        let messages = TextMessages::new().add_message(TextMessageRole::User, prompt);

        debug!(
            temperature = options.temperature,
            max_tokens = options.max_tokens,
            "Running inference"
        );

        let mut request = RequestBuilder::from(messages);

        if let Some(temp) = options.temperature {
            request = request.set_sampler_temperature(f64::from(temp));
        }

        if let Some(max_tokens) = options.max_tokens {
            request = request.set_sampler_max_len(max_tokens as usize);
        }

        let response = model
            .send_chat_request(request)
            .await
            .map_err(|e| NativeError::InvalidConfig(format!("Inference failed: {e}")))?;

        let content = response
            .choices
            .first()
            .and_then(|c| c.message.content.clone())
            .ok_or_else(|| {
                NativeError::InvalidConfig("Model returned empty response (no choices)".to_string())
            })?;

        debug!(
            prompt_tokens = response.usage.prompt_tokens,
            completion_tokens = response.usage.completion_tokens,
            avg_prompt_tok_per_sec = ?response.usage.avg_prompt_tok_per_sec,
            avg_compl_tok_per_sec = ?response.usage.avg_compl_tok_per_sec,
            "Inference completed"
        );

        Ok(ChatResponse {
            message: spn_core::ChatMessage {
                role: ChatRole::Assistant,
                content,
            },
            done: true,
            total_duration: None,
            // Truncating casts: token counts are assumed to fit in u32 —
            // TODO confirm upstream usage type range.
            prompt_eval_count: Some(response.usage.prompt_tokens as u32),
            eval_count: Some(response.usage.completion_tokens as u32),
        })
    }

    /// Streams a chat completion chunk-by-chunk.
    ///
    /// Inference runs on a spawned task; chunks are forwarded through a
    /// bounded channel and re-yielded as an `async_stream` stream. Errors
    /// during streaming are surfaced as `Err` items on the stream.
    async fn infer_stream(
        &self,
        prompt: &str,
        options: ChatOptions,
    ) -> Result<impl Stream<Item = Result<String, NativeError>> + Send, NativeError> {
        use async_stream::stream;
        use mistralrs::Response;
        use tokio::sync::mpsc;

        let model = self.model.as_ref().ok_or(NativeError::ModelNotLoaded)?;
        let model_arc = Arc::clone(model);
        let prompt_owned = prompt.to_string();

        // Bounded channel gives backpressure when the consumer is slow.
        let (tx, mut rx) = mpsc::channel::<Result<String, NativeError>>(32);

        tokio::spawn(async move {
            let model = model_arc.read().await;

            let messages = TextMessages::new().add_message(TextMessageRole::User, &prompt_owned);

            let mut request = RequestBuilder::from(messages);

            if let Some(temp) = options.temperature {
                request = request.set_sampler_temperature(f64::from(temp));
            }

            if let Some(max_tokens) = options.max_tokens {
                request = request.set_sampler_max_len(max_tokens as usize);
            }

            match model.stream_chat_request(request).await {
                Ok(mut stream) => {
                    while let Some(chunk) = stream.next().await {
                        match chunk {
                            Response::Chunk(chunk_response) => {
                                if let Some(choice) = chunk_response.choices.first() {
                                    if let Some(text) = &choice.delta.content {
                                        // Receiver dropped: stop forwarding.
                                        if tx.send(Ok(text.clone())).await.is_err() {
                                            break;
                                        }
                                    }
                                }
                            }
                            Response::Done(_) => {
                                debug!("Streaming completed");
                                break;
                            }
                            Response::ModelError(msg, _) => {
                                let _ = tx
                                    .send(Err(NativeError::InvalidConfig(format!(
                                        "Model error: {msg}"
                                    ))))
                                    .await;
                                break;
                            }
                            Response::ValidationError(err) => {
                                let _ = tx
                                    .send(Err(NativeError::InvalidConfig(format!(
                                        "Validation error: {err:?}"
                                    ))))
                                    .await;
                                break;
                            }
                            Response::InternalError(err) => {
                                let _ = tx
                                    .send(Err(NativeError::InvalidConfig(format!(
                                        "Internal error: {err:?}"
                                    ))))
                                    .await;
                                break;
                            }
                            // Other response kinds carry no streamable text.
                            _ => {
                            }
                        }
                    }
                }
                Err(e) => {
                    let _ = tx
                        .send(Err(NativeError::InvalidConfig(format!(
                            "Failed to start streaming: {e}"
                        ))))
                        .await;
                }
            }
        });

        Ok(stream! {
            while let Some(result) = rx.recv().await {
                yield result;
            }
        })
    }
}
393
/// Derives a quantization label (if recognizable) from the model filename.
#[cfg(feature = "inference")]
fn extract_quantization_from_path(path: &Path) -> Option<String> {
    path.file_name()
        .map(|name| name.to_string_lossy())
        .and_then(|name| crate::extract_quantization(&name))
}
402
403#[cfg(not(feature = "inference"))]
405impl InferenceBackend for NativeRuntime {
406 async fn load(&mut self, _model_path: PathBuf, _config: LoadConfig) -> Result<(), NativeError> {
407 Err(NativeError::InvalidConfig(
408 "Inference feature not enabled. Rebuild with --features inference".to_string(),
409 ))
410 }
411
412 async fn unload(&mut self) -> Result<(), NativeError> {
413 Ok(())
414 }
415
416 fn is_loaded(&self) -> bool {
417 false
418 }
419
420 fn model_info(&self) -> Option<&ModelInfo> {
421 None
422 }
423
424 async fn infer(
425 &self,
426 _prompt: &str,
427 _options: ChatOptions,
428 ) -> Result<ChatResponse, NativeError> {
429 Err(NativeError::InvalidConfig(
430 "Inference feature not enabled. Rebuild with --features inference".to_string(),
431 ))
432 }
433
434 async fn infer_stream(
435 &self,
436 _prompt: &str,
437 _options: ChatOptions,
438 ) -> Result<impl Stream<Item = Result<String, NativeError>> + Send, NativeError> {
439 Err::<futures_util::stream::Empty<Result<String, NativeError>>, _>(
440 NativeError::InvalidConfig(
441 "Inference feature not enabled. Rebuild with --features inference".to_string(),
442 ),
443 )
444 }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed runtime carries no model or metadata.
    #[test]
    fn test_runtime_creation() {
        let rt = NativeRuntime::new();
        assert!(!rt.is_loaded());
        assert!(rt.model_info().is_none());
        assert!(rt.model_path().is_none());
    }

    /// `Default` must match `new()`: empty, nothing loaded.
    #[test]
    fn test_runtime_default() {
        assert!(!NativeRuntime::default().is_loaded());
    }

    /// Without the feature, `load` fails with an explanatory message.
    #[tokio::test]
    #[cfg(not(feature = "inference"))]
    async fn test_load_without_feature() {
        let mut rt = NativeRuntime::new();
        let err = rt
            .load(PathBuf::from("test.gguf"), LoadConfig::default())
            .await
            .expect_err("load must fail when the inference feature is off");
        assert!(err.to_string().contains("Inference feature not enabled"));
    }

    /// Without the feature, `infer` also fails.
    #[tokio::test]
    #[cfg(not(feature = "inference"))]
    async fn test_infer_without_feature() {
        let rt = NativeRuntime::new();
        assert!(rt.infer("test", ChatOptions::default()).await.is_err());
    }
}