1use crate::engine::download::{self, TempFileGuard};
32use crate::engine::sd_provision;
33use crate::engine::{Engine, EngineCapabilities};
34use crate::types::{ImageParams, ModelFileRole, ModelSource, Task, TaskKind, TaskResult};
35use anyhow::{anyhow, bail, Context, Result};
36use parking_lot::Mutex;
37use std::collections::BTreeMap;
38use std::ffi::OsString;
39use std::path::{Path, PathBuf};
40use std::process::Command;
41use std::time::Instant;
42use tracing::{debug, info, warn};
43
/// `tracing` target attached to every event emitted from this module.
const TRACE_TARGET: &str = "studio_worker::engine::sdcpp";

/// Step count used when neither the job nor the model's CLI defaults yield a usable value.
const STEPS_FALLBACK: u32 = 8;
51
/// Image engine that serves jobs by spawning an external `sd-cli` process.
pub struct SdCppEngine {
    /// Path of the `sd-cli` binary; `None` until `ensure_sd_cli` has resolved or provisioned it.
    sd_cli: Mutex<Option<PathBuf>>,
    /// Root directory passed to the model downloader and to `sd-cli` resolution/provisioning.
    models_root: PathBuf,
}
63
64impl SdCppEngine {
65 pub fn new(models_root: &Path) -> Self {
72 info!(
73 target: TRACE_TARGET,
74 op = "register",
75 models_root = %models_root.display(),
76 sd_cli_name = sd_provision::binary_name(),
77 "sdcpp engine registered (sd-cli resolved/provisioned on first image job)"
78 );
79 Self {
80 sd_cli: Mutex::new(None),
81 models_root: models_root.to_path_buf(),
82 }
83 }
84
85 #[cfg(test)]
88 pub fn with_paths(sd_cli: PathBuf, models_root: PathBuf) -> Self {
89 Self {
90 sd_cli: Mutex::new(Some(sd_cli)),
91 models_root,
92 }
93 }
94
95 #[cfg_attr(coverage_nightly, coverage(off))]
102 fn ensure_sd_cli(&self) -> Result<PathBuf> {
103 let mut guard = self.sd_cli.lock();
104 if let Some(p) = guard.as_ref() {
105 if p.is_file() {
106 return Ok(p.clone());
107 }
108 }
109 let resolved = match resolve_sd_cli(&self.models_root) {
110 Some(p) => {
111 info!(
112 target: TRACE_TARGET,
113 op = "resolve",
114 sd_cli = %p.display(),
115 "using existing sd-cli"
116 );
117 p
118 }
119 None => sd_provision::provision(&self.models_root)
120 .context("auto-provisioning sd-cli (stable-diffusion.cpp)")?,
121 };
122 *guard = Some(resolved.clone());
123 Ok(resolved)
124 }
125
126 #[cfg_attr(coverage_nightly, coverage(off))]
130 fn ensure_files(&self, source: &ModelSource) -> Result<Vec<(ModelFileRole, PathBuf)>> {
131 let mut out = Vec::with_capacity(source.files.len());
132 for file in &source.files {
133 let local = download::ensure_file(&self.models_root, file)?;
134 out.push((file.role, local));
135 }
136 Ok(out)
137 }
138
139 #[cfg_attr(coverage_nightly, coverage(off))]
145 fn dispatch_image(
146 &self,
147 model: &str,
148 params: ImageParams,
149 source: &ModelSource,
150 ) -> Result<TaskResult> {
151 let sd_cli = self.ensure_sd_cli()?;
155 if let Err(e) = sd_provision::vulkan_runtime_status() {
160 warn!(
161 target: TRACE_TARGET,
162 op = "preflight",
163 model,
164 error = %e,
165 "GPU runtime missing; refusing image job"
166 );
167 return Err(e);
168 }
169 let files = self.ensure_files(source)?;
170 let diffusion_only = file_for_role(&files, ModelFileRole::DiffusionModel);
174 let full_checkpoint = diffusion_only.is_none();
175 let diffusion_model = diffusion_only
176 .or_else(|| file_for_role(&files, ModelFileRole::Model))
177 .ok_or_else(|| anyhow!("modelSource has no diffusion-model / model file"))?;
178 let vae = file_for_role(&files, ModelFileRole::Vae);
179 let text_encoder = file_for_role(&files, ModelFileRole::TextEncoder);
180 let text_encoder_vision = file_for_role(&files, ModelFileRole::TextEncoderVision);
181
182 let out_dir = std::env::temp_dir().join("studio-worker-sdcpp");
183 std::fs::create_dir_all(&out_dir)
184 .with_context(|| format!("creating sdcpp output dir {}", out_dir.display()))?;
185 let stem = format!(
186 "out-{}-{}",
187 std::process::id(),
188 chrono::Utc::now().timestamp_nanos_opt().unwrap_or_default()
189 );
190 let out_path = out_dir.join(format!("{stem}.webp"));
191
192 let mut temp_files = TempFileGuard::new();
196 temp_files.push(out_path.clone());
197
198 let init_img_path = match params.init_image_url.as_deref() {
205 Some(url) if !url.is_empty() => {
206 let ext = init_image_extension(url);
207 let init_path = out_dir.join(format!("{stem}-init.{ext}"));
208 download::download_file(url, &init_path).with_context(|| {
209 format!("downloading init image {} -> {}", url, init_path.display())
210 })?;
211 temp_files.push(init_path.clone());
212 Some(init_path)
213 }
214 _ => None,
215 };
216
217 let has_base = init_img_path.is_some() || params.ref_image_url.as_deref().is_some();
221 let mask_path = match (has_base, params.mask_url.as_deref()) {
222 (true, Some(url)) if !url.is_empty() => {
223 let ext = init_image_extension(url);
224 let path = out_dir.join(format!("{stem}-mask.{ext}"));
225 download::download_file(url, &path)
226 .with_context(|| format!("downloading mask {} -> {}", url, path.display()))?;
227 temp_files.push(path.clone());
228 Some(path)
229 }
230 _ => None,
231 };
232
233 let ref_img_path = match params.ref_image_url.as_deref() {
236 Some(url) if !url.is_empty() => {
237 let ext = init_image_extension(url);
238 let path = out_dir.join(format!("{stem}-ref.{ext}"));
239 download::download_file(url, &path).with_context(|| {
240 format!("downloading reference image {} -> {}", url, path.display())
241 })?;
242 temp_files.push(path.clone());
243 Some(path)
244 }
245 _ => None,
246 };
247
248 let args = build_sdcli_args(
249 ¶ms,
250 source,
251 diffusion_model,
252 vae,
253 text_encoder,
254 text_encoder_vision,
255 &out_path,
256 init_img_path.as_deref(),
257 mask_path.as_deref(),
258 ref_img_path.as_deref(),
259 full_checkpoint,
260 );
261 let mut cmd = Command::new(&sd_cli);
262 cmd.args(&args);
263 apply_library_path(&mut cmd, &sd_cli);
264
265 debug!(
266 target: TRACE_TARGET,
267 op = "spawn",
268 sd_cli = %sd_cli.display(),
269 model,
270 i2i = init_img_path.is_some(),
271 arg_count = args.len(),
272 "running sd-cli"
273 );
274
275 let started = Instant::now();
276 let output = cmd
277 .output()
278 .with_context(|| format!("running {}", sd_cli.display()))?;
279 let elapsed_ms = started.elapsed().as_millis() as u64;
280 if !output.status.success() {
281 let stderr = String::from_utf8_lossy(&output.stderr);
282 warn!(
283 target: TRACE_TARGET,
284 op = "spawn",
285 model,
286 elapsed_ms,
287 exit = ?output.status.code(),
288 stderr = %stderr,
289 "sd-cli failed"
290 );
291 bail!(
292 "sd-cli exited with {:?}: {}",
293 output.status.code(),
294 stderr.lines().last().unwrap_or("(no stderr)")
295 );
296 }
297
298 let bytes = std::fs::read(&out_path)
299 .with_context(|| format!("reading sd-cli output at {}", out_path.display()))?;
300 info!(
301 target: TRACE_TARGET,
302 op = "dispatch",
303 model,
304 elapsed_ms,
305 bytes = bytes.len(),
306 "ok"
307 );
308
309 Ok(TaskResult::Image {
310 bytes,
311 ext: "webp".to_string(),
312 })
313 }
314}
315
316impl Engine for SdCppEngine {
317 fn name(&self) -> &'static str {
318 "sdcpp"
319 }
320
321 fn capabilities(&self) -> EngineCapabilities {
322 let mut map: BTreeMap<TaskKind, Vec<String>> = BTreeMap::new();
328 map.insert(TaskKind::Image, vec!["sd-cpp:*".to_string()]);
329 EngineCapabilities {
330 supported_models_per_kind: map,
331 }
332 }
333
334 fn dispatch(&self, _model: &str, _task: Task) -> Result<TaskResult> {
335 bail!(
336 "sdcpp engine requires a ModelSource on the offer; legacy push-based offers \
337 (no modelSource) cannot be served - re-promote the job through the studio"
338 )
339 }
340
341 fn dispatch_with_source(
342 &self,
343 model: &str,
344 task: Task,
345 source: &ModelSource,
346 ) -> Result<TaskResult> {
347 match task {
348 Task::Image(p) => self.dispatch_image(model, p, source),
349 other => {
350 let kind = other.kind();
356 warn!(
357 target: TRACE_TARGET,
358 op = "dispatch",
359 model,
360 kind = kind.as_str(),
361 "sdcpp engine only serves image jobs"
362 );
363 Err(crate::engine::UnsupportedTask::new("sdcpp", kind).into())
364 }
365 }
366 }
367}
368
369fn file_for_role(files: &[(ModelFileRole, PathBuf)], role: ModelFileRole) -> Option<&Path> {
378 files
379 .iter()
380 .find(|(r, _)| *r == role)
381 .map(|(_, p)| p.as_path())
382}
383
384fn resolve_image_args(params: &ImageParams, source: &ModelSource) -> ResolvedImageArgs {
389 let width = if params.width > 0 {
390 params.width
391 } else if source.cli_defaults.width > 0 {
392 source.cli_defaults.width
393 } else {
394 1024
395 };
396 let height = if params.height > 0 {
397 params.height
398 } else if source.cli_defaults.height > 0 {
399 source.cli_defaults.height
400 } else {
401 1024
402 };
403 let steps = if params.steps > 0 && params.steps != 20 {
407 params.steps
408 } else if source.cli_defaults.steps > 0 {
409 source.cli_defaults.steps
410 } else {
411 STEPS_FALLBACK
412 };
413 let source_cfg = if source.cli_defaults.cfg_scale > 0.0 {
414 source.cli_defaults.cfg_scale
415 } else {
416 1.0
417 };
418 let cfg_scale = params.cfg_scale.filter(|v| *v > 0.0).unwrap_or(source_cfg);
419 let sampling_method = params
420 .sampling_method
421 .clone()
422 .or_else(|| source.cli_defaults.sampling_method.clone());
423 ResolvedImageArgs {
424 width,
425 height,
426 steps,
427 cfg_scale,
428 sampling_method,
429 }
430}
431
/// Generation settings after merging job parameters with the model's CLI defaults;
/// produced by `resolve_image_args`.
#[derive(Debug, Clone, PartialEq)]
struct ResolvedImageArgs {
    /// Output width, passed to sd-cli as `-W`.
    width: u32,
    /// Output height, passed to sd-cli as `-H`.
    height: u32,
    /// Step count, passed as `--steps`.
    steps: u32,
    /// Guidance scale, passed as `--cfg-scale`.
    cfg_scale: f32,
    /// Sampler name, passed as `--sampling-method` when present.
    sampling_method: Option<String>,
}
441
442#[allow(clippy::too_many_arguments)]
449fn build_sdcli_args(
450 params: &ImageParams,
451 source: &ModelSource,
452 diffusion_model: &Path,
453 vae: Option<&Path>,
454 text_encoder: Option<&Path>,
455 text_encoder_vision: Option<&Path>,
456 out_path: &Path,
457 init_img_path: Option<&Path>,
458 mask_path: Option<&Path>,
459 ref_img_path: Option<&Path>,
460 full_checkpoint: bool,
461) -> Vec<OsString> {
462 let resolved = resolve_image_args(params, source);
463 let mut args: Vec<OsString> = Vec::with_capacity(32);
464
465 args.push(
468 if full_checkpoint {
469 "--model"
470 } else {
471 "--diffusion-model"
472 }
473 .into(),
474 );
475 args.push(diffusion_model.into());
476 if let Some(p) = vae {
477 args.push("--vae".into());
478 args.push(p.into());
479 }
480 if let Some(p) = text_encoder {
481 args.push("--llm".into());
482 args.push(p.into());
483 }
484 if let Some(p) = text_encoder_vision {
485 args.push("--llm_vision".into());
486 args.push(p.into());
487 }
488 args.push("-p".into());
489 args.push((¶ms.prompt as &str).into());
490 if let Some(neg) = params.negative_prompt.as_deref() {
491 if !neg.is_empty() {
492 args.push("--negative-prompt".into());
493 args.push(neg.into());
494 }
495 }
496 if let Some(reference) = ref_img_path {
497 args.push("-r".into());
503 args.push(reference.into());
504 if let Some(mask) = mask_path {
505 args.push("--mask".into());
506 args.push(mask.into());
507 }
508 } else if let Some(init) = init_img_path {
509 args.push("--init-img".into());
510 args.push(init.into());
511 let strength = params.denoise.unwrap_or(0.75);
515 args.push("--strength".into());
516 args.push(strength.to_string().into());
517 if let Some(mask) = mask_path {
519 args.push("--mask".into());
520 args.push(mask.into());
521 }
522 }
523 args.push("--cfg-scale".into());
524 args.push(resolved.cfg_scale.to_string().into());
525 args.push("--steps".into());
526 args.push(resolved.steps.to_string().into());
527 args.push("-W".into());
528 args.push(resolved.width.to_string().into());
529 args.push("-H".into());
530 args.push(resolved.height.to_string().into());
531 args.push("-o".into());
532 args.push(out_path.into());
533 if let Some(seed) = params.seed {
534 args.push("--seed".into());
535 args.push(seed.to_string().into());
536 }
537 if let Some(method) = resolved.sampling_method.as_deref() {
538 args.push("--sampling-method".into());
539 args.push(method.into());
540 }
541 if let Some(shift) = source.cli_defaults.flow_shift {
544 args.push("--flow-shift".into());
545 args.push(shift.to_string().into());
546 }
547 if source.cli_defaults.zero_cond_t == Some(true) {
548 args.push("--qwen-image-zero-cond-t".into());
549 }
550 if source.cli_defaults.offload_to_cpu == Some(true) {
551 args.push("--offload-to-cpu".into());
552 }
553 args.push("--diffusion-fa".into());
555 args
556}
557
558#[cfg_attr(coverage_nightly, coverage(off))]
565fn apply_library_path(cmd: &mut Command, sd_cli: &Path) {
566 let Some((var, dir)) = sd_provision::library_path_env(sd_cli) else {
567 return;
568 };
569 let value = match std::env::var_os(var) {
570 Some(existing) => {
571 let mut paths = vec![dir.clone()];
572 paths.extend(std::env::split_paths(&existing));
573 std::env::join_paths(paths).unwrap_or_else(|_| dir.into_os_string())
577 }
578 None => dir.into_os_string(),
579 };
580 cmd.env(var, value);
581}
582
583#[cfg_attr(coverage_nightly, coverage(off))]
590fn resolve_sd_cli(models_root: &Path) -> Option<PathBuf> {
591 let bin = sd_provision::binary_name();
592 if let Ok(p) = std::env::var("STUDIO_WORKER_SD_CLI") {
593 let path = PathBuf::from(p);
594 if path.is_file() {
595 return Some(path);
596 }
597 }
598 let in_models = models_root.join("bin").join(bin);
599 if in_models.is_file() {
600 return Some(in_models);
601 }
602 if let Some(home) = std::env::var_os("HOME") {
603 let candidate = PathBuf::from(home).join(".local/bin").join(bin);
604 if candidate.is_file() {
605 return Some(candidate);
606 }
607 }
608 which(bin)
609}
610
/// Searches the directories listed in `PATH`, in order, for a runnable file named `bin`.
///
/// On Unix a candidate must be a regular file with at least one execute bit set, so a
/// non-executable file with the right name does not shadow a real executable later on
/// `PATH`. On other platforms any regular file is accepted. Returns `None` when `PATH` is
/// unset or nothing matches.
#[cfg_attr(coverage_nightly, coverage(off))]
fn which(bin: &str) -> Option<PathBuf> {
    #[cfg(unix)]
    fn is_runnable(candidate: &Path) -> bool {
        candidate
            .metadata()
            .map(|meta| {
                let mode = std::os::unix::fs::PermissionsExt::mode(&meta.permissions());
                meta.is_file() && mode & 0o111 != 0
            })
            .unwrap_or(false)
    }

    #[cfg(not(unix))]
    fn is_runnable(candidate: &Path) -> bool {
        candidate.is_file()
    }

    let path = std::env::var_os("PATH")?;
    std::env::split_paths(&path)
        .map(|dir| dir.join(bin))
        .find(|candidate| is_runnable(candidate))
}
624
/// Picks the file extension to use for an image downloaded from `url`.
///
/// Everything from the first `?` or `#` onwards is ignored, then the text after the last
/// `.` (or the whole remainder when there is no `.`) is compared case-insensitively
/// against the known image extensions. `jpeg` maps to `jpg` and `tiff` to `tif`; anything
/// unrecognised yields `webp`.
fn init_image_extension(url: &str) -> &'static str {
    /// `(accepted spelling, extension returned)` pairs.
    const KNOWN: [(&str, &str); 8] = [
        ("png", "png"),
        ("jpg", "jpg"),
        ("jpeg", "jpg"),
        ("webp", "webp"),
        ("bmp", "bmp"),
        ("gif", "gif"),
        ("tif", "tif"),
        ("tiff", "tif"),
    ];

    let cut = url
        .find(|c: char| c == '?' || c == '#')
        .unwrap_or(url.len());
    let path = &url[..cut];
    let tail = match path.rfind('.') {
        Some(dot) => &path[dot + 1..],
        None => path,
    };

    KNOWN
        .iter()
        .find(|(spelling, _)| tail.eq_ignore_ascii_case(spelling))
        .map(|&(_, canonical)| canonical)
        .unwrap_or("webp")
}
646
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ModelCliDefaults, ModelEngine, ModelFile, ModelFileRole};
    use tempfile::tempdir;

    /// Model source with plain defaults: cfg 1.0, 8 steps, 1024x1024, `euler` sampler.
    fn fake_source(files: Vec<ModelFile>) -> ModelSource {
        let cli_defaults = ModelCliDefaults {
            cfg_scale: 1.0,
            steps: 8,
            width: 1024,
            height: 1024,
            sampling_method: Some("euler".to_string()),
            ..Default::default()
        };
        ModelSource {
            engine: ModelEngine::SdCpp,
            files,
            cli_defaults,
        }
    }

    /// Model source with every optional CLI default set (flow shift, zero-cond-t, offload).
    fn qwen_edit_source() -> ModelSource {
        ModelSource {
            engine: ModelEngine::SdCpp,
            files: vec![],
            cli_defaults: ModelCliDefaults {
                cfg_scale: 4.0,
                steps: 20,
                width: 1024,
                height: 1024,
                sampling_method: Some("euler".to_string()),
                flow_shift: Some(3.0),
                zero_cond_t: Some(true),
                offload_to_cpu: Some(true),
            },
        }
    }

    /// Engine with a pinned dummy `sd-cli` path and the given models root.
    fn engine_in(models_root: &Path) -> SdCppEngine {
        SdCppEngine::with_paths(PathBuf::from("/usr/bin/true"), models_root.to_path_buf())
    }

    /// Lossy string view of an argument vector, for readable assertions.
    fn args_to_strings(args: &[OsString]) -> Vec<String> {
        args.iter()
            .map(|arg| arg.to_string_lossy().into_owned())
            .collect()
    }

    /// Value that directly follows the first occurrence of `flag`; panics when absent.
    fn value_after<'a>(args: &'a [String], flag: &str) -> &'a str {
        let at = args
            .iter()
            .position(|a| a == flag)
            .unwrap_or_else(|| panic!("flag {flag} missing from {args:?}"));
        &args[at + 1]
    }

    /// Whether `flag` appears anywhere in `args`.
    fn has_flag(args: &[String], flag: &str) -> bool {
        args.iter().any(|a| a == flag)
    }

    /// Arguments for the common case: `/d.gguf` weights passed as a diffusion model, no
    /// VAE / text encoders / reference image, output at `/tmp/out.webp`.
    fn simple_args(
        params: &ImageParams,
        source: &ModelSource,
        init: Option<&str>,
        mask: Option<&str>,
    ) -> Vec<String> {
        let args = build_sdcli_args(
            params,
            source,
            Path::new("/d.gguf"),
            None,
            None,
            None,
            Path::new("/tmp/out.webp"),
            init.map(Path::new),
            mask.map(Path::new),
            None,
            false,
        );
        args_to_strings(&args)
    }

    #[test]
    fn file_for_role_picks_matching_file() {
        let files = vec![
            (ModelFileRole::DiffusionModel, PathBuf::from("/d.gguf")),
            (ModelFileRole::Vae, PathBuf::from("/v.safetensors")),
        ];
        assert_eq!(
            file_for_role(&files, ModelFileRole::DiffusionModel),
            Some(Path::new("/d.gguf"))
        );
        assert_eq!(
            file_for_role(&files, ModelFileRole::Vae),
            Some(Path::new("/v.safetensors"))
        );
        assert!(file_for_role(&files, ModelFileRole::TextEncoder).is_none());
    }

    #[test]
    fn ensure_files_skips_already_present() {
        let dir = tempdir().unwrap();
        let cached = dir.path().join("cached.gguf");
        std::fs::write(&cached, b"already here").unwrap();
        let engine = engine_in(dir.path());
        let source = fake_source(vec![ModelFile {
            role: ModelFileRole::DiffusionModel,
            url: "https://example.invalid/cached.gguf".into(),
            filename: "cached.gguf".into(),
            approx_bytes: None,
            sha256: None,
        }]);

        let resolved = engine.ensure_files(&source).expect("cached file used");

        assert_eq!(resolved.len(), 1);
        let (role, path) = &resolved[0];
        assert_eq!(*role, ModelFileRole::DiffusionModel);
        assert_eq!(*path, cached);
        // The cached bytes must be untouched.
        assert_eq!(std::fs::read(&cached).unwrap(), b"already here");
    }

    #[test]
    fn dispatch_rejects_non_image_tasks() {
        use crate::types::AudioTtsParams;
        let dir = tempdir().unwrap();
        let engine = engine_in(dir.path());
        let task = Task::AudioTts(AudioTtsParams {
            text: "hi".into(),
            voice: "v".into(),
            ext: "wav".into(),
            ..Default::default()
        });
        let source = fake_source(vec![]);

        let err = engine
            .dispatch_with_source("anything", task, &source)
            .unwrap_err();

        assert!(err.to_string().contains("cannot serve audio_tts"));
    }

    #[test]
    fn build_sdcli_args_includes_required_flags() {
        let params = ImageParams {
            prompt: "hello".into(),
            width: 768,
            height: 512,
            steps: 20,
            ..Default::default()
        };
        let source = fake_source(vec![]);
        let args = build_sdcli_args(
            &params,
            &source,
            Path::new("/d.gguf"),
            Some(Path::new("/v.safetensors")),
            Some(Path::new("/llm.gguf")),
            None,
            Path::new("/tmp/out.webp"),
            None,
            None,
            None,
            false,
        );
        let s = args_to_strings(&args);

        let expected = [
            ("--diffusion-model", "/d.gguf"),
            ("--vae", "/v.safetensors"),
            ("--llm", "/llm.gguf"),
            ("-p", "hello"),
            ("-W", "768"),
            ("-H", "512"),
            ("--cfg-scale", "1"),
            // A job step count of 20 defers to the model default of 8.
            ("--steps", "8"),
            ("--sampling-method", "euler"),
            ("-o", "/tmp/out.webp"),
        ];
        for (flag, value) in expected.iter() {
            assert_eq!(value_after(&s, flag), *value, "value of {flag}");
        }
        assert!(has_flag(&s, "--diffusion-fa"));
        assert!(!has_flag(&s, "--init-img"));
        assert!(!has_flag(&s, "--strength"));
    }

    #[test]
    fn build_sdcli_args_includes_negative_prompt_when_set() {
        let params = ImageParams {
            prompt: "hi".into(),
            negative_prompt: Some("text, watermark, low quality".into()),
            ..Default::default()
        };
        let s = simple_args(&params, &fake_source(vec![]), None, None);
        assert_eq!(
            value_after(&s, "--negative-prompt"),
            "text, watermark, low quality"
        );
    }

    #[test]
    fn build_sdcli_args_omits_negative_prompt_when_empty_string() {
        let params = ImageParams {
            prompt: "hi".into(),
            negative_prompt: Some(String::new()),
            ..Default::default()
        };
        let s = simple_args(&params, &fake_source(vec![]), None, None);
        assert!(!has_flag(&s, "--negative-prompt"));
    }

    #[test]
    fn build_sdcli_args_includes_init_image_and_strength() {
        let params = ImageParams {
            prompt: "hi".into(),
            denoise: Some(0.55),
            ..Default::default()
        };
        let s = simple_args(&params, &fake_source(vec![]), Some("/tmp/init.webp"), None);
        assert_eq!(value_after(&s, "--init-img"), "/tmp/init.webp");
        assert_eq!(value_after(&s, "--strength"), "0.55");
        assert!(!has_flag(&s, "--mask"));
    }

    #[test]
    fn build_sdcli_args_includes_mask_for_inpaint() {
        let params = ImageParams {
            prompt: "remove the tree".into(),
            denoise: Some(0.8),
            ..Default::default()
        };
        let s = simple_args(
            &params,
            &fake_source(vec![]),
            Some("/tmp/init.webp"),
            Some("/tmp/mask.png"),
        );
        assert_eq!(value_after(&s, "--init-img"), "/tmp/init.webp");
        assert_eq!(value_after(&s, "--mask"), "/tmp/mask.png");
        assert_eq!(value_after(&s, "--strength"), "0.8");
    }

    #[test]
    fn build_sdcli_args_uses_model_flag_for_full_checkpoint() {
        let params = ImageParams {
            prompt: "hi".into(),
            ..Default::default()
        };
        let source = fake_source(vec![]);
        let args = build_sdcli_args(
            &params,
            &source,
            Path::new("/checkpoint.safetensors"),
            Some(Path::new("/v.safetensors")),
            None,
            None,
            Path::new("/tmp/out.webp"),
            None,
            None,
            None,
            true,
        );
        let s = args_to_strings(&args);
        assert_eq!(value_after(&s, "--model"), "/checkpoint.safetensors");
        assert!(!has_flag(&s, "--diffusion-model"));
    }

    #[test]
    fn build_sdcli_args_defaults_denoise_when_init_image_present_but_denoise_none() {
        let params = ImageParams {
            prompt: "hi".into(),
            denoise: None,
            ..Default::default()
        };
        let s = simple_args(&params, &fake_source(vec![]), Some("/tmp/init.webp"), None);
        assert_eq!(value_after(&s, "--strength"), "0.75");
    }

    #[test]
    fn build_sdcli_args_per_job_cfg_scale_overrides_model_default() {
        let params = ImageParams {
            prompt: "hi".into(),
            cfg_scale: Some(7.5),
            ..Default::default()
        };
        let s = simple_args(&params, &fake_source(vec![]), None, None);
        assert_eq!(value_after(&s, "--cfg-scale"), "7.5");
    }

    #[test]
    fn build_sdcli_args_per_job_sampling_method_overrides_model_default() {
        let params = ImageParams {
            prompt: "hi".into(),
            sampling_method: Some("dpm++2m".into()),
            ..Default::default()
        };
        let s = simple_args(&params, &fake_source(vec![]), None, None);
        assert_eq!(value_after(&s, "--sampling-method"), "dpm++2m");
    }

    #[test]
    fn build_sdcli_args_per_job_steps_overrides_when_non_default() {
        let params = ImageParams {
            prompt: "hi".into(),
            steps: 30,
            ..Default::default()
        };
        let s = simple_args(&params, &fake_source(vec![]), None, None);
        assert_eq!(value_after(&s, "--steps"), "30");
    }

    #[test]
    fn build_sdcli_args_seed_included_when_set() {
        let params = ImageParams {
            prompt: "hi".into(),
            seed: Some(42),
            ..Default::default()
        };
        let s = simple_args(&params, &fake_source(vec![]), None, None);
        assert_eq!(value_after(&s, "--seed"), "42");
    }

    #[test]
    fn build_sdcli_args_reference_mode_for_instruction_edit() {
        let params = ImageParams {
            prompt: "add a red beach ball".into(),
            denoise: Some(0.9),
            ..Default::default()
        };
        let source = qwen_edit_source();
        let args = build_sdcli_args(
            &params,
            &source,
            Path::new("/qwen.gguf"),
            Some(Path::new("/vae.safetensors")),
            Some(Path::new("/llm.gguf")),
            Some(Path::new("/mmproj.gguf")),
            Path::new("/tmp/out.webp"),
            None,
            Some(Path::new("/tmp/mask.png")),
            Some(Path::new("/tmp/ref.webp")),
            false,
        );
        let s = args_to_strings(&args);

        assert_eq!(value_after(&s, "-r"), "/tmp/ref.webp");
        assert_eq!(value_after(&s, "--mask"), "/tmp/mask.png");
        // Reference mode sends neither an init image nor a strength.
        assert!(!has_flag(&s, "--init-img"));
        assert!(!has_flag(&s, "--strength"));
        assert_eq!(value_after(&s, "--llm_vision"), "/mmproj.gguf");
        assert_eq!(value_after(&s, "--flow-shift"), "3");
        assert!(has_flag(&s, "--qwen-image-zero-cond-t"));
        assert!(has_flag(&s, "--offload-to-cpu"));
    }

    #[test]
    fn build_sdcli_args_omits_qwen_flags_for_plain_model() {
        let params = ImageParams {
            prompt: "hi".into(),
            ..Default::default()
        };
        let s = simple_args(&params, &fake_source(vec![]), None, None);
        let absent = [
            "--flow-shift",
            "--qwen-image-zero-cond-t",
            "--offload-to-cpu",
            "--llm_vision",
            "-r",
        ];
        for flag in absent.iter() {
            assert!(!has_flag(&s, flag), "{flag} should be absent");
        }
    }

    #[test]
    fn capabilities_advertises_only_image_kind() {
        let dir = tempdir().unwrap();
        let engine = engine_in(dir.path());
        let caps = engine.capabilities();
        assert!(caps
            .supported_models_per_kind
            .contains_key(&TaskKind::Image));
        assert_eq!(caps.supported_models_per_kind.len(), 1);
    }

    #[test]
    fn init_image_extension_reads_url_tail() {
        let cases = [
            ("https://x/y/latest.webp", "webp"),
            ("https://x/y/latest.PNG", "png"),
            ("https://x/y/latest.jpg", "jpg"),
            ("https://x/y/latest.jpeg", "jpg"),
            // Query strings and fragments are ignored.
            ("https://x/y/latest.webp?v=42&t=now", "webp"),
            ("https://x/y/latest.webp#frag", "webp"),
            // Unknown or missing extensions fall back to webp.
            ("https://x/y/latest.unknownext", "webp"),
            ("https://x/y/no-ext", "webp"),
        ];
        for (url, expected) in cases.iter() {
            assert_eq!(init_image_extension(url), *expected, "extension for {url}");
        }
    }
}