// spn_native/inference/traits.rs
1//! Inference backend traits.
2//!
3//! Defines the interface for local model inference backends.
4
5use futures_util::Stream;
6use spn_core::{ChatOptions, ChatResponse, LoadConfig, ModelInfo};
7use std::future::Future;
8use std::path::PathBuf;
9use std::pin::Pin;
10
11use crate::NativeError;
12
/// Trait for any inference backend (mistral.rs, llama.cpp, etc.).
///
/// This trait provides a unified interface for loading and running
/// local LLM inference. Implementations can use different backends
/// while presenting the same API to consumers.
///
/// Methods return `impl Future<...> + Send` (rather than `async fn`)
/// so the returned futures carry an explicit `Send` bound.
pub trait InferenceBackend: Send + Sync {
    /// Load a model from disk.
    ///
    /// # Arguments
    /// * `model_path` - Path to the GGUF model file
    /// * `config` - Load configuration (context size, GPU layers, etc.)
    ///
    /// # Returns
    /// `Ok(())` if the model was loaded successfully.
    ///
    /// # Errors
    /// Returns a [`NativeError`] describing why the model could not be loaded.
    fn load(
        &mut self,
        model_path: PathBuf,
        config: LoadConfig,
    ) -> impl Future<Output = Result<(), NativeError>> + Send;

    /// Unload the model from memory.
    ///
    /// Frees GPU/CPU memory used by the model.
    ///
    /// # Errors
    /// Returns a [`NativeError`] if the backend fails to unload.
    fn unload(&mut self) -> impl Future<Output = Result<(), NativeError>> + Send;

    /// Check if a model is currently loaded.
    #[must_use]
    fn is_loaded(&self) -> bool;

    /// Get metadata about the loaded model.
    ///
    /// Returns `None` if no model is loaded.
    fn model_info(&self) -> Option<&ModelInfo>;

    /// Generate a response (non-streaming).
    ///
    /// # Arguments
    /// * `prompt` - The input prompt
    /// * `options` - Generation options (temperature, max_tokens, etc.)
    ///
    /// # Returns
    /// The complete chat response.
    ///
    /// # Errors
    /// Returns a [`NativeError`] if generation fails.
    fn infer(
        &self,
        prompt: &str,
        options: ChatOptions,
    ) -> impl Future<Output = Result<ChatResponse, NativeError>> + Send;

    /// Generate a response (streaming).
    ///
    /// Returns a stream of token strings as they are generated.
    ///
    /// # Arguments
    /// * `prompt` - The input prompt
    /// * `options` - Generation options (temperature, max_tokens, etc.)
    ///
    /// # Errors
    /// The outer `Result` fails if the stream cannot be started; each
    /// stream item is itself a `Result`, so errors can also surface
    /// mid-stream.
    fn infer_stream(
        &self,
        prompt: &str,
        options: ChatOptions,
    ) -> impl Future<
        Output = Result<impl Stream<Item = Result<String, NativeError>> + Send, NativeError>,
    > + Send;
}
76
/// Object-safe version of InferenceBackend for dynamic dispatch.
///
/// Use this when you need runtime polymorphism (e.g., `Box<dyn DynInferenceBackend>`).
///
/// Note: This trait takes owned `String` instead of `&str` for prompts
/// to enable object-safe async methods.
#[allow(clippy::type_complexity)]
pub trait DynInferenceBackend: Send + Sync {
    /// Load a model from disk (boxed future for object safety).
    ///
    /// # Errors
    /// Returns a [`NativeError`] if the model could not be loaded.
    fn load_dyn(
        &mut self,
        model_path: PathBuf,
        config: LoadConfig,
    ) -> Pin<Box<dyn Future<Output = Result<(), NativeError>> + Send + '_>>;

