1use crate::engine::{Engine, EngineCapabilities};
18use crate::types::*;
19use anyhow::{bail, Result};
20use std::collections::BTreeMap;
21use tracing::{debug, warn};
22
23const TRACE_TARGET: &str = "studio_worker::engine::multi";
26
27pub struct MultiEngine {
28 engines: Vec<Box<dyn Engine>>,
29}
30
31impl MultiEngine {
32 pub fn new(engines: Vec<Box<dyn Engine>>) -> Self {
33 Self { engines }
34 }
35
36 fn pick_for(&self, kind: TaskKind, model: &str) -> Option<&dyn Engine> {
42 for e in &self.engines {
43 if e.capabilities().supports(kind, model) {
44 debug!(
45 target: TRACE_TARGET,
46 op = "pick",
47 kind = kind.as_str(),
48 model,
49 sub_engine = e.name(),
50 r#match = "exact",
51 "engine selected"
52 );
53 return Some(e.as_ref());
54 }
55 }
56 warn!(
57 target: TRACE_TARGET,
58 op = "pick",
59 kind = kind.as_str(),
60 model,
61 "no engine claims this exact (kind, model) pair"
62 );
63 None
64 }
65}
66
67impl Engine for MultiEngine {
68 fn name(&self) -> &'static str {
69 "multi"
70 }
71
72 fn capabilities(&self) -> EngineCapabilities {
73 let mut map: BTreeMap<TaskKind, Vec<String>> = BTreeMap::new();
74 for e in &self.engines {
75 for (kind, models) in e.capabilities().supported_models_per_kind {
76 let entry = map.entry(kind).or_default();
77 for m in models {
78 if !entry.contains(&m) {
79 entry.push(m);
80 }
81 }
82 }
83 }
84 EngineCapabilities {
85 supported_models_per_kind: map,
86 }
87 }
88
89 fn dispatch(&self, model: &str, task: Task) -> Result<TaskResult> {
90 let kind = task.kind();
91 let Some(engine) = self.pick_for(kind, model) else {
92 bail!(
93 "no engine on this worker can serve model {} (kind={}); \
94 synthetic fallback is disabled",
95 model,
96 kind.as_str()
97 );
98 };
99 engine.dispatch(model, task)
100 }
101
102 fn dispatch_with_source(
103 &self,
104 model: &str,
105 task: Task,
106 source: &crate::types::ModelSource,
107 ) -> Result<TaskResult> {
108 let kind = task.kind();
109 let wanted = match source.engine {
115 crate::types::ModelEngine::SdCpp => "sdcpp",
116 crate::types::ModelEngine::LlamaCpp => "llama",
117 crate::types::ModelEngine::Synthetic => "synthetic",
118 };
119 for e in &self.engines {
120 if e.name() == wanted {
121 debug!(
122 target: TRACE_TARGET,
123 op = "pick",
124 kind = kind.as_str(),
125 model,
126 sub_engine = e.name(),
127 r#match = "model-source",
128 "engine selected by ModelSource.engine"
129 );
130 return e.dispatch_with_source(model, task, source);
131 }
132 }
133 warn!(
134 target: TRACE_TARGET,
135 op = "pick",
136 kind = kind.as_str(),
137 model,
138 sub_engine = wanted,
139 r#match = "model-source",
140 "requested engine not compiled into this worker"
141 );
142 bail!(
143 "no `{}` engine compiled into this worker (model `{}` requires it). \
144 Install the all-backends release build from \
145 https://github.com/webbertakken/studio-worker/releases/latest, \
146 or rebuild from source with `cargo install studio-worker --features all`.",
147 wanted,
148 model
149 );
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::SyntheticEngine;

    /// Test double that advertises a fixed model list for a fixed set of kinds
    /// and stamps its own `name` into every result, so a test can see which
    /// sub-engine actually served a request.
    struct StubEngine {
        name: &'static str,
        kinds: Vec<TaskKind>,
        models: Vec<String>,
    }

    impl Engine for StubEngine {
        fn name(&self) -> &'static str {
            self.name
        }

        fn capabilities(&self) -> EngineCapabilities {
            let mut map: BTreeMap<TaskKind, Vec<String>> = BTreeMap::new();
            for k in &self.kinds {
                map.insert(*k, self.models.clone());
            }
            EngineCapabilities {
                supported_models_per_kind: map,
            }
        }

        fn dispatch(&self, _model: &str, task: Task) -> Result<TaskResult> {
            match task {
                Task::Image(_) => Ok(TaskResult::Image {
                    bytes: self.name.as_bytes().to_vec(),
                    ext: "test".into(),
                }),
                Task::Llm(_) => Ok(TaskResult::Llm {
                    json: serde_json::json!({ "from": self.name }),
                }),
                _ => bail!("stub doesn't serve this"),
            }
        }
    }

    /// Boxed [`StubEngine`] that advertises `models` for the single `kind`.
    fn stub(name: &'static str, kind: TaskKind, models: &[&str]) -> Box<dyn Engine> {
        Box::new(StubEngine {
            name,
            kinds: vec![kind],
            models: models.iter().map(|&m| m.to_owned()).collect(),
        })
    }

    fn image_task() -> Task {
        Task::Image(ImageParams {
            prompt: "x".into(),
            width: 64,
            height: 64,
            steps: 1,
            ext: "webp".into(),
            ..Default::default()
        })
    }

    fn llm_task() -> Task {
        Task::Llm(LlmParams {
            messages: vec![],
            max_tokens: 1,
            temperature: 0.0,
            ..Default::default()
        })
    }

    /// A `ModelSource` for `engine` with fixed, arbitrary CLI defaults.
    fn source_for(engine: crate::types::ModelEngine) -> crate::types::ModelSource {
        crate::types::ModelSource {
            engine,
            files: vec![],
            cli_defaults: crate::types::ModelCliDefaults {
                cfg_scale: 1.0,
                steps: 8,
                width: 1024,
                height: 1024,
                sampling_method: None,
                ..Default::default()
            },
        }
    }

    fn sd_cpp_source() -> crate::types::ModelSource {
        source_for(crate::types::ModelEngine::SdCpp)
    }

    #[test]
    fn multi_picks_first_engine_supporting_the_kind_and_model() {
        let multi = MultiEngine::new(vec![
            stub("a", TaskKind::Image, &["alpha"]),
            stub("b", TaskKind::Image, &["beta"]),
        ]);

        for (model, expected) in [("alpha", b"a"), ("beta", b"b")] {
            match multi.dispatch(model, image_task()).unwrap() {
                TaskResult::Image { bytes, .. } => assert_eq!(bytes, expected),
                _ => panic!("expected image"),
            }
        }
    }

    #[test]
    fn multi_prefers_earlier_engine_when_several_claim_the_model() {
        // Both engines claim `shared`; registration order must decide.
        let multi = MultiEngine::new(vec![
            stub("first", TaskKind::Image, &["shared"]),
            stub("second", TaskKind::Image, &["shared"]),
        ]);

        match multi.dispatch("shared", image_task()).unwrap() {
            TaskResult::Image { bytes, .. } => assert_eq!(bytes, b"first"),
            _ => panic!("expected image"),
        }
    }

    #[test]
    fn multi_refuses_unknown_model_without_kind_fallback() {
        // An LLM-capable engine exists, but it does not claim `unknown-model`;
        // the task must not be handed to it merely because the kind matches.
        let multi = MultiEngine::new(vec![
            stub("alpha", TaskKind::Image, &["alpha-image"]),
            stub("llm", TaskKind::Llm, &["llama-some"]),
        ]);

        let msg = multi
            .dispatch("unknown-model", llm_task())
            .unwrap_err()
            .to_string();
        assert!(
            msg.contains("no engine on this worker can serve model"),
            "expected no-fallback error, got: {msg}"
        );
        assert!(msg.contains("unknown-model"));
    }

    #[test]
    fn multi_errors_when_no_engine_serves_kind() {
        let multi = MultiEngine::new(vec![stub("image", TaskKind::Image, &["x"])]);

        let msg = multi.dispatch("x", llm_task()).unwrap_err().to_string();
        assert!(
            msg.contains("no engine on this worker can serve model"),
            "expected no-fallback error, got: {msg}"
        );
    }

    #[test]
    fn capabilities_union_across_all_engines() {
        let img: Box<dyn Engine> = Box::new(SyntheticEngine::new());
        let multi = MultiEngine::new(vec![
            img,
            stub("extra", TaskKind::Image, &["extra-image-model"]),
        ]);

        let caps = multi.capabilities();
        let image = &caps.supported_models_per_kind[&TaskKind::Image];
        assert!(image.contains(&"synthetic".to_string()));
        assert!(image.contains(&"extra-image-model".to_string()));
    }

    #[test]
    fn capabilities_union_drops_duplicates_and_keeps_first_seen_order() {
        let multi = MultiEngine::new(vec![
            stub("a", TaskKind::Image, &["shared", "only-a"]),
            stub("b", TaskKind::Image, &["shared", "only-b"]),
        ]);

        let caps = multi.capabilities();
        let image: Vec<&str> = caps.supported_models_per_kind[&TaskKind::Image]
            .iter()
            .map(String::as_str)
            .collect();
        assert_eq!(image, ["shared", "only-a", "only-b"]);
    }

    #[test]
    fn name_is_multi() {
        let multi = MultiEngine::new(vec![]);
        assert_eq!(multi.name(), "multi");
    }

    /// A model whose source demands `sdcpp` must not be served by the
    /// synthetic engine just because it is the only one available.
    #[test]
    fn dispatch_with_source_refuses_to_fall_back_to_synthetic_for_real_models() {
        let synth: Box<dyn Engine> = Box::new(SyntheticEngine::new());
        let multi = MultiEngine::new(vec![synth]);
        let source = sd_cpp_source();

        let err = multi
            .dispatch_with_source("some-real-flux-model", image_task(), &source)
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("no `sdcpp` engine compiled"),
            "expected no-sdcpp-backend error, got: {err}"
        );
    }

    /// The missing-engine path must leave a WARN breadcrumb on the module's
    /// trace target that names both the wanted engine and the model.
    #[test]
    fn dispatch_with_source_warns_when_wanted_engine_missing() {
        let logs = crate::test_support::capture(|| {
            let synth: Box<dyn Engine> = Box::new(SyntheticEngine::new());
            let multi = MultiEngine::new(vec![synth]);
            let source = sd_cpp_source();
            let _ = multi.dispatch_with_source("some-real-flux-model", image_task(), &source);
        });

        assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
        assert!(
            logs.contains("studio_worker::engine::multi"),
            "expected multi target, got: {logs}"
        );
        assert!(logs.contains("op=\"pick\""), "expected op field: {logs}");
        assert!(
            logs.contains("sdcpp"),
            "expected wanted engine name in breadcrumb: {logs}"
        );
        assert!(
            logs.contains("some-real-flux-model"),
            "expected model id in breadcrumb: {logs}"
        );
    }

    /// Models whose source names the synthetic engine are routed to it.
    #[test]
    fn dispatch_with_source_routes_synthetic_engine_for_synthetic_models() {
        let synth: Box<dyn Engine> = Box::new(SyntheticEngine::new());
        let multi = MultiEngine::new(vec![synth]);
        let source = source_for(crate::types::ModelEngine::Synthetic);

        let result = multi
            .dispatch_with_source("synthetic", image_task(), &source)
            .unwrap();
        assert!(matches!(result, TaskResult::Image { .. }));
    }
}