// spn_native/inference/runtime.rs
//! Native runtime implementation using mistral.rs.
//!
//! This module provides the `NativeRuntime` struct which implements
//! the `InferenceBackend` trait using the mistral.rs library.

use crate::inference::traits::InferenceBackend;
use crate::NativeError;
use futures_util::stream::Stream;
use spn_core::{ChatOptions, ChatResponse, LoadConfig, ModelInfo};
use std::path::PathBuf;

#[cfg(feature = "inference")]
use spn_core::ChatRole;
#[cfg(feature = "inference")]
use std::path::Path;
#[cfg(feature = "inference")]
use std::sync::Arc;
#[cfg(feature = "inference")]
use tracing::{debug, info};

#[cfg(feature = "inference")]
use mistralrs::{
    GgufModelBuilder, MemoryGpuConfig, Model, PagedAttentionMetaBuilder, RequestBuilder,
    TextMessageRole, TextMessages,
};
#[cfg(feature = "inference")]
use tokio::sync::RwLock;
29/// Native runtime for local LLM inference.
30///
31/// Uses mistral.rs for high-performance inference on GGUF models.
32/// Supports CPU and GPU (Metal on macOS, CUDA on Linux) acceleration.
33///
34/// # Example
35///
36/// ```ignore
37/// use spn_native::inference::NativeRuntime;
38/// use spn_core::LoadConfig;
39///
40/// let mut runtime = NativeRuntime::new()?;
41/// runtime.load("model.gguf".into(), LoadConfig::default()).await?;
42/// let response = runtime.infer("Hello!", Default::default()).await?;
43/// ```
44#[allow(dead_code)] // Fields used only with inference feature
45pub struct NativeRuntime {
46    /// The loaded model (None if no model is loaded).
47    #[cfg(feature = "inference")]
48    model: Option<Arc<RwLock<Model>>>,
49
50    /// Metadata about the loaded model.
51    model_info: Option<ModelInfo>,
52
53    /// Path to the currently loaded model.
54    model_path: Option<PathBuf>,
55
56    /// Load configuration used for the current model.
57    config: Option<LoadConfig>,
58}
59
60// Manual Debug implementation (Model doesn't implement Debug)
61impl std::fmt::Debug for NativeRuntime {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        f.debug_struct("NativeRuntime")
64            .field("model_info", &self.model_info)
65            .field("model_path", &self.model_path)
66            .field("config", &self.config)
67            .field("is_loaded", &self.is_loaded())
68            .finish()
69    }
70}
71
72// Manual Clone implementation (clones the Arc, not the model itself)
73impl Clone for NativeRuntime {
74    fn clone(&self) -> Self {
75        Self {
76            #[cfg(feature = "inference")]
77            model: self.model.clone(),
78            model_info: self.model_info.clone(),
79            model_path: self.model_path.clone(),
80            config: self.config.clone(),
81        }
82    }
83}
84
85impl NativeRuntime {
86    /// Create a new native runtime.
87    ///
88    /// The runtime is created without a model loaded. Call `load()` to
89    /// load a model before running inference.
90    #[must_use]
91    pub fn new() -> Self {
92        Self {
93            #[cfg(feature = "inference")]
94            model: None,
95            model_info: None,
96            model_path: None,
97            config: None,
98        }
99    }
100
101    /// Get the path to the currently loaded model.
102    #[must_use]
103    pub fn model_path(&self) -> Option<&PathBuf> {
104        self.model_path.as_ref()
105    }
106
107    /// Get the load configuration for the current model.
108    #[must_use]
109    pub fn config(&self) -> Option<&LoadConfig> {
110        self.config.as_ref()
111    }
112
113    /// Convert spn-core ChatRole to mistral.rs TextMessageRole.
114    #[cfg(feature = "inference")]
115    #[allow(dead_code)] // Will be used for streaming support
116    fn convert_role(role: ChatRole) -> TextMessageRole {
117        match role {
118            ChatRole::System => TextMessageRole::System,
119            ChatRole::User => TextMessageRole::User,
120            ChatRole::Assistant => TextMessageRole::Assistant,
121        }
122    }
123}
124
125impl Default for NativeRuntime {
126    fn default() -> Self {
127        Self::new()
128    }
129}
#[cfg(feature = "inference")]
impl InferenceBackend for NativeRuntime {
    /// Load a GGUF model from `model_path`, replacing any model already loaded.
    ///
    /// Validates the path, builds the model via `GgufModelBuilder` with
    /// PagedAttention enabled, then records `ModelInfo` derived from the
    /// file name and size.
    async fn load(&mut self, model_path: PathBuf, config: LoadConfig) -> Result<(), NativeError> {
        info!(?model_path, "Loading GGUF model");

        // Unload any existing model first so its memory is released
        // before the new one is built.
        if self.model.is_some() {
            self.unload().await?;
        }

        // Validate path exists up front — gives a clearer error than
        // letting the builder fail on a missing file.
        if !model_path.exists() {
            return Err(NativeError::ModelNotFound {
                repo: "local".to_string(),
                filename: model_path.to_string_lossy().to_string(),
            });
        }

        // Build the model using GgufModelBuilder.
        // API: GgufModelBuilder::new(directory, vec![filename]) — so the
        // path is split into its parent directory and bare file name.
        let parent = model_path
            .parent()
            .map(|p| p.to_string_lossy().to_string())
            .unwrap_or_else(|| ".".to_string());
        let filename = model_path
            .file_name()
            .map(|f| f.to_string_lossy().to_string())
            .ok_or_else(|| {
                NativeError::InvalidConfig("Invalid model path: no filename".to_string())
            })?;

        debug!(gpu_layers = config.gpu_layers, %parent, %filename, "Building model");

        // Build model with PagedAttention for better memory management.
        // PagedAttention enables efficient KV cache handling for longer contexts.
        // Use context_size from LoadConfig, defaulting to 2048 if not specified.
        // NOTE(review): config.gpu_layers is logged above but not passed to the
        // builder — confirm whether mistral.rs needs it here or handles GPU
        // offload elsewhere.
        let context_size = config.context_size.unwrap_or(2048);
        let model = GgufModelBuilder::new(parent, vec![filename])
            .with_logging()
            .with_paged_attn(|| {
                PagedAttentionMetaBuilder::default()
                    .with_block_size(32)
                    .with_gpu_memory(MemoryGpuConfig::ContextSize(context_size as usize))
                    .build()
            })
            .map_err(|e| NativeError::InvalidConfig(format!("PagedAttention config error: {e}")))?
            .build()
            .await
            .map_err(|e| NativeError::InvalidConfig(format!("Failed to build model: {e}")))?;

        // Extract model info: name from the file stem, size from filesystem
        // metadata (0 if unreadable), quantization parsed from the filename.
        let info = ModelInfo {
            name: model_path
                .file_stem()
                .map(|s| s.to_string_lossy().to_string())
                .unwrap_or_else(|| "unknown".to_string()),
            size: tokio::fs::metadata(&model_path)
                .await
                .map(|m| m.len())
                .unwrap_or(0),
            quantization: extract_quantization_from_path(&model_path),
            parameters: None,
            digest: None,
        };

        self.model = Some(Arc::new(RwLock::new(model)));
        self.model_info = Some(info);
        self.model_path = Some(model_path);
        self.config = Some(config);

        info!("Model loaded successfully");
        Ok(())
    }

