spn_native/inference/
runtime.rs1use crate::inference::traits::InferenceBackend;
7use crate::NativeError;
8use futures_util::stream::Stream;
9use spn_core::{ChatOptions, ChatResponse, LoadConfig, ModelInfo};
10use std::path::PathBuf;
11
12#[cfg(feature = "inference")]
13use spn_core::ChatRole;
14#[cfg(feature = "inference")]
15use std::path::Path;
16#[cfg(feature = "inference")]
17use std::sync::Arc;
18#[cfg(feature = "inference")]
19use tracing::{debug, info};
20
21#[cfg(feature = "inference")]
22use mistralrs::{
23 GgufModelBuilder, MemoryGpuConfig, Model, PagedAttentionMetaBuilder, RequestBuilder,
24 TextMessageRole, TextMessages,
25};
26#[cfg(feature = "inference")]
27use tokio::sync::RwLock;
28
/// Native inference runtime backed by mistral.rs for local GGUF models.
///
/// All fields start as `None`; they are populated by a successful
/// `InferenceBackend::load` and cleared again by `unload`.
#[allow(dead_code)] pub struct NativeRuntime {
    // Loaded mistral.rs model. Wrapped in Arc<RwLock<..>> so `infer`
    // (which only takes &self) can acquire a read guard.
    #[cfg(feature = "inference")]
    model: Option<Arc<RwLock<Model>>>,

    // Metadata (name, size, quantization, ...) for the loaded model.
    model_info: Option<ModelInfo>,

    // Filesystem path the current model was loaded from.
    model_path: Option<PathBuf>,

    // Configuration supplied to the most recent `load` call.
    config: Option<LoadConfig>,
}
59
60impl NativeRuntime {
61 #[must_use]
66 pub fn new() -> Self {
67 Self {
68 #[cfg(feature = "inference")]
69 model: None,
70 model_info: None,
71 model_path: None,
72 config: None,
73 }
74 }
75
76 #[must_use]
78 pub fn model_path(&self) -> Option<&PathBuf> {
79 self.model_path.as_ref()
80 }
81
82 #[must_use]
84 pub fn config(&self) -> Option<&LoadConfig> {
85 self.config.as_ref()
86 }
87
88 #[cfg(feature = "inference")]
90 #[allow(dead_code)] fn convert_role(role: ChatRole) -> TextMessageRole {
92 match role {
93 ChatRole::System => TextMessageRole::System,
94 ChatRole::User => TextMessageRole::User,
95 ChatRole::Assistant => TextMessageRole::Assistant,
96 }
97 }
98}
99
100impl Default for NativeRuntime {
101 fn default() -> Self {
102 Self::new()
103 }
104}
105
#[cfg(feature = "inference")]
impl InferenceBackend for NativeRuntime {
    /// Loads a GGUF model from `model_path`, replacing any currently loaded model.
    ///
    /// All local validation (path existence, filename extraction) runs *before*
    /// the current model is unloaded, so a `load` that fails those checks
    /// leaves the previously loaded model fully usable. (A failure inside the
    /// mistral.rs builder can still leave the runtime empty: the old model is
    /// dropped first so both models never occupy memory at once.)
    ///
    /// # Errors
    /// - `NativeError::ModelNotFound` if `model_path` does not exist.
    /// - `NativeError::InvalidConfig` if the path has no filename, the
    ///   PagedAttention config is invalid, or the model build fails.
    async fn load(&mut self, model_path: PathBuf, config: LoadConfig) -> Result<(), NativeError> {
        info!(?model_path, "Loading GGUF model");

        // Validate the requested path up front so a bad request does not
        // destroy the model that is currently serving.
        if !model_path.exists() {
            return Err(NativeError::ModelNotFound {
                repo: "local".to_string(),
                filename: model_path.to_string_lossy().to_string(),
            });
        }

        // mistral.rs takes a directory plus a list of filenames rather than
        // a full path, so split the path here. A bare filename has no parent;
        // fall back to the current directory.
        let parent = model_path
            .parent()
            .map(|p| p.to_string_lossy().to_string())
            .unwrap_or_else(|| ".".to_string());
        let filename = model_path
            .file_name()
            .map(|f| f.to_string_lossy().to_string())
            .ok_or_else(|| {
                NativeError::InvalidConfig("Invalid model path: no filename".to_string())
            })?;

        // Only drop the previous model once every check we control has passed.
        if self.model.is_some() {
            self.unload().await?;
        }

        debug!(gpu_layers = config.gpu_layers, %parent, %filename, "Building model");

        // Default context window when the caller does not specify one.
        let context_size = config.context_size.unwrap_or(2048);
        let model = GgufModelBuilder::new(parent, vec![filename])
            .with_logging()
            .with_paged_attn(|| {
                PagedAttentionMetaBuilder::default()
                    .with_block_size(32)
                    .with_gpu_memory(MemoryGpuConfig::ContextSize(context_size as usize))
                    .build()
            })
            .map_err(|e| NativeError::InvalidConfig(format!("PagedAttention config error: {e}")))?
            .build()
            .await
            .map_err(|e| NativeError::InvalidConfig(format!("Failed to build model: {e}")))?;

        // Best-effort metadata: a metadata/stat failure degrades to size 0
        // rather than failing the load.
        let info = ModelInfo {
            name: model_path
                .file_stem()
                .map(|s| s.to_string_lossy().to_string())
                .unwrap_or_else(|| "unknown".to_string()),
            size: tokio::fs::metadata(&model_path)
                .await
                .map(|m| m.len())
                .unwrap_or(0),
            quantization: extract_quantization_from_path(&model_path),
            parameters: None,
            digest: None,
        };

        self.model = Some(Arc::new(RwLock::new(model)));
        self.model_info = Some(info);
        self.model_path = Some(model_path);
        self.config = Some(config);

        info!("Model loaded successfully");
        Ok(())
    }

