// spn_native/inference/traits.rs

//! Inference backend traits.
//!
//! Defines the interface for local model inference backends.

use futures_util::Stream;
use spn_core::{ChatOptions, ChatResponse, LoadConfig, ModelInfo};
use std::future::Future;
use std::path::PathBuf;
use std::pin::Pin;

use crate::NativeError;
13/// Trait for any inference backend (mistral.rs, llama.cpp, etc.).
14///
15/// This trait provides a unified interface for loading and running
16/// local LLM inference. Implementations can use different backends
17/// while presenting the same API to consumers.
18pub trait InferenceBackend: Send + Sync {
19    /// Load a model from disk.
20    ///
21    /// # Arguments
22    /// * `model_path` - Path to the GGUF model file
23    /// * `config` - Load configuration (context size, GPU layers, etc.)
24    ///
25    /// # Returns
26    /// `Ok(())` if the model was loaded successfully.
27    fn load(
28        &mut self,
29        model_path: PathBuf,
30        config: LoadConfig,
31    ) -> impl Future<Output = Result<(), NativeError>> + Send;
32
33    /// Unload the model from memory.
34    ///
35    /// Frees GPU/CPU memory used by the model.
36    fn unload(&mut self) -> impl Future<Output = Result<(), NativeError>> + Send;
37
38    /// Check if a model is currently loaded.
39    #[must_use]
40    fn is_loaded(&self) -> bool;
41
42    /// Get metadata about the loaded model.
43    ///
44    /// Returns `None` if no model is loaded.
45    fn model_info(&self) -> Option<&ModelInfo>;
46
47    /// Generate a response (non-streaming).
48    ///
49    /// # Arguments
50    /// * `prompt` - The input prompt
51    /// * `options` - Generation options (temperature, max_tokens, etc.)
52    ///
53    /// # Returns
54    /// The complete chat response.
55    fn infer(
56        &self,
57        prompt: &str,
58        options: ChatOptions,
59    ) -> impl Future<Output = Result<ChatResponse, NativeError>> + Send;
60
61    /// Generate a response (streaming).
62    ///
63    /// Returns a stream of token strings as they are generated.
64    ///
65    /// # Arguments
66    /// * `prompt` - The input prompt
67    /// * `options` - Generation options (temperature, max_tokens, etc.)
68    fn infer_stream(
69        &self,
70        prompt: &str,
71        options: ChatOptions,
72    ) -> impl Future<
73        Output = Result<impl Stream<Item = Result<String, NativeError>> + Send, NativeError>,
74    > + Send;
75}
76
77/// Object-safe version of InferenceBackend for dynamic dispatch.
78///
79/// Use this when you need runtime polymorphism (e.g., `Box<dyn DynInferenceBackend>`).
80///
81/// Note: This trait takes owned `String` instead of `&str` for prompts
82/// to enable object-safe async methods.
83#[allow(clippy::type_complexity)]
84pub trait DynInferenceBackend: Send + Sync {
85    /// Load a model from disk (boxed future for object safety).
86    fn load_dyn(
87        &mut self,
88        model_path: PathBuf,
89        config: LoadConfig,
90    ) -> Pin<Box<dyn Future<Output = Result<(), NativeError>> + Send + '_>>;
91
92    /// Unload the model from memory (boxed future for object safety).
93    fn unload_dyn(&mut self) -> Pin<Box<dyn Future<Output = Result<(), NativeError>> + Send + '_>>;
94
95    /// Check if a model is currently loaded.
96    fn is_loaded_dyn(&self) -> bool;
97
98    /// Get metadata about the loaded model (cloned for object safety).
99    fn model_info_dyn(&self) -> Option<ModelInfo>;
100
101    /// Generate a response (boxed future for object safety).
102    ///
103    /// Takes owned `String` instead of `&str` for object safety.
104    fn infer_dyn(
105        &self,
106        prompt: String,
107        options: ChatOptions,
108    ) -> Pin<Box<dyn Future<Output = Result<ChatResponse, NativeError>> + Send + '_>>;
109
110    /// Generate a streaming response (boxed stream for object safety).
111    ///
112    /// Takes owned `String` instead of `&str` for object safety.
113    fn infer_stream_dyn(
114        &self,
115        prompt: String,
116        options: ChatOptions,
117    ) -> Pin<
118        Box<
119            dyn Future<
120                    Output = Result<
121                        Pin<Box<dyn Stream<Item = Result<String, NativeError>> + Send + 'static>>,
122                        NativeError,
123                    >,
124                > + Send
125                + '_,
126        >,
127    >;
128}
129
130/// Blanket implementation of DynInferenceBackend for any InferenceBackend.
131impl<T: InferenceBackend + 'static> DynInferenceBackend for T {
132    fn load_dyn(
133        &mut self,
134        model_path: PathBuf,
135        config: LoadConfig,
136    ) -> Pin<Box<dyn Future<Output = Result<(), NativeError>> + Send + '_>> {
137        Box::pin(self.load(model_path, config))
138    }
139
140    fn unload_dyn(&mut self) -> Pin<Box<dyn Future<Output = Result<(), NativeError>> + Send + '_>> {
141        Box::pin(self.unload())
142    }
143
144    fn is_loaded_dyn(&self) -> bool {
145        InferenceBackend::is_loaded(self)
146    }
147
148    fn model_info_dyn(&self) -> Option<ModelInfo> {
149        InferenceBackend::model_info(self).cloned()
150    }
151
152    fn infer_dyn(
153        &self,
154        prompt: String,
155        options: ChatOptions,
156    ) -> Pin<Box<dyn Future<Output = Result<ChatResponse, NativeError>> + Send + '_>> {
157        Box::pin(async move { self.infer(&prompt, options).await })
158    }
159
160    fn infer_stream_dyn(
161        &self,
162        _prompt: String,
163        _options: ChatOptions,
164    ) -> Pin<
165        Box<
166            dyn Future<
167                    Output = Result<
168                        Pin<Box<dyn Stream<Item = Result<String, NativeError>> + Send + 'static>>,
169                        NativeError,
170                    >,
171                > + Send
172                + '_,
173        >,
174    > {
175        Box::pin(async move {
176            // We cannot easily box a stream that borrows from self,
177            // so for streaming, callers should use InferenceBackend directly
178            // or collect results into a Vec first
179            Err(NativeError::InvalidConfig(
180                "Streaming not supported via DynInferenceBackend. Use InferenceBackend directly."
181                    .to_string(),
182            ))
183        })
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time check: taking `&dyn DynInferenceBackend` only type-checks
    // if the trait is object-safe, so this function failing to compile would
    // flag a regression in object safety.
    fn _assert_object_safe(_backend: &dyn DynInferenceBackend) {}
}