Skip to main content

studio_worker/engine/
sdcpp.rs

1//! Engine that runs real image inference by subprocess-invoking the
2//! `stable-diffusion.cpp` (`sd-cli`) binary.
3//!
4//! The studio's offer carries a [`ModelSource`] with everything we
5//! need: an engine identifier (`sd-cpp`), the list of files to
6//! download (diffusion-model + text-encoder + VAE, each with a public
7//! URL + filename), and CLI defaults (cfg-scale, steps, dimensions).
8//! The worker has zero hardcoded model knowledge — it caches
9//! whatever the studio asks for under `cfg.models_root` and invokes
10//! `sd-cli` with the files arranged by role.
11//!
12//! Layout under `cfg.models_root` (default `~/models`):
13//! ```text
14//! ~/models/<filename1>
15//! ~/models/<filename2>
16//! …
17//! ```
18//! Files are downloaded on first use - skipped when already present
19//! under `cfg.models_root`.  The streamed body is checked against the
20//! server's `Content-Length` so a truncated download is rejected and
21//! cleaned up instead of being renamed into place as a corrupt model
22//! that every later job would fail to load.  Cached files are re-used
23//! across every subsequent job that names them.
24//!
//! The engine always registers.  `sd-cli` is resolved lazily on the
//! first image job: an existing binary at `$STUDIO_WORKER_SD_CLI`,
//! under `<models_root>/bin/`, at `~/.local/bin/sd-cli`, or on `$PATH`
//! is used as-is; otherwise one is auto-provisioned into
//! `<models_root>/bin/`.  The multi engine falls through to synthetic
//! for any kind it doesn't have a real backend for.
30
31use crate::engine::download::{self, TempFileGuard};
32use crate::engine::sd_provision;
33use crate::engine::{Engine, EngineCapabilities};
34use crate::types::{ImageParams, ModelFileRole, ModelSource, Task, TaskKind, TaskResult};
35use anyhow::{anyhow, bail, Context, Result};
36use parking_lot::Mutex;
37use std::collections::BTreeMap;
38use std::ffi::OsString;
39use std::path::{Path, PathBuf};
40use std::process::Command;
41use std::time::Instant;
42use tracing::{debug, info, warn};
43
44const TRACE_TARGET: &str = "studio_worker::engine::sdcpp";
45
46/// Default sample-steps when the studio's `ImageParams.steps` is the
47/// upstream default (20).  Z-Image-Turbo is an 8-step distilled
48/// schedule so 20 wastes time; we honour `ModelSource.cliDefaults.steps`
49/// instead.  Only used as the very last fallback.
50const STEPS_FALLBACK: u32 = 8;
51
/// Worker-side engine that drives `sd-cli` per job.
///
/// `sd-cli` is resolved lazily on the first image job and cached: an
/// operator install (env / PATH / `~/.local/bin`) wins, otherwise the
/// binary is auto-provisioned into `<models_root>/bin/`.  The `Mutex`
/// serialises that one-time resolution so two concurrent jobs can't
/// race the download.
pub struct SdCppEngine {
    /// Cached path of the resolved `sd-cli` binary.  `None` until the
    /// first successful resolution (tests seed it via `with_paths`).
    sd_cli: Mutex<Option<PathBuf>>,
    /// Root directory handed to the model-file downloader and to the
    /// `sd-cli` lookup / provisioner.
    models_root: PathBuf,
}
63
64impl SdCppEngine {
65    /// Build the engine.  Always registers: `sd-cli` is resolved (and
66    /// provisioned into `<models_root>/bin/` if missing) lazily on the
67    /// first image job, so the engine serves real image work even on a
68    /// box that has never had a stable-diffusion.cpp build installed.
69    /// `models_root` is created on demand by the provisioner / model
70    /// downloader, so registration touches no filesystem.
71    pub fn new(models_root: &Path) -> Self {
72        info!(
73            target: TRACE_TARGET,
74            op = "register",
75            models_root = %models_root.display(),
76            sd_cli_name = sd_provision::binary_name(),
77            "sdcpp engine registered (sd-cli resolved/provisioned on first image job)"
78        );
79        Self {
80            sd_cli: Mutex::new(None),
81            models_root: models_root.to_path_buf(),
82        }
83    }
84
85    /// For tests: build with explicit paths (bypasses sd-cli lookup +
86    /// provisioning by seeding the resolved-path cache).
87    #[cfg(test)]
88    pub fn with_paths(sd_cli: PathBuf, models_root: PathBuf) -> Self {
89        Self {
90            sd_cli: Mutex::new(Some(sd_cli)),
91            models_root,
92        }
93    }
94
95    /// Resolve the `sd-cli` binary, provisioning it on first use.
96    /// Resolution order (operator installs win): a cached path from a
97    /// previous job, then env / `<models_root>/bin` / `~/.local/bin` /
98    /// `$PATH`, then an auto-provisioned download into
99    /// `<models_root>/bin/`.  The result is cached for the worker's
100    /// lifetime.
101    #[cfg_attr(coverage_nightly, coverage(off))]
102    fn ensure_sd_cli(&self) -> Result<PathBuf> {
103        let mut guard = self.sd_cli.lock();
104        if let Some(p) = guard.as_ref() {
105            if p.is_file() {
106                return Ok(p.clone());
107            }
108        }
109        let resolved = match resolve_sd_cli(&self.models_root) {
110            Some(p) => {
111                info!(
112                    target: TRACE_TARGET,
113                    op = "resolve",
114                    sd_cli = %p.display(),
115                    "using existing sd-cli"
116                );
117                p
118            }
119            None => sd_provision::provision(&self.models_root)
120                .context("auto-provisioning sd-cli (stable-diffusion.cpp)")?,
121        };
122        *guard = Some(resolved.clone());
123        Ok(resolved)
124    }
125
126    /// Ensure each file in `source.files` is present under
127    /// `self.models_root`.  Downloads anything missing.  Returns the
128    /// resolved local path for each file (in the same order).
129    #[cfg_attr(coverage_nightly, coverage(off))]
130    fn ensure_files(&self, source: &ModelSource) -> Result<Vec<(ModelFileRole, PathBuf)>> {
131        let mut out = Vec::with_capacity(source.files.len());
132        for file in &source.files {
133            let local = download::ensure_file(&self.models_root, file)?;
134            out.push((file.role, local));
135        }
136        Ok(out)
137    }
138
139    /// Subprocess to `sd-cli` with the resolved diffusion / VAE /
140    /// text-encoder files.  Excluded from coverage: requires an
141    /// actual `sd-cli` binary + cached model files on disk, neither
142    /// of which exists on the CI runner.  Exercised end-to-end via
143    /// the live dev loop.
144    #[cfg_attr(coverage_nightly, coverage(off))]
145    fn dispatch_image(
146        &self,
147        model: &str,
148        params: ImageParams,
149        source: &ModelSource,
150    ) -> Result<TaskResult> {
151        // Resolve (provisioning on first use) the sd-cli binary before
152        // we touch model files, so a missing binary fails fast with the
153        // provisioning error rather than after a multi-GB weight pull.
154        let sd_cli = self.ensure_sd_cli()?;
155        // Preflight the GPU runtime next: a missing Vulkan loader can't be
156        // auto-provisioned (it ships with the driver / a system package),
157        // so surface the actionable remedy now instead of after a
158        // multi-GB weight pull and a cryptic sd-cli crash.
159        if let Err(e) = sd_provision::vulkan_runtime_status() {
160            warn!(
161                target: TRACE_TARGET,
162                op = "preflight",
163                model,
164                error = %e,
165                "GPU runtime missing; refusing image job"
166            );
167            return Err(e);
168        }
169        let files = self.ensure_files(source)?;
170        // A `diffusion-model` file is the standalone diffusion weights (sd-cli `--diffusion-model`,
171        // used with split vae/clip); a `model` file is a full checkpoint (sd-cli `-m`/`--model`).
172        // Prefer the explicit diffusion-model role; fall back to a full checkpoint.
173        let diffusion_only = file_for_role(&files, ModelFileRole::DiffusionModel);
174        let full_checkpoint = diffusion_only.is_none();
175        let diffusion_model = diffusion_only
176            .or_else(|| file_for_role(&files, ModelFileRole::Model))
177            .ok_or_else(|| anyhow!("modelSource has no diffusion-model / model file"))?;
178        let vae = file_for_role(&files, ModelFileRole::Vae);
179        let text_encoder = file_for_role(&files, ModelFileRole::TextEncoder);
180        let text_encoder_vision = file_for_role(&files, ModelFileRole::TextEncoderVision);
181
182        let out_dir = std::env::temp_dir().join("studio-worker-sdcpp");
183        std::fs::create_dir_all(&out_dir)
184            .with_context(|| format!("creating sdcpp output dir {}", out_dir.display()))?;
185        let stem = format!(
186            "out-{}-{}",
187            std::process::id(),
188            chrono::Utc::now().timestamp_nanos_opt().unwrap_or_default()
189        );
190        let out_path = out_dir.join(format!("{stem}.webp"));
191
192        // Own the scratch files from the moment their paths exist so
193        // every failure path (sd-cli error, unreadable output) cleans
194        // up instead of leaking them into the temp dir.
195        let mut temp_files = TempFileGuard::new();
196        temp_files.push(out_path.clone());
197
198        // If the task carries an init image URL, stream it to a
199        // tempfile so we can hand the path to `sd-cli --init-img`.
200        // This is mandatory — the worker refuses i2i jobs whose
201        // init image fails to download (no silent fallback to t2i).
202        // The local extension mirrors the URL's so sd-cli's image
203        // loader can sniff the format.
204        let init_img_path = match params.init_image_url.as_deref() {
205            Some(url) if !url.is_empty() => {
206                let ext = init_image_extension(url);
207                let init_path = out_dir.join(format!("{stem}-init.{ext}"));
208                download::download_file(url, &init_path).with_context(|| {
209                    format!("downloading init image {} -> {}", url, init_path.display())
210                })?;
211                temp_files.push(init_path.clone());
212                Some(init_path)
213            }
214            _ => None,
215        };
216
217        // A mask constrains the edit region — valid alongside either an init image (img2img
218        // inpaint) or a reference image (instruction edit). Download it whenever a base image is
219        // present and a mask URL was supplied; white pixels mark the region the model may change.
220        let has_base = init_img_path.is_some() || params.ref_image_url.as_deref().is_some();
221        let mask_path = match (has_base, params.mask_url.as_deref()) {
222            (true, Some(url)) if !url.is_empty() => {
223                let ext = init_image_extension(url);
224                let path = out_dir.join(format!("{stem}-mask.{ext}"));
225                download::download_file(url, &path)
226                    .with_context(|| format!("downloading mask {} -> {}", url, path.display()))?;
227                temp_files.push(path.clone());
228                Some(path)
229            }
230            _ => None,
231        };
232
233        // Reference image for instruction-edit models (`sd-cli -r`). Downloaded like the init image;
234        // when present the arg builder uses reference mode instead of the img2img/mask path.
235        let ref_img_path = match params.ref_image_url.as_deref() {
236            Some(url) if !url.is_empty() => {
237                let ext = init_image_extension(url);
238                let path = out_dir.join(format!("{stem}-ref.{ext}"));
239                download::download_file(url, &path).with_context(|| {
240                    format!("downloading reference image {} -> {}", url, path.display())
241                })?;
242                temp_files.push(path.clone());
243                Some(path)
244            }
245            _ => None,
246        };
247
248        let args = build_sdcli_args(
249            &params,
250            source,
251            diffusion_model,
252            vae,
253            text_encoder,
254            text_encoder_vision,
255            &out_path,
256            init_img_path.as_deref(),
257            mask_path.as_deref(),
258            ref_img_path.as_deref(),
259            full_checkpoint,
260        );
261        let mut cmd = Command::new(&sd_cli);
262        cmd.args(&args);
263        apply_library_path(&mut cmd, &sd_cli);
264
265        debug!(
266            target: TRACE_TARGET,
267            op = "spawn",
268            sd_cli = %sd_cli.display(),
269            model,
270            i2i = init_img_path.is_some(),
271            arg_count = args.len(),
272            "running sd-cli"
273        );
274
275        let started = Instant::now();
276        let output = cmd
277            .output()
278            .with_context(|| format!("running {}", sd_cli.display()))?;
279        let elapsed_ms = started.elapsed().as_millis() as u64;
280        if !output.status.success() {
281            let stderr = String::from_utf8_lossy(&output.stderr);
282            warn!(
283                target: TRACE_TARGET,
284                op = "spawn",
285                model,
286                elapsed_ms,
287                exit = ?output.status.code(),
288                stderr = %stderr,
289                "sd-cli failed"
290            );
291            bail!(
292                "sd-cli exited with {:?}: {}",
293                output.status.code(),
294                stderr.lines().last().unwrap_or("(no stderr)")
295            );
296        }
297
298        let bytes = std::fs::read(&out_path)
299            .with_context(|| format!("reading sd-cli output at {}", out_path.display()))?;
300        info!(
301            target: TRACE_TARGET,
302            op = "dispatch",
303            model,
304            elapsed_ms,
305            bytes = bytes.len(),
306            "ok"
307        );
308
309        Ok(TaskResult::Image {
310            bytes,
311            ext: "webp".to_string(),
312        })
313    }
314}
315
316impl Engine for SdCppEngine {
317    fn name(&self) -> &'static str {
318        "sdcpp"
319    }
320
321    fn capabilities(&self) -> EngineCapabilities {
322        // Image kind only.  The studio's selection is kind-based now
323        // and the offer carries the model-source, so we don't need to
324        // enumerate model names ourselves.  We still list a single
325        // sentinel string so downstream code that reads
326        // `supportedModels` for display sees "any sd-cpp model".
327        let mut map: BTreeMap<TaskKind, Vec<String>> = BTreeMap::new();
328        map.insert(TaskKind::Image, vec!["sd-cpp:*".to_string()]);
329        EngineCapabilities {
330            supported_models_per_kind: map,
331        }
332    }
333
334    fn dispatch(&self, _model: &str, _task: Task) -> Result<TaskResult> {
335        bail!(
336            "sdcpp engine requires a ModelSource on the offer; legacy push-based offers \
337             (no modelSource) cannot be served - re-promote the job through the studio"
338        )
339    }
340
341    fn dispatch_with_source(
342        &self,
343        model: &str,
344        task: Task,
345        source: &ModelSource,
346    ) -> Result<TaskResult> {
347        match task {
348            Task::Image(p) => self.dispatch_image(model, p, source),
349            other => {
350                // Surface the rejection at this engine's own target,
351                // matching the onnx/llama/whisper/candle engines.
352                // Without it an operator filtering
353                // `RUST_LOG=studio_worker::engine::sdcpp=debug` sees
354                // nothing when sdcpp refuses a non-image task.
355                let kind = other.kind();
356                warn!(
357                    target: TRACE_TARGET,
358                    op = "dispatch",
359                    model,
360                    kind = kind.as_str(),
361                    "sdcpp engine only serves image jobs"
362                );
363                Err(crate::engine::UnsupportedTask::new("sdcpp", kind).into())
364            }
365        }
366    }
367}
368
369// ---------------------------------------------------------------------------
370// Helpers
371// ---------------------------------------------------------------------------
372
373// The per-job scratch cleanup primitives (`remove_temp_file` +
374// `TempFileGuard`) live in `engine::download` so this engine and the
375// onnx engine share one tested implementation.
376
377fn file_for_role(files: &[(ModelFileRole, PathBuf)], role: ModelFileRole) -> Option<&Path> {
378    files
379        .iter()
380        .find(|(r, _)| *r == role)
381        .map(|(_, p)| p.as_path())
382}
383
384/// Resolve final per-job width / height / steps / cfg / sampler /
385/// negative-prompt by layering `params` over `source.cli_defaults`
386/// with the agreed precedence (per-job override beats model default
387/// beats engine fallback).  Pure for testability.
388fn resolve_image_args(params: &ImageParams, source: &ModelSource) -> ResolvedImageArgs {
389    let width = if params.width > 0 {
390        params.width
391    } else if source.cli_defaults.width > 0 {
392        source.cli_defaults.width
393    } else {
394        1024
395    };
396    let height = if params.height > 0 {
397        params.height
398    } else if source.cli_defaults.height > 0 {
399        source.cli_defaults.height
400    } else {
401        1024
402    };
403    // Steps: per-job override wins (treat the deserialiser default of
404    // 20 as "caller didn't pick" so the model's tuned step count
405    // doesn't get clobbered by a stale default).
406    let steps = if params.steps > 0 && params.steps != 20 {
407        params.steps
408    } else if source.cli_defaults.steps > 0 {
409        source.cli_defaults.steps
410    } else {
411        STEPS_FALLBACK
412    };
413    let source_cfg = if source.cli_defaults.cfg_scale > 0.0 {
414        source.cli_defaults.cfg_scale
415    } else {
416        1.0
417    };
418    let cfg_scale = params.cfg_scale.filter(|v| *v > 0.0).unwrap_or(source_cfg);
419    let sampling_method = params
420        .sampling_method
421        .clone()
422        .or_else(|| source.cli_defaults.sampling_method.clone());
423    ResolvedImageArgs {
424        width,
425        height,
426        steps,
427        cfg_scale,
428        sampling_method,
429    }
430}
431
/// Resolved per-job sd-cli numerics.  Output of [`resolve_image_args`].
#[derive(Debug, Clone, PartialEq)]
struct ResolvedImageArgs {
    /// Output width, emitted as `-W`.
    width: u32,
    /// Output height, emitted as `-H`.
    height: u32,
    /// Sample-step count, emitted as `--steps`.
    steps: u32,
    /// Guidance scale, emitted as `--cfg-scale`.
    cfg_scale: f32,
    /// Sampler name, emitted as `--sampling-method`; the flag is
    /// omitted when `None`.
    sampling_method: Option<String>,
}
441
442/// Build the full `sd-cli` argv for one image job.  Pure (no I/O):
443/// the caller resolves files / out-path / init-image-path, this
444/// function only assembles the flag list so it can be asserted in
445/// unit tests without spawning the binary.
446// Eight model-path + i2i components; grouping them adds indirection without
447// improving readability (mirrors the `#[allow]` already used in ws::session).
448#[allow(clippy::too_many_arguments)]
449fn build_sdcli_args(
450    params: &ImageParams,
451    source: &ModelSource,
452    diffusion_model: &Path,
453    vae: Option<&Path>,
454    text_encoder: Option<&Path>,
455    text_encoder_vision: Option<&Path>,
456    out_path: &Path,
457    init_img_path: Option<&Path>,
458    mask_path: Option<&Path>,
459    ref_img_path: Option<&Path>,
460    full_checkpoint: bool,
461) -> Vec<OsString> {
462    let resolved = resolve_image_args(params, source);
463    let mut args: Vec<OsString> = Vec::with_capacity(32);
464
465    // A full checkpoint loads via `-m`/`--model`; standalone diffusion weights via
466    // `--diffusion-model` (alongside split vae/clip files).
467    args.push(
468        if full_checkpoint {
469            "--model"
470        } else {
471            "--diffusion-model"
472        }
473        .into(),
474    );
475    args.push(diffusion_model.into());
476    if let Some(p) = vae {
477        args.push("--vae".into());
478        args.push(p.into());
479    }
480    if let Some(p) = text_encoder {
481        args.push("--llm".into());
482        args.push(p.into());
483    }
484    if let Some(p) = text_encoder_vision {
485        args.push("--llm_vision".into());
486        args.push(p.into());
487    }
488    args.push("-p".into());
489    args.push((&params.prompt as &str).into());
490    if let Some(neg) = params.negative_prompt.as_deref() {
491        if !neg.is_empty() {
492            args.push("--negative-prompt".into());
493            args.push(neg.into());
494        }
495    }
496    if let Some(reference) = ref_img_path {
497        // Reference / instruction-edit mode (Qwen-Image-Edit, Flux Kontext): the model regenerates
498        // the image from the reference per the prompt. Mutually exclusive with the `--init-img`
499        // img2img path. A `--mask` is honoured here too: it constrains the edit to the masked
500        // region (white = editable) and leaves the rest, so the studio can place the edit inside
501        // the author's drawn shape. No `--strength` (that's an img2img-only knob).
502        args.push("-r".into());
503        args.push(reference.into());
504        if let Some(mask) = mask_path {
505            args.push("--mask".into());
506            args.push(mask.into());
507        }
508    } else if let Some(init) = init_img_path {
509        args.push("--init-img".into());
510        args.push(init.into());
511        // `--strength` only makes sense alongside an init image
512        // (sd-cli ignores it otherwise).  Default to 0.75 (sd-cli's
513        // own default) when the caller didn't pick a value.
514        let strength = params.denoise.unwrap_or(0.75);
515        args.push("--strength".into());
516        args.push(strength.to_string().into());
517        // Mask-guided inpaint: only valid with an init image.
518        if let Some(mask) = mask_path {
519            args.push("--mask".into());
520            args.push(mask.into());
521        }
522    }
523    args.push("--cfg-scale".into());
524    args.push(resolved.cfg_scale.to_string().into());
525    args.push("--steps".into());
526    args.push(resolved.steps.to_string().into());
527    args.push("-W".into());
528    args.push(resolved.width.to_string().into());
529    args.push("-H".into());
530    args.push(resolved.height.to_string().into());
531    args.push("-o".into());
532    args.push(out_path.into());
533    if let Some(seed) = params.seed {
534        args.push("--seed".into());
535        args.push(seed.to_string().into());
536    }
537    if let Some(method) = resolved.sampling_method.as_deref() {
538        args.push("--sampling-method".into());
539        args.push(method.into());
540    }
541    // Flow / instruction-edit model flags (model-level constants from the registry). Only emitted
542    // when the model declares them, so SDXL-style models are unaffected.
543    if let Some(shift) = source.cli_defaults.flow_shift {
544        args.push("--flow-shift".into());
545        args.push(shift.to_string().into());
546    }
547    if source.cli_defaults.zero_cond_t == Some(true) {
548        args.push("--qwen-image-zero-cond-t".into());
549    }
550    if source.cli_defaults.offload_to_cpu == Some(true) {
551        args.push("--offload-to-cpu".into());
552    }
553    // VRAM-saving flags that are safe on every box.
554    args.push("--diffusion-fa".into());
555    args
556}
557
558/// Point the per-job `Command`'s dynamic linker at the shared library
559/// that ships next to an auto-provisioned `sd-cli` (Linux / macOS).
560/// No-op on Windows (sibling DLLs resolve automatically) and when the
561/// resolved binary has no sibling library (operator wrapper-script
562/// installs manage their own load path).  Prepends to any inherited
563/// value so a pre-set `LD_LIBRARY_PATH` isn't clobbered.
564#[cfg_attr(coverage_nightly, coverage(off))]
565fn apply_library_path(cmd: &mut Command, sd_cli: &Path) {
566    let Some((var, dir)) = sd_provision::library_path_env(sd_cli) else {
567        return;
568    };
569    let value = match std::env::var_os(var) {
570        Some(existing) => {
571            let mut paths = vec![dir.clone()];
572            paths.extend(std::env::split_paths(&existing));
573            // `join_paths` only fails if a path contains the platform
574            // separator; fall back to our dir alone, the entry that
575            // matters for finding the sibling library.
576            std::env::join_paths(paths).unwrap_or_else(|_| dir.into_os_string())
577        }
578        None => dir.into_os_string(),
579    };
580    cmd.env(var, value);
581}
582
583/// Look up `sd-cli` in env override -> `<models_root>/bin` ->
584/// `~/.local/bin` -> `$PATH`.  The `<models_root>/bin` slot is where a
585/// self-provisioned binary lands, so the auto-provisioner can drop it
586/// next to the cached models and have the worker pick it up with no
587/// PATH fiddling.  Excluded from coverage: touches several host paths
588/// only one of which matches per host, and CI doesn't ship `sd-cli`.
589#[cfg_attr(coverage_nightly, coverage(off))]
590fn resolve_sd_cli(models_root: &Path) -> Option<PathBuf> {
591    let bin = sd_provision::binary_name();
592    if let Ok(p) = std::env::var("STUDIO_WORKER_SD_CLI") {
593        let path = PathBuf::from(p);
594        if path.is_file() {
595            return Some(path);
596        }
597    }
598    let in_models = models_root.join("bin").join(bin);
599    if in_models.is_file() {
600        return Some(in_models);
601    }
602    if let Some(home) = std::env::var_os("HOME") {
603        let candidate = PathBuf::from(home).join(".local/bin").join(bin);
604        if candidate.is_file() {
605            return Some(candidate);
606        }
607    }
608    which(bin)
609}
610
/// Search each `$PATH` entry, in order, for a file named `bin` and
/// return the first hit.  `None` when `$PATH` is unset or no entry
/// contains such a file.  Excluded from coverage for the same reason
/// as `resolve_sd_cli`.
#[cfg_attr(coverage_nightly, coverage(off))]
fn which(bin: &str) -> Option<PathBuf> {
    let search_path = std::env::var_os("PATH")?;
    std::env::split_paths(&search_path)
        .map(|dir| dir.join(bin))
        .find(|candidate| candidate.is_file())
}
624
/// Choose the file extension for a downloaded input-image tempfile so
/// sd-cli's image loader can sniff the format.
///
/// Takes the text after the last `.` of the URL with its query and
/// fragment removed, compares it case-insensitively against the known
/// image extensions, and returns the canonical spelling (`jpeg` ->
/// `jpg`, `tiff` -> `tif`).  Anything unrecognised yields `webp`.
fn init_image_extension(url: &str) -> &'static str {
    // (accepted spelling, canonical extension)
    const ALIASES: &[(&str, &str)] = &[
        ("png", "png"),
        ("jpg", "jpg"),
        ("jpeg", "jpg"),
        ("webp", "webp"),
        ("bmp", "bmp"),
        ("gif", "gif"),
        ("tif", "tif"),
        ("tiff", "tif"),
    ];

    let path_end = url.find(['?', '#']).unwrap_or(url.len());
    let path = &url[..path_end];
    let tail = path.rsplit('.').next().unwrap_or(path);
    ALIASES
        .iter()
        .find(|(alias, _)| tail.eq_ignore_ascii_case(alias))
        .map_or("webp", |(_, ext)| *ext)
}
646
647// ---------------------------------------------------------------------------
648// Tests
649// ---------------------------------------------------------------------------
650
651#[cfg(test)]
652mod tests {
653    use super::*;
654    use crate::types::{ModelCliDefaults, ModelEngine, ModelFile, ModelFileRole};
655    use tempfile::tempdir;
656
    /// Build an sd-cpp `ModelSource` fixture around `files`, with fixed
    /// model-level defaults (cfg-scale 1.0, 8 steps, 1024x1024, `euler`
    /// sampler) and every other CLI default left at `Default`.
    fn fake_source(files: Vec<ModelFile>) -> ModelSource {
        ModelSource {
            engine: ModelEngine::SdCpp,
            files,
            cli_defaults: ModelCliDefaults {
                cfg_scale: 1.0,
                steps: 8,
                width: 1024,
                height: 1024,
                sampling_method: Some("euler".to_string()),
                ..Default::default()
            },
        }
    }
671
    /// `file_for_role` returns the path stored for a role that is
    /// present and `None` for a role that is absent.
    #[test]
    fn file_for_role_picks_matching_file() {
        let files = vec![
            (ModelFileRole::DiffusionModel, PathBuf::from("/d.gguf")),
            (ModelFileRole::Vae, PathBuf::from("/v.safetensors")),
        ];
        assert_eq!(
            file_for_role(&files, ModelFileRole::DiffusionModel),
            Some(Path::new("/d.gguf"))
        );
        assert_eq!(
            file_for_role(&files, ModelFileRole::Vae),
            Some(Path::new("/v.safetensors"))
        );
        // No text-encoder entry was supplied.
        assert!(file_for_role(&files, ModelFileRole::TextEncoder).is_none());
    }
688
    /// A model file that already exists under `models_root` is returned
    /// as-is: the URL is unreachable (`.invalid`), so the test passes
    /// only if no download is attempted.
    #[test]
    fn ensure_files_skips_already_present() {
        let dir = tempdir().unwrap();
        let cached = dir.path().join("cached.gguf");
        std::fs::write(&cached, b"already here").unwrap();
        // The sd-cli path is a placeholder; `ensure_files` never spawns it.
        let engine = SdCppEngine::with_paths(PathBuf::from("/usr/bin/true"), dir.path().into());
        let source = fake_source(vec![ModelFile {
            role: ModelFileRole::DiffusionModel,
            url: "https://example.invalid/cached.gguf".into(),
            filename: "cached.gguf".into(),
            approx_bytes: None,
            sha256: None,
        }]);
        let resolved = engine.ensure_files(&source).expect("cached file used");
        assert_eq!(resolved.len(), 1);
        assert_eq!(resolved[0].0, ModelFileRole::DiffusionModel);
        assert_eq!(resolved[0].1, cached);
        // Untouched on disk — our "download" never ran.
        assert_eq!(std::fs::read(&cached).unwrap(), b"already here");
    }
709
    /// `dispatch_with_source` refuses a non-image task (here audio TTS)
    /// with an error naming the rejected kind.
    #[test]
    fn dispatch_rejects_non_image_tasks() {
        use crate::types::AudioTtsParams;
        let dir = tempdir().unwrap();
        // The sd-cli path is a placeholder; the rejection happens before
        // anything would be spawned.
        let engine = SdCppEngine::with_paths(PathBuf::from("/usr/bin/true"), dir.path().into());
        let task = Task::AudioTts(AudioTtsParams {
            text: "hi".into(),
            voice: "v".into(),
            ext: "wav".into(),
            ..Default::default()
        });
        let source = fake_source(vec![]);
        let err = engine
            .dispatch_with_source("anything", task, &source)
            .unwrap_err();
        // NOTE(review): the message text presumably comes from
        // `UnsupportedTask`'s `Display` impl — confirm there.
        assert!(err.to_string().contains("cannot serve audio_tts"));
    }
727
728    // The legacy `dispatch_requires_model_source` test is gone: the
729    // trait signature now takes `&ModelSource` so the compiler enforces
730    // it at every call site.  No runtime fallback to police.
731
732    // -----------------------------------------------------------------
733    // Pure arg-builder tests — lock down the sd-cli invocation contract
734    // without needing the binary on the box.
735    // -----------------------------------------------------------------
736
737    fn args_to_strings(args: &[OsString]) -> Vec<String> {
738        args.iter()
739            .map(|s| s.to_string_lossy().into_owned())
740            .collect()
741    }
742
743    fn idx_after(args: &[String], flag: &str) -> Option<usize> {
744        args.iter().position(|a| a == flag).map(|i| i + 1)
745    }
746
    /// Text-to-image job with split weights: checks the model / VAE /
    /// text-encoder paths, prompt, per-job dimensions, model-default
    /// cfg-scale / steps / sampler and the output path, and confirms
    /// that no img2img-only flag appears without an init image.
    #[test]
    fn build_sdcli_args_includes_required_flags() {
        let params = ImageParams {
            prompt: "hello".into(),
            width: 768,
            height: 512,
            steps: 20, // "caller didn't pick" → source default wins
            ..Default::default()
        };
        let source = fake_source(vec![]);
        let args = build_sdcli_args(
            &params,
            &source,
            Path::new("/d.gguf"),
            Some(Path::new("/v.safetensors")),
            Some(Path::new("/llm.gguf")),
            None,
            Path::new("/tmp/out.webp"),
            None,
            None,
            None,
            false,
        );
        let s = args_to_strings(&args);
        // `full_checkpoint == false`, so the weights go in via `--diffusion-model`.
        assert_eq!(s[idx_after(&s, "--diffusion-model").unwrap()], "/d.gguf");
        assert_eq!(s[idx_after(&s, "--vae").unwrap()], "/v.safetensors");
        assert_eq!(s[idx_after(&s, "--llm").unwrap()], "/llm.gguf");
        assert_eq!(s[idx_after(&s, "-p").unwrap()], "hello");
        // Per-job dimensions override the source's 1024x1024 defaults.
        assert_eq!(s[idx_after(&s, "-W").unwrap()], "768");
        assert_eq!(s[idx_after(&s, "-H").unwrap()], "512");
        // source default cfg_scale=1.0
        assert_eq!(s[idx_after(&s, "--cfg-scale").unwrap()], "1");
        // source default steps=8 wins (param.steps==20 treated as default)
        assert_eq!(s[idx_after(&s, "--steps").unwrap()], "8");
        assert_eq!(s[idx_after(&s, "--sampling-method").unwrap()], "euler");
        assert_eq!(s[idx_after(&s, "-o").unwrap()], "/tmp/out.webp");
        assert!(s.contains(&"--diffusion-fa".to_string()));
        // Never includes init-only flags when no init image present.
        assert!(!s.contains(&"--init-img".to_string()));
        assert!(!s.contains(&"--strength".to_string()));
    }
788
789    #[test]
790    fn build_sdcli_args_includes_negative_prompt_when_set() {
791        let params = ImageParams {
792            prompt: "hi".into(),
793            negative_prompt: Some("text, watermark, low quality".into()),
794            ..Default::default()
795        };
796        let source = fake_source(vec![]);
797        let args = build_sdcli_args(
798            &params,
799            &source,
800            Path::new("/d.gguf"),
801            None,
802            None,
803            None,
804            Path::new("/tmp/out.webp"),
805            None,
806            None,
807            None,
808            false,
809        );
810        let s = args_to_strings(&args);
811        assert_eq!(
812            s[idx_after(&s, "--negative-prompt").unwrap()],
813            "text, watermark, low quality"
814        );
815    }
816
817    #[test]
818    fn build_sdcli_args_omits_negative_prompt_when_empty_string() {
819        let params = ImageParams {
820            prompt: "hi".into(),
821            negative_prompt: Some(String::new()),
822            ..Default::default()
823        };
824        let source = fake_source(vec![]);
825        let args = build_sdcli_args(
826            &params,
827            &source,
828            Path::new("/d.gguf"),
829            None,
830            None,
831            None,
832            Path::new("/tmp/out.webp"),
833            None,
834            None,
835            None,
836            false,
837        );
838        let s = args_to_strings(&args);
839        assert!(!s.contains(&"--negative-prompt".to_string()));
840    }
841
842    #[test]
843    fn build_sdcli_args_includes_init_image_and_strength() {
844        let params = ImageParams {
845            prompt: "hi".into(),
846            denoise: Some(0.55),
847            ..Default::default()
848        };
849        let source = fake_source(vec![]);
850        let args = build_sdcli_args(
851            &params,
852            &source,
853            Path::new("/d.gguf"),
854            None,
855            None,
856            None,
857            Path::new("/tmp/out.webp"),
858            Some(Path::new("/tmp/init.webp")),
859            None,
860            None,
861            false,
862        );
863        let s = args_to_strings(&args);
864        assert_eq!(s[idx_after(&s, "--init-img").unwrap()], "/tmp/init.webp");
865        assert_eq!(s[idx_after(&s, "--strength").unwrap()], "0.55");
866        // No mask supplied → no inpaint flag.
867        assert!(!s.contains(&"--mask".to_string()));
868    }
869
870    #[test]
871    fn build_sdcli_args_includes_mask_for_inpaint() {
872        let params = ImageParams {
873            prompt: "remove the tree".into(),
874            denoise: Some(0.8),
875            ..Default::default()
876        };
877        let source = fake_source(vec![]);
878        let args = build_sdcli_args(
879            &params,
880            &source,
881            Path::new("/d.gguf"),
882            None,
883            None,
884            None,
885            Path::new("/tmp/out.webp"),
886            Some(Path::new("/tmp/init.webp")),
887            Some(Path::new("/tmp/mask.png")),
888            None,
889            false,
890        );
891        let s = args_to_strings(&args);
892        assert_eq!(s[idx_after(&s, "--init-img").unwrap()], "/tmp/init.webp");
893        assert_eq!(s[idx_after(&s, "--mask").unwrap()], "/tmp/mask.png");
894        assert_eq!(s[idx_after(&s, "--strength").unwrap()], "0.8");
895    }
896
897    #[test]
898    fn build_sdcli_args_uses_model_flag_for_full_checkpoint() {
899        let params = ImageParams {
900            prompt: "hi".into(),
901            ..Default::default()
902        };
903        let source = fake_source(vec![]);
904        let args = build_sdcli_args(
905            &params,
906            &source,
907            Path::new("/checkpoint.safetensors"),
908            Some(Path::new("/v.safetensors")),
909            None,
910            None,
911            Path::new("/tmp/out.webp"),
912            None,
913            None,
914            None,
915            true,
916        );
917        let s = args_to_strings(&args);
918        // A full checkpoint loads via -m/--model, not --diffusion-model.
919        assert_eq!(
920            s[idx_after(&s, "--model").unwrap()],
921            "/checkpoint.safetensors"
922        );
923        assert!(!s.contains(&"--diffusion-model".to_string()));
924    }
925
926    #[test]
927    fn build_sdcli_args_defaults_denoise_when_init_image_present_but_denoise_none() {
928        let params = ImageParams {
929            prompt: "hi".into(),
930            denoise: None,
931            ..Default::default()
932        };
933        let source = fake_source(vec![]);
934        let args = build_sdcli_args(
935            &params,
936            &source,
937            Path::new("/d.gguf"),
938            None,
939            None,
940            None,
941            Path::new("/tmp/out.webp"),
942            Some(Path::new("/tmp/init.webp")),
943            None,
944            None,
945            false,
946        );
947        let s = args_to_strings(&args);
948        assert_eq!(s[idx_after(&s, "--strength").unwrap()], "0.75");
949    }
950
951    #[test]
952    fn build_sdcli_args_per_job_cfg_scale_overrides_model_default() {
953        let params = ImageParams {
954            prompt: "hi".into(),
955            cfg_scale: Some(7.5),
956            ..Default::default()
957        };
958        let source = fake_source(vec![]);
959        let args = build_sdcli_args(
960            &params,
961            &source,
962            Path::new("/d.gguf"),
963            None,
964            None,
965            None,
966            Path::new("/tmp/out.webp"),
967            None,
968            None,
969            None,
970            false,
971        );
972        let s = args_to_strings(&args);
973        assert_eq!(s[idx_after(&s, "--cfg-scale").unwrap()], "7.5");
974    }
975
976    #[test]
977    fn build_sdcli_args_per_job_sampling_method_overrides_model_default() {
978        let params = ImageParams {
979            prompt: "hi".into(),
980            sampling_method: Some("dpm++2m".into()),
981            ..Default::default()
982        };
983        let source = fake_source(vec![]);
984        let args = build_sdcli_args(
985            &params,
986            &source,
987            Path::new("/d.gguf"),
988            None,
989            None,
990            None,
991            Path::new("/tmp/out.webp"),
992            None,
993            None,
994            None,
995            false,
996        );
997        let s = args_to_strings(&args);
998        assert_eq!(s[idx_after(&s, "--sampling-method").unwrap()], "dpm++2m");
999    }
1000
1001    #[test]
1002    fn build_sdcli_args_per_job_steps_overrides_when_non_default() {
1003        let params = ImageParams {
1004            prompt: "hi".into(),
1005            steps: 30, // != 20 → treat as caller override
1006            ..Default::default()
1007        };
1008        let source = fake_source(vec![]);
1009        let args = build_sdcli_args(
1010            &params,
1011            &source,
1012            Path::new("/d.gguf"),
1013            None,
1014            None,
1015            None,
1016            Path::new("/tmp/out.webp"),
1017            None,
1018            None,
1019            None,
1020            false,
1021        );
1022        let s = args_to_strings(&args);
1023        assert_eq!(s[idx_after(&s, "--steps").unwrap()], "30");
1024    }
1025
1026    #[test]
1027    fn build_sdcli_args_seed_included_when_set() {
1028        let params = ImageParams {
1029            prompt: "hi".into(),
1030            seed: Some(42),
1031            ..Default::default()
1032        };
1033        let source = fake_source(vec![]);
1034        let args = build_sdcli_args(
1035            &params,
1036            &source,
1037            Path::new("/d.gguf"),
1038            None,
1039            None,
1040            None,
1041            Path::new("/tmp/out.webp"),
1042            None,
1043            None,
1044            None,
1045            false,
1046        );
1047        let s = args_to_strings(&args);
1048        assert_eq!(s[idx_after(&s, "--seed").unwrap()], "42");
1049    }
1050
1051    /// A model source carrying the Qwen-Image-Edit flow flags.
1052    fn qwen_edit_source() -> ModelSource {
1053        ModelSource {
1054            engine: ModelEngine::SdCpp,
1055            files: vec![],
1056            cli_defaults: ModelCliDefaults {
1057                cfg_scale: 4.0,
1058                steps: 20,
1059                width: 1024,
1060                height: 1024,
1061                sampling_method: Some("euler".to_string()),
1062                flow_shift: Some(3.0),
1063                zero_cond_t: Some(true),
1064                offload_to_cpu: Some(true),
1065            },
1066        }
1067    }
1068
1069    #[test]
1070    fn build_sdcli_args_reference_mode_for_instruction_edit() {
1071        let params = ImageParams {
1072            prompt: "add a red beach ball".into(),
1073            denoise: Some(0.9),
1074            ..Default::default()
1075        };
1076        let source = qwen_edit_source();
1077        let args = build_sdcli_args(
1078            &params,
1079            &source,
1080            Path::new("/qwen.gguf"),
1081            Some(Path::new("/vae.safetensors")),
1082            Some(Path::new("/llm.gguf")),
1083            Some(Path::new("/mmproj.gguf")),
1084            Path::new("/tmp/out.webp"),
1085            None,
1086            Some(Path::new("/tmp/mask.png")),
1087            Some(Path::new("/tmp/ref.webp")),
1088            false,
1089        );
1090        let s = args_to_strings(&args);
1091        // Reference mode: `-r` set, a `--mask` constrains the edit region, and the img2img-only
1092        // `--init-img` / `--strength` flags are suppressed.
1093        assert_eq!(s[idx_after(&s, "-r").unwrap()], "/tmp/ref.webp");
1094        assert_eq!(s[idx_after(&s, "--mask").unwrap()], "/tmp/mask.png");
1095        assert!(!s.contains(&"--init-img".to_string()));
1096        assert!(!s.contains(&"--strength".to_string()));
1097        // Vision encoder + Qwen flow flags emitted.
1098        assert_eq!(s[idx_after(&s, "--llm_vision").unwrap()], "/mmproj.gguf");
1099        assert_eq!(s[idx_after(&s, "--flow-shift").unwrap()], "3");
1100        assert!(s.contains(&"--qwen-image-zero-cond-t".to_string()));
1101        assert!(s.contains(&"--offload-to-cpu".to_string()));
1102    }
1103
1104    #[test]
1105    fn build_sdcli_args_omits_qwen_flags_for_plain_model() {
1106        let params = ImageParams {
1107            prompt: "hi".into(),
1108            ..Default::default()
1109        };
1110        // fake_source has no flow_shift / zero_cond_t / offload_to_cpu.
1111        let source = fake_source(vec![]);
1112        let args = build_sdcli_args(
1113            &params,
1114            &source,
1115            Path::new("/d.gguf"),
1116            None,
1117            None,
1118            None,
1119            Path::new("/tmp/out.webp"),
1120            None,
1121            None,
1122            None,
1123            false,
1124        );
1125        let s = args_to_strings(&args);
1126        assert!(!s.contains(&"--flow-shift".to_string()));
1127        assert!(!s.contains(&"--qwen-image-zero-cond-t".to_string()));
1128        assert!(!s.contains(&"--offload-to-cpu".to_string()));
1129        assert!(!s.contains(&"--llm_vision".to_string()));
1130        assert!(!s.contains(&"-r".to_string()));
1131    }
1132
1133    #[test]
1134    fn capabilities_advertises_only_image_kind() {
1135        let dir = tempdir().unwrap();
1136        let engine = SdCppEngine::with_paths(PathBuf::from("/usr/bin/true"), dir.path().into());
1137        let caps = engine.capabilities();
1138        assert!(caps
1139            .supported_models_per_kind
1140            .contains_key(&TaskKind::Image));
1141        assert_eq!(caps.supported_models_per_kind.len(), 1);
1142    }
1143
1144    #[test]
1145    fn init_image_extension_reads_url_tail() {
1146        assert_eq!(init_image_extension("https://x/y/latest.webp"), "webp");
1147        assert_eq!(init_image_extension("https://x/y/latest.PNG"), "png");
1148        assert_eq!(init_image_extension("https://x/y/latest.jpg"), "jpg");
1149        assert_eq!(init_image_extension("https://x/y/latest.jpeg"), "jpg");
1150        // Query strings + fragments don't trick the parser.
1151        assert_eq!(
1152            init_image_extension("https://x/y/latest.webp?v=42&t=now"),
1153            "webp"
1154        );
1155        assert_eq!(init_image_extension("https://x/y/latest.webp#frag"), "webp");
1156        // Unknown extension falls back to webp.
1157        assert_eq!(
1158            init_image_extension("https://x/y/latest.unknownext"),
1159            "webp"
1160        );
1161        assert_eq!(init_image_extension("https://x/y/no-ext"), "webp");
1162    }
1163}