// spn_native/inference/traits.rs
//! Inference backend traits.
//!
//! Defines the interface for local model inference backends.

use std::future::Future;
use std::path::PathBuf;
use std::pin::Pin;

use futures_util::{Stream, StreamExt};
use spn_core::{ChatOptions, ChatResponse, LoadConfig, ModelInfo};

use crate::NativeError;
12
13/// Trait for any inference backend (mistral.rs, llama.cpp, etc.).
14///
15/// This trait provides a unified interface for loading and running
16/// local LLM inference. Implementations can use different backends
17/// while presenting the same API to consumers.
18pub trait InferenceBackend: Send + Sync {
19 /// Load a model from disk.
20 ///
21 /// # Arguments
22 /// * `model_path` - Path to the GGUF model file
23 /// * `config` - Load configuration (context size, GPU layers, etc.)
24 ///
25 /// # Returns
26 /// `Ok(())` if the model was loaded successfully.
27 fn load(
28 &mut self,
29 model_path: PathBuf,
30 config: LoadConfig,
31 ) -> impl Future<Output = Result<(), NativeError>> + Send;
32
33 /// Unload the model from memory.
34 ///
35 /// Frees GPU/CPU memory used by the model.
36 fn unload(&mut self) -> impl Future<Output = Result<(), NativeError>> + Send;
37
38 /// Check if a model is currently loaded.
39 #[must_use]
40 fn is_loaded(&self) -> bool;
41
42 /// Get metadata about the loaded model.
43 ///
44 /// Returns `None` if no model is loaded.
45 fn model_info(&self) -> Option<&ModelInfo>;
46
47 /// Generate a response (non-streaming).
48 ///
49 /// # Arguments
50 /// * `prompt` - The input prompt
51 /// * `options` - Generation options (temperature, max_tokens, etc.)
52 ///
53 /// # Returns
54 /// The complete chat response.
55 fn infer(
56 &self,
57 prompt: &str,
58 options: ChatOptions,
59 ) -> impl Future<Output = Result<ChatResponse, NativeError>> + Send;
60
61 /// Generate a response (streaming).
62 ///
63 /// Returns a stream of token strings as they are generated.
64 ///
65 /// # Arguments
66 /// * `prompt` - The input prompt
67 /// * `options` - Generation options (temperature, max_tokens, etc.)
68 fn infer_stream(
69 &self,
70 prompt: &str,
71 options: ChatOptions,
72 ) -> impl Future<Output = Result<impl Stream<Item = Result<String, NativeError>> + Send, NativeError>>
73 + Send;
74}
75
76/// Object-safe version of InferenceBackend for dynamic dispatch.
77///
78/// Use this when you need runtime polymorphism (e.g., `Box<dyn DynInferenceBackend>`).
79///
80/// Note: This trait takes owned `String` instead of `&str` for prompts
81/// to enable object-safe async methods.
82#[allow(clippy::type_complexity)]
83pub trait DynInferenceBackend: Send + Sync {
84 /// Load a model from disk (boxed future for object safety).
85 fn load_dyn(
86 &mut self,
87 model_path: PathBuf,
88 config: LoadConfig,
89 ) -> Pin<Box<dyn Future<Output = Result<(), NativeError>> + Send + '_>>;
90
91 /// Unload the model from memory (boxed future for object safety).
92 fn unload_dyn(&mut self) -> Pin<Box<dyn Future<Output = Result<(), NativeError>> + Send + '_>>;
93
94 /// Check if a model is currently loaded.
95 fn is_loaded_dyn(&self) -> bool;
96
97 /// Get metadata about the loaded model (cloned for object safety).
98 fn model_info_dyn(&self) -> Option<ModelInfo>;
99
100 /// Generate a response (boxed future for object safety).
101 ///
102 /// Takes owned `String` instead of `&str` for object safety.
103 fn infer_dyn(
104 &self,
105 prompt: String,
106 options: ChatOptions,
107 ) -> Pin<Box<dyn Future<Output = Result<ChatResponse, NativeError>> + Send + '_>>;
108
109 /// Generate a streaming response (boxed stream for object safety).
110 ///
111 /// Takes owned `String` instead of `&str` for object safety.
112 fn infer_stream_dyn(
113 &self,
114 prompt: String,
115 options: ChatOptions,
116 ) -> Pin<
117 Box<
118 dyn Future<
119 Output = Result<
120 Pin<Box<dyn Stream<Item = Result<String, NativeError>> + Send + 'static>>,
121 NativeError,
122 >,
123 > + Send
124 + '_,
125 >,
126 >;
127}
128
129/// Blanket implementation of DynInferenceBackend for any InferenceBackend.
130impl<T: InferenceBackend + 'static> DynInferenceBackend for T {
131 fn load_dyn(
132 &mut self,
133 model_path: PathBuf,
134 config: LoadConfig,
135 ) -> Pin<Box<dyn Future<Output = Result<(), NativeError>> + Send + '_>> {
136 Box::pin(self.load(model_path, config))
137 }
138
139 fn unload_dyn(&mut self) -> Pin<Box<dyn Future<Output = Result<(), NativeError>> + Send + '_>> {
140 Box::pin(self.unload())
141 }
142
143 fn is_loaded_dyn(&self) -> bool {
144 InferenceBackend::is_loaded(self)
145 }
146
147 fn model_info_dyn(&self) -> Option<ModelInfo> {
148 InferenceBackend::model_info(self).cloned()
149 }
150
151 fn infer_dyn(
152 &self,
153 prompt: String,
154 options: ChatOptions,
155 ) -> Pin<Box<dyn Future<Output = Result<ChatResponse, NativeError>> + Send + '_>> {
156 Box::pin(async move { self.infer(&prompt, options).await })
157 }
158
159 fn infer_stream_dyn(
160 &self,
161 _prompt: String,
162 _options: ChatOptions,
163 ) -> Pin<
164 Box<
165 dyn Future<
166 Output = Result<
167 Pin<Box<dyn Stream<Item = Result<String, NativeError>> + Send + 'static>>,
168 NativeError,
169 >,
170 > + Send
171 + '_,
172 >,
173 > {
174 Box::pin(async move {
175 // We cannot easily box a stream that borrows from self,
176 // so for streaming, callers should use InferenceBackend directly
177 // or collect results into a Vec first
178 Err(NativeError::InvalidConfig(
179 "Streaming not supported via DynInferenceBackend. Use InferenceBackend directly."
180 .to_string(),
181 ))
182 })
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time proof that `DynInferenceBackend` is object-safe: this
    // item only type-checks if `dyn DynInferenceBackend` is a valid type.
    const _OBJECT_SAFE: fn(&dyn DynInferenceBackend) = |_backend| {};
}