    /// Drops the loaded model (if any) and clears all associated state.
    /// Idempotent: unloading an empty runtime is a no-op.
    async fn unload(&mut self) -> Result<(), NativeError> {
        if self.model.is_some() {
            info!("Unloading model");
            self.model = None;
            self.model_info = None;
            self.model_path = None;
            self.config = None;
        }
        Ok(())
    }

    fn is_loaded(&self) -> bool {
        self.model.is_some()
    }

    fn model_info(&self) -> Option<&ModelInfo> {
        self.model_info.as_ref()
    }

    /// Runs a single-turn (one user message) chat completion against the
    /// loaded model.
    ///
    /// # Errors
    /// - `NativeError::ModelNotLoaded` if no model is loaded.
    /// - `NativeError::InvalidConfig` if the request fails or the model
    ///   returns no choices.
    async fn infer(&self, prompt: &str, options: ChatOptions) -> Result<ChatResponse, NativeError> {
        let model = self.model.as_ref().ok_or(NativeError::ModelNotLoaded)?;

        // Read guard is enough: sending a chat request does not need &mut.
        let model = model.read().await;

        let messages = TextMessages::new().add_message(TextMessageRole::User, prompt);

        debug!(
            temperature = options.temperature,
            max_tokens = options.max_tokens,
            "Running inference"
        );

        // Only forward sampler settings the caller actually provided;
        // everything else keeps mistral.rs defaults.
        let mut request = RequestBuilder::from(messages);

        if let Some(temp) = options.temperature {
            request = request.set_sampler_temperature(f64::from(temp));
        }

        if let Some(max_tokens) = options.max_tokens {
            request = request.set_sampler_max_len(max_tokens as usize);
        }

        let response = model
            .send_chat_request(request)
            .await
            .map_err(|e| NativeError::InvalidConfig(format!("Inference failed: {e}")))?;

        // The model may legitimately return zero choices; surface that as an
        // explicit error instead of panicking on an empty Vec.
        let content = response
            .choices
            .first()
            .and_then(|c| c.message.content.clone())
            .ok_or_else(|| {
                NativeError::InvalidConfig("Model returned empty response (no choices)".to_string())
            })?;

        debug!(
            prompt_tokens = response.usage.prompt_tokens,
            completion_tokens = response.usage.completion_tokens,
            avg_prompt_tok_per_sec = ?response.usage.avg_prompt_tok_per_sec,
            avg_compl_tok_per_sec = ?response.usage.avg_compl_tok_per_sec,
            "Inference completed"
        );

        Ok(ChatResponse {
            message: spn_core::ChatMessage {
                role: ChatRole::Assistant,
                content,
            },
            done: true,
            total_duration: None,
            // Token counts far exceed u32::MAX only in pathological cases;
            // the narrowing cast is accepted here.
            prompt_eval_count: Some(response.usage.prompt_tokens as u32),
            eval_count: Some(response.usage.completion_tokens as u32),
        })
    }

