Skip to main content

studio_worker/
types.rs

1//! Shared wire-format mirrors of the studio API.
2use serde::{Deserialize, Serialize};
3use std::collections::BTreeMap;
4
5// ---------------------------------------------------------------------------
6// Task kinds — every job claimed by the worker is one of these.
7// ---------------------------------------------------------------------------
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Hash)]
10#[serde(rename_all = "snake_case")]
11pub enum TaskKind {
12    Image,
13    Llm,
14    AudioStt,
15    AudioTts,
16    Video,
17}
18
19impl TaskKind {
20    pub const ALL: [TaskKind; 5] = [
21        TaskKind::Image,
22        TaskKind::Llm,
23        TaskKind::AudioStt,
24        TaskKind::AudioTts,
25        TaskKind::Video,
26    ];
27
28    pub fn as_str(&self) -> &'static str {
29        match self {
30            TaskKind::Image => "image",
31            TaskKind::Llm => "llm",
32            TaskKind::AudioStt => "audio_stt",
33            TaskKind::AudioTts => "audio_tts",
34            TaskKind::Video => "video",
35        }
36    }
37}
38
39// ---------------------------------------------------------------------------
40// Per-kind task parameters
41// ---------------------------------------------------------------------------
42
43#[derive(Debug, Clone, Default, Serialize, Deserialize)]
44#[serde(rename_all = "camelCase")]
45pub struct ImageParams {
46    pub prompt: String,
47    /// Negative prompt — what the model should steer AWAY from.
48    /// Optional: when `None`, `sd-cli` is invoked without
49    /// `--negative-prompt`.  Wire-format key: `negativePrompt`.
50    #[serde(default)]
51    pub negative_prompt: Option<String>,
52    /// HTTPS URL to a base image for image-to-image generation.
53    /// When set, the worker downloads the bytes to a local tempfile
54    /// and invokes `sd-cli --init-img <path>`.  Required for any
55    /// task profile that declares the `initImage` capability.
56    #[serde(default)]
57    pub init_image_url: Option<String>,
58    /// HTTPS URL to a black/white inpaint mask (white = the region the
59    /// model may repaint). When set alongside `init_image_url`, the
60    /// worker downloads it and invokes `sd-cli --mask <path>` so only
61    /// the masked region changes. Wire-format key: `maskUrl`.
62    #[serde(default)]
63    pub mask_url: Option<String>,
64    /// HTTPS URL to a reference image for instruction-edit models
65    /// (e.g. Qwen-Image-Edit / Flux Kontext).  When set, the worker
66    /// downloads it and invokes `sd-cli -r <path>` (reference mode)
67    /// instead of the `--init-img`/`--strength`/`--mask` img2img path:
68    /// the model regenerates the whole image from the reference per the
69    /// instruction prompt; per-region clipping happens in the studio
70    /// composite. Wire-format key: `refImageUrl`.
71    #[serde(default)]
72    pub ref_image_url: Option<String>,
73    /// Denoise / noise-strength for i2i (0.0 = keep init image
74    /// unchanged, 1.0 = full re-noise).  Maps to `sd-cli --strength`.
75    #[serde(default)]
76    pub denoise: Option<f32>,
77    /// Classifier-free guidance scale.  Per-job override; falls back
78    /// to `ModelSource.cliDefaults.cfgScale` when `None`.
79    #[serde(default)]
80    pub cfg_scale: Option<f32>,
81    /// Sampler choice (`euler`, `euler_a`, `dpm++2m`, ...).  Per-job
82    /// override; falls back to `ModelSource.cliDefaults.samplingMethod`.
83    #[serde(default)]
84    pub sampling_method: Option<String>,
85    #[serde(default = "default_image_dim")]
86    pub width: u32,
87    #[serde(default = "default_image_dim")]
88    pub height: u32,
89    #[serde(default = "default_steps")]
90    pub steps: u32,
91    #[serde(default)]
92    pub seed: Option<u64>,
93    #[serde(default = "default_image_ext")]
94    pub ext: String,
95}
96
/// Serde fallback for a missing `width` / `height` key: 512.
fn default_image_dim() -> u32 {
    512_u32
}
/// Serde fallback for a missing `steps` key: 20.
fn default_steps() -> u32 {
    20_u32
}
/// Serde fallback for a missing image `ext` key: `"webp"`.
fn default_image_ext() -> String {
    String::from("webp")
}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct ChatMessage {
109    pub role: String,
110    pub content: String,
111}
112
113#[derive(Debug, Clone, Default, Serialize, Deserialize)]
114#[serde(rename_all = "camelCase")]
115pub struct LlmParams {
116    pub messages: Vec<ChatMessage>,
117    /// System prompt prepended to the conversation.  Engines that
118    /// don't accept a separate system role inline it as the first
119    /// chat turn.
120    #[serde(default)]
121    pub system: Option<String>,
122    #[serde(default = "default_max_tokens")]
123    pub max_tokens: u32,
124    #[serde(default = "default_temperature")]
125    pub temperature: f32,
126    #[serde(default)]
127    pub top_p: Option<f32>,
128    #[serde(default)]
129    pub stop: Option<Vec<String>>,
130    /// Strict-JSON output schema.  Engine-specific; passed through
131    /// verbatim to the backend (e.g. llama.cpp's `--grammar` JSON).
132    #[serde(default)]
133    pub json_schema: Option<serde_json::Value>,
134    /// Reasoning effort hint.  Currently honoured by Gemini-style
135    /// backends only; ignored elsewhere.
136    #[serde(default)]
137    pub reasoning: Option<String>,
138}
139
/// Serde fallback for a missing `maxTokens` key: 512.
fn default_max_tokens() -> u32 {
    512_u32
}
/// Serde fallback for a missing `temperature` key: 0.7.
fn default_temperature() -> f32 {
    0.7_f32
}
146
147#[derive(Debug, Clone, Default, Serialize, Deserialize)]
148#[serde(rename_all = "camelCase")]
149pub struct AudioSttParams {
150    /// HTTPS URL to fetch the audio bytes from (e.g. R2 signed URL).
151    pub input_url: String,
152    #[serde(default)]
153    pub language: Option<String>,
154    /// Translate the audio to English (whisper `--translate`).
155    #[serde(default)]
156    pub translate: Option<bool>,
157    /// Initial prompt to bias the transcription.
158    #[serde(default)]
159    pub prompt: Option<String>,
160    /// Voice-activity detection.
161    #[serde(default)]
162    pub vad: Option<bool>,
163    /// Timestamp granularity: `"segment"` or `"word"`.
164    #[serde(default)]
165    pub timestamps: Option<String>,
166}
167
168#[derive(Debug, Clone, Default, Serialize, Deserialize)]
169#[serde(rename_all = "camelCase")]
170pub struct AudioTtsParams {
171    pub text: String,
172    #[serde(default = "default_voice")]
173    pub voice: String,
174    /// Playback speed multiplier (1.0 = natural pace).
175    #[serde(default)]
176    pub speed: Option<f32>,
177    /// Spoken-language hint (e.g. `"en"`, `"nl"`).
178    #[serde(default)]
179    pub language: Option<String>,
180    #[serde(default = "default_audio_ext")]
181    pub ext: String,
182}
183
/// Serde fallback for a missing `voice` key: `"default"`.
fn default_voice() -> String {
    String::from("default")
}
/// Serde fallback for a missing audio `ext` key: `"wav"`.
fn default_audio_ext() -> String {
    String::from("wav")
}
190
191#[derive(Debug, Clone, Default, Serialize, Deserialize)]
192#[serde(rename_all = "camelCase")]
193pub struct VideoParams {
194    pub prompt: String,
195    #[serde(default)]
196    pub negative_prompt: Option<String>,
197    /// HTTPS URL to a base frame for image-to-video models.
198    #[serde(default)]
199    pub init_image_url: Option<String>,
200    #[serde(default = "default_video_seconds")]
201    pub seconds: f32,
202    /// Frame rate; defaults to backend-specific value when `None`.
203    #[serde(default)]
204    pub fps: Option<u32>,
205    #[serde(default = "default_image_dim")]
206    pub width: u32,
207    #[serde(default = "default_image_dim")]
208    pub height: u32,
209    #[serde(default = "default_video_ext")]
210    pub ext: String,
211}
212
/// Serde fallback for a missing `seconds` key: 2.0.
fn default_video_seconds() -> f32 {
    2.0_f32
}
/// Serde fallback for a missing video `ext` key: `"mp4"`.
fn default_video_ext() -> String {
    String::from("mp4")
}
219
220#[derive(Debug, Clone, Serialize, Deserialize)]
221#[serde(tag = "kind", rename_all = "snake_case")]
222pub enum Task {
223    Image(ImageParams),
224    Llm(LlmParams),
225    AudioStt(AudioSttParams),
226    AudioTts(AudioTtsParams),
227    Video(VideoParams),
228}
229
230impl Task {
231    pub fn kind(&self) -> TaskKind {
232        match self {
233            Task::Image(_) => TaskKind::Image,
234            Task::Llm(_) => TaskKind::Llm,
235            Task::AudioStt(_) => TaskKind::AudioStt,
236            Task::AudioTts(_) => TaskKind::AudioTts,
237            Task::Video(_) => TaskKind::Video,
238        }
239    }
240}
241
242// ---------------------------------------------------------------------------
243// Per-kind task results
244// ---------------------------------------------------------------------------
245
246#[derive(Debug, Clone)]
247pub enum TaskResult {
248    /// Binary image (webp/png/...) with the chosen extension.
249    Image { bytes: Vec<u8>, ext: String },
250    /// JSON response from an LLM call (free-form to mirror common APIs).
251    Llm { json: serde_json::Value },
252    /// JSON transcript.
253    AudioStt { json: serde_json::Value },
254    /// Binary audio (wav/mp3/...) with the chosen extension.
255    AudioTts { bytes: Vec<u8>, ext: String },
256    /// Binary video (mp4/webm/...) with the chosen extension.
257    Video { bytes: Vec<u8>, ext: String },
258}
259
260impl TaskResult {
261    pub fn kind(&self) -> TaskKind {
262        match self {
263            TaskResult::Image { .. } => TaskKind::Image,
264            TaskResult::Llm { .. } => TaskKind::Llm,
265            TaskResult::AudioStt { .. } => TaskKind::AudioStt,
266            TaskResult::AudioTts { .. } => TaskKind::AudioTts,
267            TaskResult::Video { .. } => TaskKind::Video,
268        }
269    }
270}
271
272// ---------------------------------------------------------------------------
273// Worker capabilities + registration
274// ---------------------------------------------------------------------------
275
276#[derive(Debug, Clone, Serialize, Deserialize)]
277pub struct WorkerCapabilities {
278    #[serde(rename = "machineName")]
279    pub machine_name: String,
280    pub username: String,
281    #[serde(rename = "agentVersion")]
282    pub agent_version: String,
283    pub engine: String,
284    #[serde(rename = "vramTotalGb")]
285    pub vram_total_gb: f32,
286    #[serde(rename = "vramThresholdGb")]
287    pub vram_threshold_gb: f32,
288    #[serde(rename = "autoEnabled")]
289    pub auto_enabled: bool,
290    #[serde(rename = "autoStart")]
291    pub auto_start: bool,
292    /// Flat list of models, kept for backward compat with the existing studio
293    /// API that doesn't know about kinds yet.  Equivalent to
294    /// `supported_models_per_kind[Image]`.
295    #[serde(rename = "supportedModels")]
296    pub supported_models: Vec<String>,
297    /// New: task kinds this worker can serve.
298    #[serde(rename = "taskKinds", default)]
299    pub task_kinds: Vec<TaskKind>,
300    /// New: per-kind supported model ids.
301    #[serde(rename = "supportedModelsPerKind", default)]
302    pub supported_models_per_kind: BTreeMap<TaskKind, Vec<String>>,
303}
304
305// ---------------------------------------------------------------------------
306// Auto-register wire format — the only registration path.
307//
308// Worker POSTs `/workers/register-request` with hostname / username /
309// VRAM / supported models, gets a `requestId` back, and polls
310// `/workers/register-requests/:id` until the operator approves or
311// rejects from the studio dashboard.
312// ---------------------------------------------------------------------------
313
314#[derive(Debug, Clone, Serialize)]
315pub struct AutoRegisterRequest {
316    /// Per-install UUID stable across worker restarts on the same
317    /// machine.  The studio uses it to dedup re-submissions.
318    #[serde(rename = "installId")]
319    pub install_id: String,
320    /// SHA-256 hex of the worker-side `registration_secret`.  The
321    /// worker keeps the secret locally and presents it as a Bearer
322    /// token when polling for status; only the hash leaves the box.
323    #[serde(rename = "registrationSecretHash")]
324    pub registration_secret_hash: String,
325    /// Full capability snapshot — hostname, username, engine, VRAM,
326    /// supported models so the operator can decide.
327    pub capabilities: WorkerCapabilities,
328    /// `studio-worker/<version>` so the operator sees stale clients.
329    #[serde(rename = "userAgent")]
330    pub user_agent: String,
331}
332
333#[derive(Debug, Clone, Deserialize)]
334pub struct AutoRegisterRequestResponse {
335    #[serde(rename = "requestId")]
336    pub request_id: String,
337    /// Always `"pending"` on first response.  Idempotent dedup may
338    /// return the existing requestId for the same
339    /// `(installId, sourceIp)` tuple — status is still "pending".
340    pub status: String,
341}
342
343/// Tagged union returned by `GET /workers/register-requests/:id`.
344#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
345#[serde(rename_all = "snake_case", tag = "status")]
346pub enum RegisterStatus {
347    Pending,
348    Approved {
349        #[serde(rename = "workerId")]
350        worker_id: String,
351        #[serde(rename = "authToken")]
352        auth_token: String,
353    },
354    Rejected {
355        #[serde(default)]
356        reason: String,
357    },
358}
359
360#[derive(Debug, Clone, Serialize)]
361pub struct HeartbeatRequest {
362    pub capabilities: WorkerCapabilities,
363    #[serde(rename = "currentJobId", skip_serializing_if = "Option::is_none")]
364    pub current_job_id: Option<String>,
365}
366
367// ---------------------------------------------------------------------------
368// ModelSource — download spec the studio attaches to every offer so
369// the worker can fetch + run any model without per-model knowledge.
370// Mirrors `WorkerModelSource` on the TS side.
371// ---------------------------------------------------------------------------
372
373#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
374#[serde(rename_all = "kebab-case")]
375pub enum ModelFileRole {
376    DiffusionModel,
377    TextEncoder,
378    /// Vision tower / mmproj for a multimodal text encoder (e.g.
379    /// Qwen2.5-VL ViT). Maps to `sd-cli --llm_vision`; required by
380    /// instruction-edit models that condition on the reference image.
381    TextEncoderVision,
382    Vae,
383    Lora,
384    Model,
385}
386
387#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
388#[serde(rename_all = "kebab-case")]
389pub enum ModelEngine {
390    SdCpp,
391    LlamaCpp,
392    Synthetic,
393}
394
395#[derive(Debug, Clone, Serialize, Deserialize)]
396#[serde(rename_all = "camelCase")]
397pub struct ModelFile {
398    pub role: ModelFileRole,
399    pub url: String,
400    pub filename: String,
401    #[serde(default, skip_serializing_if = "Option::is_none")]
402    pub approx_bytes: Option<u64>,
403}
404
405#[derive(Debug, Clone, Default, Serialize, Deserialize)]
406#[serde(rename_all = "camelCase")]
407pub struct ModelCliDefaults {
408    pub cfg_scale: f32,
409    pub steps: u32,
410    pub width: u32,
411    pub height: u32,
412    #[serde(default, skip_serializing_if = "Option::is_none")]
413    pub sampling_method: Option<String>,
414    /// `--flow-shift` for Flow models (SD3.x / WAN / Qwen-Image). Model-level constant.
415    #[serde(default, skip_serializing_if = "Option::is_none")]
416    pub flow_shift: Option<f32>,
417    /// `--qwen-image-zero-cond-t` — mandatory for Qwen-Image edit quality.
418    #[serde(default, skip_serializing_if = "Option::is_none")]
419    pub zero_cond_t: Option<bool>,
420    /// `--offload-to-cpu` — keep weights in RAM, stream to VRAM, so a large model fits a small card.
421    #[serde(default, skip_serializing_if = "Option::is_none")]
422    pub offload_to_cpu: Option<bool>,
423}
424
425#[derive(Debug, Clone, Serialize, Deserialize)]
426#[serde(rename_all = "camelCase")]
427pub struct ModelSource {
428    pub engine: ModelEngine,
429    pub files: Vec<ModelFile>,
430    pub cli_defaults: ModelCliDefaults,
431}
432
433// ---------------------------------------------------------------------------
434// JobClaim — every offer carries `task` + `model_source` populated by the
435// studio's registry resolver.  No legacy `prompt + ext` shadow fields.
436// ---------------------------------------------------------------------------
437
438#[derive(Debug, Clone, Deserialize)]
439#[serde(rename_all = "camelCase")]
440pub struct JobClaim {
441    pub job_id: String,
442    #[allow(dead_code)]
443    pub game_id: String,
444    pub asset_name: String,
445    pub model: String,
446    pub vram_gb_estimate: f32,
447    /// Structured task payload.  Required — the studio refuses to
448    /// promote a job without one.  Worker treats a missing `task`
449    /// as a protocol_violation.
450    pub task: Task,
451    /// Download + engine + CLI defaults the studio resolved from its
452    /// model registry.  Required — `synthetic` is just another engine
453    /// option, not a fallback for missing rows.
454    pub model_source: ModelSource,
455}
456
457#[derive(Debug, Clone, Serialize)]
458pub struct FailRequest {
459    pub error: String,
460    pub retryable: bool,
461}
462
463#[derive(Debug, Clone, Serialize, Deserialize)]
464#[cfg_attr(feature = "ui", derive(PartialEq, Eq))]
465pub struct LogEntry {
466    pub ts: String,
467    pub level: String,
468    pub category: String,
469    pub message: String,
470    #[serde(rename = "jobId", default, skip_serializing_if = "Option::is_none")]
471    pub job_id: Option<String>,
472}
473
474#[derive(Debug, Clone, Serialize, Deserialize)]
475pub struct LogBatch {
476    pub entries: Vec<LogEntry>,
477}
478
479// ---------------------------------------------------------------------------
480// Release feed (auto-update)
481// ---------------------------------------------------------------------------
482
483/// Subset of the GitHub Releases API we care about.
484#[derive(Debug, Clone, Deserialize)]
485pub struct GithubRelease {
486    pub tag_name: String,
487    #[serde(default)]
488    pub prerelease: bool,
489    #[serde(default)]
490    pub draft: bool,
491    #[serde(default)]
492    pub assets: Vec<GithubReleaseAsset>,
493}
494
495#[derive(Debug, Clone, Deserialize)]
496pub struct GithubReleaseAsset {
497    pub name: String,
498    pub browser_download_url: String,
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `modelSource` payload: `synthetic` engine, no files, and
    /// only the four required `cliDefaults` keys.
    fn synthetic_model_source_json() -> serde_json::Value {
        serde_json::json!({
            "engine": "synthetic",
            "files": [],
            "cliDefaults": {
                "cfgScale": 1.0,
                "steps": 8,
                "width": 1024,
                "height": 1024,
            },
        })
    }

    #[test]
    fn job_claim_requires_task_and_model_source() {
        // Offer without `task` or `modelSource`.
        let bare = || {
            serde_json::json!({
                "jobId": "j-1",
                "gameId": "g-1",
                "assetName": "g-1/creatures/x",
                "model": "synthetic-image",
                "vramGbEstimate": 1.0,
            })
        };
        let task = serde_json::json!({ "kind": "image", "prompt": "x" });

        // Without `task` or `modelSource` the offer must fail to
        // deserialise: no silent fallback to an Option/Image default.
        assert!(
            serde_json::from_value::<JobClaim>(bare()).is_err(),
            "JobClaim must reject missing task + modelSource"
        );

        // Each field must be required on its own, not only as a pair.
        let mut task_only = bare();
        task_only["task"] = task.clone();
        assert!(
            serde_json::from_value::<JobClaim>(task_only).is_err(),
            "JobClaim must reject missing modelSource"
        );

        let mut source_only = bare();
        source_only["modelSource"] = synthetic_model_source_json();
        assert!(
            serde_json::from_value::<JobClaim>(source_only).is_err(),
            "JobClaim must reject missing task"
        );

        // Positive control: the same base payload is accepted once both
        // fields are present, so the failures above are caused by the
        // missing fields and not by something else in the fixture.
        let mut complete = bare();
        complete["task"] = task;
        complete["modelSource"] = synthetic_model_source_json();
        assert!(
            serde_json::from_value::<JobClaim>(complete).is_ok(),
            "JobClaim must accept a payload carrying task + modelSource"
        );
    }

    #[test]
    fn job_claim_with_explicit_llm_task() {
        let json = serde_json::json!({
            "jobId": "j-2",
            "gameId": "g-1",
            "assetName": "g-1/conversations/x",
            "model": "llama-3.1-8b",
            "vramGbEstimate": 8.0,
            "task": {
                "kind": "llm",
                "messages": [{"role": "user", "content": "hi"}],
                "maxTokens": 32,
                "temperature": 0.5,
            },
            "modelSource": synthetic_model_source_json(),
        });
        let claim: JobClaim = serde_json::from_value(json).unwrap();
        assert_eq!(claim.job_id, "j-2");
        assert_eq!(claim.task.kind(), TaskKind::Llm);
        match claim.task {
            Task::Llm(p) => {
                assert_eq!(p.messages.len(), 1);
                assert_eq!(p.messages[0].role, "user");
                assert_eq!(p.messages[0].content, "hi");
                assert_eq!(p.max_tokens, 32);
                assert!((p.temperature - 0.5).abs() < 1e-6);
            }
            other => panic!("expected llm, got {:?}", other),
        }
    }

    #[test]
    fn job_claim_with_explicit_image_task() {
        let json = serde_json::json!({
            "jobId": "j-3",
            "gameId": "g-1",
            "assetName": "g-1/creatures/y",
            "model": "synthetic-image",
            "vramGbEstimate": 8.0,
            "task": {
                "kind": "image",
                "prompt": "a koi",
                "width": 1024,
                "height": 1024,
                "steps": 30,
                "ext": "png",
            },
            "modelSource": synthetic_model_source_json(),
        });
        let claim: JobClaim = serde_json::from_value(json).unwrap();
        assert_eq!(claim.task.kind(), TaskKind::Image);
        assert_eq!(claim.model_source.engine, ModelEngine::Synthetic);
        match claim.task {
            Task::Image(p) => {
                assert_eq!(p.prompt, "a koi");
                assert_eq!(p.width, 1024);
                assert_eq!(p.height, 1024);
                assert_eq!(p.steps, 30);
                assert_eq!(p.ext, "png");
            }
            other => panic!("expected image, got {:?}", other),
        }
    }

    #[test]
    fn image_params_round_trips_with_new_fields() {
        /// Asserts that `task` is the fully-populated image task below.
        fn check(task: &Task) {
            let p = match task {
                Task::Image(p) => p,
                other => panic!("expected image, got {:?}", other),
            };
            assert_eq!(p.prompt, "a stone golem");
            assert_eq!(
                p.negative_prompt.as_deref(),
                Some("text, watermark, low quality")
            );
            assert_eq!(
                p.init_image_url.as_deref(),
                Some("https://example.invalid/t2-golem-stone/latest.webp")
            );
            assert_eq!(
                p.mask_url.as_deref(),
                Some("https://example.invalid/t2-golem-stone/mask.png")
            );
            assert_eq!(
                p.ref_image_url.as_deref(),
                Some("https://example.invalid/t2-golem-stone/ref.webp")
            );
            assert!((p.denoise.unwrap() - 0.55).abs() < 1e-6);
            assert!((p.cfg_scale.unwrap() - 7.5).abs() < 1e-6);
            assert_eq!(p.sampling_method.as_deref(), Some("dpm++2m"));
            assert_eq!(p.width, 768);
            assert_eq!(p.height, 512);
            assert_eq!(p.steps, 30);
            assert_eq!(p.seed, Some(1234));
            assert_eq!(p.ext, "webp");
        }

        let json = serde_json::json!({
            "kind": "image",
            "prompt": "a stone golem",
            "negativePrompt": "text, watermark, low quality",
            "initImageUrl": "https://example.invalid/t2-golem-stone/latest.webp",
            "maskUrl": "https://example.invalid/t2-golem-stone/mask.png",
            "refImageUrl": "https://example.invalid/t2-golem-stone/ref.webp",
            "denoise": 0.55,
            "cfgScale": 7.5,
            "samplingMethod": "dpm++2m",
            "width": 768,
            "height": 512,
            "steps": 30,
            "seed": 1234,
            "ext": "webp",
        });

        // Decode, then encode and decode again: the second pass proves the
        // serialised key names are the ones the deserialiser accepts.
        let decoded: Task = serde_json::from_value(json).unwrap();
        check(&decoded);

        let reencoded = serde_json::to_value(&decoded).unwrap();
        assert_eq!(reencoded["kind"], "image");
        assert_eq!(reencoded["negativePrompt"], "text, watermark, low quality");
        assert_eq!(
            reencoded["maskUrl"],
            "https://example.invalid/t2-golem-stone/mask.png"
        );
        assert_eq!(
            reencoded["refImageUrl"],
            "https://example.invalid/t2-golem-stone/ref.webp"
        );

        let again: Task = serde_json::from_value(reencoded).unwrap();
        check(&again);
    }

    #[test]
    fn image_params_defaults_when_optional_fields_absent() {
        let json = serde_json::json!({
            "kind": "image",
            "prompt": "a fox"
        });
        let task: Task = serde_json::from_value(json).unwrap();
        match task {
            Task::Image(p) => {
                assert_eq!(p.prompt, "a fox");
                assert!(p.negative_prompt.is_none());
                assert!(p.init_image_url.is_none());
                assert!(p.mask_url.is_none());
                assert!(p.ref_image_url.is_none());
                assert!(p.denoise.is_none());
                assert!(p.cfg_scale.is_none());
                assert!(p.sampling_method.is_none());
                assert!(p.seed.is_none());
                assert_eq!(p.width, 512);
                assert_eq!(p.height, 512);
                assert_eq!(p.steps, 20);
                assert_eq!(p.ext, "webp");
            }
            other => panic!("expected image, got {:?}", other),
        }
    }

    #[test]
    fn task_kinds_round_trip_via_json() {
        for kind in TaskKind::ALL {
            let s = serde_json::to_string(&kind).unwrap();
            // `as_str` is maintained by hand; pin it to the serde wire name
            // so the two cannot drift apart.
            assert_eq!(s, format!("\"{}\"", kind.as_str()));
            let back: TaskKind = serde_json::from_str(&s).unwrap();
            assert_eq!(kind, back);
        }
    }
}