1use crate::config::Config;
11use crate::types::*;
12use anyhow::{anyhow, bail, Result};
13use image::{ImageBuffer, Rgb, RgbImage};
14use sha2::{Digest, Sha256};
15use std::collections::BTreeMap;
16use std::io::Cursor;
17use std::time::Instant;
18use tracing::{debug, warn};
19
/// `tracing` target attached to every event emitted by the synthetic engine.
const TRACE_TARGET_SYNTHETIC: &'static str = "studio_worker::engine::synthetic";

/// `tracing` target attached to every event emitted by the Gradio engine and
/// its HTTP helper.
const TRACE_TARGET_GRADIO: &'static str = "studio_worker::engine::gradio";
26
27#[derive(Debug, Clone, Default)]
29pub struct EngineCapabilities {
30 pub supported_models_per_kind: BTreeMap<TaskKind, Vec<String>>,
33}
34
35impl EngineCapabilities {
36 pub fn supports(&self, kind: TaskKind, model: &str) -> bool {
37 self.supported_models_per_kind
38 .get(&kind)
39 .map(|ms| ms.iter().any(|m| m == model))
40 .unwrap_or(false)
41 }
42
43 pub fn kinds(&self) -> Vec<TaskKind> {
44 self.supported_models_per_kind.keys().copied().collect()
45 }
46
47 pub fn flat_models(&self) -> Vec<String> {
48 self.supported_models_per_kind
49 .values()
50 .flat_map(|ms| ms.iter().cloned())
51 .collect()
52 }
53}
54
55#[cfg(feature = "image-candle")]
56pub mod candle_image;
57#[cfg(feature = "llama")]
58pub mod llama;
59pub mod multi;
60#[cfg(feature = "tts")]
61pub mod tts;
62#[cfg(feature = "video")]
63pub mod video;
64#[cfg(feature = "whisper")]
65pub mod whisper;
66
67pub trait Engine: Send + Sync {
68 fn name(&self) -> &'static str;
69 fn capabilities(&self) -> EngineCapabilities;
70 fn dispatch(&self, model: &str, task: Task) -> Result<TaskResult>;
71}
72
73pub fn build(cfg: &Config) -> Result<Box<dyn Engine>> {
74 if cfg.engine == "multi" {
75 return build_multi(cfg);
76 }
77 build_single(cfg, cfg.engine.as_str())
78}
79
80fn build_multi(cfg: &Config) -> Result<Box<dyn Engine>> {
81 let names = &cfg.engines;
82 if names.is_empty() {
83 bail!("multi engine requires a non-empty `engines` list in the config");
84 }
85 let mut built: Vec<Box<dyn Engine>> = Vec::with_capacity(names.len());
86 for name in names {
87 built.push(build_single(cfg, name)?);
88 }
89 Ok(Box::new(multi::MultiEngine::new(built)))
90}
91
92fn build_single(cfg: &Config, name: &str) -> Result<Box<dyn Engine>> {
93 match name {
94 "synthetic" => Ok(Box::new(SyntheticEngine::new(
95 cfg.supported_models_override.clone(),
96 ))),
97 "gradio" => {
98 let url = cfg
99 .gradio_endpoint_url
100 .clone()
101 .ok_or_else(|| anyhow!("gradio engine requires gradio_endpoint_url"))?;
102 Ok(Box::new(GradioEngine::new(
103 url,
104 cfg.supported_models_override.clone(),
105 )))
106 }
107 #[cfg(feature = "llama")]
108 "llama" => {
109 let root = cfg.models_root.clone().unwrap_or_else(default_models_root);
110 Ok(Box::new(llama::LlamaEngine::new(root)?))
111 }
112 #[cfg(feature = "whisper")]
113 "whisper" => {
114 let root = cfg.models_root.clone().unwrap_or_else(default_models_root);
115 Ok(Box::new(whisper::WhisperEngine::new(root)))
116 }
117 #[cfg(feature = "image-candle")]
118 "image-candle" => Ok(Box::new(candle_image::CandleImageEngine::new())),
119 #[cfg(feature = "video")]
120 "video" => Ok(Box::new(video::VideoEngine::new())),
121 #[cfg(feature = "tts")]
122 "tts" => Ok(Box::new(tts::TtsEngine::new())),
123 "multi" => bail!("nested `multi` engines are not allowed"),
124 other => bail!("unknown engine: {other}"),
125 }
126}
127
128pub fn default_models_root() -> std::path::PathBuf {
130 if let Some(dir) = directories::ProjectDirs::from("gg", "minis", "minis-studio-worker") {
131 return dir.cache_dir().to_path_buf();
132 }
133 std::env::temp_dir().join("studio-worker-models")
134}
135
/// Engine that produces deterministic placeholder results for every task kind.
pub struct SyntheticEngine {
    /// Model names to advertise instead of the built-in defaults. An empty
    /// list means "use the defaults".
    overrides: Vec<String>,
}

