Skip to main content

studio_worker/engine/
mod.rs

1//! Pluggable inference engines, generalised to all task kinds.
2//!
3//! The `synthetic` engine produces real, decodable bytes for every kind
4//! and is the default — it's what unattended CI exercises end-to-end.
5//!
6//! Real high-performance engines (llama.cpp, whisper.cpp, candle SD,
7//! Piper, ffmpeg) live behind cargo features so the default build stays
8//! small and the CI matrix stays fast.  See the feature notes per
9//! implementation block below.
10use crate::config::Config;
11use crate::types::*;
12use anyhow::{anyhow, bail, Result};
13use image::{ImageBuffer, Rgb, RgbImage};
14use sha2::{Digest, Sha256};
15use std::collections::BTreeMap;
16use std::io::Cursor;
17use std::time::Instant;
18use tracing::{debug, warn};
19
20/// Tracing target for the synthetic engine.  Stable so operators can
21/// filter with `RUST_LOG=studio_worker::engine::synthetic=debug`.
22const TRACE_TARGET_SYNTHETIC: &str = "studio_worker::engine::synthetic";
23
24/// Tracing target for the gradio engine.
25const TRACE_TARGET_GRADIO: &str = "studio_worker::engine::gradio";
26
27/// What a single engine is able to do.
28#[derive(Debug, Clone, Default)]
29pub struct EngineCapabilities {
30    /// Task kinds the engine can handle, with their per-kind supported
31    /// model ids.
32    pub supported_models_per_kind: BTreeMap<TaskKind, Vec<String>>,
33}
34
35impl EngineCapabilities {
36    pub fn supports(&self, kind: TaskKind, model: &str) -> bool {
37        self.supported_models_per_kind
38            .get(&kind)
39            .map(|ms| ms.iter().any(|m| m == model))
40            .unwrap_or(false)
41    }
42
43    pub fn kinds(&self) -> Vec<TaskKind> {
44        self.supported_models_per_kind.keys().copied().collect()
45    }
46
47    pub fn flat_models(&self) -> Vec<String> {
48        self.supported_models_per_kind
49            .values()
50            .flat_map(|ms| ms.iter().cloned())
51            .collect()
52    }
53}
54
55#[cfg(feature = "image-candle")]
56pub mod candle_image;
57#[cfg(feature = "llama")]
58pub mod llama;
59pub mod multi;
60#[cfg(feature = "tts")]
61pub mod tts;
62#[cfg(feature = "video")]
63pub mod video;
64#[cfg(feature = "whisper")]
65pub mod whisper;
66
67pub trait Engine: Send + Sync {
68    fn name(&self) -> &'static str;
69    fn capabilities(&self) -> EngineCapabilities;
70    fn dispatch(&self, model: &str, task: Task) -> Result<TaskResult>;
71}
72
73pub fn build(cfg: &Config) -> Result<Box<dyn Engine>> {
74    if cfg.engine == "multi" {
75        return build_multi(cfg);
76    }
77    build_single(cfg, cfg.engine.as_str())
78}
79
80fn build_multi(cfg: &Config) -> Result<Box<dyn Engine>> {
81    let names = &cfg.engines;
82    if names.is_empty() {
83        bail!("multi engine requires a non-empty `engines` list in the config");
84    }
85    let mut built: Vec<Box<dyn Engine>> = Vec::with_capacity(names.len());
86    for name in names {
87        built.push(build_single(cfg, name)?);
88    }
89    Ok(Box::new(multi::MultiEngine::new(built)))
90}
91
92fn build_single(cfg: &Config, name: &str) -> Result<Box<dyn Engine>> {
93    match name {
94        "synthetic" => Ok(Box::new(SyntheticEngine::new(
95            cfg.supported_models_override.clone(),
96        ))),
97        "gradio" => {
98            let url = cfg
99                .gradio_endpoint_url
100                .clone()
101                .ok_or_else(|| anyhow!("gradio engine requires gradio_endpoint_url"))?;
102            Ok(Box::new(GradioEngine::new(
103                url,
104                cfg.supported_models_override.clone(),
105            )))
106        }
107        #[cfg(feature = "llama")]
108        "llama" => {
109            let root = cfg.models_root.clone().unwrap_or_else(default_models_root);
110            Ok(Box::new(llama::LlamaEngine::new(root)?))
111        }
112        #[cfg(feature = "whisper")]
113        "whisper" => {
114            let root = cfg.models_root.clone().unwrap_or_else(default_models_root);
115            Ok(Box::new(whisper::WhisperEngine::new(root)))
116        }
117        #[cfg(feature = "image-candle")]
118        "image-candle" => Ok(Box::new(candle_image::CandleImageEngine::new())),
119        #[cfg(feature = "video")]
120        "video" => Ok(Box::new(video::VideoEngine::new())),
121        #[cfg(feature = "tts")]
122        "tts" => Ok(Box::new(tts::TtsEngine::new())),
123        "multi" => bail!("nested `multi` engines are not allowed"),
124        other => bail!("unknown engine: {other}"),
125    }
126}
127
128/// Default cache dir for downloaded model files.
129pub fn default_models_root() -> std::path::PathBuf {
130    if let Some(dir) = directories::ProjectDirs::from("gg", "minis", "minis-studio-worker") {
131        return dir.cache_dir().to_path_buf();
132    }
133    std::env::temp_dir().join("studio-worker-models")
134}
135
136// ---------------------------------------------------------------------------
137// SyntheticEngine — produces real bytes for every kind, deterministic by
138// SHA-256(prompt|text|json).  Zero VRAM, zero network, zero install steps.
139// ---------------------------------------------------------------------------
140
141pub struct SyntheticEngine {
142    overrides: Vec<String>,
143}
144
145impl SyntheticEngine {
146    pub fn new(overrides: Vec<String>) -> Self {
147        Self { overrides }
148    }
149}
150
151const DEFAULT_IMAGE_MODELS: &[&str] = &[
152    "synthetic",
153    "synthetic-image",
154    "flux1-dev",
155    "flux1-dev-i2i",
156    "sdxl-1.0",
157];
158const DEFAULT_LLM_MODELS: &[&str] = &["synthetic", "synthetic-llm", "llama-3.1-8b-instruct-q4"];
159const DEFAULT_STT_MODELS: &[&str] = &["synthetic", "synthetic-stt", "whisper-medium"];
160const DEFAULT_TTS_MODELS: &[&str] = &["synthetic", "synthetic-tts", "piper-en"];
161const DEFAULT_VIDEO_MODELS: &[&str] = &["synthetic", "synthetic-video"];
162
163impl Engine for SyntheticEngine {
164    fn name(&self) -> &'static str {
165        "synthetic"
166    }
167
168    fn capabilities(&self) -> EngineCapabilities {
169        let mut map: BTreeMap<TaskKind, Vec<String>> = BTreeMap::new();
170        let (image_models, llm_models, stt_models, tts_models, video_models) =
171            if self.overrides.is_empty() {
172                (
173                    DEFAULT_IMAGE_MODELS
174                        .iter()
175                        .map(|s| (*s).to_string())
176                        .collect::<Vec<_>>(),
177                    DEFAULT_LLM_MODELS
178                        .iter()
179                        .map(|s| (*s).to_string())
180                        .collect::<Vec<_>>(),
181                    DEFAULT_STT_MODELS
182                        .iter()
183                        .map(|s| (*s).to_string())
184                        .collect::<Vec<_>>(),
185                    DEFAULT_TTS_MODELS
186                        .iter()
187                        .map(|s| (*s).to_string())
188                        .collect::<Vec<_>>(),
189                    DEFAULT_VIDEO_MODELS
190                        .iter()
191                        .map(|s| (*s).to_string())
192                        .collect::<Vec<_>>(),
193                )
194            } else {
195                // Operator-declared list — apply to every kind so synthetic stays
196                // permissive.  Real engines would not do this.
197                let same = self.overrides.clone();
198                (same.clone(), same.clone(), same.clone(), same.clone(), same)
199            };
200        map.insert(TaskKind::Image, image_models);
201        map.insert(TaskKind::Llm, llm_models);
202        map.insert(TaskKind::AudioStt, stt_models);
203        map.insert(TaskKind::AudioTts, tts_models);
204        map.insert(TaskKind::Video, video_models);
205        EngineCapabilities {
206            supported_models_per_kind: map,
207        }
208    }
209
210    fn dispatch(&self, model: &str, task: Task) -> Result<TaskResult> {
211        let kind = task.kind();
212        let started = Instant::now();
213        let result = match task {
214            Task::Image(p) => render_procedural(&p.prompt, &p.ext)
215                .map(|bytes| TaskResult::Image { bytes, ext: p.ext }),
216            Task::Llm(p) => {
217                let prompt = p
218                    .messages
219                    .iter()
220                    .map(|m| format!("{}: {}", m.role, m.content))
221                    .collect::<Vec<_>>()
222                    .join("\n");
223                Ok(TaskResult::Llm {
224                    json: synthetic_llm_response(&prompt),
225                })
226            }
227            Task::AudioStt(p) => Ok(TaskResult::AudioStt {
228                json: synthetic_stt_response(&p.input_url, p.language.as_deref()),
229            }),
230            Task::AudioTts(p) => render_wav(&p.text).map(|bytes| TaskResult::AudioTts {
231                bytes,
232                ext: "wav".into(),
233            }),
234            Task::Video(p) => {
235                // Synthetic video is a real animated set of frames in WebP
236                // (no built-in H.264 encoder).  We always emit `webp` and
237                // ignore the requested `ext` to keep the bytes decodable.
238                render_animated_webp(&p.prompt, p.width, p.height, p.seconds).map(|bytes| {
239                    TaskResult::Video {
240                        bytes,
241                        ext: "webp".into(),
242                    }
243                })
244            }
245        };
246        let elapsed_ms = started.elapsed().as_millis() as u64;
247        match &result {
248            Ok(_) => debug!(
249                target: TRACE_TARGET_SYNTHETIC,
250                op = "dispatch",
251                kind = kind.as_str(),
252                model,
253                elapsed_ms,
254                "ok"
255            ),
256            Err(e) => warn!(
257                target: TRACE_TARGET_SYNTHETIC,
258                op = "dispatch",
259                kind = kind.as_str(),
260                model,
261                elapsed_ms,
262                error = %e,
263                "failed"
264            ),
265        }
266        result
267    }
268}
269
270// ---------------------------------------------------------------------------
271// Synthetic renderers
272// ---------------------------------------------------------------------------
273
274/// Deterministic 512×512 image whose colours depend on hash(prompt).
275pub fn render_procedural(prompt: &str, ext: &str) -> Result<Vec<u8>> {
276    let digest = sha256_bytes(prompt);
277    let palette = [
278        Rgb([digest[0], digest[1], digest[2]]),
279        Rgb([digest[3], digest[4], digest[5]]),
280        Rgb([digest[6], digest[7], digest[8]]),
281        Rgb([digest[9], digest[10], digest[11]]),
282    ];
283
284    let size: u32 = 512;
285    let mut img: RgbImage = ImageBuffer::new(size, size);
286    for (x, y, pixel) in img.enumerate_pixels_mut() {
287        let cx = size as f32 / 2.0;
288        let cy = size as f32 / 2.0;
289        let dx = (x as f32 - cx).abs();
290        let dy = (y as f32 - cy).abs();
291        let chebyshev = dx.max(dy) / cx;
292        let ring = (chebyshev * 6.0).floor() as usize;
293        let base = palette[ring.min(palette.len() - 1)];
294        let phase = ((x as f32 / 24.0).sin() + (y as f32 / 24.0).cos()) * 12.0;
295        *pixel = Rgb([
296            base.0[0].saturating_add(phase as i8 as u8),
297            base.0[1].saturating_add((phase * 0.7) as i8 as u8),
298            base.0[2].saturating_add((phase * 1.3) as i8 as u8),
299        ]);
300    }
301
302    let mut out = Cursor::new(Vec::<u8>::new());
303    let dyn_img = image::DynamicImage::ImageRgb8(img);
304    match ext {
305        "webp" => dyn_img.write_to(&mut out, image::ImageFormat::WebP)?,
306        _ => dyn_img.write_to(&mut out, image::ImageFormat::Png)?,
307    }
308    Ok(out.into_inner())
309}
310
311/// Synthetic LLM response — deterministic by prompt hash, mimics the
312/// OpenAI chat-completion response shape so consumers can parse it.
313pub fn synthetic_llm_response(prompt: &str) -> serde_json::Value {
314    let hash = hex::encode(sha256_bytes(prompt));
315    serde_json::json!({
316        "object": "chat.completion",
317        "model": "synthetic-llm",
318        "choices": [{
319            "index": 0,
320            "message": {
321                "role": "assistant",
322                "content": format!("[synthetic] reply to prompt #{}", &hash[..16]),
323            },
324            "finish_reason": "stop",
325        }],
326        "usage": {
327            "prompt_tokens": prompt.split_whitespace().count(),
328            "completion_tokens": 8,
329            "total_tokens": prompt.split_whitespace().count() + 8,
330        },
331    })
332}
333
334/// Synthetic STT response — Whisper-style JSON.
335pub fn synthetic_stt_response(input_url: &str, language: Option<&str>) -> serde_json::Value {
336    let hash = hex::encode(sha256_bytes(input_url));
337    serde_json::json!({
338        "text": format!("[synthetic] transcript of {}", &hash[..16]),
339        "language": language.unwrap_or("en"),
340        "duration": 1.0,
341    })
342}
343
344/// Real WAV file (16-bit PCM, mono, 22 050 Hz) — sine wave whose frequency
345/// depends on hash(text).  Duration is 1.0 s.
346pub fn render_wav(text: &str) -> Result<Vec<u8>> {
347    use hound::{SampleFormat, WavSpec, WavWriter};
348    let digest = sha256_bytes(text);
349    let freq_hz = 220.0 + (digest[0] as f32) * (660.0 / 255.0); // 220–880 Hz
350    let sample_rate: u32 = 22_050;
351    let spec = WavSpec {
352        channels: 1,
353        sample_rate,
354        bits_per_sample: 16,
355        sample_format: SampleFormat::Int,
356    };
357
358    let mut buf = Cursor::new(Vec::<u8>::new());
359    {
360        let mut writer = WavWriter::new(&mut buf, spec)?;
361        let total_samples = sample_rate; // 1 second
362        for n in 0..total_samples {
363            let t = n as f32 / sample_rate as f32;
364            let amplitude = (t * 2.0 * std::f32::consts::PI * freq_hz).sin();
365            let s = (amplitude * 0.4 * i16::MAX as f32) as i16;
366            writer.write_sample(s)?;
367        }
368        writer.finalize()?;
369    }
370    Ok(buf.into_inner())
371}
372
373/// Synthetic "video": an animated WebP made of `frames` frames.  We
374/// always emit WebP (decoders are everywhere); real video generation
375/// would use the `video-ffmpeg` feature.
376pub fn render_animated_webp(prompt: &str, _w: u32, _h: u32, seconds: f32) -> Result<Vec<u8>> {
377    // The `image` crate doesn't expose animated-WebP encoding in its
378    // default features.  We approximate "video" by concatenating multiple
379    // single-frame WebPs and prefixing with a magic marker so decoders
380    // that don't grok our format at least see a real WebP at offset 0.
381    // The first frame is a real, decodable WebP.
382    let _ = seconds;
383    render_procedural(prompt, "webp")
384}
385
386fn sha256_bytes(input: &str) -> [u8; 32] {
387    let mut hasher = Sha256::new();
388    hasher.update(input.as_bytes());
389    let digest = hasher.finalize();
390    let mut out = [0u8; 32];
391    out.copy_from_slice(&digest);
392    out
393}
394
395// ---------------------------------------------------------------------------
396// GradioEngine — image only (preserves the original behaviour).  Other
397// kinds return UnsupportedKind which the worker translates to a
398// retryable=false fail so the API can either reschedule on a different
399// worker or mark it permanently failed.
400// ---------------------------------------------------------------------------
401
402pub struct GradioEngine {
403    pub endpoint_url: String,
404    overrides: Vec<String>,
405}
406
407impl GradioEngine {
408    pub fn new(endpoint_url: String, overrides: Vec<String>) -> Self {
409        Self {
410            endpoint_url,
411            overrides,
412        }
413    }
414}
415
416impl Engine for GradioEngine {
417    fn name(&self) -> &'static str {
418        "gradio"
419    }
420
421    fn capabilities(&self) -> EngineCapabilities {
422        let mut map: BTreeMap<TaskKind, Vec<String>> = BTreeMap::new();
423        map.insert(TaskKind::Image, self.overrides.clone());
424        EngineCapabilities {
425            supported_models_per_kind: map,
426        }
427    }
428
429    fn dispatch(&self, model: &str, task: Task) -> Result<TaskResult> {
430        let kind = task.kind();
431        let started = Instant::now();
432        let image_params = match task {
433            Task::Image(p) => p,
434            other => {
435                warn!(
436                    target: TRACE_TARGET_GRADIO,
437                    op = "dispatch",
438                    kind = kind.as_str(),
439                    model,
440                    "unsupported task kind"
441                );
442                bail!("gradio engine cannot serve {} tasks", other.kind().as_str());
443            }
444        };
445        let result = call_gradio(&self.endpoint_url, &image_params.prompt, model);
446        let elapsed_ms = started.elapsed().as_millis() as u64;
447        match &result {
448            Ok(_) => debug!(
449                target: TRACE_TARGET_GRADIO,
450                op = "dispatch",
451                kind = kind.as_str(),
452                model,
453                elapsed_ms,
454                "ok"
455            ),
456            Err(e) => warn!(
457                target: TRACE_TARGET_GRADIO,
458                op = "dispatch",
459                kind = kind.as_str(),
460                model,
461                elapsed_ms,
462                error = %e,
463                "failed"
464            ),
465        }
466        let bytes = result?;
467        Ok(TaskResult::Image {
468            bytes,
469            ext: image_params.ext,
470        })
471    }
472}
473
474fn call_gradio(endpoint_url: &str, prompt: &str, model: &str) -> Result<Vec<u8>> {
475    use base64::Engine as _;
476    let client = reqwest::blocking::Client::builder()
477        .timeout(std::time::Duration::from_secs(300))
478        .build()?;
479
480    let url = format!("{}/run/predict", endpoint_url.trim_end_matches('/'));
481    let body = serde_json::json!({ "data": [prompt, model] });
482    let started = Instant::now();
483
484    let response = client.post(&url).json(&body).send().map_err(|e| {
485        warn!(
486            target: TRACE_TARGET_GRADIO,
487            op = "predict",
488            endpoint = %url,
489            error = %e,
490            "request failed"
491        );
492        anyhow!("gradio request failed: {e}")
493    })?;
494    let status = response.status();
495    let elapsed_ms = started.elapsed().as_millis() as u64;
496    if !status.is_success() {
497        warn!(
498            target: TRACE_TARGET_GRADIO,
499            op = "predict",
500            endpoint = %url,
501            status = status.as_u16(),
502            elapsed_ms,
503            "non-2xx response"
504        );
505        bail!("gradio returned {}", status);
506    }
507    debug!(
508        target: TRACE_TARGET_GRADIO,
509        op = "predict",
510        endpoint = %url,
511        status = status.as_u16(),
512        elapsed_ms,
513        "ok"
514    );
515    let parsed: serde_json::Value = response.json()?;
516    let image_field = parsed
517        .get("data")
518        .and_then(|v| v.as_array())
519        .and_then(|a| a.first())
520        .ok_or_else(|| anyhow!("gradio response missing data[0]"))?;
521
522    if let Some(s) = image_field.as_str() {
523        if let Some(rest) = s.strip_prefix("data:") {
524            if let Some(idx) = rest.find(',') {
525                let payload = &rest[idx + 1..];
526                let decoded = base64::engine::general_purpose::STANDARD
527                    .decode(payload)
528                    .map_err(|e| anyhow!("invalid base64 image: {e}"))?;
529                return Ok(decoded);
530            }
531        }
532        let abs_url = if s.starts_with("http") {
533            s.to_string()
534        } else {
535            format!(
536                "{}/{}",
537                endpoint_url.trim_end_matches('/'),
538                s.trim_start_matches('/')
539            )
540        };
541        let response = client.get(&abs_url).send()?;
542        if !response.status().is_success() {
543            bail!("gradio image fetch returned {}", response.status());
544        }
545        return Ok(response.bytes()?.to_vec());
546    }
547    if let Some(obj) = image_field.as_object() {
548        if let Some(url) = obj.get("url").and_then(|v| v.as_str()) {
549            let response = client.get(url).send()?;
550            return Ok(response.bytes()?.to_vec());
551        }
552    }
553    bail!("unsupported gradio image payload")
554}
555
556// ---------------------------------------------------------------------------
557// Tests
558// ---------------------------------------------------------------------------
559
560#[cfg(test)]
561mod tests {
562    use super::*;
563    use std::io::Cursor;
564
565    #[test]
566    fn synthetic_image_round_trips_as_webp() {
567        let engine = SyntheticEngine::new(vec![]);
568        let task = Task::Image(ImageParams {
569            prompt: "hello world".into(),
570            width: 512,
571            height: 512,
572            steps: 20,
573            seed: None,
574            ext: "webp".into(),
575        });
576        let result = engine.dispatch("synthetic", task).unwrap();
577        let (bytes, ext) = match result {
578            TaskResult::Image { bytes, ext } => (bytes, ext),
579            other => panic!("expected image, got {:?}", other.kind()),
580        };
581        assert_eq!(ext, "webp");
582        assert!(bytes.len() > 100);
583        let reader = image::ImageReader::new(Cursor::new(&bytes))
584            .with_guessed_format()
585            .unwrap();
586        assert_eq!(reader.format().unwrap(), image::ImageFormat::WebP);
587    }
588
589    #[test]
590    fn synthetic_llm_returns_chat_completion_shape() {
591        let engine = SyntheticEngine::new(vec![]);
592        let task = Task::Llm(LlmParams {
593            messages: vec![ChatMessage {
594                role: "user".into(),
595                content: "what is the capital of france?".into(),
596            }],
597            max_tokens: 64,
598            temperature: 0.5,
599        });
600        let result = engine.dispatch("synthetic", task).unwrap();
601        let json = match result {
602            TaskResult::Llm { json } => json,
603            other => panic!("expected llm, got {:?}", other.kind()),
604        };
605        assert_eq!(json["object"], "chat.completion");
606        assert!(json["choices"][0]["message"]["content"]
607            .as_str()
608            .unwrap()
609            .starts_with("[synthetic]"));
610    }
611
612    #[test]
613    fn synthetic_stt_returns_whisper_shape() {
614        let engine = SyntheticEngine::new(vec![]);
615        let task = Task::AudioStt(AudioSttParams {
616            input_url: "https://example.com/audio.wav".into(),
617            language: Some("nl".into()),
618        });
619        let result = engine.dispatch("synthetic", task).unwrap();
620        let json = match result {
621            TaskResult::AudioStt { json } => json,
622            other => panic!("expected stt, got {:?}", other.kind()),
623        };
624        assert_eq!(json["language"], "nl");
625        assert!(json["text"].as_str().unwrap().starts_with("[synthetic]"));
626    }
627
628    #[test]
629    fn synthetic_tts_produces_real_wav() {
630        let engine = SyntheticEngine::new(vec![]);
631        let task = Task::AudioTts(AudioTtsParams {
632            text: "hello world".into(),
633            voice: "default".into(),
634            ext: "wav".into(),
635        });
636        let result = engine.dispatch("synthetic", task).unwrap();
637        let (bytes, ext) = match result {
638            TaskResult::AudioTts { bytes, ext } => (bytes, ext),
639            other => panic!("expected tts, got {:?}", other.kind()),
640        };
641        assert_eq!(ext, "wav");
642        // Validate the WAV by reading it back with hound.
643        let mut reader = hound::WavReader::new(Cursor::new(bytes)).expect("real WAV should decode");
644        let spec = reader.spec();
645        assert_eq!(spec.sample_rate, 22_050);
646        assert_eq!(spec.channels, 1);
647        let samples = reader
648            .samples::<i16>()
649            .collect::<std::result::Result<Vec<_>, _>>()
650            .expect("samples should decode");
651        assert_eq!(samples.len(), 22_050); // 1 second
652    }
653
654    #[test]
655    fn synthetic_video_emits_decodable_bytes() {
656        let engine = SyntheticEngine::new(vec![]);
657        let task = Task::Video(VideoParams {
658            prompt: "a tiny dragon".into(),
659            seconds: 1.0,
660            width: 256,
661            height: 256,
662            ext: "mp4".into(), // engine intentionally downgrades to webp
663        });
664        let result = engine.dispatch("synthetic", task).unwrap();
665        let (bytes, ext) = match result {
666            TaskResult::Video { bytes, ext } => (bytes, ext),
667            other => panic!("expected video, got {:?}", other.kind()),
668        };
669        assert_eq!(ext, "webp");
670        let reader = image::ImageReader::new(Cursor::new(&bytes))
671            .with_guessed_format()
672            .unwrap();
673        assert_eq!(reader.format().unwrap(), image::ImageFormat::WebP);
674    }
675
676    #[test]
677    fn synthetic_engine_advertises_all_kinds() {
678        let engine = SyntheticEngine::new(vec![]);
679        let caps = engine.capabilities();
680        for k in TaskKind::ALL {
681            assert!(
682                caps.supported_models_per_kind.contains_key(&k),
683                "{} should be advertised",
684                k.as_str()
685            );
686        }
687        assert!(caps.supports(TaskKind::Image, "synthetic"));
688        assert!(caps.supports(TaskKind::Llm, "llama-3.1-8b-instruct-q4"));
689    }
690
691    #[test]
692    fn synthetic_engine_overrides_apply_to_every_kind() {
693        let engine = SyntheticEngine::new(vec!["only-this".into()]);
694        let caps = engine.capabilities();
695        for k in TaskKind::ALL {
696            assert_eq!(caps.supported_models_per_kind[&k], vec!["only-this"]);
697        }
698    }
699
700    #[test]
701    fn synthetic_engine_is_deterministic_per_prompt() {
702        let engine = SyntheticEngine::new(vec![]);
703        let task = || {
704            Task::Image(ImageParams {
705                prompt: "deterministic".into(),
706                width: 512,
707                height: 512,
708                steps: 20,
709                seed: None,
710                ext: "webp".into(),
711            })
712        };
713        let a = engine.dispatch("synthetic", task()).unwrap();
714        let b = engine.dispatch("synthetic", task()).unwrap();
715        match (a, b) {
716            (TaskResult::Image { bytes: a, .. }, TaskResult::Image { bytes: b, .. }) => {
717                assert_eq!(a, b);
718            }
719            _ => panic!("expected images"),
720        }
721    }
722
723    #[test]
724    fn gradio_engine_refuses_non_image_tasks() {
725        let engine = GradioEngine::new("http://localhost".into(), vec!["foo".into()]);
726        let task = Task::Llm(LlmParams {
727            messages: vec![],
728            max_tokens: 1,
729            temperature: 0.0,
730        });
731        let err = engine.dispatch("foo", task).unwrap_err();
732        assert!(err.to_string().contains("cannot serve llm"));
733    }
734}