Skip to main content

studio_worker/engine/
sdcpp.rs

1//! Engine that runs real image inference by subprocess-invoking the
2//! `stable-diffusion.cpp` (`sd-cli`) binary.
3//!
4//! The studio's offer carries a [`ModelSource`] with everything we
5//! need: an engine identifier (`sd-cpp`), the list of files to
6//! download (diffusion-model + text-encoder + VAE, each with a public
7//! URL + filename), and CLI defaults (cfg-scale, steps, dimensions).
//! The worker has zero hardcoded model knowledge — it caches
9//! whatever the studio asks for under `cfg.models_root` and invokes
10//! `sd-cli` with the files arranged by role.
11//!
12//! Layout under `cfg.models_root` (default `~/models`):
13//! ```text
14//! ~/models/<filename1>
15//! ~/models/<filename2>
//! …
17//! ```
18//! Files are downloaded on first use - skipped when already present
19//! under `cfg.models_root`.  The streamed body is checked against the
20//! server's `Content-Length` so a truncated download is rejected and
21//! cleaned up instead of being renamed into place as a corrupt model
22//! that every later job would fail to load.  Cached files are re-used
23//! across every subsequent job that names them.
24//!
//! The engine always registers.  `sd-cli` is resolved lazily on the
//! first image job — from `$STUDIO_WORKER_SD_CLI`, `<models_root>/bin/`,
//! `~/.local/bin/`, or `$PATH`, in that order — and, when none of those
//! has it, auto-provisioned into `<models_root>/bin/`, so a box without a
//! stable-diffusion.cpp install can still serve real image jobs.
30
31use crate::engine::download;
32use crate::engine::sd_provision;
33use crate::engine::{Engine, EngineCapabilities};
34use crate::types::{ImageParams, ModelFileRole, ModelSource, Task, TaskKind, TaskResult};
35use anyhow::{anyhow, bail, Context, Result};
36use parking_lot::Mutex;
37use std::collections::BTreeMap;
38use std::ffi::OsString;
39use std::path::{Path, PathBuf};
40use std::process::Command;
41use std::time::Instant;
42use tracing::{debug, info, warn};
43
/// `tracing` target attached to every log event this engine emits.
const TRACE_TARGET: &str = "studio_worker::engine::sdcpp";

