Skip to main content

studio_worker/
types.rs

//! Shared wire-format mirrors of the studio API.
2use serde::{Deserialize, Serialize};
3use std::collections::BTreeMap;
4
// ---------------------------------------------------------------------------
// Task kinds — every job claimed by the worker is one of these.
// ---------------------------------------------------------------------------
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Hash)]
10#[serde(rename_all = "snake_case")]
11pub enum TaskKind {
12    Image,
13    Llm,
14    AudioStt,
15    AudioTts,
16    Video,
17}
18
19impl TaskKind {
20    pub const ALL: [TaskKind; 5] = [
21        TaskKind::Image,
22        TaskKind::Llm,
23        TaskKind::AudioStt,
24        TaskKind::AudioTts,
25        TaskKind::Video,
26    ];
27
28    pub fn as_str(&self) -> &'static str {
29        match self {
30            TaskKind::Image => "image",
31            TaskKind::Llm => "llm",
32            TaskKind::AudioStt => "audio_stt",
33            TaskKind::AudioTts => "audio_tts",
34            TaskKind::Video => "video",
35        }
36    }
37}
38
// ---------------------------------------------------------------------------
// Per-kind task parameters
// ---------------------------------------------------------------------------
42
43#[derive(Debug, Clone, Default, Serialize, Deserialize)]
44#[serde(rename_all = "camelCase")]
45pub struct ImageParams {
46    pub prompt: String,
47    /// Negative prompt — what the model should steer AWAY from.
48    /// Optional: when `None`, `sd-cli` is invoked without
49    /// `--negative-prompt`.  Wire-format key: `negativePrompt`.
50    #[serde(default)]
51    pub negative_prompt: Option<String>,
52    /// HTTPS URL to a base image for image-to-image generation.
53    /// When set, the worker downloads the bytes to a local tempfile
54    /// and invokes `sd-cli --init-img <path>`.  Required for any
55    /// task profile that declares the `initImage` capability.
56    #[serde(default)]
57    pub init_image_url: Option<String>,
58    /// HTTPS URL to a black/white inpaint mask (white = the region the
59    /// model may repaint). When set alongside `init_image_url`, the
60    /// worker downloads it and invokes `sd-cli --mask <path>` so only
61    /// the masked region changes. Wire-format key: `maskUrl`.
62    #[serde(default)]
63    pub mask_url: Option<String>,
64    /// HTTPS URL to a reference image for instruction-edit models
65    /// (e.g. Qwen-Image-Edit / Flux Kontext).  When set, the worker
66    /// downloads it and invokes `sd-cli -r <path>` (reference mode)
67    /// instead of the `--init-img`/`--strength`/`--mask` img2img path:
68    /// the model regenerates the whole image from the reference per the
69    /// instruction prompt; per-region clipping happens in the studio
70    /// composite. Wire-format key: `refImageUrl`.
71    #[serde(default)]
72    pub ref_image_url: Option<String>,
73    /// Denoise / noise-strength for i2i (0.0 = keep init image
74    /// unchanged, 1.0 = full re-noise).  Maps to `sd-cli --strength`.
75    #[serde(default)]
76    pub denoise: Option<f32>,
77    /// Classifier-free guidance scale.  Per-job override; falls back
78    /// to `ModelSource.cliDefaults.cfgScale` when `None`.
79    #[serde(default)]
80    pub cfg_scale: Option<f32>,
81    /// Sampler choice (`euler`, `euler_a`, `dpm++2m`, ...).  Per-job
82    /// override; falls back to `ModelSource.cliDefaults.samplingMethod`.
83    #[serde(default)]
84    pub sampling_method: Option<String>,
85    #[serde(default = "default_image_dim")]
86    pub width: u32,
87    #[serde(default = "default_image_dim")]
88    pub height: u32,
89    #[serde(default = "default_steps")]
90    pub steps: u32,
91    #[serde(default)]
92    pub seed: Option<u64>,
93    #[serde(default = "default_image_ext")]
94    pub ext: String,
95}
96
/// Serde default for `width`/`height` on image and video tasks.
fn default_image_dim() -> u32 {
    512
}

/// Serde default for `ImageParams::steps`.
fn default_steps() -> u32 {
    20
}