impl SyntheticEngine {
    /// Creates a synthetic engine advertising `overrides`, or the built-in
    /// default model lists when `overrides` is empty.
    pub fn new(overrides: Vec<String>) -> Self {
        SyntheticEngine { overrides }
    }
}
150
/// Image models the synthetic engine advertises when no overrides are set.
const DEFAULT_IMAGE_MODELS: &[&str] = &[
    "synthetic",
    "synthetic-image",
    "flux1-dev",
    "flux1-dev-i2i",
    "sdxl-1.0",
];

/// LLM models the synthetic engine advertises when no overrides are set.
const DEFAULT_LLM_MODELS: &[&str] = &[
    "synthetic",
    "synthetic-llm",
    "llama-3.1-8b-instruct-q4",
];

/// Speech-to-text models the synthetic engine advertises when no overrides are set.
const DEFAULT_STT_MODELS: &[&str] = &[
    "synthetic",
    "synthetic-stt",
    "whisper-medium",
];

/// Text-to-speech models the synthetic engine advertises when no overrides are set.
const DEFAULT_TTS_MODELS: &[&str] = &[
    "synthetic",
    "synthetic-tts",
    "piper-en",
];

/// Video models the synthetic engine advertises when no overrides are set.
const DEFAULT_VIDEO_MODELS: &[&str] = &[
    "synthetic",
    "synthetic-video",
];
162
163impl Engine for SyntheticEngine {
164 fn name(&self) -> &'static str {
165 "synthetic"
166 }
167
168 fn capabilities(&self) -> EngineCapabilities {
169 let mut map: BTreeMap<TaskKind, Vec<String>> = BTreeMap::new();
170 let (image_models, llm_models, stt_models, tts_models, video_models) =
171 if self.overrides.is_empty() {
172 (
173 DEFAULT_IMAGE_MODELS
174 .iter()
175 .map(|s| (*s).to_string())
176 .collect::<Vec<_>>(),
177 DEFAULT_LLM_MODELS
178 .iter()
179 .map(|s| (*s).to_string())
180 .collect::<Vec<_>>(),
181 DEFAULT_STT_MODELS
182 .iter()
183 .map(|s| (*s).to_string())
184 .collect::<Vec<_>>(),
185 DEFAULT_TTS_MODELS
186 .iter()
187 .map(|s| (*s).to_string())
188 .collect::<Vec<_>>(),
189 DEFAULT_VIDEO_MODELS
190 .iter()
191 .map(|s| (*s).to_string())
192 .collect::<Vec<_>>(),
193 )
194 } else {
195 let same = self.overrides.clone();
198 (same.clone(), same.clone(), same.clone(), same.clone(), same)
199 };
200 map.insert(TaskKind::Image, image_models);
201 map.insert(TaskKind::Llm, llm_models);
202 map.insert(TaskKind::AudioStt, stt_models);
203 map.insert(TaskKind::AudioTts, tts_models);
204 map.insert(TaskKind::Video, video_models);
205 EngineCapabilities {
206 supported_models_per_kind: map,
207 }
208 }
209
210 fn dispatch(&self, model: &str, task: Task) -> Result<TaskResult> {
211 let kind = task.kind();
212 let started = Instant::now();
213 let result = match task {
214 Task::Image(p) => render_procedural(&p.prompt, &p.ext)
215 .map(|bytes| TaskResult::Image { bytes, ext: p.ext }),
216 Task::Llm(p) => {
217 let prompt = p
218 .messages
219 .iter()
220 .map(|m| format!("{}: {}", m.role, m.content))
221 .collect::<Vec<_>>()
222 .join("\n");
223 Ok(TaskResult::Llm {
224 json: synthetic_llm_response(&prompt),
225 })
226 }
227 Task::AudioStt(p) => Ok(TaskResult::AudioStt {
228 json: synthetic_stt_response(&p.input_url, p.language.as_deref()),
229 }),
230 Task::AudioTts(p) => render_wav(&p.text).map(|bytes| TaskResult::AudioTts {
231 bytes,
232 ext: "wav".into(),
233 }),
234 Task::Video(p) => {
235 render_animated_webp(&p.prompt, p.width, p.height, p.seconds).map(|bytes| {
239 TaskResult::Video {
240 bytes,
241 ext: "webp".into(),
242 }
243 })
244 }
245 };
246 let elapsed_ms = started.elapsed().as_millis() as u64;
247 match &result {
248 Ok(_) => debug!(
249 target: TRACE_TARGET_SYNTHETIC,
250 op = "dispatch",
251 kind = kind.as_str(),
252 model,
253 elapsed_ms,
254 "ok"
255 ),
256 Err(e) => warn!(
257 target: TRACE_TARGET_SYNTHETIC,
258 op = "dispatch",
259 kind = kind.as_str(),
260 model,
261 elapsed_ms,
262 error = %e,
263 "failed"
264 ),
265 }
266 result
267 }
268}
269
270pub fn render_procedural(prompt: &str, ext: &str) -> Result<Vec<u8>> {
276 let digest = sha256_bytes(prompt);
277 let palette = [
278 Rgb([digest[0], digest[1], digest[2]]),
279 Rgb([digest[3], digest[4], digest[5]]),
280 Rgb([digest[6], digest[7], digest[8]]),
281 Rgb([digest[9], digest[10], digest[11]]),
282 ];
283
284 let size: u32 = 512;
285 let mut img: RgbImage = ImageBuffer::new(size, size);
286 for (x, y, pixel) in img.enumerate_pixels_mut() {
287 let cx = size as f32 / 2.0;
288 let cy = size as f32 / 2.0;
289 let dx = (x as f32 - cx).abs();
290 let dy = (y as f32 - cy).abs();
291 let chebyshev = dx.max(dy) / cx;
292 let ring = (chebyshev * 6.0).floor() as usize;
293 let base = palette[ring.min(palette.len() - 1)];
294 let phase = ((x as f32 / 24.0).sin() + (y as f32 / 24.0).cos()) * 12.0;
295 *pixel = Rgb([
296 base.0[0].saturating_add(phase as i8 as u8),
297 base.0[1].saturating_add((phase * 0.7) as i8 as u8),
298 base.0[2].saturating_add((phase * 1.3) as i8 as u8),
299 ]);
300 }
301
302 let mut out = Cursor::new(Vec::<u8>::new());
303 let dyn_img = image::DynamicImage::ImageRgb8(img);
304 match ext {
305 "webp" => dyn_img.write_to(&mut out, image::ImageFormat::WebP)?,
306 _ => dyn_img.write_to(&mut out, image::ImageFormat::Png)?,
307 }
308 Ok(out.into_inner())
309}
310
311pub fn synthetic_llm_response(prompt: &str) -> serde_json::Value {
314 let hash = hex::encode(sha256_bytes(prompt));
315 serde_json::json!({
316 "object": "chat.completion",
317 "model": "synthetic-llm",
318 "choices": [{
319 "index": 0,
320 "message": {
321 "role": "assistant",
322 "content": format!("[synthetic] reply to prompt #{}", &hash[..16]),
323 },
324 "finish_reason": "stop",
325 }],
326 "usage": {
327 "prompt_tokens": prompt.split_whitespace().count(),
328 "completion_tokens": 8,
329 "total_tokens": prompt.split_whitespace().count() + 8,
330 },
331 })
332}
333
334pub fn synthetic_stt_response(input_url: &str, language: Option<&str>) -> serde_json::Value {
336 let hash = hex::encode(sha256_bytes(input_url));
337 serde_json::json!({
338 "text": format!("[synthetic] transcript of {}", &hash[..16]),
339 "language": language.unwrap_or("en"),
340 "duration": 1.0,
341 })
342}
343
344pub fn render_wav(text: &str) -> Result<Vec<u8>> {
347 use hound::{SampleFormat, WavSpec, WavWriter};
348 let digest = sha256_bytes(text);
349 let freq_hz = 220.0 + (digest[0] as f32) * (660.0 / 255.0); let sample_rate: u32 = 22_050;
351 let spec = WavSpec {
352 channels: 1,
353 sample_rate,
354 bits_per_sample: 16,
355 sample_format: SampleFormat::Int,
356 };
357
358 let mut buf = Cursor::new(Vec::<u8>::new());
359 {
360 let mut writer = WavWriter::new(&mut buf, spec)?;
361 let total_samples = sample_rate; for n in 0..total_samples {
363 let t = n as f32 / sample_rate as f32;
364 let amplitude = (t * 2.0 * std::f32::consts::PI * freq_hz).sin();
365 let s = (amplitude * 0.4 * i16::MAX as f32) as i16;
366 writer.write_sample(s)?;
367 }
368 writer.finalize()?;
369 }
370 Ok(buf.into_inner())
371}
372
373pub fn render_animated_webp(prompt: &str, _w: u32, _h: u32, seconds: f32) -> Result<Vec<u8>> {
377 let _ = seconds;
383 render_procedural(prompt, "webp")
384}
385
386fn sha256_bytes(input: &str) -> [u8; 32] {
387 let mut hasher = Sha256::new();
388 hasher.update(input.as_bytes());
389 let digest = hasher.finalize();
390 let mut out = [0u8; 32];
391 out.copy_from_slice(&digest);
392 out
393}
394
/// Engine that forwards image tasks to a remote Gradio application.
pub struct GradioEngine {
    /// Base URL of the Gradio app; request paths are appended to it.
    pub endpoint_url: String,
    /// Model names advertised for image tasks.
    overrides: Vec<String>,
}