/// Last-resort sample-step count.  `resolve_image_args` uses it only when
/// the job did not choose a step count (its `steps` is 0 or the
/// deserialiser default of 20) AND the model's `cliDefaults.steps` is 0.
/// Z-Image-Turbo is an 8-step distilled schedule, so the upstream default
/// of 20 would only waste time; a model's own `cliDefaults.steps` normally
/// takes precedence over this value.
const STEPS_FALLBACK: u32 = 8;
51
52/// Worker-side engine that drives `sd-cli` per job.
53///
54/// `sd-cli` is resolved lazily on the first image job and cached: an
55/// operator install (env / PATH / `~/.local/bin`) wins, otherwise the
56/// binary is auto-provisioned into `<models_root>/bin/`.  The `Mutex`
57/// serialises that one-time resolution so two concurrent jobs can't
58/// race the download.
59pub struct SdCppEngine {
60    sd_cli: Mutex<Option<PathBuf>>,
61    models_root: PathBuf,
62}
63
64impl SdCppEngine {
65    /// Build the engine.  Always registers: `sd-cli` is resolved (and
66    /// provisioned into `<models_root>/bin/` if missing) lazily on the
67    /// first image job, so the engine serves real image work even on a
68    /// box that has never had a stable-diffusion.cpp build installed.
69    /// `models_root` is created on demand by the provisioner / model
70    /// downloader, so registration touches no filesystem.
71    pub fn new(models_root: &Path) -> Self {
72        info!(
73            target: TRACE_TARGET,
74            op = "register",
75            models_root = %models_root.display(),
76            sd_cli_name = sd_provision::binary_name(),
77            "sdcpp engine registered (sd-cli resolved/provisioned on first image job)"
78        );
79        Self {
80            sd_cli: Mutex::new(None),
81            models_root: models_root.to_path_buf(),
82        }
83    }
84
85    /// For tests: build with explicit paths (bypasses sd-cli lookup +
86    /// provisioning by seeding the resolved-path cache).
87    #[cfg(test)]
88    pub fn with_paths(sd_cli: PathBuf, models_root: PathBuf) -> Self {
89        Self {
90            sd_cli: Mutex::new(Some(sd_cli)),
91            models_root,
92        }
93    }
94
95    /// Resolve the `sd-cli` binary, provisioning it on first use.
96    /// Resolution order (operator installs win): a cached path from a
97    /// previous job, then env / `<models_root>/bin` / `~/.local/bin` /
98    /// `$PATH`, then an auto-provisioned download into
99    /// `<models_root>/bin/`.  The result is cached for the worker's
100    /// lifetime.
101    #[cfg_attr(coverage_nightly, coverage(off))]
102    fn ensure_sd_cli(&self) -> Result<PathBuf> {
103        let mut guard = self.sd_cli.lock();
104        if let Some(p) = guard.as_ref() {
105            if p.is_file() {
106                return Ok(p.clone());
107            }
108        }
109        let resolved = match resolve_sd_cli(&self.models_root) {
110            Some(p) => {
111                info!(
112                    target: TRACE_TARGET,
113                    op = "resolve",
114                    sd_cli = %p.display(),
115                    "using existing sd-cli"
116                );
117                p
118            }
119            None => sd_provision::provision(&self.models_root)
120                .context("auto-provisioning sd-cli (stable-diffusion.cpp)")?,
121        };
122        *guard = Some(resolved.clone());
123        Ok(resolved)
124    }
125
126    /// Ensure each file in `source.files` is present under
127    /// `self.models_root`.  Downloads anything missing.  Returns the
128    /// resolved local path for each file (in the same order).
129    #[cfg_attr(coverage_nightly, coverage(off))]
130    fn ensure_files(&self, source: &ModelSource) -> Result<Vec<(ModelFileRole, PathBuf)>> {
131        let mut out = Vec::with_capacity(source.files.len());
132        for file in &source.files {
133            let local = download::ensure_file(&self.models_root, &file.filename, &file.url)?;
134            out.push((file.role, local));
135        }
136        Ok(out)
137    }
138
139    /// Subprocess to `sd-cli` with the resolved diffusion / VAE /
140    /// text-encoder files.  Excluded from coverage: requires an
141    /// actual `sd-cli` binary + cached model files on disk, neither
142    /// of which exists on the CI runner.  Exercised end-to-end via
143    /// the live dev loop.
144    #[cfg_attr(coverage_nightly, coverage(off))]
145    fn dispatch_image(
146        &self,
147        model: &str,
148        params: ImageParams,
149        source: &ModelSource,
150    ) -> Result<TaskResult> {
151        // Resolve (provisioning on first use) the sd-cli binary before
152        // we touch model files, so a missing binary fails fast with the
153        // provisioning error rather than after a multi-GB weight pull.
154        let sd_cli = self.ensure_sd_cli()?;
155        let files = self.ensure_files(source)?;
156        // A `diffusion-model` file is the standalone diffusion weights (sd-cli `--diffusion-model`,
157        // used with split vae/clip); a `model` file is a full checkpoint (sd-cli `-m`/`--model`).
158        // Prefer the explicit diffusion-model role; fall back to a full checkpoint.
159        let diffusion_only = file_for_role(&files, ModelFileRole::DiffusionModel);
160        let full_checkpoint = diffusion_only.is_none();
161        let diffusion_model = diffusion_only
162            .or_else(|| file_for_role(&files, ModelFileRole::Model))
163            .ok_or_else(|| anyhow!("modelSource has no diffusion-model / model file"))?;
164        let vae = file_for_role(&files, ModelFileRole::Vae);
165        let text_encoder = file_for_role(&files, ModelFileRole::TextEncoder);
166        let text_encoder_vision = file_for_role(&files, ModelFileRole::TextEncoderVision);
167
168        let out_dir = std::env::temp_dir().join("studio-worker-sdcpp");
169        std::fs::create_dir_all(&out_dir)
170            .with_context(|| format!("creating sdcpp output dir {}", out_dir.display()))?;
171        let stem = format!(
172            "out-{}-{}",
173            std::process::id(),
174            chrono::Utc::now().timestamp_nanos_opt().unwrap_or_default()
175        );
176        let out_path = out_dir.join(format!("{stem}.webp"));
177
178        // Own the scratch files from the moment their paths exist so
179        // every failure path (sd-cli error, unreadable output) cleans
180        // up instead of leaking them into the temp dir.
181        let mut temp_files = TempFileGuard::new();
182        temp_files.push(out_path.clone());
183
184        // If the task carries an init image URL, stream it to a
185        // tempfile so we can hand the path to `sd-cli --init-img`.
186        // This is mandatory — the worker refuses i2i jobs whose
187        // init image fails to download (no silent fallback to t2i).
188        // The local extension mirrors the URL's so sd-cli's image
189        // loader can sniff the format.
190        let init_img_path = match params.init_image_url.as_deref() {
191            Some(url) if !url.is_empty() => {
192                let ext = init_image_extension(url);
193                let init_path = out_dir.join(format!("{stem}-init.{ext}"));
194                download::download_file(url, &init_path).with_context(|| {
195                    format!("downloading init image {} -> {}", url, init_path.display())
196                })?;
197                temp_files.push(init_path.clone());
198                Some(init_path)
199            }
200            _ => None,
201        };
202
203        // A mask constrains the edit region — valid alongside either an init image (img2img
204        // inpaint) or a reference image (instruction edit). Download it whenever a base image is
205        // present and a mask URL was supplied; white pixels mark the region the model may change.
206        let has_base = init_img_path.is_some() || params.ref_image_url.as_deref().is_some();
207        let mask_path = match (has_base, params.mask_url.as_deref()) {
208            (true, Some(url)) if !url.is_empty() => {
209                let ext = init_image_extension(url);
210                let path = out_dir.join(format!("{stem}-mask.{ext}"));
211                download::download_file(url, &path)
212                    .with_context(|| format!("downloading mask {} -> {}", url, path.display()))?;
213                temp_files.push(path.clone());
214                Some(path)
215            }
216            _ => None,
217        };
218
219        // Reference image for instruction-edit models (`sd-cli -r`). Downloaded like the init image;
220        // when present the arg builder uses reference mode instead of the img2img/mask path.
221        let ref_img_path = match params.ref_image_url.as_deref() {
222            Some(url) if !url.is_empty() => {
223                let ext = init_image_extension(url);
224                let path = out_dir.join(format!("{stem}-ref.{ext}"));
225                download::download_file(url, &path).with_context(|| {
226                    format!("downloading reference image {} -> {}", url, path.display())
227                })?;
228                temp_files.push(path.clone());
229                Some(path)
230            }
231            _ => None,
232        };
233
234        let args = build_sdcli_args(
235            &params,
236            source,
237            diffusion_model,
238            vae,
239            text_encoder,
240            text_encoder_vision,
241            &out_path,
242            init_img_path.as_deref(),
243            mask_path.as_deref(),
244            ref_img_path.as_deref(),
245            full_checkpoint,
246        );
247        let mut cmd = Command::new(&sd_cli);
248        cmd.args(&args);
249        apply_library_path(&mut cmd, &sd_cli);
250
251        debug!(
252            target: TRACE_TARGET,
253            op = "spawn",
254            sd_cli = %sd_cli.display(),
255            model,
256            i2i = init_img_path.is_some(),
257            arg_count = args.len(),
258            "running sd-cli"
259        );
260
261        let started = Instant::now();
262        let output = cmd
263            .output()
264            .with_context(|| format!("running {}", sd_cli.display()))?;
265        let elapsed_ms = started.elapsed().as_millis() as u64;
266        if !output.status.success() {
267            let stderr = String::from_utf8_lossy(&output.stderr);
268            warn!(
269                target: TRACE_TARGET,
270                op = "spawn",
271                model,
272                elapsed_ms,
273                exit = ?output.status.code(),
274                stderr = %stderr,
275                "sd-cli failed"
276            );
277            bail!(
278                "sd-cli exited with {:?}: {}",
279                output.status.code(),
280                stderr.lines().last().unwrap_or("(no stderr)")
281            );
282        }
283
284        let bytes = std::fs::read(&out_path)
285            .with_context(|| format!("reading sd-cli output at {}", out_path.display()))?;
286        info!(
287            target: TRACE_TARGET,
288            op = "dispatch",
289            model,
290            elapsed_ms,
291            bytes = bytes.len(),
292            "ok"
293        );
294
295        Ok(TaskResult::Image {
296            bytes,
297            ext: "webp".to_string(),
298        })
299    }
300}
301
302impl Engine for SdCppEngine {
303    fn name(&self) -> &'static str {
304        "sdcpp"
305    }
306
307    fn capabilities(&self) -> EngineCapabilities {
308        // Image kind only.  The studio's selection is kind-based now
309        // and the offer carries the model-source, so we don't need to
310        // enumerate model names ourselves.  We still list a single
311        // sentinel string so downstream code that reads
312        // `supportedModels` for display sees "any sd-cpp model".
313        let mut map: BTreeMap<TaskKind, Vec<String>> = BTreeMap::new();
314        map.insert(TaskKind::Image, vec!["sd-cpp:*".to_string()]);
315        EngineCapabilities {
316            supported_models_per_kind: map,
317        }
318    }
319
320    fn dispatch(&self, _model: &str, _task: Task) -> Result<TaskResult> {
321        bail!(
322            "sdcpp engine requires a ModelSource on the offer; legacy push-based offers \
323             (no modelSource) cannot be served - re-promote the job through the studio"
324        )
325    }
326
327    fn dispatch_with_source(
328        &self,
329        model: &str,
330        task: Task,
331        source: &ModelSource,
332    ) -> Result<TaskResult> {
333        let kind = task.kind();
334        match task {
335            Task::Image(p) => self.dispatch_image(model, p, source),
336            _ => bail!("sdcpp engine cannot serve {} tasks", kind.as_str()),
337        }
338    }
339}
340
341// ---------------------------------------------------------------------------
342// Helpers
343// ---------------------------------------------------------------------------
344
345/// Best-effort removal of a temporary file (a per-job `sd-cli` output,
346/// an init image, or a half-written `.part` download).  Removal is
347/// non-fatal — the artefact has already been read or the job already
348/// failed — but a remove that keeps failing silently leaks temp files
349/// and can quietly fill the worker's disk over a long-running session,
350/// so we surface the failure instead of swallowing it.  A `NotFound`
351/// is the desired end state (something already cleaned it up), so it's
352/// not logged.
353fn remove_temp_file(path: &Path) {
354    if let Err(e) = std::fs::remove_file(path) {
355        if e.kind() != std::io::ErrorKind::NotFound {
356            warn!(
357                target: TRACE_TARGET,
358                op = "cleanup",
359                path = %path.display(),
360                error = %e,
361                "failed to remove temp file"
362            );
363        }
364    }
365}
366
367/// RAII owner of a job's scratch files (the `sd-cli` output image and a
368/// downloaded init image).  Registering them up front means every exit
369/// path - the success return, an `sd-cli` non-zero exit, an unreadable
370/// output file, even a panic - removes them on drop instead of leaking
371/// them into the temp dir and slowly filling the worker's disk over a
372/// long-running session.  Removal is best-effort via [`remove_temp_file`],
373/// so a path that never materialised (job failed before `sd-cli` wrote
374/// anything) is silently tolerated.
375struct TempFileGuard {
376    paths: Vec<PathBuf>,
377}
378
379impl TempFileGuard {
380    fn new() -> Self {
381        Self { paths: Vec::new() }
382    }
383
384    fn push(&mut self, path: PathBuf) {
385        self.paths.push(path);
386    }
387}
388
389impl Drop for TempFileGuard {
390    fn drop(&mut self) {
391        for path in &self.paths {
392            remove_temp_file(path);
393        }
394    }
395}
396
397fn file_for_role(files: &[(ModelFileRole, PathBuf)], role: ModelFileRole) -> Option<&Path> {
398    files
399        .iter()
400        .find(|(r, _)| *r == role)
401        .map(|(_, p)| p.as_path())
402}
403
404/// Resolve final per-job width / height / steps / cfg / sampler /
405/// negative-prompt by layering `params` over `source.cli_defaults`
406/// with the agreed precedence (per-job override beats model default
407/// beats engine fallback).  Pure for testability.
408fn resolve_image_args(params: &ImageParams, source: &ModelSource) -> ResolvedImageArgs {
409    let width = if params.width > 0 {
410        params.width
411    } else if source.cli_defaults.width > 0 {
412        source.cli_defaults.width
413    } else {
414        1024
415    };
416    let height = if params.height > 0 {
417        params.height
418    } else if source.cli_defaults.height > 0 {
419        source.cli_defaults.height
420    } else {
421        1024
422    };
423    // Steps: per-job override wins (treat the deserialiser default of
424    // 20 as "caller didn't pick" so the model's tuned step count
425    // doesn't get clobbered by a stale default).
426    let steps = if params.steps > 0 && params.steps != 20 {
427        params.steps
428    } else if source.cli_defaults.steps > 0 {
429        source.cli_defaults.steps
430    } else {
431        STEPS_FALLBACK
432    };
433    let source_cfg = if source.cli_defaults.cfg_scale > 0.0 {
434        source.cli_defaults.cfg_scale
435    } else {
436        1.0
437    };
438    let cfg_scale = params.cfg_scale.filter(|v| *v > 0.0).unwrap_or(source_cfg);
439    let sampling_method = params
440        .sampling_method
441        .clone()
442        .or_else(|| source.cli_defaults.sampling_method.clone());
443    ResolvedImageArgs {
444        width,
445        height,
446        steps,
447        cfg_scale,
448        sampling_method,
449    }
450}
451
/// Final per-job `sd-cli` numerics, as produced by [`resolve_image_args`].
#[derive(Clone, Debug, PartialEq)]
struct ResolvedImageArgs {
    /// Output width in pixels (`-W`).
    width: u32,
    /// Output height in pixels (`-H`).
    height: u32,
    /// Sampling step count (`--steps`).
    steps: u32,
    /// Classifier-free guidance scale (`--cfg-scale`).
    cfg_scale: f32,
    /// Sampler name (`--sampling-method`); `None` leaves sd-cli's default.
    sampling_method: Option<String>,
}
461
462/// Build the full `sd-cli` argv for one image job.  Pure (no I/O):
463/// the caller resolves files / out-path / init-image-path, this
464/// function only assembles the flag list so it can be asserted in
465/// unit tests without spawning the binary.
466// Eight model-path + i2i components; grouping them adds indirection without
467// improving readability (mirrors the `#[allow]` already used in ws::session).
468#[allow(clippy::too_many_arguments)]
469fn build_sdcli_args(
470    params: &ImageParams,
471    source: &ModelSource,
472    diffusion_model: &Path,
473    vae: Option<&Path>,
474    text_encoder: Option<&Path>,
475    text_encoder_vision: Option<&Path>,
476    out_path: &Path,
477    init_img_path: Option<&Path>,
478    mask_path: Option<&Path>,
479    ref_img_path: Option<&Path>,
480    full_checkpoint: bool,
481) -> Vec<OsString> {
482    let resolved = resolve_image_args(params, source);
483    let mut args: Vec<OsString> = Vec::with_capacity(32);
484
485    // A full checkpoint loads via `-m`/`--model`; standalone diffusion weights via
486    // `--diffusion-model` (alongside split vae/clip files).
487    args.push(
488        if full_checkpoint {
489            "--model"
490        } else {
491            "--diffusion-model"
492        }
493        .into(),
494    );
495    args.push(diffusion_model.into());
496    if let Some(p) = vae {
497        args.push("--vae".into());
498        args.push(p.into());
499    }
500    if let Some(p) = text_encoder {
501        args.push("--llm".into());
502        args.push(p.into());
503    }
504    if let Some(p) = text_encoder_vision {
505        args.push("--llm_vision".into());
506        args.push(p.into());
507    }
508    args.push("-p".into());
509    args.push((&params.prompt as &str).into());
510    if let Some(neg) = params.negative_prompt.as_deref() {
511        if !neg.is_empty() {
512            args.push("--negative-prompt".into());
513            args.push(neg.into());
514        }
515    }
516    if let Some(reference) = ref_img_path {
517        // Reference / instruction-edit mode (Qwen-Image-Edit, Flux Kontext): the model regenerates
518        // the image from the reference per the prompt. Mutually exclusive with the `--init-img`
519        // img2img path. A `--mask` is honoured here too: it constrains the edit to the masked
520        // region (white = editable) and leaves the rest, so the studio can place the edit inside
521        // the author's drawn shape. No `--strength` (that's an img2img-only knob).
522        args.push("-r".into());
523        args.push(reference.into());
524        if let Some(mask) = mask_path {
525            args.push("--mask".into());
526            args.push(mask.into());
527        }
528    } else if let Some(init) = init_img_path {
529        args.push("--init-img".into());
530        args.push(init.into());
531        // `--strength` only makes sense alongside an init image
532        // (sd-cli ignores it otherwise).  Default to 0.75 (sd-cli's
533        // own default) when the caller didn't pick a value.
534        let strength = params.denoise.unwrap_or(0.75);
535        args.push("--strength".into());
536        args.push(strength.to_string().into());
537        // Mask-guided inpaint: only valid with an init image.
538        if let Some(mask) = mask_path {
539            args.push("--mask".into());
540            args.push(mask.into());
541        }
542    }
543    args.push("--cfg-scale".into());
544    args.push(resolved.cfg_scale.to_string().into());
545    args.push("--steps".into());
546    args.push(resolved.steps.to_string().into());
547    args.push("-W".into());
548    args.push(resolved.width.to_string().into());
549    args.push("-H".into());
550    args.push(resolved.height.to_string().into());
551    args.push("-o".into());
552    args.push(out_path.into());
553    if let Some(seed) = params.seed {
554        args.push("--seed".into());
555        args.push(seed.to_string().into());
556    }
557    if let Some(method) = resolved.sampling_method.as_deref() {
558        args.push("--sampling-method".into());
559        args.push(method.into());
560    }
561    // Flow / instruction-edit model flags (model-level constants from the registry). Only emitted
562    // when the model declares them, so SDXL-style models are unaffected.
563    if let Some(shift) = source.cli_defaults.flow_shift {
564        args.push("--flow-shift".into());
565        args.push(shift.to_string().into());
566    }
567    if source.cli_defaults.zero_cond_t == Some(true) {
568        args.push("--qwen-image-zero-cond-t".into());
569    }
570    if source.cli_defaults.offload_to_cpu == Some(true) {
571        args.push("--offload-to-cpu".into());
572    }
573    // VRAM-saving flags that are safe on every box.
574    args.push("--diffusion-fa".into());
575    args
576}
577
578/// Point the per-job `Command`'s dynamic linker at the shared library
579/// that ships next to an auto-provisioned `sd-cli` (Linux / macOS).
580/// No-op on Windows (sibling DLLs resolve automatically) and when the
581/// resolved binary has no sibling library (operator wrapper-script
582/// installs manage their own load path).  Prepends to any inherited
583/// value so a pre-set `LD_LIBRARY_PATH` isn't clobbered.
584#[cfg_attr(coverage_nightly, coverage(off))]
585fn apply_library_path(cmd: &mut Command, sd_cli: &Path) {
586    let Some((var, dir)) = sd_provision::library_path_env(sd_cli) else {
587        return;
588    };
589    let value = match std::env::var_os(var) {
590        Some(existing) => {
591            let mut paths = vec![dir.clone()];
592            paths.extend(std::env::split_paths(&existing));
593            // `join_paths` only fails if a path contains the platform
594            // separator; fall back to our dir alone, the entry that
595            // matters for finding the sibling library.
596            std::env::join_paths(paths).unwrap_or_else(|_| dir.into_os_string())
597        }
598        None => dir.into_os_string(),
599    };
600    cmd.env(var, value);
601}
602
603/// Look up `sd-cli` in env override -> `<models_root>/bin` ->
604/// `~/.local/bin` -> `$PATH`.  The `<models_root>/bin` slot is where a
605/// self-provisioned binary lands, so the auto-provisioner can drop it
606/// next to the cached models and have the worker pick it up with no
607/// PATH fiddling.  Excluded from coverage: touches several host paths
608/// only one of which matches per host, and CI doesn't ship `sd-cli`.
609#[cfg_attr(coverage_nightly, coverage(off))]
610fn resolve_sd_cli(models_root: &Path) -> Option<PathBuf> {
611    let bin = sd_provision::binary_name();
612    if let Ok(p) = std::env::var("STUDIO_WORKER_SD_CLI") {
613        let path = PathBuf::from(p);
614        if path.is_file() {
615            return Some(path);
616        }
617    }
618    let in_models = models_root.join("bin").join(bin);
619    if in_models.is_file() {
620        return Some(in_models);
621    }
622    if let Some(home) = std::env::var_os("HOME") {
623        let candidate = PathBuf::from(home).join(".local/bin").join(bin);
624        if candidate.is_file() {
625            return Some(candidate);
626        }
627    }
628    which(bin)
629}
630
/// `$PATH` lookup for a bare binary name: the first `<entry>/<bin>` that is
/// an existing file, in `$PATH` order.  Returns `None` when `$PATH` is unset
/// or no entry has the binary.
///
/// Empty `$PATH` entries are skipped.  Joining `bin` onto an empty entry
/// yields a bare relative name: `is_file` checks it against the CWD, but
/// `Command::new` on a slash-free name searches `$PATH` again instead of
/// the CWD, so that "hit" could run a different binary or nothing at all.
///
/// Excluded from coverage for the same reason as `resolve_sd_cli`.
#[cfg_attr(coverage_nightly, coverage(off))]
fn which(bin: &str) -> Option<PathBuf> {
    let path = std::env::var_os("PATH")?;
    std::env::split_paths(&path)
        .filter(|dir| !dir.as_os_str().is_empty())
        .map(|dir| dir.join(bin))
        .find(|candidate| candidate.is_file())
}
644
/// Choose the file extension for a downloaded init / mask / reference image
/// so sd-cli's image loader can sniff the format.
///
/// Looks at the text after the last `.` in the URL once any query string
/// (`?…`) or fragment (`#…`) is cut off, compared case-insensitively.
/// `jpeg` is normalised to `jpg` and `tiff` to `tif`; anything unrecognised
/// (including a URL with no extension) falls back to `webp`.
fn init_image_extension(url: &str) -> &'static str {
    // (spelling that may appear in the URL, extension we write locally)
    const KNOWN: [(&str, &str); 8] = [
        ("png", "png"),
        ("jpg", "jpg"),
        ("jpeg", "jpg"),
        ("webp", "webp"),
        ("bmp", "bmp"),
        ("gif", "gif"),
        ("tif", "tif"),
        ("tiff", "tif"),
    ];
    let path = match url.find(['?', '#']) {
        Some(cut) => &url[..cut],
        None => url,
    };
    let tail = path.rsplit_once('.').map_or(path, |(_, ext)| ext);
    KNOWN
        .iter()
        .find(|(alias, _)| tail.eq_ignore_ascii_case(alias))
        .map_or("webp", |&(_, ext)| ext)
}
666
667// ---------------------------------------------------------------------------
668// Tests
669// ---------------------------------------------------------------------------
670
671#[cfg(test)]
672mod tests {
673    use super::*;
674    use crate::types::{ModelCliDefaults, ModelEngine, ModelFile, ModelFileRole};
675    use tempfile::tempdir;
676
677    fn fake_source(files: Vec<ModelFile>) -> ModelSource {
678        ModelSource {
679            engine: ModelEngine::SdCpp,
680            files,
681            cli_defaults: ModelCliDefaults {
682                cfg_scale: 1.0,
683                steps: 8,
684                width: 1024,
685                height: 1024,
686                sampling_method: Some("euler".to_string()),
687                ..Default::default()
688            },
689        }
690    }
691
692    #[test]
693    fn temp_file_guard_removes_every_registered_file_on_drop() {
694        let dir = tempdir().unwrap();
695        let out = dir.path().join("out.webp");
696        let init = dir.path().join("out-init.png");
697        std::fs::write(&out, b"image").unwrap();
698        std::fs::write(&init, b"init").unwrap();
699        {
700            let mut guard = TempFileGuard::new();
701            guard.push(out.clone());
702            guard.push(init.clone());
703            assert!(out.exists() && init.exists(), "files present before drop");
704        }
705        assert!(!out.exists(), "sd-cli output temp must be removed on drop");
706        assert!(!init.exists(), "init-image temp must be removed on drop");
707    }
708
709    #[test]
710    fn temp_file_guard_tolerates_a_file_that_never_materialised() {
711        // The output path is registered before sd-cli runs, so a job
712        // that fails before writing anything drops a guard pointing at
713        // a path that never existed.  That is the desired end state,
714        // not a cleanup warning.
715        let dir = tempdir().unwrap();
716        let missing = dir.path().join("never-written.webp");
717        let out = crate::test_support::capture(move || {
718            let mut guard = TempFileGuard::new();
719            guard.push(missing);
720            drop(guard);
721        });
722        assert!(
723            !out.contains("failed to remove temp file"),
724            "a never-created temp file must not warn on cleanup: {out:?}"
725        );
726    }
727
728    #[test]
729    fn remove_temp_file_deletes_an_existing_file_quietly() {
730        let dir = tempdir().unwrap();
731        let f = dir.path().join("artefact.webp");
732        std::fs::write(&f, b"bytes").unwrap();
733        let out = crate::test_support::capture({
734            let f = f.clone();
735            move || remove_temp_file(&f)
736        });
737        assert!(!f.exists(), "file should be gone after cleanup");
738        assert!(
739            !out.contains("failed to remove temp file"),
740            "the success path must not warn: {out:?}"
741        );
742    }
743
744    #[test]
745    fn remove_temp_file_ignores_an_already_missing_file() {
746        let dir = tempdir().unwrap();
747        let missing = dir.path().join("never-existed.webp");
748        let out = crate::test_support::capture(move || remove_temp_file(&missing));
749        assert!(
750            !out.contains("failed to remove temp file"),
751            "a not-found file is the desired end state, not a warning: {out:?}"
752        );
753    }
754
755    #[test]
756    fn remove_temp_file_surfaces_a_failed_removal() {
757        // Pointing the helper at a directory makes `remove_file` fail
758        // on every platform (it refuses to unlink a dir): the closest
759        // portable stand-in for a locked / permission-denied temp file.
760        let dir = tempdir().unwrap();
761        let stubborn = dir.path().join("subdir");
762        std::fs::create_dir(&stubborn).unwrap();
763        let out = crate::test_support::capture(move || remove_temp_file(&stubborn));
764        assert!(
765            out.contains("failed to remove temp file"),
766            "a failed removal must surface in the logs: {out:?}"
767        );
768        assert!(
769            out.contains("subdir"),
770            "the warning must name the offending path: {out:?}"
771        );
772        assert!(
773            out.contains("cleanup"),
774            "the warning should tag the cleanup op: {out:?}"
775        );
776    }
777
778    #[test]
779    fn file_for_role_picks_matching_file() {
780        let files = vec![
781            (ModelFileRole::DiffusionModel, PathBuf::from("/d.gguf")),
782            (ModelFileRole::Vae, PathBuf::from("/v.safetensors")),
783        ];
784        assert_eq!(
785            file_for_role(&files, ModelFileRole::DiffusionModel),
786            Some(Path::new("/d.gguf"))
787        );
788        assert_eq!(
789            file_for_role(&files, ModelFileRole::Vae),
790            Some(Path::new("/v.safetensors"))
791        );
792        assert!(file_for_role(&files, ModelFileRole::TextEncoder).is_none());
793    }
794
795    #[test]
796    fn ensure_files_skips_already_present() {
797        let dir = tempdir().unwrap();
798        let cached = dir.path().join("cached.gguf");
799        std::fs::write(&cached, b"already here").unwrap();
800        let engine = SdCppEngine::with_paths(PathBuf::from("/usr/bin/true"), dir.path().into());
801        let source = fake_source(vec![ModelFile {
802            role: ModelFileRole::DiffusionModel,
803            url: "https://example.invalid/cached.gguf".into(),
804            filename: "cached.gguf".into(),
805            approx_bytes: None,
806        }]);
807        let resolved = engine.ensure_files(&source).expect("cached file used");
808        assert_eq!(resolved.len(), 1);
809        assert_eq!(resolved[0].0, ModelFileRole::DiffusionModel);
810        assert_eq!(resolved[0].1, cached);
811        // Untouched on disk \u2014 our "download" never ran.
812        assert_eq!(std::fs::read(&cached).unwrap(), b"already here");
813    }
814
815    #[test]
816    fn dispatch_rejects_non_image_tasks() {
817        use crate::types::AudioTtsParams;
818        let dir = tempdir().unwrap();
819        let engine = SdCppEngine::with_paths(PathBuf::from("/usr/bin/true"), dir.path().into());
820        let task = Task::AudioTts(AudioTtsParams {
821            text: "hi".into(),
822            voice: "v".into(),
823            ext: "wav".into(),
824            ..Default::default()
825        });
826        let source = fake_source(vec![]);
827        let err = engine
828            .dispatch_with_source("anything", task, &source)
829            .unwrap_err();
830        assert!(err.to_string().contains("cannot serve audio_tts"));
831    }
832
833    // The legacy `dispatch_requires_model_source` test is gone: the
834    // trait signature now takes `&ModelSource` so the compiler enforces
835    // it at every call site.  No runtime fallback to police.
836
837    // -----------------------------------------------------------------
838    // Pure arg-builder tests — lock down the sd-cli invocation contract
839    // without needing the binary on the box.
840    // -----------------------------------------------------------------
841
842    fn args_to_strings(args: &[OsString]) -> Vec<String> {
843        args.iter()
844            .map(|s| s.to_string_lossy().into_owned())
845            .collect()
846    }
847
848    fn idx_after(args: &[String], flag: &str) -> Option<usize> {
849        args.iter().position(|a| a == flag).map(|i| i + 1)
850    }
851
852    #[test]
853    fn build_sdcli_args_includes_required_flags() {
854        let params = ImageParams {
855            prompt: "hello".into(),
856            width: 768,
857            height: 512,
858            steps: 20, // "caller didn't pick" → source default wins
859            ..Default::default()
860        };
861        let source = fake_source(vec![]);
862        let args = build_sdcli_args(
863            &params,
864            &source,
865            Path::new("/d.gguf"),
866            Some(Path::new("/v.safetensors")),
867            Some(Path::new("/llm.gguf")),
868            None,
869            Path::new("/tmp/out.webp"),
870            None,
871            None,
872            None,
873            false,
874        );
875        let s = args_to_strings(&args);
876        assert_eq!(s[idx_after(&s, "--diffusion-model").unwrap()], "/d.gguf");
877        assert_eq!(s[idx_after(&s, "--vae").unwrap()], "/v.safetensors");
878        assert_eq!(s[idx_after(&s, "--llm").unwrap()], "/llm.gguf");
879        assert_eq!(s[idx_after(&s, "-p").unwrap()], "hello");
880        assert_eq!(s[idx_after(&s, "-W").unwrap()], "768");
881        assert_eq!(s[idx_after(&s, "-H").unwrap()], "512");
882        // source default cfg_scale=1.0
883        assert_eq!(s[idx_after(&s, "--cfg-scale").unwrap()], "1");
884        // source default steps=8 wins (param.steps==20 treated as default)
885        assert_eq!(s[idx_after(&s, "--steps").unwrap()], "8");
886        assert_eq!(s[idx_after(&s, "--sampling-method").unwrap()], "euler");
887        assert_eq!(s[idx_after(&s, "-o").unwrap()], "/tmp/out.webp");
888        assert!(s.contains(&"--diffusion-fa".to_string()));
889        // Never includes init-only flags when no init image present.
890        assert!(!s.contains(&"--init-img".to_string()));
891        assert!(!s.contains(&"--strength".to_string()));
892    }
893
894    #[test]
895    fn build_sdcli_args_includes_negative_prompt_when_set() {
896        let params = ImageParams {
897            prompt: "hi".into(),
898            negative_prompt: Some("text, watermark, low quality".into()),
899            ..Default::default()
900        };
901        let source = fake_source(vec![]);
902        let args = build_sdcli_args(
903            &params,
904            &source,
905            Path::new("/d.gguf"),
906            None,
907            None,
908            None,
909            Path::new("/tmp/out.webp"),
910            None,
911            None,
912            None,
913            false,
914        );
915        let s = args_to_strings(&args);
916        assert_eq!(
917            s[idx_after(&s, "--negative-prompt").unwrap()],
918            "text, watermark, low quality"
919        );
920    }
921
922    #[test]
923    fn build_sdcli_args_omits_negative_prompt_when_empty_string() {
924        let params = ImageParams {
925            prompt: "hi".into(),
926            negative_prompt: Some(String::new()),
927            ..Default::default()
928        };
929        let source = fake_source(vec![]);
930        let args = build_sdcli_args(
931            &params,
932            &source,
933            Path::new("/d.gguf"),
934            None,
935            None,
936            None,
937            Path::new("/tmp/out.webp"),
938            None,
939            None,
940            None,
941            false,
942        );
943        let s = args_to_strings(&args);
944        assert!(!s.contains(&"--negative-prompt".to_string()));
945    }
946
947    #[test]
948    fn build_sdcli_args_includes_init_image_and_strength() {
949        let params = ImageParams {
950            prompt: "hi".into(),
951            denoise: Some(0.55),
952            ..Default::default()
953        };
954        let source = fake_source(vec![]);
955        let args = build_sdcli_args(
956            &params,
957            &source,
958            Path::new("/d.gguf"),
959            None,
960            None,
961            None,
962            Path::new("/tmp/out.webp"),
963            Some(Path::new("/tmp/init.webp")),
964            None,
965            None,
966            false,
967        );
968        let s = args_to_strings(&args);
969        assert_eq!(s[idx_after(&s, "--init-img").unwrap()], "/tmp/init.webp");
970        assert_eq!(s[idx_after(&s, "--strength").unwrap()], "0.55");
971        // No mask supplied → no inpaint flag.
972        assert!(!s.contains(&"--mask".to_string()));
973    }
974
975    #[test]
976    fn build_sdcli_args_includes_mask_for_inpaint() {
977        let params = ImageParams {
978            prompt: "remove the tree".into(),
979            denoise: Some(0.8),
980            ..Default::default()
981        };
982        let source = fake_source(vec![]);
983        let args = build_sdcli_args(
984            &params,
985            &source,
986            Path::new("/d.gguf"),
987            None,
988            None,
989            None,
990            Path::new("/tmp/out.webp"),
991            Some(Path::new("/tmp/init.webp")),
992            Some(Path::new("/tmp/mask.png")),
993            None,
994            false,
995        );
996        let s = args_to_strings(&args);
997        assert_eq!(s[idx_after(&s, "--init-img").unwrap()], "/tmp/init.webp");
998        assert_eq!(s[idx_after(&s, "--mask").unwrap()], "/tmp/mask.png");
999        assert_eq!(s[idx_after(&s, "--strength").unwrap()], "0.8");
1000    }
1001
1002    #[test]
1003    fn build_sdcli_args_uses_model_flag_for_full_checkpoint() {
1004        let params = ImageParams {
1005            prompt: "hi".into(),
1006            ..Default::default()
1007        };
1008        let source = fake_source(vec![]);
1009        let args = build_sdcli_args(
1010            &params,
1011            &source,
1012            Path::new("/checkpoint.safetensors"),
1013            Some(Path::new("/v.safetensors")),
1014            None,
1015            None,
1016            Path::new("/tmp/out.webp"),
1017            None,
1018            None,
1019            None,
1020            true,
1021        );
1022        let s = args_to_strings(&args);
1023        // A full checkpoint loads via -m/--model, not --diffusion-model.
1024        assert_eq!(
1025            s[idx_after(&s, "--model").unwrap()],
1026            "/checkpoint.safetensors"
1027        );
1028        assert!(!s.contains(&"--diffusion-model".to_string()));
1029    }
1030
1031    #[test]
1032    fn build_sdcli_args_defaults_denoise_when_init_image_present_but_denoise_none() {
1033        let params = ImageParams {
1034            prompt: "hi".into(),
1035            denoise: None,
1036            ..Default::default()
1037        };
1038        let source = fake_source(vec![]);
1039        let args = build_sdcli_args(
1040            &params,
1041            &source,
1042            Path::new("/d.gguf"),
1043            None,
1044            None,
1045            None,
1046            Path::new("/tmp/out.webp"),
1047            Some(Path::new("/tmp/init.webp")),
1048            None,
1049            None,
1050            false,
1051        );
1052        let s = args_to_strings(&args);
1053        assert_eq!(s[idx_after(&s, "--strength").unwrap()], "0.75");
1054    }
1055
1056    #[test]
1057    fn build_sdcli_args_per_job_cfg_scale_overrides_model_default() {
1058        let params = ImageParams {
1059            prompt: "hi".into(),
1060            cfg_scale: Some(7.5),
1061            ..Default::default()
1062        };
1063        let source = fake_source(vec![]);
1064        let args = build_sdcli_args(
1065            &params,
1066            &source,
1067            Path::new("/d.gguf"),
1068            None,
1069            None,
1070            None,
1071            Path::new("/tmp/out.webp"),
1072            None,
1073            None,
1074            None,
1075            false,
1076        );
1077        let s = args_to_strings(&args);
1078        assert_eq!(s[idx_after(&s, "--cfg-scale").unwrap()], "7.5");
1079    }
1080
1081    #[test]
1082    fn build_sdcli_args_per_job_sampling_method_overrides_model_default() {
1083        let params = ImageParams {
1084            prompt: "hi".into(),
1085            sampling_method: Some("dpm++2m".into()),
1086            ..Default::default()
1087        };
1088        let source = fake_source(vec![]);
1089        let args = build_sdcli_args(
1090            &params,
1091            &source,
1092            Path::new("/d.gguf"),
1093            None,
1094            None,
1095            None,
1096            Path::new("/tmp/out.webp"),
1097            None,
1098            None,
1099            None,
1100            false,
1101        );
1102        let s = args_to_strings(&args);
1103        assert_eq!(s[idx_after(&s, "--sampling-method").unwrap()], "dpm++2m");
1104    }
1105
1106    #[test]
1107    fn build_sdcli_args_per_job_steps_overrides_when_non_default() {
1108        let params = ImageParams {
1109            prompt: "hi".into(),
1110            steps: 30, // != 20 → treat as caller override
1111            ..Default::default()
1112        };
1113        let source = fake_source(vec![]);
1114        let args = build_sdcli_args(
1115            &params,
1116            &source,
1117            Path::new("/d.gguf"),
1118            None,
1119            None,
1120            None,
1121            Path::new("/tmp/out.webp"),
1122            None,
1123            None,
1124            None,
1125            false,
1126        );
1127        let s = args_to_strings(&args);
1128        assert_eq!(s[idx_after(&s, "--steps").unwrap()], "30");
1129    }
1130
1131    #[test]
1132    fn build_sdcli_args_seed_included_when_set() {
1133        let params = ImageParams {
1134            prompt: "hi".into(),
1135            seed: Some(42),
1136            ..Default::default()
1137        };
1138        let source = fake_source(vec![]);
1139        let args = build_sdcli_args(
1140            &params,
1141            &source,
1142            Path::new("/d.gguf"),
1143            None,
1144            None,
1145            None,
1146            Path::new("/tmp/out.webp"),
1147            None,
1148            None,
1149            None,
1150            false,
1151        );
1152        let s = args_to_strings(&args);
1153        assert_eq!(s[idx_after(&s, "--seed").unwrap()], "42");
1154    }
1155
1156    /// A model source carrying the Qwen-Image-Edit flow flags.
1157    fn qwen_edit_source() -> ModelSource {
1158        ModelSource {
1159            engine: ModelEngine::SdCpp,
1160            files: vec![],
1161            cli_defaults: ModelCliDefaults {
1162                cfg_scale: 4.0,
1163                steps: 20,
1164                width: 1024,
1165                height: 1024,
1166                sampling_method: Some("euler".to_string()),
1167                flow_shift: Some(3.0),
1168                zero_cond_t: Some(true),
1169                offload_to_cpu: Some(true),
1170            },
1171        }
1172    }
1173
1174    #[test]
1175    fn build_sdcli_args_reference_mode_for_instruction_edit() {
1176        let params = ImageParams {
1177            prompt: "add a red beach ball".into(),
1178            denoise: Some(0.9),
1179            ..Default::default()
1180        };
1181        let source = qwen_edit_source();
1182        let args = build_sdcli_args(
1183            &params,
1184            &source,
1185            Path::new("/qwen.gguf"),
1186            Some(Path::new("/vae.safetensors")),
1187            Some(Path::new("/llm.gguf")),
1188            Some(Path::new("/mmproj.gguf")),
1189            Path::new("/tmp/out.webp"),
1190            None,
1191            Some(Path::new("/tmp/mask.png")),
1192            Some(Path::new("/tmp/ref.webp")),
1193            false,
1194        );
1195        let s = args_to_strings(&args);
1196        // Reference mode: `-r` set, a `--mask` constrains the edit region, and the img2img-only
1197        // `--init-img` / `--strength` flags are suppressed.
1198        assert_eq!(s[idx_after(&s, "-r").unwrap()], "/tmp/ref.webp");
1199        assert_eq!(s[idx_after(&s, "--mask").unwrap()], "/tmp/mask.png");
1200        assert!(!s.contains(&"--init-img".to_string()));
1201        assert!(!s.contains(&"--strength".to_string()));
1202        // Vision encoder + Qwen flow flags emitted.
1203        assert_eq!(s[idx_after(&s, "--llm_vision").unwrap()], "/mmproj.gguf");
1204        assert_eq!(s[idx_after(&s, "--flow-shift").unwrap()], "3");
1205        assert!(s.contains(&"--qwen-image-zero-cond-t".to_string()));
1206        assert!(s.contains(&"--offload-to-cpu".to_string()));
1207    }
1208
1209    #[test]
1210    fn build_sdcli_args_omits_qwen_flags_for_plain_model() {
1211        let params = ImageParams {
1212            prompt: "hi".into(),
1213            ..Default::default()
1214        };
1215        // fake_source has no flow_shift / zero_cond_t / offload_to_cpu.
1216        let source = fake_source(vec![]);
1217        let args = build_sdcli_args(
1218            &params,
1219            &source,
1220            Path::new("/d.gguf"),
1221            None,
1222            None,
1223            None,
1224            Path::new("/tmp/out.webp"),
1225            None,
1226            None,
1227            None,
1228            false,
1229        );
1230        let s = args_to_strings(&args);
1231        assert!(!s.contains(&"--flow-shift".to_string()));
1232        assert!(!s.contains(&"--qwen-image-zero-cond-t".to_string()));
1233        assert!(!s.contains(&"--offload-to-cpu".to_string()));
1234        assert!(!s.contains(&"--llm_vision".to_string()));
1235        assert!(!s.contains(&"-r".to_string()));
1236    }
1237
1238    #[test]
1239    fn capabilities_advertises_only_image_kind() {
1240        let dir = tempdir().unwrap();
1241        let engine = SdCppEngine::with_paths(PathBuf::from("/usr/bin/true"), dir.path().into());
1242        let caps = engine.capabilities();
1243        assert!(caps
1244            .supported_models_per_kind
1245            .contains_key(&TaskKind::Image));
1246        assert_eq!(caps.supported_models_per_kind.len(), 1);
1247    }
1248
1249    #[test]
1250    fn init_image_extension_reads_url_tail() {
1251        assert_eq!(init_image_extension("https://x/y/latest.webp"), "webp");
1252        assert_eq!(init_image_extension("https://x/y/latest.PNG"), "png");
1253        assert_eq!(init_image_extension("https://x/y/latest.jpg"), "jpg");
1254        assert_eq!(init_image_extension("https://x/y/latest.jpeg"), "jpg");
1255        // Query strings + fragments don't trick the parser.
1256        assert_eq!(
1257            init_image_extension("https://x/y/latest.webp?v=42&t=now"),
1258            "webp"
1259        );
1260        assert_eq!(init_image_extension("https://x/y/latest.webp#frag"), "webp");
1261        // Unknown extension falls back to webp.
1262        assert_eq!(
1263            init_image_extension("https://x/y/latest.unknownext"),
1264            "webp"
1265        );
1266        assert_eq!(init_image_extension("https://x/y/no-ext"), "webp");
1267    }
1268}