/// Serde default for `ImageParams::ext`.
fn default_image_ext() -> String {
    String::from("webp")
}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct ChatMessage {
109    pub role: String,
110    pub content: String,
111}
112
113#[derive(Debug, Clone, Default, Serialize, Deserialize)]
114#[serde(rename_all = "camelCase")]
115pub struct LlmParams {
116    pub messages: Vec<ChatMessage>,
117    /// System prompt prepended to the conversation.  Engines that
118    /// don't accept a separate system role inline it as the first
119    /// chat turn.
120    #[serde(default)]
121    pub system: Option<String>,
122    #[serde(default = "default_max_tokens")]
123    pub max_tokens: u32,
124    #[serde(default = "default_temperature")]
125    pub temperature: f32,
126    #[serde(default)]
127    pub top_p: Option<f32>,
128    #[serde(default)]
129    pub stop: Option<Vec<String>>,
130    /// Strict-JSON output schema.  Engine-specific; passed through
131    /// verbatim to the backend (e.g. llama.cpp's `--grammar` JSON).
132    #[serde(default)]
133    pub json_schema: Option<serde_json::Value>,
134    /// Reasoning effort hint.  Currently honoured by Gemini-style
135    /// backends only; ignored elsewhere.
136    #[serde(default)]
137    pub reasoning: Option<String>,
138}
139
/// Serde default for `LlmParams::max_tokens`.
fn default_max_tokens() -> u32 {
    512
}

/// Serde default for `LlmParams::temperature`.
fn default_temperature() -> f32 {
    0.7
}
146
147#[derive(Debug, Clone, Default, Serialize, Deserialize)]
148#[serde(rename_all = "camelCase")]
149pub struct AudioSttParams {
150    /// HTTPS URL to fetch the audio bytes from (e.g. R2 signed URL).
151    pub input_url: String,
152    #[serde(default)]
153    pub language: Option<String>,
154    /// Translate the audio to English (whisper `--translate`).
155    #[serde(default)]
156    pub translate: Option<bool>,
157    /// Initial prompt to bias the transcription.
158    #[serde(default)]
159    pub prompt: Option<String>,
160    /// Voice-activity detection.
161    #[serde(default)]
162    pub vad: Option<bool>,
163    /// Timestamp granularity: `"segment"` or `"word"`.
164    #[serde(default)]
165    pub timestamps: Option<String>,
166}
167
168#[derive(Debug, Clone, Default, Serialize, Deserialize)]
169#[serde(rename_all = "camelCase")]
170pub struct AudioTtsParams {
171    pub text: String,
172    #[serde(default = "default_voice")]
173    pub voice: String,
174    /// Playback speed multiplier (1.0 = natural pace).
175    #[serde(default)]
176    pub speed: Option<f32>,
177    /// Spoken-language hint (e.g. `"en"`, `"nl"`).
178    #[serde(default)]
179    pub language: Option<String>,
180    #[serde(default = "default_audio_ext")]
181    pub ext: String,
182}
183
/// Serde default for `AudioTtsParams::voice`.
fn default_voice() -> String {
    String::from("default")
}

/// Serde default for `AudioTtsParams::ext`.
fn default_audio_ext() -> String {
    String::from("wav")
}
190
191#[derive(Debug, Clone, Default, Serialize, Deserialize)]
192#[serde(rename_all = "camelCase")]
193pub struct VideoParams {
194    pub prompt: String,
195    #[serde(default)]
196    pub negative_prompt: Option<String>,
197    /// HTTPS URL to a base frame for image-to-video models.
198    #[serde(default)]
199    pub init_image_url: Option<String>,
200    #[serde(default = "default_video_seconds")]
201    pub seconds: f32,
202    /// Frame rate; defaults to backend-specific value when `None`.
203    #[serde(default)]
204    pub fps: Option<u32>,
205    #[serde(default = "default_image_dim")]
206    pub width: u32,
207    #[serde(default = "default_image_dim")]
208    pub height: u32,
209    #[serde(default = "default_video_ext")]
210    pub ext: String,
211}
212
/// Serde default for `VideoParams::seconds`.
fn default_video_seconds() -> f32 {
    2.0
}