    /// Drop the current model and all associated metadata.
    ///
    /// No-op (still `Ok`) when nothing is loaded. Note that clones of this
    /// runtime holding the same `Arc` keep the model alive until they drop.
    async fn unload(&mut self) -> Result<(), NativeError> {
        if self.model.is_some() {
            info!("Unloading model");
            self.model = None;
            self.model_info = None;
            self.model_path = None;
            self.config = None;
        }
        Ok(())
    }

    /// Whether a model is currently loaded.
    fn is_loaded(&self) -> bool {
        self.model.is_some()
    }

    /// Metadata for the loaded model, or `None` if nothing is loaded.
    fn model_info(&self) -> Option<&ModelInfo> {
        self.model_info.as_ref()
    }

    /// Run a single, non-streaming chat completion for `prompt`.
    ///
    /// Returns `NativeError::ModelNotLoaded` when no model is loaded, and
    /// `InvalidConfig` when inference fails or the model returns no choices.
    async fn infer(&self, prompt: &str, options: ChatOptions) -> Result<ChatResponse, NativeError> {
        let model = self.model.as_ref().ok_or(NativeError::ModelNotLoaded)?;

        // Read lock only: inference does not mutate the model, so concurrent
        // infer() calls can proceed in parallel.
        let model = model.read().await;

        // Build messages - just user prompt for now
        // System messages should be passed as part of the prompt or via messages API
        let messages = TextMessages::new().add_message(TextMessageRole::User, prompt);

        debug!(
            temperature = options.temperature,
            max_tokens = options.max_tokens,
            "Running inference"
        );

        // Build request with sampling parameters
        let mut request = RequestBuilder::from(messages);

        // Apply temperature if provided (convert f32 to f64)
        if let Some(temp) = options.temperature {
            request = request.set_sampler_temperature(f64::from(temp));
        }

        // Apply max_tokens if provided
        if let Some(max_tokens) = options.max_tokens {
            request = request.set_sampler_max_len(max_tokens as usize);
        }

        // Send request with sampling parameters
        let response = model
            .send_chat_request(request)
            .await
            .map_err(|e| NativeError::InvalidConfig(format!("Inference failed: {e}")))?;

        // Extract response content - fail if no content returned
        let content = response
            .choices
            .first()
            .and_then(|c| c.message.content.clone())
            .ok_or_else(|| {
                NativeError::InvalidConfig("Model returned empty response (no choices)".to_string())
            })?;

        // Log performance metrics for debugging and optimization
        debug!(
            prompt_tokens = response.usage.prompt_tokens,
            completion_tokens = response.usage.completion_tokens,
            avg_prompt_tok_per_sec = ?response.usage.avg_prompt_tok_per_sec,
            avg_compl_tok_per_sec = ?response.usage.avg_compl_tok_per_sec,
            "Inference completed"
        );

        Ok(ChatResponse {
            message: spn_core::ChatMessage {
                role: ChatRole::Assistant,
                content,
            },
            done: true,
            total_duration: None,
            prompt_eval_count: Some(response.usage.prompt_tokens as u32),
            eval_count: Some(response.usage.completion_tokens as u32),
        })
    }

