Skip to main content

studio_worker/engine/
sdcpp.rs

1//! Engine that runs real image inference by subprocess-invoking the
2//! `stable-diffusion.cpp` (`sd-cli`) binary.
3//!
4//! The studio's offer carries a [`ModelSource`] with everything we
5//! need: an engine identifier (`sd-cpp`), the list of files to
6//! download (diffusion-model + text-encoder + VAE, each with a public
7//! URL + filename), and CLI defaults (cfg-scale, steps, dimensions).
8//! The worker has zero hardcoded model knowledge — it caches
9//! whatever the studio asks for under `cfg.models_root` and invokes
10//! `sd-cli` with the files arranged by role.
11//!
12//! Layout under `cfg.models_root` (default `~/models`):
13//! ```text
14//! ~/models/<filename1>
15//! ~/models/<filename2>
16//! …
17//! ```
18//! Files are downloaded on first use - skipped when already present
19//! under `cfg.models_root`.  The streamed body is checked against the
20//! server's `Content-Length` so a truncated download is rejected and
21//! cleaned up instead of being renamed into place as a corrupt model
22//! that every later job would fail to load.  Cached files are re-used
23//! across every subsequent job that names them.
24//!
//! The engine always registers.  `sd-cli` is resolved lazily on the
//! first image job — `$STUDIO_WORKER_SD_CLI`, then `<models_root>/bin/`,
//! then `~/.local/bin/`, then `$PATH` — and, if none of those has it,
//! auto-provisioned into `<models_root>/bin/`.  A missing binary
//! therefore surfaces as a per-job provisioning error rather than as the
//! engine being absent from the registry.
30
31use crate::engine::download::{self, TempFileGuard};
32use crate::engine::sd_provision;
33use crate::engine::{Engine, EngineCapabilities};
34use crate::types::{ImageParams, ModelFileRole, ModelSource, Task, TaskKind, TaskResult};
35use anyhow::{anyhow, bail, Context, Result};
36use parking_lot::Mutex;
37use std::collections::BTreeMap;
38use std::ffi::OsString;
39use std::path::{Path, PathBuf};
40use std::process::Command;
41use std::time::Instant;
42use tracing::{debug, info, warn};
43
/// `tracing` target used by every event this engine emits, so operators
/// can scope logs with `RUST_LOG=studio_worker::engine::sdcpp=...`.
const TRACE_TARGET: &str = "studio_worker::engine::sdcpp";

