Skip to main content

studio_worker/engine/
multi.rs

1//! Composite engine that delegates per task kind.
2//!
3//! Operators set `engine = "multi"` and `engines = [...]` in their
4//! config to combine multiple inference backends in a single binary:
5//!
6//! ```toml
7//! engine = "multi"
8//! engines = ["llama", "tts", "video", "synthetic"]
9//! ```
10//!
//! For each task kind, [`MultiEngine`] picks the first engine in the
//! list that advertises support for the requested model.  If none
//! does, it falls back to the first engine that advertises the task
//! kind at all.  If no engine advertises the kind, the dispatch fails
//! with the same "cannot serve <kind>" shape the studio's claim loop
//! already knows how to handle.
15use crate::engine::{Engine, EngineCapabilities};
16use crate::types::*;
17use anyhow::{bail, Result};
18use std::collections::BTreeMap;
19use tracing::{debug, warn};
20
/// Tracing target attached to every event this module emits.  Kept
/// stable so operators can filter with
/// `RUST_LOG=studio_worker::engine::multi=debug`.
const TRACE_TARGET: &str = "studio_worker::engine::multi";
24
/// Composite [`Engine`] that owns a list of sub-engines and routes
/// each task to one of them.
pub struct MultiEngine {
    /// Sub-engines in priority order: routing scans this list front to
    /// back and the first suitable entry wins.
    engines: Vec<Box<dyn Engine>>,
}
28
29impl MultiEngine {
30    pub fn new(engines: Vec<Box<dyn Engine>>) -> Self {
31        Self { engines }
32    }
33
34    fn pick_for(&self, kind: TaskKind, model: &str) -> Option<&dyn Engine> {
35        for e in &self.engines {
36            if e.capabilities().supports(kind, model) {
37                debug!(
38                    target: TRACE_TARGET,
39                    op = "pick",
40                    kind = kind.as_str(),
41                    model,
42                    sub_engine = e.name(),
43                    r#match = "exact",
44                    "engine selected"
45                );
46                return Some(e.as_ref());
47            }
48        }
49        // Fallback: any engine that advertises the kind at all.
50        for e in &self.engines {
51            if e.capabilities()
52                .supported_models_per_kind
53                .contains_key(&kind)
54            {
55                debug!(
56                    target: TRACE_TARGET,
57                    op = "pick",
58                    kind = kind.as_str(),
59                    model,
60                    sub_engine = e.name(),
61                    r#match = "fallback",
62                    "engine selected by kind fallback"
63                );
64                return Some(e.as_ref());
65            }
66        }
67        warn!(
68            target: TRACE_TARGET,
69            op = "pick",
70            kind = kind.as_str(),
71            model,
72            "no engine advertises this kind"
73        );
74        None
75    }
76}
77
78impl Engine for MultiEngine {
79    fn name(&self) -> &'static str {
80        "multi"
81    }
82
83    fn capabilities(&self) -> EngineCapabilities {
84        let mut map: BTreeMap<TaskKind, Vec<String>> = BTreeMap::new();
85        for e in &self.engines {
86            for (kind, models) in e.capabilities().supported_models_per_kind {
87                let entry = map.entry(kind).or_default();
88                for m in models {
89                    if !entry.contains(&m) {
90                        entry.push(m);
91                    }
92                }
93            }
94        }
95        EngineCapabilities {
96            supported_models_per_kind: map,
97        }
98    }
99
100    fn dispatch(&self, model: &str, task: Task) -> Result<TaskResult> {
101        let kind = task.kind();
102        let Some(engine) = self.pick_for(kind, model) else {
103            bail!("multi engine cannot serve {} tasks", kind.as_str());
104        };
105        engine.dispatch(model, task)
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::SyntheticEngine;

    /// Test double that advertises the same model list for each of its
    /// kinds and stamps its own name into every result it produces.
    struct StubEngine {
        name: &'static str,
        kinds: Vec<TaskKind>,
        models: Vec<String>,
    }

    impl Engine for StubEngine {
        fn name(&self) -> &'static str {
            self.name
        }

        fn capabilities(&self) -> EngineCapabilities {
            let advertised: BTreeMap<TaskKind, Vec<String>> = self
                .kinds
                .iter()
                .map(|kind| (*kind, self.models.clone()))
                .collect();
            EngineCapabilities {
                supported_models_per_kind: advertised,
            }
        }

        fn dispatch(&self, _model: &str, task: Task) -> Result<TaskResult> {
            // The engine name is embedded in the result so tests can tell
            // which sub-engine actually handled the task.
            let tagged = match task {
                Task::Image(_) => TaskResult::Image {
                    bytes: self.name.as_bytes().to_vec(),
                    ext: "test".into(),
                },
                Task::Llm(_) => TaskResult::Llm {
                    json: serde_json::json!({ "from": self.name }),
                },
                _ => bail!("stub doesn't serve this"),
            };
            Ok(tagged)
        }
    }

    /// Boxes a [`StubEngine`] with the given name, kinds and models.
    fn stub(name: &'static str, kinds: &[TaskKind], models: &[&str]) -> Box<dyn Engine> {
        Box::new(StubEngine {
            name,
            kinds: kinds.to_vec(),
            models: models.iter().map(|m| m.to_string()).collect(),
        })
    }

    /// Smallest image task the tests need.
    fn image_task() -> Task {
        let params = ImageParams {
            prompt: "x".into(),
            width: 64,
            height: 64,
            steps: 1,
            seed: None,
            ext: "webp".into(),
        };
        Task::Image(params)
    }

    /// Smallest LLM task the tests need.
    fn llm_task() -> Task {
        let params = LlmParams {
            messages: vec![],
            max_tokens: 1,
            temperature: 0.0,
        };
        Task::Llm(params)
    }

    #[test]
    fn multi_picks_first_engine_supporting_the_kind_and_model() {
        let multi = MultiEngine::new(vec![
            stub("a", &[TaskKind::Image], &["alpha"]),
            stub("b", &[TaskKind::Image], &["beta"]),
        ]);

        let TaskResult::Image { bytes, .. } = multi.dispatch("alpha", image_task()).unwrap()
        else {
            panic!("expected image");
        };
        assert_eq!(bytes, b"a");

        let TaskResult::Image { bytes, .. } = multi.dispatch("beta", image_task()).unwrap()
        else {
            panic!("expected image");
        };
        assert_eq!(bytes, b"b");
    }

    #[test]
    fn multi_falls_back_to_first_engine_advertising_the_kind() {
        let multi = MultiEngine::new(vec![
            stub("alpha", &[TaskKind::Image], &["alpha-image"]),
            stub("llm", &[TaskKind::Llm], &["llama-some"]),
        ]);

        // No engine lists "unknown-model", so routing goes by kind alone.
        let TaskResult::Llm { json } = multi.dispatch("unknown-model", llm_task()).unwrap()
        else {
            panic!("expected llm");
        };
        assert_eq!(json["from"], "llm");
    }

    #[test]
    fn multi_errors_when_no_engine_serves_kind() {
        let multi = MultiEngine::new(vec![stub("image", &[TaskKind::Image], &["x"])]);
        let message = multi.dispatch("x", llm_task()).unwrap_err().to_string();
        assert!(message.contains("cannot serve llm"));
    }

    #[test]
    fn capabilities_union_across_all_engines() {
        let synthetic: Box<dyn Engine> = Box::new(SyntheticEngine::new(vec![]));
        let extra = stub("extra", &[TaskKind::Image], &["extra-image-model"]);
        let multi = MultiEngine::new(vec![synthetic, extra]);

        let caps = multi.capabilities();
        let image_models = &caps.supported_models_per_kind[&TaskKind::Image];
        assert!(image_models.iter().any(|m| m == "synthetic"));
        assert!(image_models.iter().any(|m| m == "extra-image-model"));
    }

    #[test]
    fn name_is_multi() {
        let multi = MultiEngine::new(Vec::new());
        assert_eq!(multi.name(), "multi");
    }
}