    /// Stream a chat completion for `prompt`, yielding text chunks.
    ///
    /// Spawns a background task that holds a read lock on the model and
    /// forwards chunks (or an error) over a bounded mpsc channel; the
    /// returned stream drains that channel. Dropping the returned stream
    /// closes the channel, which stops the background task at the next send.
    async fn infer_stream(
        &self,
        prompt: &str,
        options: ChatOptions,
    ) -> Result<impl Stream<Item = Result<String, NativeError>> + Send, NativeError> {
        use async_stream::stream;
        use mistralrs::Response;
        use tokio::sync::mpsc;

        let model = self.model.as_ref().ok_or(NativeError::ModelNotLoaded)?;
        let model_arc = Arc::clone(model);
        let prompt_owned = prompt.to_string();

        // Create channel for streaming chunks; capacity 32 provides
        // backpressure if the consumer falls behind.
        let (tx, mut rx) = mpsc::channel::<Result<String, NativeError>>(32);

        // Spawn streaming task that holds the model lock for the duration
        // of the generation.
        tokio::spawn(async move {
            let model = model_arc.read().await;

            // Build messages
            let messages = TextMessages::new().add_message(TextMessageRole::User, &prompt_owned);

            // Build request with sampling parameters (mirrors infer())
            let mut request = RequestBuilder::from(messages);

            if let Some(temp) = options.temperature {
                request = request.set_sampler_temperature(f64::from(temp));
            }

            if let Some(max_tokens) = options.max_tokens {
                request = request.set_sampler_max_len(max_tokens as usize);
            }

            // Stream chat request.
            // NOTE(review): `.next()` here appears to resolve to an inherent
            // method on mistralrs's stream type (no StreamExt import in this
            // file) — confirm against the mistralrs version in Cargo.toml.
            match model.stream_chat_request(request).await {
                Ok(mut stream) => {
                    while let Some(chunk) = stream.next().await {
                        match chunk {
                            Response::Chunk(chunk_response) => {
                                if let Some(choice) = chunk_response.choices.first() {
                                    if let Some(text) = &choice.delta.content {
                                        if tx.send(Ok(text.clone())).await.is_err() {
                                            // Receiver dropped, stop streaming
                                            break;
                                        }
                                    }
                                }
                            }
                            Response::Done(_) => {
                                debug!("Streaming completed");
                                break;
                            }
                            Response::ModelError(msg, _) => {
                                // Send failures are ignored: the receiver is
                                // gone, so there is nobody to report to.
                                let _ = tx
                                    .send(Err(NativeError::InvalidConfig(format!(
                                        "Model error: {}",
                                        msg
                                    ))))
                                    .await;
                                break;
                            }
                            Response::ValidationError(err) => {
                                let _ = tx
                                    .send(Err(NativeError::InvalidConfig(format!(
                                        "Validation error: {:?}",
                                        err
                                    ))))
                                    .await;
                                break;
                            }
                            Response::InternalError(err) => {
                                let _ = tx
                                    .send(Err(NativeError::InvalidConfig(format!(
                                        "Internal error: {:?}",
                                        err
                                    ))))
                                    .await;
                                break;
                            }
                            _ => {
                                // Other response types, continue
                            }
                        }
                    }
                }
                Err(e) => {
                    let _ = tx
                        .send(Err(NativeError::InvalidConfig(format!(
                            "Failed to start streaming: {}",
                            e
                        ))))
                        .await;
                }
            }
        });

        // Convert mpsc receiver to Stream; ends when the sender side drops.
        Ok(stream! {
            while let Some(result) = rx.recv().await {
                yield result;
            }
        })
    }
}
/// Extract quantization from a model file path.
///
/// Delegates to [`crate::extract_quantization`] for the actual parsing;
/// returns `None` when the path has no file name component.
#[cfg(feature = "inference")]
fn extract_quantization_from_path(path: &Path) -> Option<String> {
    path.file_name()
        .map(|name| name.to_string_lossy())
        .and_then(|name| crate::extract_quantization(&name))
}
402
403// Stub implementation when inference feature is not enabled
404#[cfg(not(feature = "inference"))]
405impl InferenceBackend for NativeRuntime {
406    async fn load(&mut self, _model_path: PathBuf, _config: LoadConfig) -> Result<(), NativeError> {
407        Err(NativeError::InvalidConfig(
408            "Inference feature not enabled. Rebuild with --features inference".to_string(),
409        ))
410    }
411
412    async fn unload(&mut self) -> Result<(), NativeError> {
413        Ok(())
414    }
415
416    fn is_loaded(&self) -> bool {
417        false
418    }
419
420    fn model_info(&self) -> Option<&ModelInfo> {
421        None
422    }
423
424    async fn infer(
425        &self,
426        _prompt: &str,
427        _options: ChatOptions,
428    ) -> Result<ChatResponse, NativeError> {
429        Err(NativeError::InvalidConfig(
430            "Inference feature not enabled. Rebuild with --features inference".to_string(),
431        ))
432    }
433
434    async fn infer_stream(
435        &self,
436        _prompt: &str,
437        _options: ChatOptions,
438    ) -> Result<impl Stream<Item = Result<String, NativeError>> + Send, NativeError> {
439        Err::<futures_util::stream::Empty<Result<String, NativeError>>, _>(
440            NativeError::InvalidConfig(
441                "Inference feature not enabled. Rebuild with --features inference".to_string(),
442            ),
443        )
444    }
445}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_runtime_creation() {
        // A fresh runtime starts completely empty.
        let rt = NativeRuntime::new();
        assert!(!rt.is_loaded());
        assert!(rt.model_info().is_none());
        assert!(rt.model_path().is_none());
    }

    #[test]
    fn test_runtime_default() {
        // Default::default() must behave like new().
        let rt = NativeRuntime::default();
        assert!(!rt.is_loaded());
    }

    #[tokio::test]
    #[cfg(not(feature = "inference"))]
    async fn test_load_without_feature() {
        // Without the inference feature, load() must fail with a
        // message pointing at the missing feature flag.
        let mut rt = NativeRuntime::new();
        let err = rt
            .load(PathBuf::from("test.gguf"), LoadConfig::default())
            .await
            .unwrap_err();
        assert!(err.to_string().contains("Inference feature not enabled"));
    }

    #[tokio::test]
    #[cfg(not(feature = "inference"))]
    async fn test_infer_without_feature() {
        // infer() is likewise unavailable without the feature.
        let rt = NativeRuntime::new();
        assert!(rt.infer("test", ChatOptions::default()).await.is_err());
    }
}