uni_xervo/traits.rs
1//! Core traits that every provider and model implementation must satisfy.
2
3use crate::api::{ModelAliasSpec, ModelTask};
4use crate::error::Result;
5use async_trait::async_trait;
6use std::any::Any;
7
8/// Advertised capabilities of a [`ModelProvider`].
9#[derive(Debug, Clone)]
10pub struct ProviderCapabilities {
11 /// The set of [`ModelTask`] variants this provider can handle.
12 pub supported_tasks: Vec<ModelTask>,
13}
14
/// Health status reported by a provider.
///
/// Derives `PartialEq`/`Eq` so callers can compare a reported status
/// directly, e.g. `health == ProviderHealth::Healthy`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProviderHealth {
    /// The provider is fully operational.
    Healthy,
    /// The provider is operational but experiencing partial issues,
    /// described by the accompanying message.
    Degraded(String),
    /// The provider cannot serve requests, described by the accompanying
    /// message.
    Unhealthy(String),
}

impl ProviderHealth {
    /// Returns `true` only for [`ProviderHealth::Healthy`].
    pub fn is_healthy(&self) -> bool {
        matches!(self, Self::Healthy)
    }

    /// Returns `true` when the provider can still serve requests.
    ///
    /// That is the case for [`Healthy`](ProviderHealth::Healthy) and
    /// [`Degraded`](ProviderHealth::Degraded), but not for
    /// [`Unhealthy`](ProviderHealth::Unhealthy).
    pub fn is_operational(&self) -> bool {
        !matches!(self, Self::Unhealthy(_))
    }
}
25
26/// A pluggable backend that knows how to load models for one or more
27/// [`ModelTask`] types.
28///
29/// Providers are registered with [`ModelRuntimeBuilder::register_provider`](crate::runtime::ModelRuntimeBuilder::register_provider)
30/// and are identified by their [`provider_id`](ModelProvider::provider_id)
31/// (e.g. `"local/candle"`, `"remote/openai"`).
32#[async_trait]
33pub trait ModelProvider: Send + Sync {
34 /// Unique identifier for this provider (e.g. `"local/candle"`, `"remote/openai"`).
35 fn provider_id(&self) -> &'static str;
36
37 /// Return the set of tasks this provider supports.
38 fn capabilities(&self) -> ProviderCapabilities;
39
40 /// Load (or connect to) a model described by `spec` and return a type-erased
41 /// handle.
42 ///
43 /// The returned [`LoadedModelHandle`] is expected to contain an
44 /// `Arc<dyn EmbeddingModel>`, `Arc<dyn RerankerModel>`, or
45 /// `Arc<dyn GeneratorModel>` depending on the task.
46 async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle>;
47
48 /// Report the current health of this provider.
49 async fn health(&self) -> ProviderHealth;
50
51 /// Optional one-time warmup hook called during runtime startup.
52 ///
53 /// Use this for provider-wide initialization such as setting up API clients
54 /// or pre-caching shared resources. The default implementation is a no-op.
55 async fn warmup(&self) -> Result<()> {
56 Ok(())
57 }
58}
59
/// Reference-counted, type-erased handle to a loaded model instance.
///
/// A provider wraps its concrete model object (for instance an
/// `Arc<dyn EmbeddingModel>`) in this `Arc<dyn Any + Send + Sync>`. This lets
/// the runtime store every model the same way. The runtime later downcasts
/// the handle back to the trait object it expects.
pub type LoadedModelHandle = ::std::sync::Arc<dyn Any + Send + Sync>;
66
67/// A model that produces dense vector embeddings from text.
68#[async_trait]
69pub trait EmbeddingModel: Send + Sync + Any {
70 /// Embed a batch of text strings into dense vectors.
71 ///
72 /// Returns one `Vec<f32>` per input text, each with [`dimensions()`](EmbeddingModel::dimensions)
73 /// elements.
74 async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>>;
75
76 /// The dimensionality of the embedding vectors produced by this model.
77 fn dimensions(&self) -> u32;
78
79 /// The underlying model identifier (e.g. a HuggingFace repo ID or API model name).
80 fn model_id(&self) -> &str;
81
82 /// Optional warmup hook (e.g. load weights into memory on first access).
83 /// The default is a no-op.
84 async fn warmup(&self) -> Result<()> {
85 Ok(())
86 }
87}
88
/// One document together with the relevance score a [`RerankerModel`]
/// assigned to it.
#[derive(Clone, Debug)]
pub struct ScoredDoc {
    /// Zero-based position of this document within the `docs` slice that
    /// was handed to [`RerankerModel::rerank`].
    pub index: usize,
    /// Relevance score from the reranker; larger means more relevant.
    pub score: f32,
    /// Document text, when the provider echoes it back; otherwise `None`.
    pub text: Option<String>,
}
100
101/// A model that re-scores documents against a query for relevance ranking.
102#[async_trait]
103pub trait RerankerModel: Send + Sync {
104 /// Rerank `docs` by relevance to `query`, returning scored results
105 /// (typically sorted by descending score).
106 async fn rerank(&self, query: &str, docs: &[&str]) -> Result<Vec<ScoredDoc>>;
107
108 /// Optional warmup hook. The default is a no-op.
109 async fn warmup(&self) -> Result<()> {
110 Ok(())
111 }
112}
113
114// ---------------------------------------------------------------------------
115// Multimodal message types
116// ---------------------------------------------------------------------------
117
/// The role of a message in a conversation.
///
/// This is a fieldless enum, so it is `Copy`: roles can be passed by value
/// without cloning. It is also `Hash`, so roles can be used as
/// `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MessageRole {
    /// System-level instructions.
    System,
    /// A user turn.
    User,
    /// An assistant (model) turn.
    Assistant,
}
128
/// Image payload that can be embedded in a [`ContentBlock`].
#[derive(Clone, Debug)]
pub enum ImageInput {
    /// Inline image bytes plus their MIME type (for example `"image/png"`).
    Bytes { data: Vec<u8>, media_type: String },
    /// Location of an image given as a URL.
    Url(String),
}
137
138/// A single block of content within a [`Message`].
139#[derive(Debug, Clone)]
140pub enum ContentBlock {
141 /// Plain text content.
142 Text(String),
143 /// An image (for vision models).
144 Image(ImageInput),
145}
146
147/// A single message in a conversation, containing one or more content blocks.
148#[derive(Debug, Clone)]
149pub struct Message {
150 /// The role of the message sender.
151 pub role: MessageRole,
152 /// The content blocks that make up this message.
153 pub content: Vec<ContentBlock>,
154}
155
156impl Message {
157 /// Create a user message with a single text block.
158 pub fn user(text: impl Into<String>) -> Self {
159 Self {
160 role: MessageRole::User,
161 content: vec![ContentBlock::Text(text.into())],
162 }
163 }
164
165 /// Create an assistant message with a single text block.
166 pub fn assistant(text: impl Into<String>) -> Self {
167 Self {
168 role: MessageRole::Assistant,
169 content: vec![ContentBlock::Text(text.into())],
170 }
171 }
172
173 /// Create a system message with a single text block.
174 pub fn system(text: impl Into<String>) -> Self {
175 Self {
176 role: MessageRole::System,
177 content: vec![ContentBlock::Text(text.into())],
178 }
179 }
180
181 /// Extract the concatenated text from all [`ContentBlock::Text`] blocks.
182 pub fn text(&self) -> String {
183 self.content
184 .iter()
185 .filter_map(|b| match b {
186 ContentBlock::Text(t) => Some(t.as_str()),
187 _ => None,
188 })
189 .collect::<Vec<_>>()
190 .join(" ")
191 }
192}
193
194// ---------------------------------------------------------------------------
195// Generation options and results
196// ---------------------------------------------------------------------------
197
/// Sampling, length and output-size parameters for a generation call.
///
/// Every field is optional; `None` leaves the choice to the provider.
#[derive(Clone, Debug, Default)]
pub struct GenerationOptions {
    /// Upper bound on the number of generated tokens (provider default when
    /// `None`).
    pub max_tokens: Option<usize>,
    /// Sampling temperature: `0.0` is greedy, larger values are more random.
    pub temperature: Option<f32>,
    /// Nucleus (top-p) sampling threshold.
    pub top_p: Option<f32>,
    /// Requested image width; used by diffusion models, ignored by text and
    /// vision models.
    pub width: Option<u32>,
    /// Requested image height; used by diffusion models, ignored by text and
    /// vision models.
    pub height: Option<u32>,
}
212
/// An image returned by a generation call, such as a diffusion model's output.
#[derive(Clone, Debug)]
pub struct GeneratedImage {
    /// Encoded image bytes (for example PNG).
    pub data: Vec<u8>,
    /// MIME type of `data` (for example `"image/png"`).
    pub media_type: String,
}
221
/// Audio produced by a speech model.
#[derive(Clone, Debug)]
pub struct AudioOutput {
    /// Raw PCM samples as `f32`.
    // NOTE(review): presumably interleaved across channels when
    // `channels > 1` — confirm against the speech providers.
    pub pcm_data: Vec<f32>,
    /// Sampling rate, in Hz.
    pub sample_rate: usize,
    /// Channel count.
    pub channels: usize,
}
232
233/// The output of a generation call.
234#[derive(Debug, Clone)]
235pub struct GenerationResult {
236 /// The generated text (may be empty for image/audio-only results).
237 pub text: String,
238 /// Token usage statistics, if reported by the provider.
239 pub usage: Option<TokenUsage>,
240 /// Generated images (non-empty for diffusion models).
241 pub images: Vec<GeneratedImage>,
242 /// Generated audio (present for speech models).
243 pub audio: Option<AudioOutput>,
244}
245
/// Token counts for a generation request.
///
/// This is a small bundle of plain counters. It is `Copy` so it can be
/// passed around freely, `Default` (all zeros) for accumulation, and
/// comparable with `==`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct TokenUsage {
    /// Number of tokens in the prompt / input.
    pub prompt_tokens: usize,
    /// Number of tokens generated.
    pub completion_tokens: usize,
    /// Sum of prompt and completion tokens. The fields are public, so this
    /// invariant is the producer's responsibility and is not enforced here.
    pub total_tokens: usize,
}
256
257/// A model that generates text, images, or audio from a conversational
258/// message history.
259///
260/// Messages carry explicit roles via [`Message`] and may contain multimodal
261/// content (text and images). The output [`GenerationResult`] is a union:
262/// text, images, and audio fields — consumers check what is populated.
263#[async_trait]
264pub trait GeneratorModel: Send + Sync {
265 /// Generate a response given a conversation history and sampling options.
266 async fn generate(
267 &self,
268 messages: &[Message],
269 options: GenerationOptions,
270 ) -> Result<GenerationResult>;
271
272 /// Optional warmup hook. The default is a no-op.
273 async fn warmup(&self) -> Result<()> {
274 Ok(())
275 }
276}