Skip to main content

studio_worker/
types.rs

1//! Shared wire-format mirrors of the studio API.
2use serde::{Deserialize, Serialize};
3use std::collections::BTreeMap;
4
5// ---------------------------------------------------------------------------
6// Task kinds — every job claimed by the worker is one of these.
7// ---------------------------------------------------------------------------
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Hash)]
10#[serde(rename_all = "snake_case")]
11pub enum TaskKind {
12    Image,
13    Llm,
14    AudioStt,
15    AudioTts,
16    Video,
17}
18
19impl TaskKind {
20    pub const ALL: [TaskKind; 5] = [
21        TaskKind::Image,
22        TaskKind::Llm,
23        TaskKind::AudioStt,
24        TaskKind::AudioTts,
25        TaskKind::Video,
26    ];
27
28    pub fn as_str(&self) -> &'static str {
29        match self {
30            TaskKind::Image => "image",
31            TaskKind::Llm => "llm",
32            TaskKind::AudioStt => "audio_stt",
33            TaskKind::AudioTts => "audio_tts",
34            TaskKind::Video => "video",
35        }
36    }
37}
38
39// ---------------------------------------------------------------------------
40// Per-kind task parameters
41// ---------------------------------------------------------------------------
42
43#[derive(Debug, Clone, Default, Serialize, Deserialize)]
44#[serde(rename_all = "camelCase")]
45pub struct ImageParams {
46    pub prompt: String,
47    /// Negative prompt — what the model should steer AWAY from.
48    /// Optional: when `None`, `sd-cli` is invoked without
49    /// `--negative-prompt`.  Wire-format key: `negativePrompt`.
50    #[serde(default)]
51    pub negative_prompt: Option<String>,
52    /// HTTPS URL to a base image for image-to-image generation.
53    /// When set, the worker downloads the bytes to a local tempfile
54    /// and invokes `sd-cli --init-img <path>`.  Required for any
55    /// task profile that declares the `initImage` capability.
56    #[serde(default)]
57    pub init_image_url: Option<String>,
58    /// HTTPS URL to a black/white inpaint mask (white = the region the
59    /// model may repaint). When set alongside `init_image_url`, the
60    /// worker downloads it and invokes `sd-cli --mask <path>` so only
61    /// the masked region changes. Wire-format key: `maskUrl`.
62    #[serde(default)]
63    pub mask_url: Option<String>,
64    /// HTTPS URL to a reference image for instruction-edit models
65    /// (e.g. Qwen-Image-Edit / Flux Kontext).  When set, the worker
66    /// downloads it and invokes `sd-cli -r <path>` (reference mode)
67    /// instead of the `--init-img`/`--strength`/`--mask` img2img path:
68    /// the model regenerates the whole image from the reference per the
69    /// instruction prompt; per-region clipping happens in the studio
70    /// composite. Wire-format key: `refImageUrl`.
71    #[serde(default)]
72    pub ref_image_url: Option<String>,
73    /// Denoise / noise-strength for i2i (0.0 = keep init image
74    /// unchanged, 1.0 = full re-noise).  Maps to `sd-cli --strength`.
75    #[serde(default)]
76    pub denoise: Option<f32>,
77    /// Classifier-free guidance scale.  Per-job override; falls back
78    /// to `ModelSource.cliDefaults.cfgScale` when `None`.
79    #[serde(default)]
80    pub cfg_scale: Option<f32>,
81    /// Sampler choice (`euler`, `euler_a`, `dpm++2m`, ...).  Per-job
82    /// override; falls back to `ModelSource.cliDefaults.samplingMethod`.
83    #[serde(default)]
84    pub sampling_method: Option<String>,
85    #[serde(default = "default_image_dim")]
86    pub width: u32,
87    #[serde(default = "default_image_dim")]
88    pub height: u32,
89    #[serde(default = "default_steps")]
90    pub steps: u32,
91    #[serde(default)]
92    pub seed: Option<u64>,
93    #[serde(default = "default_image_ext")]
94    pub ext: String,
95}
96
97fn default_image_dim() -> u32 {
98    512
99}
100fn default_steps() -> u32 {
101    20
102}
103fn default_image_ext() -> String {
104    "webp".into()
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct ChatMessage {
109    pub role: String,
110    pub content: String,
111}
112
113#[derive(Debug, Clone, Default, Serialize, Deserialize)]
114#[serde(rename_all = "camelCase")]
115pub struct LlmParams {
116    pub messages: Vec<ChatMessage>,
117    /// System prompt prepended to the conversation.  Engines that
118    /// don't accept a separate system role inline it as the first
119    /// chat turn.
120    #[serde(default)]
121    pub system: Option<String>,
122    #[serde(default = "default_max_tokens")]
123    pub max_tokens: u32,
124    #[serde(default = "default_temperature")]
125    pub temperature: f32,
126    #[serde(default)]
127    pub top_p: Option<f32>,
128    #[serde(default)]
129    pub stop: Option<Vec<String>>,
130    /// Strict-JSON output schema.  Engine-specific; passed through
131    /// verbatim to the backend (e.g. llama.cpp's `--grammar` JSON).
132    #[serde(default)]
133    pub json_schema: Option<serde_json::Value>,
134    /// Reasoning effort hint.  Currently honoured by Gemini-style
135    /// backends only; ignored elsewhere.
136    #[serde(default)]
137    pub reasoning: Option<String>,
138}
139
140fn default_max_tokens() -> u32 {
141    512
142}
143fn default_temperature() -> f32 {
144    0.7
145}
146
147#[derive(Debug, Clone, Default, Serialize, Deserialize)]
148#[serde(rename_all = "camelCase")]
149pub struct AudioSttParams {
150    /// HTTPS URL to fetch the audio bytes from (e.g. R2 signed URL).
151    pub input_url: String,
152    #[serde(default)]
153    pub language: Option<String>,
154    /// Translate the audio to English (whisper `--translate`).
155    #[serde(default)]
156    pub translate: Option<bool>,
157    /// Initial prompt to bias the transcription.
158    #[serde(default)]
159    pub prompt: Option<String>,
160    /// Voice-activity detection.
161    #[serde(default)]
162    pub vad: Option<bool>,
163    /// Timestamp granularity: `"segment"` or `"word"`.
164    #[serde(default)]
165    pub timestamps: Option<String>,
166}
167
168#[derive(Debug, Clone, Default, Serialize, Deserialize)]
169#[serde(rename_all = "camelCase")]
170pub struct AudioTtsParams {
171    pub text: String,
172    #[serde(default = "default_voice")]
173    pub voice: String,
174    /// Playback speed multiplier (1.0 = natural pace).
175    #[serde(default)]
176    pub speed: Option<f32>,
177    /// Spoken-language hint (e.g. `"en"`, `"nl"`).
178    #[serde(default)]
179    pub language: Option<String>,
180    #[serde(default = "default_audio_ext")]
181    pub ext: String,
182}
183
184fn default_voice() -> String {
185    "default".into()
186}
187fn default_audio_ext() -> String {
188    "wav".into()
189}
190
191#[derive(Debug, Clone, Default, Serialize, Deserialize)]
192#[serde(rename_all = "camelCase")]
193pub struct VideoParams {
194    pub prompt: String,
195    #[serde(default)]
196    pub negative_prompt: Option<String>,
197    /// HTTPS URL to a base frame for image-to-video models.
198    #[serde(default)]
199    pub init_image_url: Option<String>,
200    #[serde(default = "default_video_seconds")]
201    pub seconds: f32,
202    /// Frame rate; defaults to backend-specific value when `None`.
203    #[serde(default)]
204    pub fps: Option<u32>,
205    #[serde(default = "default_image_dim")]
206    pub width: u32,
207    #[serde(default = "default_image_dim")]
208    pub height: u32,
209    #[serde(default = "default_video_ext")]
210    pub ext: String,
211}
212
213fn default_video_seconds() -> f32 {
214    2.0
215}
216fn default_video_ext() -> String {
217    "mp4".into()
218}
219
220#[derive(Debug, Clone, Serialize, Deserialize)]
221#[serde(tag = "kind", rename_all = "snake_case")]
222pub enum Task {
223    Image(ImageParams),
224    Llm(LlmParams),
225    AudioStt(AudioSttParams),
226    AudioTts(AudioTtsParams),
227    Video(VideoParams),
228}
229
230impl Task {
231    pub fn kind(&self) -> TaskKind {
232        match self {
233            Task::Image(_) => TaskKind::Image,
234            Task::Llm(_) => TaskKind::Llm,
235            Task::AudioStt(_) => TaskKind::AudioStt,
236            Task::AudioTts(_) => TaskKind::AudioTts,
237            Task::Video(_) => TaskKind::Video,
238        }
239    }
240}
241
242// ---------------------------------------------------------------------------
243// Per-kind task results
244// ---------------------------------------------------------------------------
245
246#[derive(Debug, Clone)]
247pub enum TaskResult {
248    /// Binary image (webp/png/...) with the chosen extension.
249    Image { bytes: Vec<u8>, ext: String },
250    /// JSON response from an LLM call (free-form to mirror common APIs).
251    Llm { json: serde_json::Value },
252    /// JSON transcript.
253    AudioStt { json: serde_json::Value },
254    /// Binary audio (wav/mp3/...) with the chosen extension.
255    AudioTts { bytes: Vec<u8>, ext: String },
256    /// Binary video (mp4/webm/...) with the chosen extension.
257    Video { bytes: Vec<u8>, ext: String },
258}
259
260impl TaskResult {
261    pub fn kind(&self) -> TaskKind {
262        match self {
263            TaskResult::Image { .. } => TaskKind::Image,
264            TaskResult::Llm { .. } => TaskKind::Llm,
265            TaskResult::AudioStt { .. } => TaskKind::AudioStt,
266            TaskResult::AudioTts { .. } => TaskKind::AudioTts,
267            TaskResult::Video { .. } => TaskKind::Video,
268        }
269    }
270}
271
272// ---------------------------------------------------------------------------
273// Worker capabilities + registration
274// ---------------------------------------------------------------------------
275
276#[derive(Debug, Clone, Serialize, Deserialize)]
277pub struct WorkerCapabilities {
278    #[serde(rename = "machineName")]
279    pub machine_name: String,
280    pub username: String,
281    #[serde(rename = "agentVersion")]
282    pub agent_version: String,
283    pub engine: String,
284    #[serde(rename = "vramTotalGb")]
285    pub vram_total_gb: f32,
286    #[serde(rename = "vramThresholdGb")]
287    pub vram_threshold_gb: f32,
288    #[serde(rename = "autoEnabled")]
289    pub auto_enabled: bool,
290    #[serde(rename = "autoStart")]
291    pub auto_start: bool,
292    /// Flat list of models, kept for backward compat with the existing studio
293    /// API that doesn't know about kinds yet.  Equivalent to
294    /// `supported_models_per_kind[Image]`.
295    #[serde(rename = "supportedModels")]
296    pub supported_models: Vec<String>,
297    /// New: task kinds this worker can serve.
298    #[serde(rename = "taskKinds", default)]
299    pub task_kinds: Vec<TaskKind>,
300    /// New: per-kind supported model ids.
301    #[serde(rename = "supportedModelsPerKind", default)]
302    pub supported_models_per_kind: BTreeMap<TaskKind, Vec<String>>,
303}
304
305// ---------------------------------------------------------------------------
306// Auto-register wire format — the only registration path.
307//
308// Worker POSTs `/workers/register-request` with hostname / username /
309// VRAM / supported models, gets a `requestId` back, and polls
310// `/workers/register-requests/:id` until the operator approves or
311// rejects from the studio dashboard.
312// ---------------------------------------------------------------------------
313
314#[derive(Debug, Clone, Serialize)]
315pub struct AutoRegisterRequest {
316    /// Per-install UUID stable across worker restarts on the same
317    /// machine.  The studio uses it to dedup re-submissions.
318    #[serde(rename = "installId")]
319    pub install_id: String,
320    /// SHA-256 hex of the worker-side `registration_secret`.  The
321    /// worker keeps the secret locally and presents it as a Bearer
322    /// token when polling for status; only the hash leaves the box.
323    #[serde(rename = "registrationSecretHash")]
324    pub registration_secret_hash: String,
325    /// Full capability snapshot — hostname, username, engine, VRAM,
326    /// supported models so the operator can decide.
327    pub capabilities: WorkerCapabilities,
328    /// `studio-worker/<version>` so the operator sees stale clients.
329    #[serde(rename = "userAgent")]
330    pub user_agent: String,
331}
332
333#[derive(Debug, Clone, Deserialize)]
334pub struct AutoRegisterRequestResponse {
335    #[serde(rename = "requestId")]
336    pub request_id: String,
337    /// Always `"pending"` on first response.  Idempotent dedup may
338    /// return the existing requestId for the same
339    /// `(installId, sourceIp)` tuple — status is still "pending".
340    pub status: String,
341}
342
343/// Tagged union returned by `GET /workers/register-requests/:id`.
344#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
345#[serde(rename_all = "snake_case", tag = "status")]
346pub enum RegisterStatus {
347    Pending,
348    Approved {
349        #[serde(rename = "workerId")]
350        worker_id: String,
351        #[serde(rename = "authToken")]
352        auth_token: String,
353    },
354    Rejected {
355        #[serde(default)]
356        reason: String,
357    },
358}
359
360#[derive(Debug, Clone, Serialize)]
361pub struct HeartbeatRequest {
362    pub capabilities: WorkerCapabilities,
363    #[serde(rename = "currentJobId", skip_serializing_if = "Option::is_none")]
364    pub current_job_id: Option<String>,
365}
366
367// ---------------------------------------------------------------------------
368// ModelSource — download spec the studio attaches to every offer so
369// the worker can fetch + run any model without per-model knowledge.
370// Mirrors `WorkerModelSource` on the TS side.
371// ---------------------------------------------------------------------------
372
373#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
374#[serde(rename_all = "kebab-case")]
375pub enum ModelFileRole {
376    DiffusionModel,
377    TextEncoder,
378    /// Vision tower / mmproj for a multimodal text encoder (e.g.
379    /// Qwen2.5-VL ViT). Maps to `sd-cli --llm_vision`; required by
380    /// instruction-edit models that condition on the reference image.
381    TextEncoderVision,
382    Vae,
383    Lora,
384    Model,
385}
386
387#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
388#[serde(rename_all = "kebab-case")]
389pub enum ModelEngine {
390    SdCpp,
391    LlamaCpp,
392    /// ONNX Runtime image engine (pykeio/ort).  Serves the LaMa
393    /// object-removal model used by Find-the-Differences removals.
394    Onnx,
395    Synthetic,
396}
397
398#[derive(Debug, Clone, Serialize, Deserialize)]
399#[serde(rename_all = "camelCase")]
400pub struct ModelFile {
401    pub role: ModelFileRole,
402    pub url: String,
403    pub filename: String,
404    #[serde(default, skip_serializing_if = "Option::is_none")]
405    pub approx_bytes: Option<u64>,
406    /// Hex sha256 of the file's bytes.  When present the worker
407    /// verifies the downloaded body against it before committing the
408    /// file to the cache; absent (legacy registry rows) means
409    /// Content-Length is the only integrity check.
410    #[serde(default, skip_serializing_if = "Option::is_none")]
411    pub sha256: Option<String>,
412}
413
414#[derive(Debug, Clone, Default, Serialize, Deserialize)]
415#[serde(rename_all = "camelCase")]
416pub struct ModelCliDefaults {
417    pub cfg_scale: f32,
418    pub steps: u32,
419    pub width: u32,
420    pub height: u32,
421    #[serde(default, skip_serializing_if = "Option::is_none")]
422    pub sampling_method: Option<String>,
423    /// `--flow-shift` for Flow models (SD3.x / WAN / Qwen-Image). Model-level constant.
424    #[serde(default, skip_serializing_if = "Option::is_none")]
425    pub flow_shift: Option<f32>,
426    /// `--qwen-image-zero-cond-t` — mandatory for Qwen-Image edit quality.
427    #[serde(default, skip_serializing_if = "Option::is_none")]
428    pub zero_cond_t: Option<bool>,
429    /// `--offload-to-cpu` — keep weights in RAM, stream to VRAM, so a large model fits a small card.
430    #[serde(default, skip_serializing_if = "Option::is_none")]
431    pub offload_to_cpu: Option<bool>,
432}
433
434#[derive(Debug, Clone, Serialize, Deserialize)]
435#[serde(rename_all = "camelCase")]
436pub struct ModelSource {
437    pub engine: ModelEngine,
438    pub files: Vec<ModelFile>,
439    pub cli_defaults: ModelCliDefaults,
440}
441
442// ---------------------------------------------------------------------------
443// JobClaim — every offer carries `task` + `model_source` populated by the
444// studio's registry resolver.  No legacy `prompt + ext` shadow fields.
445// ---------------------------------------------------------------------------
446
447#[derive(Debug, Clone, Deserialize)]
448#[serde(rename_all = "camelCase")]
449pub struct JobClaim {
450    pub job_id: String,
451    #[allow(dead_code)]
452    pub game_id: String,
453    pub asset_name: String,
454    pub model: String,
455    pub vram_gb_estimate: f32,
456    /// Structured task payload.  Required — the studio refuses to
457    /// promote a job without one.  Worker treats a missing `task`
458    /// as a protocol_violation.
459    pub task: Task,
460    /// Download + engine + CLI defaults the studio resolved from its
461    /// model registry.  Required — `synthetic` is just another engine
462    /// option, not a fallback for missing rows.
463    pub model_source: ModelSource,
464}
465
466#[derive(Debug, Clone, Serialize)]
467pub struct FailRequest {
468    pub error: String,
469    pub retryable: bool,
470}
471
472#[derive(Debug, Clone, Serialize, Deserialize)]
473#[cfg_attr(feature = "ui", derive(PartialEq, Eq))]
474pub struct LogEntry {
475    pub ts: String,
476    pub level: String,
477    pub category: String,
478    pub message: String,
479    #[serde(rename = "jobId", default, skip_serializing_if = "Option::is_none")]
480    pub job_id: Option<String>,
481}
482
483#[derive(Debug, Clone, Serialize, Deserialize)]
484pub struct LogBatch {
485    pub entries: Vec<LogEntry>,
486}
487
488// ---------------------------------------------------------------------------
489// Release feed (auto-update)
490// ---------------------------------------------------------------------------
491
492/// Subset of the GitHub Releases API we care about.
493#[derive(Debug, Clone, Deserialize)]
494pub struct GithubRelease {
495    pub tag_name: String,
496    #[serde(default)]
497    pub prerelease: bool,
498    #[serde(default)]
499    pub draft: bool,
500    #[serde(default)]
501    pub assets: Vec<GithubReleaseAsset>,
502}
503
504#[derive(Debug, Clone, Deserialize)]
505pub struct GithubReleaseAsset {
506    pub name: String,
507    pub browser_download_url: String,
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// Smallest `modelSource` payload that `ModelSource` accepts.
    fn synthetic_model_source_json() -> serde_json::Value {
        serde_json::json!({
            "engine": "synthetic",
            "files": [],
            "cliDefaults": {
                "cfgScale": 1.0,
                "steps": 8,
                "width": 1024,
                "height": 1024,
            },
        })
    }

    /// Smallest `task` payload that `Task` accepts (prompt-only image).
    fn minimal_image_task_json() -> serde_json::Value {
        serde_json::json!({
            "kind": "image",
            "prompt": "a fox",
        })
    }

    /// Claim envelope carrying every key except `task` and `modelSource`.
    fn bare_claim_json() -> serde_json::Value {
        serde_json::json!({
            "jobId": "j-1",
            "gameId": "g-1",
            "assetName": "g-1/creatures/x",
            "model": "synthetic-image",
            "vramGbEstimate": 1.0,
        })
    }

    #[test]
    fn job_claim_requires_task_and_model_source() {
        // Without `task` or `modelSource` the offer must fail to
        // deserialise — no silent fallback to an Option/Image default.
        let bare = bare_claim_json();
        assert!(
            serde_json::from_value::<JobClaim>(bare.clone()).is_err(),
            "JobClaim must reject missing task + modelSource"
        );

        // Check each field on its own, so that making only one of them
        // optional cannot slip through.
        let mut without_model_source = bare.clone();
        without_model_source["task"] = minimal_image_task_json();
        assert!(
            serde_json::from_value::<JobClaim>(without_model_source).is_err(),
            "JobClaim must reject missing modelSource"
        );

        let mut without_task = bare.clone();
        without_task["modelSource"] = synthetic_model_source_json();
        assert!(
            serde_json::from_value::<JobClaim>(without_task).is_err(),
            "JobClaim must reject missing task"
        );

        // Control: with both present the same envelope parses, so the
        // rejections above really are caused by the missing keys.
        let mut complete = bare;
        complete["task"] = minimal_image_task_json();
        complete["modelSource"] = synthetic_model_source_json();
        assert!(
            serde_json::from_value::<JobClaim>(complete).is_ok(),
            "JobClaim must accept a payload carrying task + modelSource"
        );
    }

    #[test]
    fn job_claim_with_explicit_llm_task() {
        let json = serde_json::json!({
            "jobId": "j-2",
            "gameId": "g-1",
            "assetName": "g-1/conversations/x",
            "model": "llama-3.1-8b",
            "vramGbEstimate": 8.0,
            "task": {
                "kind": "llm",
                "messages": [{"role": "user", "content": "hi"}],
                "maxTokens": 32,
                "temperature": 0.5,
            },
            "modelSource": synthetic_model_source_json(),
        });
        let claim: JobClaim = serde_json::from_value(json).unwrap();
        assert_eq!(claim.task.kind(), TaskKind::Llm);
        match claim.task {
            Task::Llm(p) => {
                assert_eq!(p.messages.len(), 1);
                assert_eq!(p.messages[0].role, "user");
                assert_eq!(p.messages[0].content, "hi");
                assert_eq!(p.max_tokens, 32);
                assert!((p.temperature - 0.5).abs() < 1e-6);
            }
            other => panic!("expected llm, got {:?}", other),
        }
    }

    #[test]
    fn job_claim_with_explicit_image_task() {
        let json = serde_json::json!({
            "jobId": "j-3",
            "gameId": "g-1",
            "assetName": "g-1/creatures/y",
            "model": "synthetic-image",
            "vramGbEstimate": 8.0,
            "task": {
                "kind": "image",
                "prompt": "a koi",
                "width": 1024,
                "height": 1024,
                "steps": 30,
                "ext": "png",
            },
            "modelSource": synthetic_model_source_json(),
        });
        let claim: JobClaim = serde_json::from_value(json).unwrap();
        assert_eq!(claim.model_source.engine, ModelEngine::Synthetic);
        match claim.task {
            Task::Image(p) => {
                assert_eq!(p.prompt, "a koi");
                assert_eq!(p.width, 1024);
                assert_eq!(p.height, 1024);
                assert_eq!(p.steps, 30);
                assert_eq!(p.ext, "png");
            }
            other => panic!("expected image, got {:?}", other),
        }
    }

    #[test]
    fn image_params_round_trips_with_new_fields() {
        let json = serde_json::json!({
            "kind": "image",
            "prompt": "a stone golem",
            "negativePrompt": "text, watermark, low quality",
            "initImageUrl": "https://example.invalid/t2-golem-stone/latest.webp",
            "maskUrl": "https://example.invalid/t2-golem-stone/mask.png",
            "refImageUrl": "https://example.invalid/t2-golem-stone/ref.webp",
            "denoise": 0.55,
            "cfgScale": 7.5,
            "samplingMethod": "dpm++2m",
            "width": 768,
            "height": 512,
            "steps": 30,
            "seed": 1234,
            "ext": "webp",
        });
        let task: Task = serde_json::from_value(json).unwrap();

        // Genuine round trip: serialise, parse the output again, and check
        // that a second serialisation is identical to the first.
        let first = serde_json::to_value(&task).unwrap();
        assert_eq!(first["kind"], "image");
        assert_eq!(
            first["maskUrl"],
            "https://example.invalid/t2-golem-stone/mask.png"
        );
        assert_eq!(
            first["refImageUrl"],
            "https://example.invalid/t2-golem-stone/ref.webp"
        );
        let reparsed: Task = serde_json::from_value(first.clone()).unwrap();
        let second = serde_json::to_value(&reparsed).unwrap();
        assert_eq!(first, second, "Task must survive a serialise/parse cycle");

        match task {
            Task::Image(p) => {
                assert_eq!(p.prompt, "a stone golem");
                assert_eq!(
                    p.negative_prompt.as_deref(),
                    Some("text, watermark, low quality")
                );
                assert_eq!(
                    p.init_image_url.as_deref(),
                    Some("https://example.invalid/t2-golem-stone/latest.webp")
                );
                assert_eq!(
                    p.mask_url.as_deref(),
                    Some("https://example.invalid/t2-golem-stone/mask.png")
                );
                assert_eq!(
                    p.ref_image_url.as_deref(),
                    Some("https://example.invalid/t2-golem-stone/ref.webp")
                );
                assert!((p.denoise.unwrap() - 0.55).abs() < 1e-6);
                assert!((p.cfg_scale.unwrap() - 7.5).abs() < 1e-6);
                assert_eq!(p.sampling_method.as_deref(), Some("dpm++2m"));
                assert_eq!(p.width, 768);
                assert_eq!(p.height, 512);
                assert_eq!(p.steps, 30);
                assert_eq!(p.seed, Some(1234));
                assert_eq!(p.ext, "webp");
            }
            other => panic!("expected image, got {:?}", other),
        }
    }

    #[test]
    fn image_params_defaults_when_optional_fields_absent() {
        let task: Task = serde_json::from_value(minimal_image_task_json()).unwrap();
        match task {
            Task::Image(p) => {
                assert_eq!(p.prompt, "a fox");
                assert!(p.negative_prompt.is_none());
                assert!(p.init_image_url.is_none());
                assert!(p.mask_url.is_none());
                assert!(p.ref_image_url.is_none());
                assert!(p.denoise.is_none());
                assert!(p.cfg_scale.is_none());
                assert!(p.sampling_method.is_none());
                assert!(p.seed.is_none());
                assert_eq!(p.width, 512);
                assert_eq!(p.height, 512);
                assert_eq!(p.steps, 20);
                assert_eq!(p.ext, "webp");
            }
            other => panic!("expected image, got {:?}", other),
        }
    }

    #[test]
    fn task_kinds_round_trip_via_json() {
        for kind in TaskKind::ALL {
            let s = serde_json::to_string(&kind).unwrap();
            let back: TaskKind = serde_json::from_str(&s).unwrap();
            assert_eq!(kind, back);
        }
    }

    #[test]
    fn task_kind_as_str_matches_wire_name() {
        // `as_str` spells the names out by hand; this guards against it
        // drifting from the serde `rename_all` spelling.
        for kind in TaskKind::ALL {
            let wire = serde_json::to_value(kind).unwrap();
            assert_eq!(wire, serde_json::Value::String(kind.as_str().to_owned()));
        }
    }
}