1use crate::engine::{Engine, EngineCapabilities};
18use crate::types::*;
19use anyhow::{bail, Result};
20use std::collections::BTreeMap;
21use tracing::{debug, warn};
22
/// `tracing` target attached to every log event emitted from this module;
/// the module's tests assert that captured logs contain this string.
const TRACE_TARGET: &str = "studio_worker::engine::multi";

/// Composite [`Engine`] that owns several sub-engines and routes each task to
/// exactly one of them.
///
/// `dispatch` selects the first sub-engine whose capabilities claim the exact
/// `(kind, model)` pair. `dispatch_with_source` selects a sub-engine by name,
/// as dictated by `ModelSource.engine`. Neither path falls back to a different
/// engine when no match is found; both return an error instead.
pub struct MultiEngine {
    /// Sub-engines in priority order: when several qualify, the earliest wins.
    engines: Vec<Box<dyn Engine>>,
}
30
31impl MultiEngine {
32 pub fn new(engines: Vec<Box<dyn Engine>>) -> Self {
33 Self { engines }
34 }
35
36 fn pick_for(&self, kind: TaskKind, model: &str) -> Option<&dyn Engine> {
42 for e in &self.engines {
43 if e.capabilities().supports(kind, model) {
44 debug!(
45 target: TRACE_TARGET,
46 op = "pick",
47 kind = kind.as_str(),
48 model,
49 sub_engine = e.name(),
50 r#match = "exact",
51 "engine selected"
52 );
53 return Some(e.as_ref());
54 }
55 }
56 warn!(
57 target: TRACE_TARGET,
58 op = "pick",
59 kind = kind.as_str(),
60 model,
61 "no engine claims this exact (kind, model) pair"
62 );
63 None
64 }
65}
66
67impl Engine for MultiEngine {
68 fn name(&self) -> &'static str {
69 "multi"
70 }
71
72 fn capabilities(&self) -> EngineCapabilities {
73 let mut map: BTreeMap<TaskKind, Vec<String>> = BTreeMap::new();
74 for e in &self.engines {
75 for (kind, models) in e.capabilities().supported_models_per_kind {
76 let entry = map.entry(kind).or_default();
77 for m in models {
78 if !entry.contains(&m) {
79 entry.push(m);
80 }
81 }
82 }
83 }
84 EngineCapabilities {
85 supported_models_per_kind: map,
86 }
87 }
88
89 fn dispatch(&self, model: &str, task: Task) -> Result<TaskResult> {
90 let kind = task.kind();
91 let Some(engine) = self.pick_for(kind, model) else {
92 bail!(
93 "no engine on this worker can serve model {} (kind={}); \
94 synthetic fallback is disabled",
95 model,
96 kind.as_str()
97 );
98 };
99 engine.dispatch(model, task)
100 }
101
102 fn dispatch_with_source(
103 &self,
104 model: &str,
105 task: Task,
106 source: &crate::types::ModelSource,
107 ) -> Result<TaskResult> {
108 let kind = task.kind();
109 let wanted = match source.engine {
115 crate::types::ModelEngine::SdCpp => "sdcpp",
116 crate::types::ModelEngine::LlamaCpp => "llama",
117 crate::types::ModelEngine::Onnx => "onnx",
118 crate::types::ModelEngine::Synthetic => "synthetic",
119 };
120 for e in &self.engines {
121 if e.name() == wanted {
122 debug!(
123 target: TRACE_TARGET,
124 op = "pick",
125 kind = kind.as_str(),
126 model,
127 sub_engine = e.name(),
128 r#match = "model-source",
129 "engine selected by ModelSource.engine"
130 );
131 return e.dispatch_with_source(model, task, source);
132 }
133 }
134 warn!(
135 target: TRACE_TARGET,
136 op = "pick",
137 kind = kind.as_str(),
138 model,
139 sub_engine = wanted,
140 r#match = "model-source",
141 "requested engine not compiled into this worker"
142 );
143 bail!(
144 "no `{}` engine compiled into this worker (model `{}` requires it). \
145 Install the all-backends release build from \
146 https://github.com/webbertakken/studio-worker/releases/latest, \
147 or rebuild from source with `cargo install studio-worker --features all`.",
148 wanted,
149 model
150 );
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::SyntheticEngine;

    /// Test double that advertises a fixed set of kinds/models and stamps its
    /// own `name` into every result, so assertions can tell which sub-engine
    /// `MultiEngine` actually routed to.
    struct StubEngine {
        name: &'static str,
        kinds: Vec<TaskKind>,
        models: Vec<String>,
    }

    impl Engine for StubEngine {
        fn name(&self) -> &'static str {
            self.name
        }
        fn capabilities(&self) -> EngineCapabilities {
            let mut map: BTreeMap<TaskKind, Vec<String>> = BTreeMap::new();
            for k in &self.kinds {
                map.insert(*k, self.models.clone());
            }
            EngineCapabilities {
                supported_models_per_kind: map,
            }
        }
        fn dispatch(&self, _model: &str, task: Task) -> Result<TaskResult> {
            match task {
                Task::Image(_) => Ok(TaskResult::Image {
                    bytes: self.name.as_bytes().to_vec(),
                    ext: "test".into(),
                }),
                Task::Llm(_) => Ok(TaskResult::Llm {
                    json: serde_json::json!({ "from": self.name }),
                }),
                _ => bail!("stub doesn't serve this"),
            }
        }
    }

    /// Boxes a `StubEngine` that claims exactly one `(kind, model)` pair.
    fn stub(name: &'static str, kind: TaskKind, model: &str) -> Box<dyn Engine> {
        Box::new(StubEngine {
            name,
            kinds: vec![kind],
            models: vec![model.to_string()],
        })
    }

    /// Unwraps an image result's payload, panicking on any other variant.
    fn image_bytes(result: TaskResult) -> Vec<u8> {
        match result {
            TaskResult::Image { bytes, .. } => bytes,
            _ => panic!("expected image"),
        }
    }

    fn image_task() -> Task {
        Task::Image(ImageParams {
            prompt: "x".into(),
            width: 64,
            height: 64,
            steps: 1,
            ext: "webp".into(),
            ..Default::default()
        })
    }

    fn llm_task() -> Task {
        Task::Llm(LlmParams {
            messages: vec![],
            max_tokens: 1,
            temperature: 0.0,
            ..Default::default()
        })
    }

    /// Builds a file-less `ModelSource` for `engine` with fixed CLI defaults.
    /// Shared by every `dispatch_with_source` test so the fixture lives in one
    /// place.
    fn source_for(engine: crate::types::ModelEngine) -> crate::types::ModelSource {
        crate::types::ModelSource {
            engine,
            files: vec![],
            cli_defaults: crate::types::ModelCliDefaults {
                cfg_scale: 1.0,
                steps: 8,
                width: 1024,
                height: 1024,
                sampling_method: None,
                ..Default::default()
            },
        }
    }

    #[test]
    fn multi_picks_first_engine_supporting_the_kind_and_model() {
        let multi = MultiEngine::new(vec![
            stub("a", TaskKind::Image, "alpha"),
            stub("b", TaskKind::Image, "beta"),
        ]);

        let bytes = image_bytes(multi.dispatch("alpha", image_task()).unwrap());
        assert_eq!(bytes, b"a");
        let bytes = image_bytes(multi.dispatch("beta", image_task()).unwrap());
        assert_eq!(bytes, b"b");
    }

    #[test]
    fn multi_prefers_earlier_engine_when_several_claim_the_model() {
        // Both sub-engines claim the same pair; list order decides.
        let multi = MultiEngine::new(vec![
            stub("first", TaskKind::Image, "shared"),
            stub("second", TaskKind::Image, "shared"),
        ]);
        let bytes = image_bytes(multi.dispatch("shared", image_task()).unwrap());
        assert_eq!(bytes, b"first");
    }

    #[test]
    fn multi_refuses_unknown_model_without_kind_fallback() {
        let multi = MultiEngine::new(vec![
            stub("alpha", TaskKind::Image, "alpha-image"),
            stub("llm", TaskKind::Llm, "llama-some"),
        ]);

        let err = multi.dispatch("unknown-model", llm_task()).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("no engine on this worker can serve model"),
            "expected no-fallback error, got: {msg}"
        );
        assert!(msg.contains("unknown-model"));
    }

    #[test]
    fn multi_errors_when_no_engine_serves_kind() {
        let multi = MultiEngine::new(vec![stub("image", TaskKind::Image, "x")]);
        let err = multi.dispatch("x", llm_task()).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("no engine on this worker can serve model"),
            "expected no-fallback error, got: {msg}"
        );
    }

    #[test]
    fn capabilities_union_across_all_engines() {
        let img: Box<dyn Engine> = Box::new(SyntheticEngine::new());
        let multi = MultiEngine::new(vec![
            img,
            stub("extra", TaskKind::Image, "extra-image-model"),
        ]);
        let caps = multi.capabilities();
        let image = &caps.supported_models_per_kind[&TaskKind::Image];
        assert!(image.contains(&"synthetic".to_string()));
        assert!(image.contains(&"extra-image-model".to_string()));
    }

    #[test]
    fn capabilities_deduplicate_models_shared_between_engines() {
        let multi = MultiEngine::new(vec![
            stub("a", TaskKind::Image, "shared"),
            stub("b", TaskKind::Image, "shared"),
        ]);
        let caps = multi.capabilities();
        let image = &caps.supported_models_per_kind[&TaskKind::Image];
        assert_eq!(image, &vec!["shared".to_string()]);
    }

    #[test]
    fn name_is_multi() {
        let multi = MultiEngine::new(vec![]);
        assert_eq!(multi.name(), "multi");
    }

    #[test]
    fn dispatch_with_source_refuses_to_fall_back_to_synthetic_for_real_models() {
        let synth: Box<dyn Engine> = Box::new(SyntheticEngine::new());
        let multi = MultiEngine::new(vec![synth]);
        let source = source_for(crate::types::ModelEngine::SdCpp);
        let err = multi
            .dispatch_with_source("some-real-flux-model", image_task(), &source)
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("no `sdcpp` engine compiled"),
            "expected no-sdcpp-backend error, got: {err}"
        );
    }

    #[test]
    fn dispatch_with_source_warns_when_wanted_engine_missing() {
        let logs = crate::test_support::capture(|| {
            let synth: Box<dyn Engine> = Box::new(SyntheticEngine::new());
            let multi = MultiEngine::new(vec![synth]);
            let source = source_for(crate::types::ModelEngine::SdCpp);
            let _ = multi.dispatch_with_source("some-real-flux-model", image_task(), &source);
        });
        assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
        assert!(
            logs.contains("studio_worker::engine::multi"),
            "expected multi target, got: {logs}"
        );
        assert!(logs.contains("op=\"pick\""), "expected op field: {logs}");
        assert!(
            logs.contains("sdcpp"),
            "expected wanted engine name in breadcrumb: {logs}"
        );
        assert!(
            logs.contains("some-real-flux-model"),
            "expected model id in breadcrumb: {logs}"
        );
    }

    #[test]
    fn dispatch_with_source_routes_synthetic_engine_for_synthetic_models() {
        let synth: Box<dyn Engine> = Box::new(SyntheticEngine::new());
        let multi = MultiEngine::new(vec![synth]);
        let source = source_for(crate::types::ModelEngine::Synthetic);
        let result = multi
            .dispatch_with_source("synthetic", image_task(), &source)
            .unwrap();
        assert!(matches!(result, TaskResult::Image { .. }));
    }
}