1use crate::engine::download;
32use crate::engine::{Engine, EngineCapabilities};
33use crate::types::{ImageParams, ModelFileRole, ModelSource, Task, TaskKind, TaskResult};
34use anyhow::{anyhow, bail, Context, Result};
35use std::collections::BTreeMap;
36use std::ffi::OsString;
37use std::path::{Path, PathBuf};
38use std::process::Command;
39use std::time::Instant;
40use tracing::{debug, info, warn};
41
/// `tracing` target attached to every event this engine emits.
const TRACE_TARGET: &str = "studio_worker::engine::sdcpp";

/// Step count used when neither the job nor the model's `cli_defaults`
/// supplies a usable (positive) value.
const STEPS_FALLBACK: u32 = 8;
49
/// [`Engine`] implementation that renders images by spawning an external
/// `sd-cli` process.
pub struct SdCppEngine {
    /// Path to the `sd-cli` executable run for every dispatch.
    sd_cli: PathBuf,
    /// Directory that model files are fetched into / reused from.
    models_root: PathBuf,
}
55
56impl SdCppEngine {
57 #[cfg_attr(coverage_nightly, coverage(off))]
61 pub fn try_new(models_root: &Path) -> Option<Self> {
62 let Some(sd_cli) = resolve_sd_cli(models_root) else {
63 info!(
69 target: TRACE_TARGET,
70 op = "register",
71 models_root = %models_root.display(),
72 sd_cli_name = sd_cli_binary_name(),
73 "sd-cli not found (checked $STUDIO_WORKER_SD_CLI, \
74 <models_root>/bin, ~/.local/bin, and $PATH); real image \
75 generation is disabled on this worker until it is \
76 installed — see docs/operations/sd-cli-install.md"
77 );
78 return None;
79 };
80 if let Err(e) = std::fs::create_dir_all(models_root) {
81 warn!(
82 target: TRACE_TARGET,
83 models_root = %models_root.display(),
84 error = %e,
85 "could not create models_root; skipping sdcpp registration"
86 );
87 return None;
88 }
89 info!(
90 target: TRACE_TARGET,
91 sd_cli = %sd_cli.display(),
92 models_root = %models_root.display(),
93 "sdcpp engine registered"
94 );
95 Some(Self {
96 sd_cli,
97 models_root: models_root.to_path_buf(),
98 })
99 }
100
101 #[cfg(test)]
103 pub fn with_paths(sd_cli: PathBuf, models_root: PathBuf) -> Self {
104 Self {
105 sd_cli,
106 models_root,
107 }
108 }
109
110 #[cfg_attr(coverage_nightly, coverage(off))]
114 fn ensure_files(&self, source: &ModelSource) -> Result<Vec<(ModelFileRole, PathBuf)>> {
115 let mut out = Vec::with_capacity(source.files.len());
116 for file in &source.files {
117 let local = download::ensure_file(&self.models_root, &file.filename, &file.url)?;
118 out.push((file.role, local));
119 }
120 Ok(out)
121 }
122
123 #[cfg_attr(coverage_nightly, coverage(off))]
129 fn dispatch_image(
130 &self,
131 model: &str,
132 params: ImageParams,
133 source: &ModelSource,
134 ) -> Result<TaskResult> {
135 let files = self.ensure_files(source)?;
136 let diffusion_only = file_for_role(&files, ModelFileRole::DiffusionModel);
140 let full_checkpoint = diffusion_only.is_none();
141 let diffusion_model = diffusion_only
142 .or_else(|| file_for_role(&files, ModelFileRole::Model))
143 .ok_or_else(|| anyhow!("modelSource has no diffusion-model / model file"))?;
144 let vae = file_for_role(&files, ModelFileRole::Vae);
145 let text_encoder = file_for_role(&files, ModelFileRole::TextEncoder);
146 let text_encoder_vision = file_for_role(&files, ModelFileRole::TextEncoderVision);
147
148 let out_dir = std::env::temp_dir().join("studio-worker-sdcpp");
149 std::fs::create_dir_all(&out_dir)
150 .with_context(|| format!("creating sdcpp output dir {}", out_dir.display()))?;
151 let stem = format!(
152 "out-{}-{}",
153 std::process::id(),
154 chrono::Utc::now().timestamp_nanos_opt().unwrap_or_default()
155 );
156 let out_path = out_dir.join(format!("{stem}.webp"));
157
158 let mut temp_files = TempFileGuard::new();
162 temp_files.push(out_path.clone());
163
164 let init_img_path = match params.init_image_url.as_deref() {
171 Some(url) if !url.is_empty() => {
172 let ext = init_image_extension(url);
173 let init_path = out_dir.join(format!("{stem}-init.{ext}"));
174 download::download_file(url, &init_path).with_context(|| {
175 format!("downloading init image {} -> {}", url, init_path.display())
176 })?;
177 temp_files.push(init_path.clone());
178 Some(init_path)
179 }
180 _ => None,
181 };
182
183 let has_base = init_img_path.is_some() || params.ref_image_url.as_deref().is_some();
187 let mask_path = match (has_base, params.mask_url.as_deref()) {
188 (true, Some(url)) if !url.is_empty() => {
189 let ext = init_image_extension(url);
190 let path = out_dir.join(format!("{stem}-mask.{ext}"));
191 download::download_file(url, &path)
192 .with_context(|| format!("downloading mask {} -> {}", url, path.display()))?;
193 temp_files.push(path.clone());
194 Some(path)
195 }
196 _ => None,
197 };
198
199 let ref_img_path = match params.ref_image_url.as_deref() {
202 Some(url) if !url.is_empty() => {
203 let ext = init_image_extension(url);
204 let path = out_dir.join(format!("{stem}-ref.{ext}"));
205 download::download_file(url, &path).with_context(|| {
206 format!("downloading reference image {} -> {}", url, path.display())
207 })?;
208 temp_files.push(path.clone());
209 Some(path)
210 }
211 _ => None,
212 };
213
214 let args = build_sdcli_args(
215 ¶ms,
216 source,
217 diffusion_model,
218 vae,
219 text_encoder,
220 text_encoder_vision,
221 &out_path,
222 init_img_path.as_deref(),
223 mask_path.as_deref(),
224 ref_img_path.as_deref(),
225 full_checkpoint,
226 );
227 let mut cmd = Command::new(&self.sd_cli);
228 cmd.args(&args);
229
230 debug!(
231 target: TRACE_TARGET,
232 op = "spawn",
233 sd_cli = %self.sd_cli.display(),
234 model,
235 i2i = init_img_path.is_some(),
236 arg_count = args.len(),
237 "running sd-cli"
238 );
239
240 let started = Instant::now();
241 let output = cmd
242 .output()
243 .with_context(|| format!("running {}", self.sd_cli.display()))?;
244 let elapsed_ms = started.elapsed().as_millis() as u64;
245 if !output.status.success() {
246 let stderr = String::from_utf8_lossy(&output.stderr);
247 warn!(
248 target: TRACE_TARGET,
249 op = "spawn",
250 model,
251 elapsed_ms,
252 exit = ?output.status.code(),
253 stderr = %stderr,
254 "sd-cli failed"
255 );
256 bail!(
257 "sd-cli exited with {:?}: {}",
258 output.status.code(),
259 stderr.lines().last().unwrap_or("(no stderr)")
260 );
261 }
262
263 let bytes = std::fs::read(&out_path)
264 .with_context(|| format!("reading sd-cli output at {}", out_path.display()))?;
265 info!(
266 target: TRACE_TARGET,
267 op = "dispatch",
268 model,
269 elapsed_ms,
270 bytes = bytes.len(),
271 "ok"
272 );
273
274 Ok(TaskResult::Image {
275 bytes,
276 ext: "webp".to_string(),
277 })
278 }
279}
280
281impl Engine for SdCppEngine {
282 fn name(&self) -> &'static str {
283 "sdcpp"
284 }
285
286 fn capabilities(&self) -> EngineCapabilities {
287 let mut map: BTreeMap<TaskKind, Vec<String>> = BTreeMap::new();
293 map.insert(TaskKind::Image, vec!["sd-cpp:*".to_string()]);
294 EngineCapabilities {
295 supported_models_per_kind: map,
296 }
297 }
298
299 fn dispatch(&self, _model: &str, _task: Task) -> Result<TaskResult> {
300 bail!(
301 "sdcpp engine requires a ModelSource on the offer; legacy push-based offers \
302 (no modelSource) cannot be served - re-promote the job through the studio"
303 )
304 }
305
306 fn dispatch_with_source(
307 &self,
308 model: &str,
309 task: Task,
310 source: &ModelSource,
311 ) -> Result<TaskResult> {
312 let kind = task.kind();
313 match task {
314 Task::Image(p) => self.dispatch_image(model, p, source),
315 _ => bail!("sdcpp engine cannot serve {} tasks", kind.as_str()),
316 }
317 }
318}
319
320fn remove_temp_file(path: &Path) {
333 if let Err(e) = std::fs::remove_file(path) {
334 if e.kind() != std::io::ErrorKind::NotFound {
335 warn!(
336 target: TRACE_TARGET,
337 op = "cleanup",
338 path = %path.display(),
339 error = %e,
340 "failed to remove temp file"
341 );
342 }
343 }
344}
345
346struct TempFileGuard {
355 paths: Vec<PathBuf>,
356}
357
358impl TempFileGuard {
359 fn new() -> Self {
360 Self { paths: Vec::new() }
361 }
362
363 fn push(&mut self, path: PathBuf) {
364 self.paths.push(path);
365 }
366}
367
368impl Drop for TempFileGuard {
369 fn drop(&mut self) {
370 for path in &self.paths {
371 remove_temp_file(path);
372 }
373 }
374}
375
376fn file_for_role(files: &[(ModelFileRole, PathBuf)], role: ModelFileRole) -> Option<&Path> {
377 files
378 .iter()
379 .find(|(r, _)| *r == role)
380 .map(|(_, p)| p.as_path())
381}
382
383fn resolve_image_args(params: &ImageParams, source: &ModelSource) -> ResolvedImageArgs {
388 let width = if params.width > 0 {
389 params.width
390 } else if source.cli_defaults.width > 0 {
391 source.cli_defaults.width
392 } else {
393 1024
394 };
395 let height = if params.height > 0 {
396 params.height
397 } else if source.cli_defaults.height > 0 {
398 source.cli_defaults.height
399 } else {
400 1024
401 };
402 let steps = if params.steps > 0 && params.steps != 20 {
406 params.steps
407 } else if source.cli_defaults.steps > 0 {
408 source.cli_defaults.steps
409 } else {
410 STEPS_FALLBACK
411 };
412 let source_cfg = if source.cli_defaults.cfg_scale > 0.0 {
413 source.cli_defaults.cfg_scale
414 } else {
415 1.0
416 };
417 let cfg_scale = params.cfg_scale.filter(|v| *v > 0.0).unwrap_or(source_cfg);
418 let sampling_method = params
419 .sampling_method
420 .clone()
421 .or_else(|| source.cli_defaults.sampling_method.clone());
422 ResolvedImageArgs {
423 width,
424 height,
425 steps,
426 cfg_scale,
427 sampling_method,
428 }
429}
430
/// Effective generation settings after merging job parameters, model CLI
/// defaults, and hard fallbacks (produced by `resolve_image_args`).
#[derive(Debug, Clone, PartialEq)]
struct ResolvedImageArgs {
    /// Output width, passed as `-W`.
    width: u32,
    /// Output height, passed as `-H`.
    height: u32,
    /// Sampling step count, passed as `--steps`.
    steps: u32,
    /// Guidance scale, passed as `--cfg-scale`.
    cfg_scale: f32,
    /// Sampler name for `--sampling-method`; `None` omits the flag.
    sampling_method: Option<String>,
}
440
441#[allow(clippy::too_many_arguments)]
448fn build_sdcli_args(
449 params: &ImageParams,
450 source: &ModelSource,
451 diffusion_model: &Path,
452 vae: Option<&Path>,
453 text_encoder: Option<&Path>,
454 text_encoder_vision: Option<&Path>,
455 out_path: &Path,
456 init_img_path: Option<&Path>,
457 mask_path: Option<&Path>,
458 ref_img_path: Option<&Path>,
459 full_checkpoint: bool,
460) -> Vec<OsString> {
461 let resolved = resolve_image_args(params, source);
462 let mut args: Vec<OsString> = Vec::with_capacity(32);
463
464 args.push(
467 if full_checkpoint {
468 "--model"
469 } else {
470 "--diffusion-model"
471 }
472 .into(),
473 );
474 args.push(diffusion_model.into());
475 if let Some(p) = vae {
476 args.push("--vae".into());
477 args.push(p.into());
478 }
479 if let Some(p) = text_encoder {
480 args.push("--llm".into());
481 args.push(p.into());
482 }
483 if let Some(p) = text_encoder_vision {
484 args.push("--llm_vision".into());
485 args.push(p.into());
486 }
487 args.push("-p".into());
488 args.push((¶ms.prompt as &str).into());
489 if let Some(neg) = params.negative_prompt.as_deref() {
490 if !neg.is_empty() {
491 args.push("--negative-prompt".into());
492 args.push(neg.into());
493 }
494 }
495 if let Some(reference) = ref_img_path {
496 args.push("-r".into());
502 args.push(reference.into());
503 if let Some(mask) = mask_path {
504 args.push("--mask".into());
505 args.push(mask.into());
506 }
507 } else if let Some(init) = init_img_path {
508 args.push("--init-img".into());
509 args.push(init.into());
510 let strength = params.denoise.unwrap_or(0.75);
514 args.push("--strength".into());
515 args.push(strength.to_string().into());
516 if let Some(mask) = mask_path {
518 args.push("--mask".into());
519 args.push(mask.into());
520 }
521 }
522 args.push("--cfg-scale".into());
523 args.push(resolved.cfg_scale.to_string().into());
524 args.push("--steps".into());
525 args.push(resolved.steps.to_string().into());
526 args.push("-W".into());
527 args.push(resolved.width.to_string().into());
528 args.push("-H".into());
529 args.push(resolved.height.to_string().into());
530 args.push("-o".into());
531 args.push(out_path.into());
532 if let Some(seed) = params.seed {
533 args.push("--seed".into());
534 args.push(seed.to_string().into());
535 }
536 if let Some(method) = resolved.sampling_method.as_deref() {
537 args.push("--sampling-method".into());
538 args.push(method.into());
539 }
540 if let Some(shift) = source.cli_defaults.flow_shift {
543 args.push("--flow-shift".into());
544 args.push(shift.to_string().into());
545 }
546 if source.cli_defaults.zero_cond_t == Some(true) {
547 args.push("--qwen-image-zero-cond-t".into());
548 }
549 if source.cli_defaults.offload_to_cpu == Some(true) {
550 args.push("--offload-to-cpu".into());
551 }
552 args.push("--diffusion-fa".into());
554 args
555}
556
/// File name of the `sd-cli` executable for the current target:
/// `sd-cli.exe` on Windows, `sd-cli` everywhere else.
fn sd_cli_binary_name() -> &'static str {
    const EXE_NAME: &str = if cfg!(target_os = "windows") {
        "sd-cli.exe"
    } else {
        "sd-cli"
    };
    EXE_NAME
}
565
566#[cfg_attr(coverage_nightly, coverage(off))]
574fn resolve_sd_cli(models_root: &Path) -> Option<PathBuf> {
575 let bin = sd_cli_binary_name();
576 if let Ok(p) = std::env::var("STUDIO_WORKER_SD_CLI") {
577 let path = PathBuf::from(p);
578 if path.is_file() {
579 return Some(path);
580 }
581 }
582 let in_models = models_root.join("bin").join(bin);
583 if in_models.is_file() {
584 return Some(in_models);
585 }
586 if let Some(home) = std::env::var_os("HOME") {
587 let candidate = PathBuf::from(home).join(".local/bin").join(bin);
588 if candidate.is_file() {
589 return Some(candidate);
590 }
591 }
592 which(bin)
593}
594
/// Minimal `which(1)`: the first `$PATH` entry containing a regular file named
/// `bin`, or `None` if `$PATH` is unset or nothing matches.
///
/// Only existence as a file is checked, not the executable bit.
#[cfg_attr(coverage_nightly, coverage(off))]
fn which(bin: &str) -> Option<PathBuf> {
    let search_path = std::env::var_os("PATH")?;
    std::env::split_paths(&search_path)
        .map(|dir| dir.join(bin))
        .find(|candidate| candidate.is_file())
}
608
/// Picks the file extension used for a downloaded auxiliary image, based on
/// the tail of its URL.
///
/// The query string and fragment are ignored and matching is ASCII
/// case-insensitive. `jpeg` / `tiff` normalise to `jpg` / `tif`. Anything
/// unrecognised, including no extension at all, falls back to `webp`.
fn init_image_extension(url: &str) -> &'static str {
    const KNOWN: [(&str, &str); 8] = [
        ("png", "png"),
        ("jpg", "jpg"),
        ("jpeg", "jpg"),
        ("webp", "webp"),
        ("bmp", "bmp"),
        ("gif", "gif"),
        ("tif", "tif"),
        ("tiff", "tif"),
    ];
    let path = match url.find(['?', '#']) {
        Some(cut) => &url[..cut],
        None => url,
    };
    let tail = path.rsplit_once('.').map_or(path, |(_, ext)| ext);
    KNOWN
        .iter()
        .find(|(ext, _)| tail.eq_ignore_ascii_case(ext))
        .map_or("webp", |&(_, canonical)| canonical)
}
630
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ModelCliDefaults, ModelEngine, ModelFile, ModelFileRole};
    use tempfile::tempdir;

    /// Output path every argv-building test renders to.
    const OUT_PATH: &str = "/tmp/out.webp";

    /// Minimal sd-cpp source with plain-model CLI defaults (cfg 1.0, 8 steps,
    /// 1024x1024, euler) and no Qwen-specific switches.
    fn fake_source(files: Vec<ModelFile>) -> ModelSource {
        ModelSource {
            engine: ModelEngine::SdCpp,
            files,
            cli_defaults: ModelCliDefaults {
                cfg_scale: 1.0,
                steps: 8,
                width: 1024,
                height: 1024,
                sampling_method: Some("euler".to_string()),
                ..Default::default()
            },
        }
    }

    /// Source shaped like a Qwen image-edit model: every optional CLI default set.
    fn qwen_edit_source() -> ModelSource {
        ModelSource {
            engine: ModelEngine::SdCpp,
            files: vec![],
            cli_defaults: ModelCliDefaults {
                cfg_scale: 4.0,
                steps: 20,
                width: 1024,
                height: 1024,
                sampling_method: Some("euler".to_string()),
                flow_shift: Some(3.0),
                zero_cond_t: Some(true),
                offload_to_cpu: Some(true),
            },
        }
    }

    /// Optional inputs to `build_sdcli_args`. `Default` is plain text-to-image
    /// with no auxiliary files and a split (non-checkpoint) model layout.
    #[derive(Default)]
    struct Extras<'a> {
        vae: Option<&'a str>,
        llm: Option<&'a str>,
        llm_vision: Option<&'a str>,
        init: Option<&'a str>,
        mask: Option<&'a str>,
        reference: Option<&'a str>,
        full_checkpoint: bool,
    }

    /// Builds the argv for `params` / `source` with `model` as the model file,
    /// rendering each argument lossily to `String` for easy assertions.
    fn render(
        params: &ImageParams,
        source: &ModelSource,
        model: &str,
        extras: Extras<'_>,
    ) -> Vec<String> {
        build_sdcli_args(
            params,
            source,
            Path::new(model),
            extras.vae.map(Path::new),
            extras.llm.map(Path::new),
            extras.llm_vision.map(Path::new),
            Path::new(OUT_PATH),
            extras.init.map(Path::new),
            extras.mask.map(Path::new),
            extras.reference.map(Path::new),
            extras.full_checkpoint,
        )
        .iter()
        .map(|arg| arg.to_string_lossy().into_owned())
        .collect()
    }

    /// The argument immediately following `flag`; panics when `flag` is absent.
    fn value_after<'a>(args: &'a [String], flag: &str) -> &'a str {
        let at = args
            .iter()
            .position(|arg| arg == flag)
            .unwrap_or_else(|| panic!("flag {flag} missing from {args:?}"));
        &args[at + 1]
    }

    /// Whether `flag` appears anywhere in `args`.
    fn has(args: &[String], flag: &str) -> bool {
        args.iter().any(|arg| arg == flag)
    }

    #[test]
    fn temp_file_guard_removes_every_registered_file_on_drop() {
        let dir = tempdir().unwrap();
        let out = dir.path().join("out.webp");
        let init = dir.path().join("out-init.png");
        std::fs::write(&out, b"image").unwrap();
        std::fs::write(&init, b"init").unwrap();
        {
            let mut guard = TempFileGuard::new();
            guard.push(out.clone());
            guard.push(init.clone());
            assert!(out.exists() && init.exists(), "files present before drop");
        }
        assert!(!out.exists(), "sd-cli output temp must be removed on drop");
        assert!(!init.exists(), "init-image temp must be removed on drop");
    }

    #[test]
    fn temp_file_guard_tolerates_a_file_that_never_materialised() {
        let dir = tempdir().unwrap();
        let missing = dir.path().join("never-written.webp");
        let out = crate::test_support::capture(move || {
            let mut guard = TempFileGuard::new();
            guard.push(missing);
            drop(guard);
        });
        assert!(
            !out.contains("failed to remove temp file"),
            "a never-created temp file must not warn on cleanup: {out:?}"
        );
    }

    #[test]
    fn remove_temp_file_deletes_an_existing_file_quietly() {
        let dir = tempdir().unwrap();
        let artefact = dir.path().join("artefact.webp");
        std::fs::write(&artefact, b"bytes").unwrap();
        let target = artefact.clone();
        let out = crate::test_support::capture(move || remove_temp_file(&target));
        assert!(!artefact.exists(), "file should be gone after cleanup");
        assert!(
            !out.contains("failed to remove temp file"),
            "the success path must not warn: {out:?}"
        );
    }

    #[test]
    fn remove_temp_file_ignores_an_already_missing_file() {
        let dir = tempdir().unwrap();
        let missing = dir.path().join("never-existed.webp");
        let out = crate::test_support::capture(move || remove_temp_file(&missing));
        assert!(
            !out.contains("failed to remove temp file"),
            "a not-found file is the desired end state, not a warning: {out:?}"
        );
    }

    #[test]
    fn remove_temp_file_surfaces_a_failed_removal() {
        let dir = tempdir().unwrap();
        // A directory cannot be removed with `remove_file`, forcing a failure.
        let stubborn = dir.path().join("subdir");
        std::fs::create_dir(&stubborn).unwrap();
        let out = crate::test_support::capture(move || remove_temp_file(&stubborn));
        assert!(
            out.contains("failed to remove temp file"),
            "a failed removal must surface in the logs: {out:?}"
        );
        assert!(
            out.contains("subdir"),
            "the warning must name the offending path: {out:?}"
        );
        assert!(
            out.contains("cleanup"),
            "the warning should tag the cleanup op: {out:?}"
        );
    }

    #[test]
    fn file_for_role_picks_matching_file() {
        let files = vec![
            (ModelFileRole::DiffusionModel, PathBuf::from("/d.gguf")),
            (ModelFileRole::Vae, PathBuf::from("/v.safetensors")),
        ];
        assert_eq!(
            file_for_role(&files, ModelFileRole::DiffusionModel),
            Some(Path::new("/d.gguf"))
        );
        assert_eq!(
            file_for_role(&files, ModelFileRole::Vae),
            Some(Path::new("/v.safetensors"))
        );
        assert!(file_for_role(&files, ModelFileRole::TextEncoder).is_none());
    }

    #[test]
    fn ensure_files_skips_already_present() {
        let dir = tempdir().unwrap();
        let cached = dir.path().join("cached.gguf");
        std::fs::write(&cached, b"already here").unwrap();
        let engine = SdCppEngine::with_paths(PathBuf::from("/usr/bin/true"), dir.path().into());
        let source = fake_source(vec![ModelFile {
            role: ModelFileRole::DiffusionModel,
            url: "https://example.invalid/cached.gguf".into(),
            filename: "cached.gguf".into(),
            approx_bytes: None,
        }]);
        let resolved = engine.ensure_files(&source).expect("cached file used");
        assert_eq!(resolved.len(), 1);
        assert_eq!(resolved[0].0, ModelFileRole::DiffusionModel);
        assert_eq!(resolved[0].1, cached);
        assert_eq!(std::fs::read(&cached).unwrap(), b"already here");
    }

    #[test]
    fn dispatch_rejects_non_image_tasks() {
        use crate::types::AudioTtsParams;
        let dir = tempdir().unwrap();
        let engine = SdCppEngine::with_paths(PathBuf::from("/usr/bin/true"), dir.path().into());
        let task = Task::AudioTts(AudioTtsParams {
            text: "hi".into(),
            voice: "v".into(),
            ext: "wav".into(),
            ..Default::default()
        });
        let err = engine
            .dispatch_with_source("anything", task, &fake_source(vec![]))
            .unwrap_err();
        assert!(err.to_string().contains("cannot serve audio_tts"));
    }

    #[test]
    fn build_sdcli_args_includes_required_flags() {
        let params = ImageParams {
            prompt: "hello".into(),
            width: 768,
            height: 512,
            steps: 20,
            ..Default::default()
        };
        let s = render(
            &params,
            &fake_source(vec![]),
            "/d.gguf",
            Extras {
                vae: Some("/v.safetensors"),
                llm: Some("/llm.gguf"),
                ..Default::default()
            },
        );
        assert_eq!(value_after(&s, "--diffusion-model"), "/d.gguf");
        assert_eq!(value_after(&s, "--vae"), "/v.safetensors");
        assert_eq!(value_after(&s, "--llm"), "/llm.gguf");
        assert_eq!(value_after(&s, "-p"), "hello");
        assert_eq!(value_after(&s, "-W"), "768");
        assert_eq!(value_after(&s, "-H"), "512");
        assert_eq!(value_after(&s, "--cfg-scale"), "1");
        // A job step count of 20 defers to the model default (8).
        assert_eq!(value_after(&s, "--steps"), "8");
        assert_eq!(value_after(&s, "--sampling-method"), "euler");
        assert_eq!(value_after(&s, "-o"), OUT_PATH);
        assert!(has(&s, "--diffusion-fa"));
        assert!(!has(&s, "--init-img"));
        assert!(!has(&s, "--strength"));
    }

    #[test]
    fn build_sdcli_args_includes_negative_prompt_when_set() {
        let params = ImageParams {
            prompt: "hi".into(),
            negative_prompt: Some("text, watermark, low quality".into()),
            ..Default::default()
        };
        let s = render(&params, &fake_source(vec![]), "/d.gguf", Extras::default());
        assert_eq!(
            value_after(&s, "--negative-prompt"),
            "text, watermark, low quality"
        );
    }

    #[test]
    fn build_sdcli_args_omits_negative_prompt_when_empty_string() {
        let params = ImageParams {
            prompt: "hi".into(),
            negative_prompt: Some(String::new()),
            ..Default::default()
        };
        let s = render(&params, &fake_source(vec![]), "/d.gguf", Extras::default());
        assert!(!has(&s, "--negative-prompt"));
    }

    #[test]
    fn build_sdcli_args_includes_init_image_and_strength() {
        let params = ImageParams {
            prompt: "hi".into(),
            denoise: Some(0.55),
            ..Default::default()
        };
        let s = render(
            &params,
            &fake_source(vec![]),
            "/d.gguf",
            Extras {
                init: Some("/tmp/init.webp"),
                ..Default::default()
            },
        );
        assert_eq!(value_after(&s, "--init-img"), "/tmp/init.webp");
        assert_eq!(value_after(&s, "--strength"), "0.55");
        assert!(!has(&s, "--mask"));
    }

    #[test]
    fn build_sdcli_args_includes_mask_for_inpaint() {
        let params = ImageParams {
            prompt: "remove the tree".into(),
            denoise: Some(0.8),
            ..Default::default()
        };
        let s = render(
            &params,
            &fake_source(vec![]),
            "/d.gguf",
            Extras {
                init: Some("/tmp/init.webp"),
                mask: Some("/tmp/mask.png"),
                ..Default::default()
            },
        );
        assert_eq!(value_after(&s, "--init-img"), "/tmp/init.webp");
        assert_eq!(value_after(&s, "--mask"), "/tmp/mask.png");
        assert_eq!(value_after(&s, "--strength"), "0.8");
    }

    #[test]
    fn build_sdcli_args_uses_model_flag_for_full_checkpoint() {
        let params = ImageParams {
            prompt: "hi".into(),
            ..Default::default()
        };
        let s = render(
            &params,
            &fake_source(vec![]),
            "/checkpoint.safetensors",
            Extras {
                vae: Some("/v.safetensors"),
                full_checkpoint: true,
                ..Default::default()
            },
        );
        assert_eq!(value_after(&s, "--model"), "/checkpoint.safetensors");
        assert!(!has(&s, "--diffusion-model"));
    }

    #[test]
    fn build_sdcli_args_defaults_denoise_when_init_image_present_but_denoise_none() {
        let params = ImageParams {
            prompt: "hi".into(),
            denoise: None,
            ..Default::default()
        };
        let s = render(
            &params,
            &fake_source(vec![]),
            "/d.gguf",
            Extras {
                init: Some("/tmp/init.webp"),
                ..Default::default()
            },
        );
        assert_eq!(value_after(&s, "--strength"), "0.75");
    }

    #[test]
    fn build_sdcli_args_per_job_cfg_scale_overrides_model_default() {
        let params = ImageParams {
            prompt: "hi".into(),
            cfg_scale: Some(7.5),
            ..Default::default()
        };
        let s = render(&params, &fake_source(vec![]), "/d.gguf", Extras::default());
        assert_eq!(value_after(&s, "--cfg-scale"), "7.5");
    }

    #[test]
    fn build_sdcli_args_per_job_sampling_method_overrides_model_default() {
        let params = ImageParams {
            prompt: "hi".into(),
            sampling_method: Some("dpm++2m".into()),
            ..Default::default()
        };
        let s = render(&params, &fake_source(vec![]), "/d.gguf", Extras::default());
        assert_eq!(value_after(&s, "--sampling-method"), "dpm++2m");
    }

    #[test]
    fn build_sdcli_args_per_job_steps_overrides_when_non_default() {
        let params = ImageParams {
            prompt: "hi".into(),
            steps: 30,
            ..Default::default()
        };
        let s = render(&params, &fake_source(vec![]), "/d.gguf", Extras::default());
        assert_eq!(value_after(&s, "--steps"), "30");
    }

    #[test]
    fn build_sdcli_args_seed_included_when_set() {
        let params = ImageParams {
            prompt: "hi".into(),
            seed: Some(42),
            ..Default::default()
        };
        let s = render(&params, &fake_source(vec![]), "/d.gguf", Extras::default());
        assert_eq!(value_after(&s, "--seed"), "42");
    }

    #[test]
    fn build_sdcli_args_reference_mode_for_instruction_edit() {
        let params = ImageParams {
            prompt: "add a red beach ball".into(),
            denoise: Some(0.9),
            ..Default::default()
        };
        let s = render(
            &params,
            &qwen_edit_source(),
            "/qwen.gguf",
            Extras {
                vae: Some("/vae.safetensors"),
                llm: Some("/llm.gguf"),
                llm_vision: Some("/mmproj.gguf"),
                mask: Some("/tmp/mask.png"),
                reference: Some("/tmp/ref.webp"),
                ..Default::default()
            },
        );
        assert_eq!(value_after(&s, "-r"), "/tmp/ref.webp");
        assert_eq!(value_after(&s, "--mask"), "/tmp/mask.png");
        assert!(!has(&s, "--init-img"));
        assert!(!has(&s, "--strength"));
        assert_eq!(value_after(&s, "--llm_vision"), "/mmproj.gguf");
        assert_eq!(value_after(&s, "--flow-shift"), "3");
        assert!(has(&s, "--qwen-image-zero-cond-t"));
        assert!(has(&s, "--offload-to-cpu"));
    }

    #[test]
    fn build_sdcli_args_omits_qwen_flags_for_plain_model() {
        let params = ImageParams {
            prompt: "hi".into(),
            ..Default::default()
        };
        let s = render(&params, &fake_source(vec![]), "/d.gguf", Extras::default());
        for absent in [
            "--flow-shift",
            "--qwen-image-zero-cond-t",
            "--offload-to-cpu",
            "--llm_vision",
            "-r",
        ] {
            assert!(!has(&s, absent));
        }
    }

    #[test]
    fn capabilities_advertises_only_image_kind() {
        let dir = tempdir().unwrap();
        let engine = SdCppEngine::with_paths(PathBuf::from("/usr/bin/true"), dir.path().into());
        let caps = engine.capabilities();
        assert!(caps
            .supported_models_per_kind
            .contains_key(&TaskKind::Image));
        assert_eq!(caps.supported_models_per_kind.len(), 1);
    }

    #[test]
    fn init_image_extension_reads_url_tail() {
        assert_eq!(init_image_extension("https://x/y/latest.webp"), "webp");
        assert_eq!(init_image_extension("https://x/y/latest.PNG"), "png");
        assert_eq!(init_image_extension("https://x/y/latest.jpg"), "jpg");
        assert_eq!(init_image_extension("https://x/y/latest.jpeg"), "jpg");
        assert_eq!(
            init_image_extension("https://x/y/latest.webp?v=42&t=now"),
            "webp"
        );
        assert_eq!(init_image_extension("https://x/y/latest.webp#frag"), "webp");
        assert_eq!(
            init_image_extension("https://x/y/latest.unknownext"),
            "webp"
        );
        assert_eq!(init_image_extension("https://x/y/no-ext"), "webp");
    }
}