/// Last-resort sampling step count.
///
/// Only reached when neither the job (`ImageParams.steps`, where the
/// deserialiser default of 20 counts as "not chosen") nor the model
/// (`ModelSource.cliDefaults.steps`) supplies a value.  Eight suits
/// distilled schedules such as Z-Image-Turbo, for which 20 steps would
/// just waste time.
const STEPS_FALLBACK: u32 = 8;
51
52/// Worker-side engine that drives `sd-cli` per job.
53///
54/// `sd-cli` is resolved lazily on the first image job and cached: an
55/// operator install (env / PATH / `~/.local/bin`) wins, otherwise the
56/// binary is auto-provisioned into `<models_root>/bin/`.  The `Mutex`
57/// serialises that one-time resolution so two concurrent jobs can't
58/// race the download.
59pub struct SdCppEngine {
60    sd_cli: Mutex<Option<PathBuf>>,
61    models_root: PathBuf,
62}
63
64impl SdCppEngine {
65    /// Build the engine.  Always registers: `sd-cli` is resolved (and
66    /// provisioned into `<models_root>/bin/` if missing) lazily on the
67    /// first image job, so the engine serves real image work even on a
68    /// box that has never had a stable-diffusion.cpp build installed.
69    /// `models_root` is created on demand by the provisioner / model
70    /// downloader, so registration touches no filesystem.
71    pub fn new(models_root: &Path) -> Self {
72        info!(
73            target: TRACE_TARGET,
74            op = "register",
75            models_root = %models_root.display(),
76            sd_cli_name = sd_provision::binary_name(),
77            "sdcpp engine registered (sd-cli resolved/provisioned on first image job)"
78        );
79        Self {
80            sd_cli: Mutex::new(None),
81            models_root: models_root.to_path_buf(),
82        }
83    }
84
85    /// For tests: build with explicit paths (bypasses sd-cli lookup +
86    /// provisioning by seeding the resolved-path cache).
87    #[cfg(test)]
88    pub fn with_paths(sd_cli: PathBuf, models_root: PathBuf) -> Self {
89        Self {
90            sd_cli: Mutex::new(Some(sd_cli)),
91            models_root,
92        }
93    }
94
95    /// Resolve the `sd-cli` binary, provisioning it on first use.
96    /// Resolution order (operator installs win): a cached path from a
97    /// previous job, then env / `<models_root>/bin` / `~/.local/bin` /
98    /// `$PATH`, then an auto-provisioned download into
99    /// `<models_root>/bin/`.  The result is cached for the worker's
100    /// lifetime.
101    #[cfg_attr(coverage_nightly, coverage(off))]
102    fn ensure_sd_cli(&self) -> Result<PathBuf> {
103        let mut guard = self.sd_cli.lock();
104        if let Some(p) = guard.as_ref() {
105            if p.is_file() {
106                return Ok(p.clone());
107            }
108        }
109        let resolved = match resolve_sd_cli(&self.models_root) {
110            Some(p) => {
111                info!(
112                    target: TRACE_TARGET,
113                    op = "resolve",
114                    sd_cli = %p.display(),
115                    "using existing sd-cli"
116                );
117                p
118            }
119            None => sd_provision::provision(&self.models_root)
120                .context("auto-provisioning sd-cli (stable-diffusion.cpp)")?,
121        };
122        *guard = Some(resolved.clone());
123        Ok(resolved)
124    }
125
126    /// Ensure each file in `source.files` is present under
127    /// `self.models_root`.  Downloads anything missing.  Returns the
128    /// resolved local path for each file (in the same order).
129    #[cfg_attr(coverage_nightly, coverage(off))]
130    fn ensure_files(&self, source: &ModelSource) -> Result<Vec<(ModelFileRole, PathBuf)>> {
131        let mut out = Vec::with_capacity(source.files.len());
132        for file in &source.files {
133            let local = download::ensure_file(&self.models_root, file)?;
134            out.push((file.role, local));
135        }
136        Ok(out)
137    }
138
139    /// Subprocess to `sd-cli` with the resolved diffusion / VAE /
140    /// text-encoder files.  Excluded from coverage: requires an
141    /// actual `sd-cli` binary + cached model files on disk, neither
142    /// of which exists on the CI runner.  Exercised end-to-end via
143    /// the live dev loop.
144    #[cfg_attr(coverage_nightly, coverage(off))]
145    fn dispatch_image(
146        &self,
147        model: &str,
148        params: ImageParams,
149        source: &ModelSource,
150    ) -> Result<TaskResult> {
151        // Resolve (provisioning on first use) the sd-cli binary before
152        // we touch model files, so a missing binary fails fast with the
153        // provisioning error rather than after a multi-GB weight pull.
154        let sd_cli = self.ensure_sd_cli()?;
155        // Preflight the GPU runtime next: a missing Vulkan loader can't be
156        // auto-provisioned (it ships with the driver / a system package),
157        // so surface the actionable remedy now instead of after a
158        // multi-GB weight pull and a cryptic sd-cli crash.
159        if let Err(e) = sd_provision::vulkan_runtime_status() {
160            warn!(
161                target: TRACE_TARGET,
162                op = "preflight",
163                model,
164                error = %e,
165                "GPU runtime missing; refusing image job"
166            );
167            return Err(e);
168        }
169        let files = self.ensure_files(source)?;
170        // A `diffusion-model` file is the standalone diffusion weights (sd-cli `--diffusion-model`,
171        // used with split vae/clip); a `model` file is a full checkpoint (sd-cli `-m`/`--model`).
172        // Prefer the explicit diffusion-model role; fall back to a full checkpoint.
173        let diffusion_only = file_for_role(&files, ModelFileRole::DiffusionModel);
174        let full_checkpoint = diffusion_only.is_none();
175        let diffusion_model = diffusion_only
176            .or_else(|| file_for_role(&files, ModelFileRole::Model))
177            .ok_or_else(|| anyhow!("modelSource has no diffusion-model / model file"))?;
178        let vae = file_for_role(&files, ModelFileRole::Vae);
179        let text_encoder = file_for_role(&files, ModelFileRole::TextEncoder);
180        let text_encoder_vision = file_for_role(&files, ModelFileRole::TextEncoderVision);
181
182        let out_dir = std::env::temp_dir().join("studio-worker-sdcpp");
183        std::fs::create_dir_all(&out_dir)
184            .with_context(|| format!("creating sdcpp output dir {}", out_dir.display()))?;
185        let stem = format!(
186            "out-{}-{}",
187            std::process::id(),
188            chrono::Utc::now().timestamp_nanos_opt().unwrap_or_default()
189        );
190        let out_path = out_dir.join(format!("{stem}.webp"));
191
192        // Own the scratch files from the moment their paths exist so
193        // every failure path (sd-cli error, unreadable output) cleans
194        // up instead of leaking them into the temp dir.
195        let mut temp_files = TempFileGuard::new();
196        temp_files.push(out_path.clone());
197
198        // If the task carries an init image URL, stream it to a
199        // tempfile so we can hand the path to `sd-cli --init-img`.
200        // This is mandatory — the worker refuses i2i jobs whose
201        // init image fails to download (no silent fallback to t2i).
202        // The local extension first mirrors the URL's, then is corrected
203        // to the file's real content format — studio asset URLs lie
204        // (`latest.webp` is often JPEG bytes) and sd-cli picks its image
205        // decoder purely from the extension.
206        let init_img_path = match params.init_image_url.as_deref() {
207            Some(url) if !url.is_empty() => {
208                let ext = init_image_extension(url);
209                let init_path = out_dir.join(format!("{stem}-init.{ext}"));
210                download::download_file(url, &init_path).with_context(|| {
211                    format!("downloading init image {} -> {}", url, init_path.display())
212                })?;
213                temp_files.push(init_path.clone());
214                let usable = download::ensure_correct_image_extension(&init_path)?;
215                if usable != init_path {
216                    temp_files.push(usable.clone());
217                }
218                Some(usable)
219            }
220            _ => None,
221        };
222
223        // A mask constrains the edit region — valid alongside either an init image (img2img
224        // inpaint) or a reference image (instruction edit). Download it whenever a base image is
225        // present and a mask URL was supplied; white pixels mark the region the model may change.
226        let has_base = init_img_path.is_some() || params.ref_image_url.as_deref().is_some();
227        let mask_path = match (has_base, params.mask_url.as_deref()) {
228            (true, Some(url)) if !url.is_empty() => {
229                let ext = init_image_extension(url);
230                let path = out_dir.join(format!("{stem}-mask.{ext}"));
231                download::download_file(url, &path)
232                    .with_context(|| format!("downloading mask {} -> {}", url, path.display()))?;
233                temp_files.push(path.clone());
234                let usable = download::ensure_correct_image_extension(&path)?;
235                if usable != path {
236                    temp_files.push(usable.clone());
237                }
238                Some(usable)
239            }
240            _ => None,
241        };
242
243        // Reference image for instruction-edit models (`sd-cli -r`). Downloaded like the init image;
244        // when present the arg builder uses reference mode instead of the img2img/mask path.
245        let ref_img_path = match params.ref_image_url.as_deref() {
246            Some(url) if !url.is_empty() => {
247                let ext = init_image_extension(url);
248                let path = out_dir.join(format!("{stem}-ref.{ext}"));
249                download::download_file(url, &path).with_context(|| {
250                    format!("downloading reference image {} -> {}", url, path.display())
251                })?;
252                temp_files.push(path.clone());
253                let usable = download::ensure_correct_image_extension(&path)?;
254                if usable != path {
255                    temp_files.push(usable.clone());
256                }
257                Some(usable)
258            }
259            _ => None,
260        };
261
262        let args = build_sdcli_args(
263            &params,
264            source,
265            diffusion_model,
266            vae,
267            text_encoder,
268            text_encoder_vision,
269            &out_path,
270            init_img_path.as_deref(),
271            mask_path.as_deref(),
272            ref_img_path.as_deref(),
273            full_checkpoint,
274        );
275        let mut cmd = Command::new(&sd_cli);
276        cmd.args(&args);
277        apply_library_path(&mut cmd, &sd_cli);
278
279        debug!(
280            target: TRACE_TARGET,
281            op = "spawn",
282            sd_cli = %sd_cli.display(),
283            model,
284            i2i = init_img_path.is_some(),
285            arg_count = args.len(),
286            "running sd-cli"
287        );
288
289        let started = Instant::now();
290        let output = cmd
291            .output()
292            .with_context(|| format!("running {}", sd_cli.display()))?;
293        let elapsed_ms = started.elapsed().as_millis() as u64;
294        if !output.status.success() {
295            let stderr = String::from_utf8_lossy(&output.stderr);
296            warn!(
297                target: TRACE_TARGET,
298                op = "spawn",
299                model,
300                elapsed_ms,
301                exit = ?output.status.code(),
302                stderr = %stderr,
303                "sd-cli failed"
304            );
305            bail!(
306                "sd-cli exited with {:?}: {}",
307                output.status.code(),
308                stderr.lines().last().unwrap_or("(no stderr)")
309            );
310        }
311
312        let bytes = std::fs::read(&out_path)
313            .with_context(|| format!("reading sd-cli output at {}", out_path.display()))?;
314        info!(
315            target: TRACE_TARGET,
316            op = "dispatch",
317            model,
318            elapsed_ms,
319            bytes = bytes.len(),
320            "ok"
321        );
322
323        Ok(TaskResult::Image {
324            bytes,
325            ext: "webp".to_string(),
326        })
327    }
328}
329
330impl Engine for SdCppEngine {
331    fn name(&self) -> &'static str {
332        "sdcpp"
333    }
334
335    fn capabilities(&self) -> EngineCapabilities {
336        // Image kind only.  The studio's selection is kind-based now
337        // and the offer carries the model-source, so we don't need to
338        // enumerate model names ourselves.  We still list a single
339        // sentinel string so downstream code that reads
340        // `supportedModels` for display sees "any sd-cpp model".
341        let mut map: BTreeMap<TaskKind, Vec<String>> = BTreeMap::new();
342        map.insert(TaskKind::Image, vec!["sd-cpp:*".to_string()]);
343        EngineCapabilities {
344            supported_models_per_kind: map,
345        }
346    }
347
348    fn dispatch(&self, _model: &str, _task: Task) -> Result<TaskResult> {
349        bail!(
350            "sdcpp engine requires a ModelSource on the offer; legacy push-based offers \
351             (no modelSource) cannot be served - re-promote the job through the studio"
352        )
353    }
354
355    fn dispatch_with_source(
356        &self,
357        model: &str,
358        task: Task,
359        source: &ModelSource,
360    ) -> Result<TaskResult> {
361        match task {
362            Task::Image(p) => self.dispatch_image(model, p, source),
363            other => {
364                // Surface the rejection at this engine's own target,
365                // matching the onnx/llama/whisper/candle engines.
366                // Without it an operator filtering
367                // `RUST_LOG=studio_worker::engine::sdcpp=debug` sees
368                // nothing when sdcpp refuses a non-image task.
369                let kind = other.kind();
370                warn!(
371                    target: TRACE_TARGET,
372                    op = "dispatch",
373                    model,
374                    kind = kind.as_str(),
375                    "sdcpp engine only serves image jobs"
376                );
377                Err(crate::engine::UnsupportedTask::new("sdcpp", kind).into())
378            }
379        }
380    }
381}
382
383// ---------------------------------------------------------------------------
384// Helpers
385// ---------------------------------------------------------------------------
386
387// The per-job scratch cleanup primitives (`remove_temp_file` +
388// `TempFileGuard`) live in `engine::download` so this engine and the
389// onnx engine share one tested implementation.
390
391fn file_for_role(files: &[(ModelFileRole, PathBuf)], role: ModelFileRole) -> Option<&Path> {
392    files
393        .iter()
394        .find(|(r, _)| *r == role)
395        .map(|(_, p)| p.as_path())
396}
397
398/// Resolve final per-job width / height / steps / cfg / sampler /
399/// negative-prompt by layering `params` over `source.cli_defaults`
400/// with the agreed precedence (per-job override beats model default
401/// beats engine fallback).  Pure for testability.
402fn resolve_image_args(params: &ImageParams, source: &ModelSource) -> ResolvedImageArgs {
403    let width = if params.width > 0 {
404        params.width
405    } else if source.cli_defaults.width > 0 {
406        source.cli_defaults.width
407    } else {
408        1024
409    };
410    let height = if params.height > 0 {
411        params.height
412    } else if source.cli_defaults.height > 0 {
413        source.cli_defaults.height
414    } else {
415        1024
416    };
417    // Steps: per-job override wins (treat the deserialiser default of
418    // 20 as "caller didn't pick" so the model's tuned step count
419    // doesn't get clobbered by a stale default).
420    let steps = if params.steps > 0 && params.steps != 20 {
421        params.steps
422    } else if source.cli_defaults.steps > 0 {
423        source.cli_defaults.steps
424    } else {
425        STEPS_FALLBACK
426    };
427    let source_cfg = if source.cli_defaults.cfg_scale > 0.0 {
428        source.cli_defaults.cfg_scale
429    } else {
430        1.0
431    };
432    let cfg_scale = params.cfg_scale.filter(|v| *v > 0.0).unwrap_or(source_cfg);
433    let sampling_method = params
434        .sampling_method
435        .clone()
436        .or_else(|| source.cli_defaults.sampling_method.clone());
437    ResolvedImageArgs {
438        width,
439        height,
440        steps,
441        cfg_scale,
442        sampling_method,
443    }
444}
445
446/// Resolved per-job sd-cli numerics.  Output of [`resolve_image_args`].
447#[derive(Debug, Clone, PartialEq)]
448struct ResolvedImageArgs {
449    width: u32,
450    height: u32,
451    steps: u32,
452    cfg_scale: f32,
453    sampling_method: Option<String>,
454}
455
456/// Build the full `sd-cli` argv for one image job.  Pure (no I/O):
457/// the caller resolves files / out-path / init-image-path, this
458/// function only assembles the flag list so it can be asserted in
459/// unit tests without spawning the binary.
460// Eight model-path + i2i components; grouping them adds indirection without
461// improving readability (mirrors the `#[allow]` already used in ws::session).
462#[allow(clippy::too_many_arguments)]
463fn build_sdcli_args(
464    params: &ImageParams,
465    source: &ModelSource,
466    diffusion_model: &Path,
467    vae: Option<&Path>,
468    text_encoder: Option<&Path>,
469    text_encoder_vision: Option<&Path>,
470    out_path: &Path,
471    init_img_path: Option<&Path>,
472    mask_path: Option<&Path>,
473    ref_img_path: Option<&Path>,
474    full_checkpoint: bool,
475) -> Vec<OsString> {
476    let resolved = resolve_image_args(params, source);
477    let mut args: Vec<OsString> = Vec::with_capacity(32);
478
479    // A full checkpoint loads via `-m`/`--model`; standalone diffusion weights via
480    // `--diffusion-model` (alongside split vae/clip files).
481    args.push(
482        if full_checkpoint {
483            "--model"
484        } else {
485            "--diffusion-model"
486        }
487        .into(),
488    );
489    args.push(diffusion_model.into());
490    if let Some(p) = vae {
491        args.push("--vae".into());
492        args.push(p.into());
493    }
494    if let Some(p) = text_encoder {
495        args.push("--llm".into());
496        args.push(p.into());
497    }
498    if let Some(p) = text_encoder_vision {
499        args.push("--llm_vision".into());
500        args.push(p.into());
501    }
502    args.push("-p".into());
503    args.push((&params.prompt as &str).into());
504    if let Some(neg) = params.negative_prompt.as_deref() {
505        if !neg.is_empty() {
506            args.push("--negative-prompt".into());
507            args.push(neg.into());
508        }
509    }
510    if let Some(reference) = ref_img_path {
511        // Reference / instruction-edit mode (Qwen-Image-Edit, Flux Kontext): the model regenerates
512        // the image from the reference per the prompt. Mutually exclusive with the `--init-img`
513        // img2img path. A `--mask` is honoured here too: it constrains the edit to the masked
514        // region (white = editable) and leaves the rest, so the studio can place the edit inside
515        // the author's drawn shape. No `--strength` (that's an img2img-only knob).
516        args.push("-r".into());
517        args.push(reference.into());
518        if let Some(mask) = mask_path {
519            args.push("--mask".into());
520            args.push(mask.into());
521        }
522    } else if let Some(init) = init_img_path {
523        args.push("--init-img".into());
524        args.push(init.into());
525        // `--strength` only makes sense alongside an init image
526        // (sd-cli ignores it otherwise).  Default to 0.75 (sd-cli's
527        // own default) when the caller didn't pick a value.
528        let strength = params.denoise.unwrap_or(0.75);
529        args.push("--strength".into());
530        args.push(strength.to_string().into());
531        // Mask-guided inpaint: only valid with an init image.
532        if let Some(mask) = mask_path {
533            args.push("--mask".into());
534            args.push(mask.into());
535        }
536    }
537    args.push("--cfg-scale".into());
538    args.push(resolved.cfg_scale.to_string().into());
539    args.push("--steps".into());
540    args.push(resolved.steps.to_string().into());
541    args.push("-W".into());
542    args.push(resolved.width.to_string().into());
543    args.push("-H".into());
544    args.push(resolved.height.to_string().into());
545    args.push("-o".into());
546    args.push(out_path.into());
547    if let Some(seed) = params.seed {
548        args.push("--seed".into());
549        args.push(seed.to_string().into());
550    }
551    if let Some(method) = resolved.sampling_method.as_deref() {
552        args.push("--sampling-method".into());
553        args.push(method.into());
554    }
555    // Flow / instruction-edit model flags (model-level constants from the registry). Only emitted
556    // when the model declares them, so SDXL-style models are unaffected.
557    if let Some(shift) = source.cli_defaults.flow_shift {
558        args.push("--flow-shift".into());
559        args.push(shift.to_string().into());
560    }
561    if source.cli_defaults.zero_cond_t == Some(true) {
562        args.push("--qwen-image-zero-cond-t".into());
563    }
564    if source.cli_defaults.offload_to_cpu == Some(true) {
565        args.push("--offload-to-cpu".into());
566    }
567    // VRAM-saving flags that are safe on every box.
568    args.push("--diffusion-fa".into());
569    args
570}
571
572/// Point the per-job `Command`'s dynamic linker at the shared library
573/// that ships next to an auto-provisioned `sd-cli` (Linux / macOS).
574/// No-op on Windows (sibling DLLs resolve automatically) and when the
575/// resolved binary has no sibling library (operator wrapper-script
576/// installs manage their own load path).  Prepends to any inherited
577/// value so a pre-set `LD_LIBRARY_PATH` isn't clobbered.
578#[cfg_attr(coverage_nightly, coverage(off))]
579fn apply_library_path(cmd: &mut Command, sd_cli: &Path) {
580    let Some((var, dir)) = sd_provision::library_path_env(sd_cli) else {
581        return;
582    };
583    let value = match std::env::var_os(var) {
584        Some(existing) => {
585            let mut paths = vec![dir.clone()];
586            paths.extend(std::env::split_paths(&existing));
587            // `join_paths` only fails if a path contains the platform
588            // separator; fall back to our dir alone, the entry that
589            // matters for finding the sibling library.
590            std::env::join_paths(paths).unwrap_or_else(|_| dir.into_os_string())
591        }
592        None => dir.into_os_string(),
593    };
594    cmd.env(var, value);
595}
596
597/// Look up `sd-cli` in env override -> `<models_root>/bin` ->
598/// `~/.local/bin` -> `$PATH`.  The `<models_root>/bin` slot is where a
599/// self-provisioned binary lands, so the auto-provisioner can drop it
600/// next to the cached models and have the worker pick it up with no
601/// PATH fiddling.  Excluded from coverage: touches several host paths
602/// only one of which matches per host, and CI doesn't ship `sd-cli`.
603#[cfg_attr(coverage_nightly, coverage(off))]
604fn resolve_sd_cli(models_root: &Path) -> Option<PathBuf> {
605    let bin = sd_provision::binary_name();
606    if let Ok(p) = std::env::var("STUDIO_WORKER_SD_CLI") {
607        let path = PathBuf::from(p);
608        if path.is_file() {
609            return Some(path);
610        }
611    }
612    let in_models = models_root.join("bin").join(bin);
613    if in_models.is_file() {
614        return Some(in_models);
615    }
616    if let Some(home) = std::env::var_os("HOME") {
617        let candidate = PathBuf::from(home).join(".local/bin").join(bin);
618        if candidate.is_file() {
619            return Some(candidate);
620        }
621    }
622    which(bin)
623}
624
/// Search `$PATH` for a bare binary name and return the first candidate
/// that is a regular file, or `None` if `$PATH` is unset or has no match.
///
/// Zero-length `$PATH` entries (leading, trailing or doubled separators)
/// are skipped.  A shell would treat them as the current directory, but
/// here that would return a *relative* path.  The path is cached and
/// later executed from whatever directory is current then, and it would
/// let a stray file in the worker's working directory shadow a real
/// install.
///
/// The executable bit is not checked; a non-executable match fails when
/// it is spawned.  Excluded from coverage: the result depends on the
/// host's `$PATH`, and CI doesn't ship `sd-cli`.
#[cfg_attr(coverage_nightly, coverage(off))]
fn which(bin: &str) -> Option<PathBuf> {
    let search_path = std::env::var_os("PATH")?;
    std::env::split_paths(&search_path)
        .filter(|dir| !dir.as_os_str().is_empty())
        .map(|dir| dir.join(bin))
        .find(|candidate| candidate.is_file())
}
638
/// Choose the tempfile extension for a downloaded per-job image (init,
/// mask or reference) so sd-cli's loader, which picks a decoder from the
/// extension, gets a sensible first guess.
///
/// Inspects the text after the last `.` of the URL once any query
/// (`?…`) or fragment (`#…`) is cut off; if there is no `.`, the whole
/// remaining string is inspected.  Comparison is ASCII-case-insensitive.
/// `jpeg` normalises to `jpg` and `tiff` to `tif`.  Anything
/// unrecognised yields `webp`.
fn init_image_extension(url: &str) -> &'static str {
    const KNOWN: [(&[&str], &str); 6] = [
        (&["png"], "png"),
        (&["jpg", "jpeg"], "jpg"),
        (&["webp"], "webp"),
        (&["bmp"], "bmp"),
        (&["gif"], "gif"),
        (&["tif", "tiff"], "tif"),
    ];
    let without_suffix = url.find(['?', '#']).map_or(url, |cut| &url[..cut]);
    let tail = without_suffix
        .rfind('.')
        .map_or(without_suffix, |dot| &without_suffix[dot + 1..]);
    KNOWN
        .iter()
        .find(|(aliases, _)| aliases.iter().any(|alias| tail.eq_ignore_ascii_case(alias)))
        .map_or("webp", |&(_, ext)| ext)
}
660
661// ---------------------------------------------------------------------------
662// Tests
663// ---------------------------------------------------------------------------
664
665#[cfg(test)]
666mod tests {
667    use super::*;
668    use crate::types::{ModelCliDefaults, ModelEngine, ModelFile, ModelFileRole};
669    use tempfile::tempdir;
670
671    fn fake_source(files: Vec<ModelFile>) -> ModelSource {
672        ModelSource {
673            engine: ModelEngine::SdCpp,
674            files,
675            cli_defaults: ModelCliDefaults {
676                cfg_scale: 1.0,
677                steps: 8,
678                width: 1024,
679                height: 1024,
680                sampling_method: Some("euler".to_string()),
681                ..Default::default()
682            },
683        }
684    }
685
686    #[test]
687    fn file_for_role_picks_matching_file() {
688        let files = vec![
689            (ModelFileRole::DiffusionModel, PathBuf::from("/d.gguf")),
690            (ModelFileRole::Vae, PathBuf::from("/v.safetensors")),
691        ];
692        assert_eq!(
693            file_for_role(&files, ModelFileRole::DiffusionModel),
694            Some(Path::new("/d.gguf"))
695        );
696        assert_eq!(
697            file_for_role(&files, ModelFileRole::Vae),
698            Some(Path::new("/v.safetensors"))
699        );
700        assert!(file_for_role(&files, ModelFileRole::TextEncoder).is_none());
701    }
702
703    #[test]
704    fn ensure_files_skips_already_present() {
705        let dir = tempdir().unwrap();
706        let cached = dir.path().join("cached.gguf");
707        std::fs::write(&cached, b"already here").unwrap();
708        let engine = SdCppEngine::with_paths(PathBuf::from("/usr/bin/true"), dir.path().into());
709        let source = fake_source(vec![ModelFile {
710            role: ModelFileRole::DiffusionModel,
711            url: "https://example.invalid/cached.gguf".into(),
712            filename: "cached.gguf".into(),
713            approx_bytes: None,
714            sha256: None,
715        }]);
716        let resolved = engine.ensure_files(&source).expect("cached file used");
717        assert_eq!(resolved.len(), 1);
718        assert_eq!(resolved[0].0, ModelFileRole::DiffusionModel);
719        assert_eq!(resolved[0].1, cached);
720        // Untouched on disk — our "download" never ran.
721        assert_eq!(std::fs::read(&cached).unwrap(), b"already here");
722    }
723
724    #[test]
725    fn dispatch_rejects_non_image_tasks() {
726        use crate::types::AudioTtsParams;
727        let dir = tempdir().unwrap();
728        let engine = SdCppEngine::with_paths(PathBuf::from("/usr/bin/true"), dir.path().into());
729        let task = Task::AudioTts(AudioTtsParams {
730            text: "hi".into(),
731            voice: "v".into(),
732            ext: "wav".into(),
733            ..Default::default()
734        });
735        let source = fake_source(vec![]);
736        let err = engine
737            .dispatch_with_source("anything", task, &source)
738            .unwrap_err();
739        assert!(err.to_string().contains("cannot serve audio_tts"));
740    }
741
742    // The legacy `dispatch_requires_model_source` test is gone: the
743    // trait signature now takes `&ModelSource` so the compiler enforces
744    // it at every call site.  No runtime fallback to police.
745
746    // -----------------------------------------------------------------
747    // Pure arg-builder tests — lock down the sd-cli invocation contract
748    // without needing the binary on the box.
749    // -----------------------------------------------------------------
750
751    fn args_to_strings(args: &[OsString]) -> Vec<String> {
752        args.iter()
753            .map(|s| s.to_string_lossy().into_owned())
754            .collect()
755    }
756
757    fn idx_after(args: &[String], flag: &str) -> Option<usize> {
758        args.iter().position(|a| a == flag).map(|i| i + 1)
759    }
760
761    #[test]
762    fn build_sdcli_args_includes_required_flags() {
763        let params = ImageParams {
764            prompt: "hello".into(),
765            width: 768,
766            height: 512,
767            steps: 20, // "caller didn't pick" → source default wins
768            ..Default::default()
769        };
770        let source = fake_source(vec![]);
771        let args = build_sdcli_args(
772            &params,
773            &source,
774            Path::new("/d.gguf"),
775            Some(Path::new("/v.safetensors")),
776            Some(Path::new("/llm.gguf")),
777            None,
778            Path::new("/tmp/out.webp"),
779            None,
780            None,
781            None,
782            false,
783        );
784        let s = args_to_strings(&args);
785        assert_eq!(s[idx_after(&s, "--diffusion-model").unwrap()], "/d.gguf");
786        assert_eq!(s[idx_after(&s, "--vae").unwrap()], "/v.safetensors");
787        assert_eq!(s[idx_after(&s, "--llm").unwrap()], "/llm.gguf");
788        assert_eq!(s[idx_after(&s, "-p").unwrap()], "hello");
789        assert_eq!(s[idx_after(&s, "-W").unwrap()], "768");
790        assert_eq!(s[idx_after(&s, "-H").unwrap()], "512");
791        // source default cfg_scale=1.0
792        assert_eq!(s[idx_after(&s, "--cfg-scale").unwrap()], "1");
793        // source default steps=8 wins (param.steps==20 treated as default)
794        assert_eq!(s[idx_after(&s, "--steps").unwrap()], "8");
795        assert_eq!(s[idx_after(&s, "--sampling-method").unwrap()], "euler");
796        assert_eq!(s[idx_after(&s, "-o").unwrap()], "/tmp/out.webp");
797        assert!(s.contains(&"--diffusion-fa".to_string()));
798        // Never includes init-only flags when no init image present.
799        assert!(!s.contains(&"--init-img".to_string()));
800        assert!(!s.contains(&"--strength".to_string()));
801    }
802
803    #[test]
804    fn build_sdcli_args_includes_negative_prompt_when_set() {
805        let params = ImageParams {
806            prompt: "hi".into(),
807            negative_prompt: Some("text, watermark, low quality".into()),
808            ..Default::default()
809        };
810        let source = fake_source(vec![]);
811        let args = build_sdcli_args(
812            &params,
813            &source,
814            Path::new("/d.gguf"),
815            None,
816            None,
817            None,
818            Path::new("/tmp/out.webp"),
819            None,
820            None,
821            None,
822            false,
823        );
824        let s = args_to_strings(&args);
825        assert_eq!(
826            s[idx_after(&s, "--negative-prompt").unwrap()],
827            "text, watermark, low quality"
828        );
829    }
830
831    #[test]
832    fn build_sdcli_args_omits_negative_prompt_when_empty_string() {
833        let params = ImageParams {
834            prompt: "hi".into(),
835            negative_prompt: Some(String::new()),
836            ..Default::default()
837        };
838        let source = fake_source(vec![]);
839        let args = build_sdcli_args(
840            &params,
841            &source,
842            Path::new("/d.gguf"),
843            None,
844            None,
845            None,
846            Path::new("/tmp/out.webp"),
847            None,
848            None,
849            None,
850            false,
851        );
852        let s = args_to_strings(&args);
853        assert!(!s.contains(&"--negative-prompt".to_string()));
854    }
855
856    #[test]
857    fn build_sdcli_args_includes_init_image_and_strength() {
858        let params = ImageParams {
859            prompt: "hi".into(),
860            denoise: Some(0.55),
861            ..Default::default()
862        };
863        let source = fake_source(vec![]);
864        let args = build_sdcli_args(
865            &params,
866            &source,
867            Path::new("/d.gguf"),
868            None,
869            None,
870            None,
871            Path::new("/tmp/out.webp"),
872            Some(Path::new("/tmp/init.webp")),
873            None,
874            None,
875            false,
876        );
877        let s = args_to_strings(&args);
878        assert_eq!(s[idx_after(&s, "--init-img").unwrap()], "/tmp/init.webp");
879        assert_eq!(s[idx_after(&s, "--strength").unwrap()], "0.55");
880        // No mask supplied → no inpaint flag.
881        assert!(!s.contains(&"--mask".to_string()));
882    }
883
884    #[test]
885    fn build_sdcli_args_includes_mask_for_inpaint() {
886        let params = ImageParams {
887            prompt: "remove the tree".into(),
888            denoise: Some(0.8),
889            ..Default::default()
890        };
891        let source = fake_source(vec![]);
892        let args = build_sdcli_args(
893            &params,
894            &source,
895            Path::new("/d.gguf"),
896            None,
897            None,
898            None,
899            Path::new("/tmp/out.webp"),
900            Some(Path::new("/tmp/init.webp")),
901            Some(Path::new("/tmp/mask.png")),
902            None,
903            false,
904        );
905        let s = args_to_strings(&args);
906        assert_eq!(s[idx_after(&s, "--init-img").unwrap()], "/tmp/init.webp");
907        assert_eq!(s[idx_after(&s, "--mask").unwrap()], "/tmp/mask.png");
908        assert_eq!(s[idx_after(&s, "--strength").unwrap()], "0.8");
909    }
910
911    #[test]
912    fn build_sdcli_args_uses_model_flag_for_full_checkpoint() {
913        let params = ImageParams {
914            prompt: "hi".into(),
915            ..Default::default()
916        };
917        let source = fake_source(vec![]);
918        let args = build_sdcli_args(
919            &params,
920            &source,
921            Path::new("/checkpoint.safetensors"),
922            Some(Path::new("/v.safetensors")),
923            None,
924            None,
925            Path::new("/tmp/out.webp"),
926            None,
927            None,
928            None,
929            true,
930        );
931        let s = args_to_strings(&args);
932        // A full checkpoint loads via -m/--model, not --diffusion-model.
933        assert_eq!(
934            s[idx_after(&s, "--model").unwrap()],
935            "/checkpoint.safetensors"
936        );
937        assert!(!s.contains(&"--diffusion-model".to_string()));
938    }
939
940    #[test]
941    fn build_sdcli_args_defaults_denoise_when_init_image_present_but_denoise_none() {
942        let params = ImageParams {
943            prompt: "hi".into(),
944            denoise: None,
945            ..Default::default()
946        };
947        let source = fake_source(vec![]);
948        let args = build_sdcli_args(
949            &params,
950            &source,
951            Path::new("/d.gguf"),
952            None,
953            None,
954            None,
955            Path::new("/tmp/out.webp"),
956            Some(Path::new("/tmp/init.webp")),
957            None,
958            None,
959            false,
960        );
961        let s = args_to_strings(&args);
962        assert_eq!(s[idx_after(&s, "--strength").unwrap()], "0.75");
963    }
964
965    #[test]
966    fn build_sdcli_args_per_job_cfg_scale_overrides_model_default() {
967        let params = ImageParams {
968            prompt: "hi".into(),
969            cfg_scale: Some(7.5),
970            ..Default::default()
971        };
972        let source = fake_source(vec![]);
973        let args = build_sdcli_args(
974            &params,
975            &source,
976            Path::new("/d.gguf"),
977            None,
978            None,
979            None,
980            Path::new("/tmp/out.webp"),
981            None,
982            None,
983            None,
984            false,
985        );
986        let s = args_to_strings(&args);
987        assert_eq!(s[idx_after(&s, "--cfg-scale").unwrap()], "7.5");
988    }
989
990    #[test]
991    fn build_sdcli_args_per_job_sampling_method_overrides_model_default() {
992        let params = ImageParams {
993            prompt: "hi".into(),
994            sampling_method: Some("dpm++2m".into()),
995            ..Default::default()
996        };
997        let source = fake_source(vec![]);
998        let args = build_sdcli_args(
999            &params,
1000            &source,
1001            Path::new("/d.gguf"),
1002            None,
1003            None,
1004            None,
1005            Path::new("/tmp/out.webp"),
1006            None,
1007            None,
1008            None,
1009            false,
1010        );
1011        let s = args_to_strings(&args);
1012        assert_eq!(s[idx_after(&s, "--sampling-method").unwrap()], "dpm++2m");
1013    }
1014
1015    #[test]
1016    fn build_sdcli_args_per_job_steps_overrides_when_non_default() {
1017        let params = ImageParams {
1018            prompt: "hi".into(),
1019            steps: 30, // != 20 → treat as caller override
1020            ..Default::default()
1021        };
1022        let source = fake_source(vec![]);
1023        let args = build_sdcli_args(
1024            &params,
1025            &source,
1026            Path::new("/d.gguf"),
1027            None,
1028            None,
1029            None,
1030            Path::new("/tmp/out.webp"),
1031            None,
1032            None,
1033            None,
1034            false,
1035        );
1036        let s = args_to_strings(&args);
1037        assert_eq!(s[idx_after(&s, "--steps").unwrap()], "30");
1038    }
1039
1040    #[test]
1041    fn build_sdcli_args_seed_included_when_set() {
1042        let params = ImageParams {
1043            prompt: "hi".into(),
1044            seed: Some(42),
1045            ..Default::default()
1046        };
1047        let source = fake_source(vec![]);
1048        let args = build_sdcli_args(
1049            &params,
1050            &source,
1051            Path::new("/d.gguf"),
1052            None,
1053            None,
1054            None,
1055            Path::new("/tmp/out.webp"),
1056            None,
1057            None,
1058            None,
1059            false,
1060        );
1061        let s = args_to_strings(&args);
1062        assert_eq!(s[idx_after(&s, "--seed").unwrap()], "42");
1063    }
1064
1065    /// A model source carrying the Qwen-Image-Edit flow flags.
1066    fn qwen_edit_source() -> ModelSource {
1067        ModelSource {
1068            engine: ModelEngine::SdCpp,
1069            files: vec![],
1070            cli_defaults: ModelCliDefaults {
1071                cfg_scale: 4.0,
1072                steps: 20,
1073                width: 1024,
1074                height: 1024,
1075                sampling_method: Some("euler".to_string()),
1076                flow_shift: Some(3.0),
1077                zero_cond_t: Some(true),
1078                offload_to_cpu: Some(true),
1079            },
1080        }
1081    }
1082
1083    #[test]
1084    fn build_sdcli_args_reference_mode_for_instruction_edit() {
1085        let params = ImageParams {
1086            prompt: "add a red beach ball".into(),
1087            denoise: Some(0.9),
1088            ..Default::default()
1089        };
1090        let source = qwen_edit_source();
1091        let args = build_sdcli_args(
1092            &params,
1093            &source,
1094            Path::new("/qwen.gguf"),
1095            Some(Path::new("/vae.safetensors")),
1096            Some(Path::new("/llm.gguf")),
1097            Some(Path::new("/mmproj.gguf")),
1098            Path::new("/tmp/out.webp"),
1099            None,
1100            Some(Path::new("/tmp/mask.png")),
1101            Some(Path::new("/tmp/ref.webp")),
1102            false,
1103        );
1104        let s = args_to_strings(&args);
1105        // Reference mode: `-r` set, a `--mask` constrains the edit region, and the img2img-only
1106        // `--init-img` / `--strength` flags are suppressed.
1107        assert_eq!(s[idx_after(&s, "-r").unwrap()], "/tmp/ref.webp");
1108        assert_eq!(s[idx_after(&s, "--mask").unwrap()], "/tmp/mask.png");
1109        assert!(!s.contains(&"--init-img".to_string()));
1110        assert!(!s.contains(&"--strength".to_string()));
1111        // Vision encoder + Qwen flow flags emitted.
1112        assert_eq!(s[idx_after(&s, "--llm_vision").unwrap()], "/mmproj.gguf");
1113        assert_eq!(s[idx_after(&s, "--flow-shift").unwrap()], "3");
1114        assert!(s.contains(&"--qwen-image-zero-cond-t".to_string()));
1115        assert!(s.contains(&"--offload-to-cpu".to_string()));
1116    }
1117
1118    #[test]
1119    fn build_sdcli_args_omits_qwen_flags_for_plain_model() {
1120        let params = ImageParams {
1121            prompt: "hi".into(),
1122            ..Default::default()
1123        };
1124        // fake_source has no flow_shift / zero_cond_t / offload_to_cpu.
1125        let source = fake_source(vec![]);
1126        let args = build_sdcli_args(
1127            &params,
1128            &source,
1129            Path::new("/d.gguf"),
1130            None,
1131            None,
1132            None,
1133            Path::new("/tmp/out.webp"),
1134            None,
1135            None,
1136            None,
1137            false,
1138        );
1139        let s = args_to_strings(&args);
1140        assert!(!s.contains(&"--flow-shift".to_string()));
1141        assert!(!s.contains(&"--qwen-image-zero-cond-t".to_string()));
1142        assert!(!s.contains(&"--offload-to-cpu".to_string()));
1143        assert!(!s.contains(&"--llm_vision".to_string()));
1144        assert!(!s.contains(&"-r".to_string()));
1145    }
1146
1147    #[test]
1148    fn capabilities_advertises_only_image_kind() {
1149        let dir = tempdir().unwrap();
1150        let engine = SdCppEngine::with_paths(PathBuf::from("/usr/bin/true"), dir.path().into());
1151        let caps = engine.capabilities();
1152        assert!(caps
1153            .supported_models_per_kind
1154            .contains_key(&TaskKind::Image));
1155        assert_eq!(caps.supported_models_per_kind.len(), 1);
1156    }
1157
1158    #[test]
1159    fn init_image_extension_reads_url_tail() {
1160        assert_eq!(init_image_extension("https://x/y/latest.webp"), "webp");
1161        assert_eq!(init_image_extension("https://x/y/latest.PNG"), "png");
1162        assert_eq!(init_image_extension("https://x/y/latest.jpg"), "jpg");
1163        assert_eq!(init_image_extension("https://x/y/latest.jpeg"), "jpg");
1164        // Query strings + fragments don't trick the parser.
1165        assert_eq!(
1166            init_image_extension("https://x/y/latest.webp?v=42&t=now"),
1167            "webp"
1168        );
1169        assert_eq!(init_image_extension("https://x/y/latest.webp#frag"), "webp");
1170        // Unknown extension falls back to webp.
1171        assert_eq!(
1172            init_image_extension("https://x/y/latest.unknownext"),
1173            "webp"
1174        );
1175        assert_eq!(init_image_extension("https://x/y/no-ext"), "webp");
1176    }
1177}