impl GradioEngine {
    /// Creates an engine targeting `endpoint_url` that advertises `overrides`
    /// as its image models.
    pub fn new(endpoint_url: String, overrides: Vec<String>) -> Self {
        GradioEngine {
            endpoint_url,
            overrides,
        }
    }
}
415
416impl Engine for GradioEngine {
417 fn name(&self) -> &'static str {
418 "gradio"
419 }
420
421 fn capabilities(&self) -> EngineCapabilities {
422 let mut map: BTreeMap<TaskKind, Vec<String>> = BTreeMap::new();
423 map.insert(TaskKind::Image, self.overrides.clone());
424 EngineCapabilities {
425 supported_models_per_kind: map,
426 }
427 }
428
429 fn dispatch(&self, model: &str, task: Task) -> Result<TaskResult> {
430 let kind = task.kind();
431 let started = Instant::now();
432 let image_params = match task {
433 Task::Image(p) => p,
434 other => {
435 warn!(
436 target: TRACE_TARGET_GRADIO,
437 op = "dispatch",
438 kind = kind.as_str(),
439 model,
440 "unsupported task kind"
441 );
442 bail!("gradio engine cannot serve {} tasks", other.kind().as_str());
443 }
444 };
445 let result = call_gradio(&self.endpoint_url, &image_params.prompt, model);
446 let elapsed_ms = started.elapsed().as_millis() as u64;
447 match &result {
448 Ok(_) => debug!(
449 target: TRACE_TARGET_GRADIO,
450 op = "dispatch",
451 kind = kind.as_str(),
452 model,
453 elapsed_ms,
454 "ok"
455 ),
456 Err(e) => warn!(
457 target: TRACE_TARGET_GRADIO,
458 op = "dispatch",
459 kind = kind.as_str(),
460 model,
461 elapsed_ms,
462 error = %e,
463 "failed"
464 ),
465 }
466 let bytes = result?;
467 Ok(TaskResult::Image {
468 bytes,
469 ext: image_params.ext,
470 })
471 }
472}
473
474fn call_gradio(endpoint_url: &str, prompt: &str, model: &str) -> Result<Vec<u8>> {
475 use base64::Engine as _;
476 let client = reqwest::blocking::Client::builder()
477 .timeout(std::time::Duration::from_secs(300))
478 .build()?;
479
480 let url = format!("{}/run/predict", endpoint_url.trim_end_matches('/'));
481 let body = serde_json::json!({ "data": [prompt, model] });
482 let started = Instant::now();
483
484 let response = client.post(&url).json(&body).send().map_err(|e| {
485 warn!(
486 target: TRACE_TARGET_GRADIO,
487 op = "predict",
488 endpoint = %url,
489 error = %e,
490 "request failed"
491 );
492 anyhow!("gradio request failed: {e}")
493 })?;
494 let status = response.status();
495 let elapsed_ms = started.elapsed().as_millis() as u64;
496 if !status.is_success() {
497 warn!(
498 target: TRACE_TARGET_GRADIO,
499 op = "predict",
500 endpoint = %url,
501 status = status.as_u16(),
502 elapsed_ms,
503 "non-2xx response"
504 );
505 bail!("gradio returned {}", status);
506 }
507 debug!(
508 target: TRACE_TARGET_GRADIO,
509 op = "predict",
510 endpoint = %url,
511 status = status.as_u16(),
512 elapsed_ms,
513 "ok"
514 );
515 let parsed: serde_json::Value = response.json()?;
516 let image_field = parsed
517 .get("data")
518 .and_then(|v| v.as_array())
519 .and_then(|a| a.first())
520 .ok_or_else(|| anyhow!("gradio response missing data[0]"))?;
521
522 if let Some(s) = image_field.as_str() {
523 if let Some(rest) = s.strip_prefix("data:") {
524 if let Some(idx) = rest.find(',') {
525 let payload = &rest[idx + 1..];
526 let decoded = base64::engine::general_purpose::STANDARD
527 .decode(payload)
528 .map_err(|e| anyhow!("invalid base64 image: {e}"))?;
529 return Ok(decoded);
530 }
531 }
532 let abs_url = if s.starts_with("http") {
533 s.to_string()
534 } else {
535 format!(
536 "{}/{}",
537 endpoint_url.trim_end_matches('/'),
538 s.trim_start_matches('/')
539 )
540 };
541 let response = client.get(&abs_url).send()?;
542 if !response.status().is_success() {
543 bail!("gradio image fetch returned {}", response.status());
544 }
545 return Ok(response.bytes()?.to_vec());
546 }
547 if let Some(obj) = image_field.as_object() {
548 if let Some(url) = obj.get("url").and_then(|v| v.as_str()) {
549 let response = client.get(url).send()?;
550 return Ok(response.bytes()?.to_vec());
551 }
552 }
553 bail!("unsupported gradio image payload")
554}
555
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Synthetic engine with no model overrides.
    fn default_synthetic() -> SyntheticEngine {
        SyntheticEngine::new(vec![])
    }

    /// A 512x512, 20-step WebP image request for `prompt`.
    fn webp_image_task(prompt: &str) -> Task {
        Task::Image(ImageParams {
            prompt: prompt.into(),
            width: 512,
            height: 512,
            steps: 20,
            seed: None,
            ext: "webp".into(),
        })
    }

    /// Guesses the image container format of `bytes` from its content.
    fn sniff_format(bytes: &[u8]) -> image::ImageFormat {
        image::ImageReader::new(Cursor::new(bytes))
            .with_guessed_format()
            .unwrap()
            .format()
            .unwrap()
    }

    /// Image tasks yield bytes that really are WebP.
    #[test]
    fn synthetic_image_round_trips_as_webp() {
        let outcome = default_synthetic()
            .dispatch("synthetic", webp_image_task("hello world"))
            .unwrap();
        let (bytes, ext) = match outcome {
            TaskResult::Image { bytes, ext } => (bytes, ext),
            other => panic!("expected image, got {:?}", other.kind()),
        };
        assert_eq!(ext, "webp");
        assert!(bytes.len() > 100);
        assert_eq!(sniff_format(&bytes), image::ImageFormat::WebP);
    }

    /// LLM tasks yield a chat-completion-shaped document.
    #[test]
    fn synthetic_llm_returns_chat_completion_shape() {
        let question = ChatMessage {
            role: "user".into(),
            content: "what is the capital of france?".into(),
        };
        let task = Task::Llm(LlmParams {
            messages: vec![question],
            max_tokens: 64,
            temperature: 0.5,
        });
        let outcome = default_synthetic().dispatch("synthetic", task).unwrap();
        let json = match outcome {
            TaskResult::Llm { json } => json,
            other => panic!("expected llm, got {:?}", other.kind()),
        };
        assert_eq!(json["object"], "chat.completion");
        let content = json["choices"][0]["message"]["content"].as_str().unwrap();
        assert!(content.starts_with("[synthetic]"));
    }

    /// STT tasks echo the requested language and return a synthetic transcript.
    #[test]
    fn synthetic_stt_returns_whisper_shape() {
        let task = Task::AudioStt(AudioSttParams {
            input_url: "https://example.com/audio.wav".into(),
            language: Some("nl".into()),
        });
        let outcome = default_synthetic().dispatch("synthetic", task).unwrap();
        let json = match outcome {
            TaskResult::AudioStt { json } => json,
            other => panic!("expected stt, got {:?}", other.kind()),
        };
        assert_eq!(json["language"], "nl");
        assert!(json["text"].as_str().unwrap().starts_with("[synthetic]"));
    }

    /// TTS tasks yield a decodable mono WAV holding 22 050 samples.
    #[test]
    fn synthetic_tts_produces_real_wav() {
        let task = Task::AudioTts(AudioTtsParams {
            text: "hello world".into(),
            voice: "default".into(),
            ext: "wav".into(),
        });
        let outcome = default_synthetic().dispatch("synthetic", task).unwrap();
        let (bytes, ext) = match outcome {
            TaskResult::AudioTts { bytes, ext } => (bytes, ext),
            other => panic!("expected tts, got {:?}", other.kind()),
        };
        assert_eq!(ext, "wav");

        let mut wav = hound::WavReader::new(Cursor::new(bytes)).expect("real WAV should decode");
        let spec = wav.spec();
        assert_eq!(spec.sample_rate, 22_050);
        assert_eq!(spec.channels, 1);
        let samples = wav
            .samples::<i16>()
            .collect::<std::result::Result<Vec<_>, _>>()
            .expect("samples should decode");
        assert_eq!(samples.len(), 22_050);
    }

    /// Video tasks are answered with WebP bytes even when `mp4` is requested.
    #[test]
    fn synthetic_video_emits_decodable_bytes() {
        let task = Task::Video(VideoParams {
            prompt: "a tiny dragon".into(),
            seconds: 1.0,
            width: 256,
            height: 256,
            ext: "mp4".into(),
        });
        let outcome = default_synthetic().dispatch("synthetic", task).unwrap();
        let (bytes, ext) = match outcome {
            TaskResult::Video { bytes, ext } => (bytes, ext),
            other => panic!("expected video, got {:?}", other.kind()),
        };
        assert_eq!(ext, "webp");
        assert_eq!(sniff_format(&bytes), image::ImageFormat::WebP);
    }

    /// With no overrides, every task kind is advertised with its defaults.
    #[test]
    fn synthetic_engine_advertises_all_kinds() {
        let caps = default_synthetic().capabilities();
        for k in TaskKind::ALL {
            assert!(
                caps.supported_models_per_kind.contains_key(&k),
                "{} should be advertised",
                k.as_str()
            );
        }
        assert!(caps.supports(TaskKind::Image, "synthetic"));
        assert!(caps.supports(TaskKind::Llm, "llama-3.1-8b-instruct-q4"));
    }

    /// A non-empty override list replaces the defaults for every kind.
    #[test]
    fn synthetic_engine_overrides_apply_to_every_kind() {
        let caps = SyntheticEngine::new(vec!["only-this".into()]).capabilities();
        for k in TaskKind::ALL {
            assert_eq!(caps.supported_models_per_kind[&k], vec!["only-this"]);
        }
    }

    /// The same prompt always renders to the same bytes.
    #[test]
    fn synthetic_engine_is_deterministic_per_prompt() {
        let engine = default_synthetic();
        let first = engine
            .dispatch("synthetic", webp_image_task("deterministic"))
            .unwrap();
        let second = engine
            .dispatch("synthetic", webp_image_task("deterministic"))
            .unwrap();
        match (first, second) {
            (TaskResult::Image { bytes: a, .. }, TaskResult::Image { bytes: b, .. }) => {
                assert_eq!(a, b);
            }
            _ => panic!("expected images"),
        }
    }

    /// The Gradio engine rejects anything that is not an image task.
    #[test]
    fn gradio_engine_refuses_non_image_tasks() {
        let engine = GradioEngine::new("http://localhost".into(), vec!["foo".into()]);
        let task = Task::Llm(LlmParams {
            messages: vec![],
            max_tokens: 1,
            temperature: 0.0,
        });
        let err = engine.dispatch("foo", task).unwrap_err();
        assert!(err.to_string().contains("cannot serve llm"));
    }
}
734}