    /// Unload the model from memory (boxed future for object safety).
    ///
    /// # Errors
    /// Returns a [`NativeError`] if the backend fails to unload.
    fn unload_dyn(&mut self) -> Pin<Box<dyn Future<Output = Result<(), NativeError>> + Send + '_>>;

    /// Check if a model is currently loaded.
    fn is_loaded_dyn(&self) -> bool;

    /// Get metadata about the loaded model (cloned for object safety).
    fn model_info_dyn(&self) -> Option<ModelInfo>;

    /// Generate a response (boxed future for object safety).
    ///
    /// Takes owned `String` instead of `&str` for object safety.
    ///
    /// # Errors
    /// Returns a [`NativeError`] if generation fails.
    fn infer_dyn(
        &self,
        prompt: String,
        options: ChatOptions,
    ) -> Pin<Box<dyn Future<Output = Result<ChatResponse, NativeError>> + Send + '_>>;

    /// Generate a streaming response (boxed stream for object safety).
    ///
    /// Takes owned `String` instead of `&str` for object safety.
    ///
    /// NOTE: the blanket implementation provided in this module always
    /// returns an error from this method, because a stream borrowing
    /// `self` cannot be boxed as `'static`; implementors that can
    /// produce a `'static` stream may override this behavior.
    ///
    /// # Errors
    /// The outer `Result` fails if the stream cannot be started; each
    /// stream item is itself a `Result`, so errors can also surface
    /// mid-stream.
    fn infer_stream_dyn(
        &self,
        prompt: String,
        options: ChatOptions,
    ) -> Pin<
        Box<
            dyn Future<
                    Output = Result<
                        Pin<Box<dyn Stream<Item = Result<String, NativeError>> + Send + 'static>>,
                        NativeError,
                    >,
                > + Send
                + '_,
        >,
    >;
}
129
130/// Blanket implementation of DynInferenceBackend for any InferenceBackend.
131impl<T: InferenceBackend + 'static> DynInferenceBackend for T {
132 fn load_dyn(
133 &mut self,
134 model_path: PathBuf,
135 config: LoadConfig,
136 ) -> Pin<Box<dyn Future<Output = Result<(), NativeError>> + Send + '_>> {
137 Box::pin(self.load(model_path, config))
138 }
139
140 fn unload_dyn(&mut self) -> Pin<Box<dyn Future<Output = Result<(), NativeError>> + Send + '_>> {
141 Box::pin(self.unload())
142 }
143
144 fn is_loaded_dyn(&self) -> bool {
145 InferenceBackend::is_loaded(self)
146 }
147
148 fn model_info_dyn(&self) -> Option<ModelInfo> {
149 InferenceBackend::model_info(self).cloned()
150 }
151
152 fn infer_dyn(
153 &self,
154 prompt: String,
155 options: ChatOptions,
156 ) -> Pin<Box<dyn Future<Output = Result<ChatResponse, NativeError>> + Send + '_>> {
157 Box::pin(async move { self.infer(&prompt, options).await })
158 }
159
160 fn infer_stream_dyn(
161 &self,
162 _prompt: String,
163 _options: ChatOptions,
164 ) -> Pin<
165 Box<
166 dyn Future<
167 Output = Result<
168 Pin<Box<dyn Stream<Item = Result<String, NativeError>> + Send + 'static>>,
169 NativeError,
170 >,
171 > + Send
172 + '_,
173 >,
174 > {
175 Box::pin(async move {
176 // We cannot easily box a stream that borrows from self,
177 // so for streaming, callers should use InferenceBackend directly
178 // or collect results into a Vec first
179 Err(NativeError::InvalidConfig(
180 "Streaming not supported via DynInferenceBackend. Use InferenceBackend directly."
181 .to_string(),
182 ))
183 })
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time guarantee that `DynInferenceBackend` stays object-safe:
    // this signature only type-checks while `dyn DynInferenceBackend` is a
    // valid trait-object type.
    fn _takes_trait_object(_backend: &dyn DynInferenceBackend) {}
}