    /// Streaming is not implemented yet; always returns an error.
    /// The `Empty` stream type only pins down `impl Stream` for the compiler.
    async fn infer_stream(
        &self,
        _prompt: &str,
        _options: ChatOptions,
    ) -> Result<impl Stream<Item = Result<String, NativeError>> + Send, NativeError> {
        Err::<futures_util::stream::Empty<Result<String, NativeError>>, _>(
            NativeError::InvalidConfig(
                "Streaming not yet implemented for native runtime. Use infer() instead.".to_string(),
            ),
        )
    }
}
278
/// Derives the quantization label (e.g. "Q4_K_M") from a model file's name.
///
/// Returns `None` when the path has no filename or the filename carries no
/// recognizable quantization marker.
#[cfg(feature = "inference")]
fn extract_quantization_from_path(path: &Path) -> Option<String> {
    path.file_name()
        .map(|name| name.to_string_lossy())
        .and_then(|name| crate::extract_quantization(&name))
}
287
288#[cfg(not(feature = "inference"))]
290impl InferenceBackend for NativeRuntime {
291 async fn load(&mut self, _model_path: PathBuf, _config: LoadConfig) -> Result<(), NativeError> {
292 Err(NativeError::InvalidConfig(
293 "Inference feature not enabled. Rebuild with --features inference".to_string(),
294 ))
295 }
296
297 async fn unload(&mut self) -> Result<(), NativeError> {
298 Ok(())
299 }
300
301 fn is_loaded(&self) -> bool {
302 false
303 }
304
305 fn model_info(&self) -> Option<&ModelInfo> {
306 None
307 }
308
309 async fn infer(
310 &self,
311 _prompt: &str,
312 _options: ChatOptions,
313 ) -> Result<ChatResponse, NativeError> {
314 Err(NativeError::InvalidConfig(
315 "Inference feature not enabled. Rebuild with --features inference".to_string(),
316 ))
317 }
318
319 async fn infer_stream(
320 &self,
321 _prompt: &str,
322 _options: ChatOptions,
323 ) -> Result<impl Stream<Item = Result<String, NativeError>> + Send, NativeError> {
324 Err::<futures_util::stream::Empty<Result<String, NativeError>>, _>(
325 NativeError::InvalidConfig(
326 "Inference feature not enabled. Rebuild with --features inference".to_string(),
327 ),
328 )
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed runtime reports no model state at all.
    #[test]
    fn test_runtime_creation() {
        let runtime = NativeRuntime::new();
        assert!(!runtime.is_loaded());
        assert!(runtime.model_info().is_none());
        assert!(runtime.model_path().is_none());
    }

    /// `Default` must match `new()`: empty, nothing loaded.
    #[test]
    fn test_runtime_default() {
        assert!(!NativeRuntime::default().is_loaded());
    }

    /// Without the feature, `load` fails with a message pointing at the
    /// required build flag.
    #[tokio::test]
    #[cfg(not(feature = "inference"))]
    async fn test_load_without_feature() {
        let mut runtime = NativeRuntime::new();
        let err = runtime
            .load(PathBuf::from("test.gguf"), LoadConfig::default())
            .await
            .unwrap_err();
        assert!(err.to_string().contains("Inference feature not enabled"));
    }

    /// Without the feature, `infer` fails as well.
    #[tokio::test]
    #[cfg(not(feature = "inference"))]
    async fn test_infer_without_feature() {
        let runtime = NativeRuntime::new();
        assert!(runtime.infer("test", ChatOptions::default()).await.is_err());
    }
}