1use crate::engine::{Engine, EngineCapabilities};
16use crate::types::*;
17use anyhow::{bail, Result};
18use std::collections::BTreeMap;
19use tracing::{debug, warn};
20
21const TRACE_TARGET: &str = "studio_worker::engine::multi";
24
25pub struct MultiEngine {
26 engines: Vec<Box<dyn Engine>>,
27}
28
29impl MultiEngine {
30 pub fn new(engines: Vec<Box<dyn Engine>>) -> Self {
31 Self { engines }
32 }
33
34 fn pick_for(&self, kind: TaskKind, model: &str) -> Option<&dyn Engine> {
35 for e in &self.engines {
36 if e.capabilities().supports(kind, model) {
37 debug!(
38 target: TRACE_TARGET,
39 op = "pick",
40 kind = kind.as_str(),
41 model,
42 sub_engine = e.name(),
43 r#match = "exact",
44 "engine selected"
45 );
46 return Some(e.as_ref());
47 }
48 }
49 for e in &self.engines {
51 if e.capabilities()
52 .supported_models_per_kind
53 .contains_key(&kind)
54 {
55 debug!(
56 target: TRACE_TARGET,
57 op = "pick",
58 kind = kind.as_str(),
59 model,
60 sub_engine = e.name(),
61 r#match = "fallback",
62 "engine selected by kind fallback"
63 );
64 return Some(e.as_ref());
65 }
66 }
67 warn!(
68 target: TRACE_TARGET,
69 op = "pick",
70 kind = kind.as_str(),
71 model,
72 "no engine advertises this kind"
73 );
74 None
75 }
76}
77
78impl Engine for MultiEngine {
79 fn name(&self) -> &'static str {
80 "multi"
81 }
82
83 fn capabilities(&self) -> EngineCapabilities {
84 let mut map: BTreeMap<TaskKind, Vec<String>> = BTreeMap::new();
85 for e in &self.engines {
86 for (kind, models) in e.capabilities().supported_models_per_kind {
87 let entry = map.entry(kind).or_default();
88 for m in models {
89 if !entry.contains(&m) {
90 entry.push(m);
91 }
92 }
93 }
94 }
95 EngineCapabilities {
96 supported_models_per_kind: map,
97 }
98 }
99
100 fn dispatch(&self, model: &str, task: Task) -> Result<TaskResult> {
101 let kind = task.kind();
102 let Some(engine) = self.pick_for(kind, model) else {
103 bail!("multi engine cannot serve {} tasks", kind.as_str());
104 };
105 engine.dispatch(model, task)
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::SyntheticEngine;

    /// Test double that advertises every kind in `kinds` with the same
    /// `models` list and answers with payloads derived from its `name`.
    struct StubEngine {
        name: &'static str,
        kinds: Vec<TaskKind>,
        models: Vec<String>,
    }

    impl Engine for StubEngine {
        fn name(&self) -> &'static str {
            self.name
        }

        fn capabilities(&self) -> EngineCapabilities {
            let supported_models_per_kind: BTreeMap<TaskKind, Vec<String>> = self
                .kinds
                .iter()
                .map(|kind| (*kind, self.models.clone()))
                .collect();
            EngineCapabilities {
                supported_models_per_kind,
            }
        }

        /// Image and LLM tasks get a result tagged with the stub's name so
        /// tests can tell which engine answered; everything else errors.
        fn dispatch(&self, _model: &str, task: Task) -> Result<TaskResult> {
            match task {
                Task::Llm(_) => {
                    let json = serde_json::json!({ "from": self.name });
                    Ok(TaskResult::Llm { json })
                }
                Task::Image(_) => {
                    let bytes = self.name.as_bytes().to_vec();
                    Ok(TaskResult::Image {
                        bytes,
                        ext: "test".into(),
                    })
                }
                _ => bail!("stub doesn't serve this"),
            }
        }
    }

    /// Boxes a [`StubEngine`] advertising `kinds`, each with `models`.
    fn boxed_stub(name: &'static str, kinds: &[TaskKind], models: &[&str]) -> Box<dyn Engine> {
        Box::new(StubEngine {
            name,
            kinds: kinds.to_vec(),
            models: models.iter().copied().map(String::from).collect(),
        })
    }

    /// Smallest image task the tests need.
    fn image_task() -> Task {
        let params = ImageParams {
            prompt: "x".into(),
            width: 64,
            height: 64,
            steps: 1,
            seed: None,
            ext: "webp".into(),
        };
        Task::Image(params)
    }

    /// Smallest LLM task the tests need.
    fn llm_task() -> Task {
        let params = LlmParams {
            messages: vec![],
            max_tokens: 1,
            temperature: 0.0,
        };
        Task::Llm(params)
    }

    /// With two image engines, the requested model decides which one serves.
    #[test]
    fn multi_picks_first_engine_supporting_the_kind_and_model() {
        let multi = MultiEngine::new(vec![
            boxed_stub("a", &[TaskKind::Image], &["alpha"]),
            boxed_stub("b", &[TaskKind::Image], &["beta"]),
        ]);

        for (model, expected) in [("alpha", b"a"), ("beta", b"b")] {
            let TaskResult::Image { bytes, .. } = multi.dispatch(model, image_task()).unwrap()
            else {
                panic!("expected image");
            };
            assert_eq!(bytes, expected);
        }
    }

    /// An unknown model still reaches the engine that advertises the kind.
    #[test]
    fn multi_falls_back_to_first_engine_advertising_the_kind() {
        let multi = MultiEngine::new(vec![
            boxed_stub("alpha", &[TaskKind::Image], &["alpha-image"]),
            boxed_stub("llm", &[TaskKind::Llm], &["llama-some"]),
        ]);

        let outcome = multi.dispatch("unknown-model", llm_task()).unwrap();
        let TaskResult::Llm { json } = outcome else {
            panic!("expected llm");
        };
        assert_eq!(json["from"], "llm");
    }

    /// Dispatching a kind nobody advertises is an error naming that kind.
    #[test]
    fn multi_errors_when_no_engine_serves_kind() {
        let multi = MultiEngine::new(vec![boxed_stub("image", &[TaskKind::Image], &["x"])]);

        let message = multi.dispatch("x", llm_task()).unwrap_err().to_string();
        assert!(message.contains("cannot serve llm"));
    }

    /// The composite lists image models contributed by both sub-engines.
    #[test]
    fn capabilities_union_across_all_engines() {
        let synthetic: Box<dyn Engine> = Box::new(SyntheticEngine::new(vec![]));
        let extra = boxed_stub("extra", &[TaskKind::Image], &["extra-image-model"]);
        let multi = MultiEngine::new(vec![synthetic, extra]);

        let caps = multi.capabilities();
        let image_models = &caps.supported_models_per_kind[&TaskKind::Image];
        for expected in ["synthetic", "extra-image-model"] {
            assert!(image_models.iter().any(|listed| listed.as_str() == expected));
        }
    }

    /// The composite reports a fixed name even with no sub-engines.
    #[test]
    fn name_is_multi() {
        assert_eq!(MultiEngine::new(vec![]).name(), "multi");
    }
}