/// Serde default for `VideoParams::ext`.
fn default_video_ext() -> String {
    String::from("mp4")
}
219
220#[derive(Debug, Clone, Serialize, Deserialize)]
221#[serde(tag = "kind", rename_all = "snake_case")]
222pub enum Task {
223    Image(ImageParams),
224    Llm(LlmParams),
225    AudioStt(AudioSttParams),
226    AudioTts(AudioTtsParams),
227    Video(VideoParams),
228}
229
230impl Task {
231    pub fn kind(&self) -> TaskKind {
232        match self {
233            Task::Image(_) => TaskKind::Image,
234            Task::Llm(_) => TaskKind::Llm,
235            Task::AudioStt(_) => TaskKind::AudioStt,
236            Task::AudioTts(_) => TaskKind::AudioTts,
237            Task::Video(_) => TaskKind::Video,
238        }
239    }
240}
241
// ---------------------------------------------------------------------------
// Per-kind task results
// ---------------------------------------------------------------------------
245
246#[derive(Debug, Clone)]
247pub enum TaskResult {
248    /// Binary image (webp/png/...) with the chosen extension.
249    Image { bytes: Vec<u8>, ext: String },
250    /// JSON response from an LLM call (free-form to mirror common APIs).
251    Llm { json: serde_json::Value },
252    /// JSON transcript.
253    AudioStt { json: serde_json::Value },
254    /// Binary audio (wav/mp3/...) with the chosen extension.
255    AudioTts { bytes: Vec<u8>, ext: String },
256    /// Binary video (mp4/webm/...) with the chosen extension.
257    Video { bytes: Vec<u8>, ext: String },
258}
259
260impl TaskResult {
261    pub fn kind(&self) -> TaskKind {
262        match self {
263            TaskResult::Image { .. } => TaskKind::Image,
264            TaskResult::Llm { .. } => TaskKind::Llm,
265            TaskResult::AudioStt { .. } => TaskKind::AudioStt,
266            TaskResult::AudioTts { .. } => TaskKind::AudioTts,
267            TaskResult::Video { .. } => TaskKind::Video,
268        }
269    }
270}
271
// ---------------------------------------------------------------------------
// Worker capabilities + registration
// ---------------------------------------------------------------------------
275
276#[derive(Debug, Clone, Serialize, Deserialize)]
277pub struct WorkerCapabilities {
278    #[serde(rename = "machineName")]
279    pub machine_name: String,
280    pub username: String,
281    #[serde(rename = "agentVersion")]
282    pub agent_version: String,
283    pub engine: String,
284    #[serde(rename = "vramTotalGb")]
285    pub vram_total_gb: f32,
286    #[serde(rename = "vramThresholdGb")]
287    pub vram_threshold_gb: f32,
288    #[serde(rename = "autoEnabled")]
289    pub auto_enabled: bool,
290    #[serde(rename = "autoStart")]
291    pub auto_start: bool,
292    /// Flat list of models, kept for backward compat with the existing studio
293    /// API that doesn't know about kinds yet.  Equivalent to
294    /// `supported_models_per_kind[Image]`.
295    #[serde(rename = "supportedModels")]
296    pub supported_models: Vec<String>,
297    /// New: task kinds this worker can serve.
298    #[serde(rename = "taskKinds", default)]
299    pub task_kinds: Vec<TaskKind>,
300    /// New: per-kind supported model ids.
301    #[serde(rename = "supportedModelsPerKind", default)]
302    pub supported_models_per_kind: BTreeMap<TaskKind, Vec<String>>,
303}
304
// ---------------------------------------------------------------------------
// Auto-register wire format — the only registration path.
//
// Worker POSTs `/workers/register-request` with hostname / username /
// VRAM / supported models, gets a `requestId` back, and polls
// `/workers/register-requests/:id` until the operator approves or
// rejects from the studio dashboard.
// ---------------------------------------------------------------------------
313
314#[derive(Debug, Clone, Serialize)]
315pub struct AutoRegisterRequest {
316    /// Per-install UUID stable across worker restarts on the same
317    /// machine.  The studio uses it to dedup re-submissions.
318    #[serde(rename = "installId")]
319    pub install_id: String,
320    /// SHA-256 hex of the worker-side `registration_secret`.  The
321    /// worker keeps the secret locally and presents it as a Bearer
322    /// token when polling for status; only the hash leaves the box.
323    #[serde(rename = "registrationSecretHash")]
324    pub registration_secret_hash: String,
325    /// Full capability snapshot — hostname, username, engine, VRAM,
326    /// supported models so the operator can decide.
327    pub capabilities: WorkerCapabilities,
328    /// `studio-worker/<version>` so the operator sees stale clients.
329    #[serde(rename = "userAgent")]
330    pub user_agent: String,
331}
332
333#[derive(Debug, Clone, Deserialize)]
334pub struct AutoRegisterRequestResponse {
335    #[serde(rename = "requestId")]
336    pub request_id: String,
337    /// Always `"pending"` on first response.  Idempotent dedup may
338    /// return the existing requestId for the same
339    /// `(installId, sourceIp)` tuple — status is still "pending".
340    pub status: String,
341}
342
343/// Tagged union returned by `GET /workers/register-requests/:id`.
344#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
345#[serde(rename_all = "snake_case", tag = "status")]
346pub enum RegisterStatus {
347    Pending,
348    Approved {
349        #[serde(rename = "workerId")]
350        worker_id: String,
351        #[serde(rename = "authToken")]
352        auth_token: String,
353    },
354    Rejected {
355        #[serde(default)]
356        reason: String,
357    },
358}
359
360#[derive(Debug, Clone, Serialize)]
361pub struct HeartbeatRequest {
362    pub capabilities: WorkerCapabilities,
363    #[serde(rename = "currentJobId", skip_serializing_if = "Option::is_none")]
364    pub current_job_id: Option<String>,
365}
366
// ---------------------------------------------------------------------------
// ModelSource — download spec the studio attaches to every offer so
// the worker can fetch + run any model without per-model knowledge.
// Mirrors `WorkerModelSource` on the TS side.
// ---------------------------------------------------------------------------
372
373#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
374#[serde(rename_all = "kebab-case")]
375pub enum ModelFileRole {
376    DiffusionModel,
377    TextEncoder,
378    /// Vision tower / mmproj for a multimodal text encoder (e.g.
379    /// Qwen2.5-VL ViT). Maps to `sd-cli --llm_vision`; required by
380    /// instruction-edit models that condition on the reference image.
381    TextEncoderVision,
382    Vae,
383    Lora,
384    Model,
385}
386
387#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
388#[serde(rename_all = "kebab-case")]
389pub enum ModelEngine {
390    SdCpp,
391    LlamaCpp,
392    /// ONNX Runtime image engine (pykeio/ort).  Serves the LaMa
393    /// object-removal model used by Find-the-Differences removals.
394    Onnx,
395    Synthetic,
396}
397
398#[derive(Debug, Clone, Serialize, Deserialize)]
399#[serde(rename_all = "camelCase")]
400pub struct ModelFile {
401    pub role: ModelFileRole,
402    pub url: String,
403    pub filename: String,
404    #[serde(default, skip_serializing_if = "Option::is_none")]
405    pub approx_bytes: Option<u64>,
406}
407
408#[derive(Debug, Clone, Default, Serialize, Deserialize)]
409#[serde(rename_all = "camelCase")]
410pub struct ModelCliDefaults {
411    pub cfg_scale: f32,
412    pub steps: u32,
413    pub width: u32,
414    pub height: u32,
415    #[serde(default, skip_serializing_if = "Option::is_none")]
416    pub sampling_method: Option<String>,
417    /// `--flow-shift` for Flow models (SD3.x / WAN / Qwen-Image). Model-level constant.
418    #[serde(default, skip_serializing_if = "Option::is_none")]
419    pub flow_shift: Option<f32>,
420    /// `--qwen-image-zero-cond-t` — mandatory for Qwen-Image edit quality.
421    #[serde(default, skip_serializing_if = "Option::is_none")]
422    pub zero_cond_t: Option<bool>,
423    /// `--offload-to-cpu` — keep weights in RAM, stream to VRAM, so a large model fits a small card.
424    #[serde(default, skip_serializing_if = "Option::is_none")]
425    pub offload_to_cpu: Option<bool>,
426}
427
428#[derive(Debug, Clone, Serialize, Deserialize)]
429#[serde(rename_all = "camelCase")]
430pub struct ModelSource {
431    pub engine: ModelEngine,
432    pub files: Vec<ModelFile>,
433    pub cli_defaults: ModelCliDefaults,
434}
435
// ---------------------------------------------------------------------------
// JobClaim — every offer carries `task` + `model_source` populated by the
// studio's registry resolver.  No legacy `prompt + ext` shadow fields.
// ---------------------------------------------------------------------------
440
441#[derive(Debug, Clone, Deserialize)]
442#[serde(rename_all = "camelCase")]
443pub struct JobClaim {
444    pub job_id: String,
445    #[allow(dead_code)]
446    pub game_id: String,
447    pub asset_name: String,
448    pub model: String,
449    pub vram_gb_estimate: f32,
450    /// Structured task payload.  Required — the studio refuses to
451    /// promote a job without one.  Worker treats a missing `task`
452    /// as a protocol_violation.
453    pub task: Task,
454    /// Download + engine + CLI defaults the studio resolved from its
455    /// model registry.  Required — `synthetic` is just another engine
456    /// option, not a fallback for missing rows.
457    pub model_source: ModelSource,
458}
459
460#[derive(Debug, Clone, Serialize)]
461pub struct FailRequest {
462    pub error: String,
463    pub retryable: bool,
464}
465
466#[derive(Debug, Clone, Serialize, Deserialize)]
467#[cfg_attr(feature = "ui", derive(PartialEq, Eq))]
468pub struct LogEntry {
469    pub ts: String,
470    pub level: String,
471    pub category: String,
472    pub message: String,
473    #[serde(rename = "jobId", default, skip_serializing_if = "Option::is_none")]
474    pub job_id: Option<String>,
475}
476
477#[derive(Debug, Clone, Serialize, Deserialize)]
478pub struct LogBatch {
479    pub entries: Vec<LogEntry>,
480}
481
// ---------------------------------------------------------------------------
// Release feed (auto-update)
// ---------------------------------------------------------------------------
485
486/// Subset of the GitHub Releases API we care about.
487#[derive(Debug, Clone, Deserialize)]
488pub struct GithubRelease {
489    pub tag_name: String,
490    #[serde(default)]
491    pub prerelease: bool,
492    #[serde(default)]
493    pub draft: bool,
494    #[serde(default)]
495    pub assets: Vec<GithubReleaseAsset>,
496}
497
498#[derive(Debug, Clone, Deserialize)]
499pub struct GithubReleaseAsset {
500    pub name: String,
501    pub browser_download_url: String,
502}
503
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal valid `modelSource` for the synthetic engine.
    fn synthetic_model_source_json() -> serde_json::Value {
        serde_json::json!({
            "engine": "synthetic",
            "files": [],
            "cliDefaults": {
                "cfgScale": 1.0,
                "steps": 8,
                "width": 1024,
                "height": 1024,
            },
        })
    }

    /// Claim envelope with every scalar field but neither `task` nor
    /// `modelSource`.
    fn bare_claim_json() -> serde_json::Value {
        serde_json::json!({
            "jobId": "j-1",
            "gameId": "g-1",
            "assetName": "g-1/creatures/x",
            "model": "synthetic-image",
            "vramGbEstimate": 1.0,
        })
    }

    /// Smallest valid image task payload.
    fn minimal_image_task_json() -> serde_json::Value {
        serde_json::json!({ "kind": "image", "prompt": "a koi" })
    }

    #[test]
    fn job_claim_requires_task_and_model_source() {
        // Without `task` or `modelSource` the offer must fail to
        // deserialise — no silent fallback to an Option/Image default.
        assert!(
            serde_json::from_value::<JobClaim>(bare_claim_json()).is_err(),
            "JobClaim must reject missing task + modelSource"
        );
    }

    #[test]
    fn job_claim_rejects_missing_task_alone() {
        let mut json = bare_claim_json();
        json["modelSource"] = synthetic_model_source_json();
        assert!(
            serde_json::from_value::<JobClaim>(json).is_err(),
            "JobClaim must reject a missing task even when modelSource is present"
        );
    }

    #[test]
    fn job_claim_rejects_missing_model_source_alone() {
        let mut json = bare_claim_json();
        json["task"] = minimal_image_task_json();
        assert!(
            serde_json::from_value::<JobClaim>(json).is_err(),
            "JobClaim must reject a missing modelSource even when task is present"
        );
    }

    #[test]
    fn job_claim_with_explicit_llm_task() {
        let json = serde_json::json!({
            "jobId": "j-2",
            "gameId": "g-1",
            "assetName": "g-1/conversations/x",
            "model": "llama-3.1-8b",
            "vramGbEstimate": 8.0,
            "task": {
                "kind": "llm",
                "messages": [{"role": "user", "content": "hi"}],
                "maxTokens": 32,
                "temperature": 0.5,
            },
            "modelSource": synthetic_model_source_json(),
        });
        let claim: JobClaim = serde_json::from_value(json).unwrap();
        assert_eq!(claim.task.kind(), TaskKind::Llm);
        assert_eq!(claim.model_source.engine, ModelEngine::Synthetic);
        match claim.task {
            Task::Llm(p) => {
                assert_eq!(p.messages.len(), 1);
                assert_eq!(p.messages[0].role, "user");
                assert_eq!(p.max_tokens, 32);
                assert!((p.temperature - 0.5).abs() < 1e-6);
            }
            other => panic!("expected llm, got {:?}", other),
        }
    }

    #[test]
    fn job_claim_with_explicit_image_task() {
        let json = serde_json::json!({
            "jobId": "j-3",
            "gameId": "g-1",
            "assetName": "g-1/creatures/y",
            "model": "synthetic-image",
            "vramGbEstimate": 8.0,
            "task": {
                "kind": "image",
                "prompt": "a koi",
                "width": 1024,
                "height": 1024,
                "steps": 30,
                "ext": "png",
            },
            "modelSource": synthetic_model_source_json(),
        });
        let claim: JobClaim = serde_json::from_value(json).unwrap();
        assert_eq!(claim.task.kind(), TaskKind::Image);
        match claim.task {
            Task::Image(p) => {
                assert_eq!(p.prompt, "a koi");
                assert_eq!(p.width, 1024);
                assert_eq!(p.steps, 30);
                assert_eq!(p.ext, "png");
            }
            other => panic!("expected image, got {:?}", other),
        }
    }

    #[test]
    fn image_params_round_trips_with_new_fields() {
        let json = serde_json::json!({
            "kind": "image",
            "prompt": "a stone golem",
            "negativePrompt": "text, watermark, low quality",
            "initImageUrl": "https://example.invalid/t2-golem-stone/latest.webp",
            "maskUrl": "https://example.invalid/t2-golem-stone/mask.png",
            "refImageUrl": "https://example.invalid/t2-golem-stone/ref.png",
            "denoise": 0.55,
            "cfgScale": 7.5,
            "samplingMethod": "dpm++2m",
            "width": 768,
            "height": 512,
            "steps": 30,
            "seed": 1234,
            "ext": "webp",
        });
        let task: Task = serde_json::from_value(json).unwrap();

        // Serialise and parse again so both directions of the wire format
        // are exercised, not just deserialisation.
        let reserialized = serde_json::to_value(&task).unwrap();
        assert_eq!(reserialized["kind"], "image");
        assert_eq!(
            reserialized["maskUrl"],
            "https://example.invalid/t2-golem-stone/mask.png"
        );
        assert_eq!(
            reserialized["refImageUrl"],
            "https://example.invalid/t2-golem-stone/ref.png"
        );
        let task: Task = serde_json::from_value(reserialized).unwrap();

        match task {
            Task::Image(p) => {
                assert_eq!(p.prompt, "a stone golem");
                assert_eq!(
                    p.negative_prompt.as_deref(),
                    Some("text, watermark, low quality")
                );
                assert_eq!(
                    p.init_image_url.as_deref(),
                    Some("https://example.invalid/t2-golem-stone/latest.webp")
                );
                assert_eq!(
                    p.mask_url.as_deref(),
                    Some("https://example.invalid/t2-golem-stone/mask.png")
                );
                assert_eq!(
                    p.ref_image_url.as_deref(),
                    Some("https://example.invalid/t2-golem-stone/ref.png")
                );
                assert!((p.denoise.unwrap() - 0.55).abs() < 1e-6);
                assert!((p.cfg_scale.unwrap() - 7.5).abs() < 1e-6);
                assert_eq!(p.sampling_method.as_deref(), Some("dpm++2m"));
                assert_eq!(p.width, 768);
                assert_eq!(p.height, 512);
                assert_eq!(p.steps, 30);
                assert_eq!(p.seed, Some(1234));
                assert_eq!(p.ext, "webp");
            }
            other => panic!("expected image, got {:?}", other),
        }
    }

    #[test]
    fn image_params_defaults_when_optional_fields_absent() {
        let json = serde_json::json!({
            "kind": "image",
            "prompt": "a fox"
        });
        let task: Task = serde_json::from_value(json).unwrap();
        match task {
            Task::Image(p) => {
                assert_eq!(p.prompt, "a fox");
                assert!(p.negative_prompt.is_none());
                assert!(p.init_image_url.is_none());
                assert!(p.mask_url.is_none());
                assert!(p.ref_image_url.is_none());
                assert!(p.denoise.is_none());
                assert!(p.cfg_scale.is_none());
                assert!(p.sampling_method.is_none());
                assert!(p.seed.is_none());
                assert_eq!(p.width, 512);
                assert_eq!(p.height, 512);
                assert_eq!(p.steps, 20);
                assert_eq!(p.ext, "webp");
            }
            other => panic!("expected image, got {:?}", other),
        }
    }

    #[test]
    fn task_kinds_round_trip_via_json() {
        for kind in TaskKind::ALL {
            let s = serde_json::to_string(&kind).unwrap();
            let back: TaskKind = serde_json::from_str(&s).unwrap();
            assert_eq!(kind, back);
        }
    }

    #[test]
    fn task_kind_as_str_matches_serde_name() {
        // `as_str` duplicates the serde names by hand; keep them in lockstep.
        for kind in TaskKind::ALL {
            assert_eq!(
                serde_json::to_value(kind).unwrap(),
                serde_json::Value::String(kind.as_str().to_owned())
            );
        }
    }

    #[test]
    fn register_status_parses_every_variant() {
        let pending: RegisterStatus =
            serde_json::from_value(serde_json::json!({ "status": "pending" })).unwrap();
        assert_eq!(pending, RegisterStatus::Pending);

        let approved: RegisterStatus = serde_json::from_value(serde_json::json!({
            "status": "approved",
            "workerId": "w-1",
            "authToken": "tok",
        }))
        .unwrap();
        assert_eq!(
            approved,
            RegisterStatus::Approved {
                worker_id: "w-1".to_owned(),
                auth_token: "tok".to_owned(),
            }
        );

        // `reason` is optional and defaults to an empty string.
        let rejected: RegisterStatus =
            serde_json::from_value(serde_json::json!({ "status": "rejected" })).unwrap();
        assert_eq!(
            rejected,
            RegisterStatus::Rejected {
                reason: String::new(),
            }
        );
    }

    #[test]
    fn worker_capabilities_use_camel_case_wire_keys() {
        let caps = WorkerCapabilities {
            machine_name: "box".to_owned(),
            username: "me".to_owned(),
            agent_version: "0.0.0".to_owned(),
            engine: "sd-cpp".to_owned(),
            vram_total_gb: 8.0,
            vram_threshold_gb: 6.0,
            auto_enabled: true,
            auto_start: false,
            supported_models: vec!["m".to_owned()],
            task_kinds: vec![TaskKind::Image],
            supported_models_per_kind: std::collections::BTreeMap::from([(
                TaskKind::Image,
                vec!["m".to_owned()],
            )]),
        };
        let v = serde_json::to_value(&caps).unwrap();
        for key in [
            "machineName",
            "username",
            "agentVersion",
            "engine",
            "vramTotalGb",
            "vramThresholdGb",
            "autoEnabled",
            "autoStart",
            "supportedModels",
            "taskKinds",
            "supportedModelsPerKind",
        ] {
            assert!(v.get(key).is_some(), "missing wire key {}", key);
        }
        assert_eq!(v["taskKinds"][0], "image");
        assert_eq!(v["supportedModelsPerKind"]["image"][0], "m");
    }

    #[test]
    fn log_entry_omits_absent_job_id() {
        let mut entry = LogEntry {
            ts: "t".to_owned(),
            level: "info".to_owned(),
            category: "c".to_owned(),
            message: "m".to_owned(),
            job_id: None,
        };
        let v = serde_json::to_value(&entry).unwrap();
        assert!(v.get("jobId").is_none());

        entry.job_id = Some("j-9".to_owned());
        let v = serde_json::to_value(&entry).unwrap();
        assert_eq!(v["jobId"], "j-9");
    }
}