Skip to main content

studio_worker/
types.rs

//! Shared wire-format mirrors of the studio API.
2use serde::{Deserialize, Serialize};
3use std::collections::BTreeMap;
4
// ---------------------------------------------------------------------------
// Task kinds — every job claimed by the worker is one of these.
// ---------------------------------------------------------------------------
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Hash)]
10#[serde(rename_all = "snake_case")]
11pub enum TaskKind {
12    Image,
13    Llm,
14    AudioStt,
15    AudioTts,
16    Video,
17}
18
19impl TaskKind {
20    pub const ALL: [TaskKind; 5] = [
21        TaskKind::Image,
22        TaskKind::Llm,
23        TaskKind::AudioStt,
24        TaskKind::AudioTts,
25        TaskKind::Video,
26    ];
27
28    pub fn as_str(&self) -> &'static str {
29        match self {
30            TaskKind::Image => "image",
31            TaskKind::Llm => "llm",
32            TaskKind::AudioStt => "audio_stt",
33            TaskKind::AudioTts => "audio_tts",
34            TaskKind::Video => "video",
35        }
36    }
37}
38
// ---------------------------------------------------------------------------
// Per-kind task parameters
// ---------------------------------------------------------------------------
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ImageParams {
45    pub prompt: String,
46    #[serde(default = "default_image_dim")]
47    pub width: u32,
48    #[serde(default = "default_image_dim")]
49    pub height: u32,
50    #[serde(default = "default_steps")]
51    pub steps: u32,
52    #[serde(default)]
53    pub seed: Option<u64>,
54    #[serde(default = "default_image_ext")]
55    pub ext: String,
56}
57
/// Serde default for the `width` / `height` fields of `ImageParams` and
/// `VideoParams`.
fn default_image_dim() -> u32 {
    512
}
/// Serde default for `ImageParams::steps`.
fn default_steps() -> u32 {
    20
}
/// Serde default for the image extension (`ImageParams::ext` and the legacy
/// `JobClaim::ext`).
fn default_image_ext() -> String {
    String::from("webp")
}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct ChatMessage {
70    pub role: String,
71    pub content: String,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct LlmParams {
76    pub messages: Vec<ChatMessage>,
77    #[serde(default = "default_max_tokens")]
78    pub max_tokens: u32,
79    #[serde(default = "default_temperature")]
80    pub temperature: f32,
81}
82
/// Serde default for `LlmParams::max_tokens`.
fn default_max_tokens() -> u32 {
    512
}
/// Serde default for `LlmParams::temperature`.
fn default_temperature() -> f32 {
    0.7
}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct AudioSttParams {
92    /// HTTPS URL to fetch the audio bytes from (e.g. R2 signed URL).
93    pub input_url: String,
94    #[serde(default)]
95    pub language: Option<String>,
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct AudioTtsParams {
100    pub text: String,
101    #[serde(default = "default_voice")]
102    pub voice: String,
103    #[serde(default = "default_audio_ext")]
104    pub ext: String,
105}
106
/// Serde default for `AudioTtsParams::voice`.
fn default_voice() -> String {
    String::from("default")
}
/// Serde default for `AudioTtsParams::ext`.
fn default_audio_ext() -> String {
    String::from("wav")
}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct VideoParams {
116    pub prompt: String,
117    #[serde(default = "default_video_seconds")]
118    pub seconds: f32,
119    #[serde(default = "default_image_dim")]
120    pub width: u32,
121    #[serde(default = "default_image_dim")]
122    pub height: u32,
123    #[serde(default = "default_video_ext")]
124    pub ext: String,
125}
126
/// Serde default for `VideoParams::seconds`.
fn default_video_seconds() -> f32 {
    2.0
}
/// Serde default for `VideoParams::ext`.
fn default_video_ext() -> String {
    String::from("mp4")
}
133
134#[derive(Debug, Clone, Serialize, Deserialize)]
135#[serde(tag = "kind", rename_all = "snake_case")]
136pub enum Task {
137    Image(ImageParams),
138    Llm(LlmParams),
139    AudioStt(AudioSttParams),
140    AudioTts(AudioTtsParams),
141    Video(VideoParams),
142}
143
144impl Task {
145    pub fn kind(&self) -> TaskKind {
146        match self {
147            Task::Image(_) => TaskKind::Image,
148            Task::Llm(_) => TaskKind::Llm,
149            Task::AudioStt(_) => TaskKind::AudioStt,
150            Task::AudioTts(_) => TaskKind::AudioTts,
151            Task::Video(_) => TaskKind::Video,
152        }
153    }
154}
155
// ---------------------------------------------------------------------------
// Per-kind task results
// ---------------------------------------------------------------------------
159
160#[derive(Debug, Clone)]
161pub enum TaskResult {
162    /// Binary image (webp/png/...) with the chosen extension.
163    Image { bytes: Vec<u8>, ext: String },
164    /// JSON response from an LLM call (free-form to mirror common APIs).
165    Llm { json: serde_json::Value },
166    /// JSON transcript.
167    AudioStt { json: serde_json::Value },
168    /// Binary audio (wav/mp3/...) with the chosen extension.
169    AudioTts { bytes: Vec<u8>, ext: String },
170    /// Binary video (mp4/webm/...) with the chosen extension.
171    Video { bytes: Vec<u8>, ext: String },
172}
173
174impl TaskResult {
175    pub fn kind(&self) -> TaskKind {
176        match self {
177            TaskResult::Image { .. } => TaskKind::Image,
178            TaskResult::Llm { .. } => TaskKind::Llm,
179            TaskResult::AudioStt { .. } => TaskKind::AudioStt,
180            TaskResult::AudioTts { .. } => TaskKind::AudioTts,
181            TaskResult::Video { .. } => TaskKind::Video,
182        }
183    }
184}
185
// ---------------------------------------------------------------------------
// Worker capabilities + registration
// ---------------------------------------------------------------------------
189
190#[derive(Debug, Clone, Serialize, Deserialize)]
191pub struct WorkerCapabilities {
192    #[serde(rename = "machineName")]
193    pub machine_name: String,
194    pub username: String,
195    #[serde(rename = "agentVersion")]
196    pub agent_version: String,
197    pub engine: String,
198    #[serde(rename = "vramTotalGb")]
199    pub vram_total_gb: f32,
200    #[serde(rename = "vramThresholdGb")]
201    pub vram_threshold_gb: f32,
202    #[serde(rename = "autoEnabled")]
203    pub auto_enabled: bool,
204    #[serde(rename = "autoStart")]
205    pub auto_start: bool,
206    /// Flat list of models, kept for backward compat with the existing studio
207    /// API that doesn't know about kinds yet.  Equivalent to
208    /// `supported_models_per_kind[Image]`.
209    #[serde(rename = "supportedModels")]
210    pub supported_models: Vec<String>,
211    /// New: task kinds this worker can serve.
212    #[serde(rename = "taskKinds", default)]
213    pub task_kinds: Vec<TaskKind>,
214    /// New: per-kind supported model ids.
215    #[serde(rename = "supportedModelsPerKind", default)]
216    pub supported_models_per_kind: BTreeMap<TaskKind, Vec<String>>,
217}
218
// ---------------------------------------------------------------------------
// Auto-register wire format — the only registration path.
//
// Worker POSTs `/workers/register-request` with hostname / username /
// VRAM / supported models, gets a `requestId` back, and polls
// `/workers/register-requests/:id` until the operator approves or
// rejects from the studio dashboard.
// ---------------------------------------------------------------------------
227
228#[derive(Debug, Clone, Serialize)]
229pub struct AutoRegisterRequest {
230    /// Per-install UUID stable across worker restarts on the same
231    /// machine.  The studio uses it to dedup re-submissions.
232    #[serde(rename = "installId")]
233    pub install_id: String,
234    /// SHA-256 hex of the worker-side `registration_secret`.  The
235    /// worker keeps the secret locally and presents it as a Bearer
236    /// token when polling for status; only the hash leaves the box.
237    #[serde(rename = "registrationSecretHash")]
238    pub registration_secret_hash: String,
239    /// Optional human label the operator sees in the Pending Workers
240    /// panel (e.g. "alice's gaming rig").
241    #[serde(skip_serializing_if = "Option::is_none")]
242    pub label: Option<String>,
243    /// Full capability snapshot — hostname, username, engine, VRAM,
244    /// supported models so the operator can decide.
245    pub capabilities: WorkerCapabilities,
246    /// `studio-worker/<version>` so the operator sees stale clients.
247    #[serde(rename = "userAgent")]
248    pub user_agent: String,
249}
250
251#[derive(Debug, Clone, Deserialize)]
252pub struct AutoRegisterRequestResponse {
253    #[serde(rename = "requestId")]
254    pub request_id: String,
255    /// Always `"pending"` on first response.  Idempotent dedup may
256    /// return the existing requestId for the same
257    /// `(installId, sourceIp)` tuple — status is still "pending".
258    pub status: String,
259}
260
261/// Tagged union returned by `GET /workers/register-requests/:id`.
262#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
263#[serde(rename_all = "snake_case", tag = "status")]
264pub enum RegisterStatus {
265    Pending,
266    Approved {
267        #[serde(rename = "workerId")]
268        worker_id: String,
269        #[serde(rename = "authToken")]
270        auth_token: String,
271    },
272    Rejected {
273        #[serde(default)]
274        reason: String,
275    },
276}
277
278#[derive(Debug, Clone, Serialize)]
279pub struct HeartbeatRequest {
280    pub capabilities: WorkerCapabilities,
281    #[serde(rename = "currentJobId", skip_serializing_if = "Option::is_none")]
282    pub current_job_id: Option<String>,
283}
284
// ---------------------------------------------------------------------------
// JobClaim — backward-compatible with the existing image-only studio API.
// ---------------------------------------------------------------------------
288
289#[derive(Debug, Clone, Deserialize)]
290pub struct JobClaim {
291    #[serde(rename = "jobId")]
292    pub job_id: String,
293    #[serde(rename = "gameId")]
294    #[allow(dead_code)]
295    pub game_id: String,
296    #[serde(rename = "assetName")]
297    pub asset_name: String,
298    pub model: String,
299    #[serde(rename = "vramGbEstimate")]
300    pub vram_gb_estimate: f32,
301    /// Legacy field (image prompt).  Ignored when `task` is present.
302    #[serde(default)]
303    pub prompt: String,
304    /// Legacy field (image extension).  Ignored when `task` is present.
305    #[serde(default = "default_image_ext")]
306    pub ext: String,
307    /// New: structured task spec.  When absent the worker reconstructs
308    /// an `Task::Image` from `prompt` + `ext` above for backward compat.
309    #[serde(default)]
310    pub task: Option<Task>,
311}
312
313impl JobClaim {
314    /// Resolve the structured task, applying the legacy fallback if
315    /// `task` is missing.
316    pub fn resolved_task(&self) -> Task {
317        if let Some(t) = self.task.clone() {
318            return t;
319        }
320        Task::Image(ImageParams {
321            prompt: self.prompt.clone(),
322            width: 512,
323            height: 512,
324            steps: 20,
325            seed: None,
326            ext: self.ext.clone(),
327        })
328    }
329}
330
331#[derive(Debug, Clone, Serialize)]
332pub struct FailRequest {
333    pub error: String,
334    pub retryable: bool,
335}
336
337#[derive(Debug, Clone, Serialize, Deserialize)]
338#[cfg_attr(feature = "ui", derive(PartialEq, Eq))]
339pub struct LogEntry {
340    pub ts: String,
341    pub level: String,
342    pub category: String,
343    pub message: String,
344    #[serde(rename = "jobId", default, skip_serializing_if = "Option::is_none")]
345    pub job_id: Option<String>,
346}
347
348#[derive(Debug, Clone, Serialize, Deserialize)]
349pub struct LogBatch {
350    pub entries: Vec<LogEntry>,
351}
352
// ---------------------------------------------------------------------------
// Release feed (auto-update)
// ---------------------------------------------------------------------------
356
357/// Subset of the GitHub Releases API we care about.
358#[derive(Debug, Clone, Deserialize)]
359pub struct GithubRelease {
360    pub tag_name: String,
361    #[serde(default)]
362    pub prerelease: bool,
363    #[serde(default)]
364    pub draft: bool,
365    #[serde(default)]
366    pub assets: Vec<GithubReleaseAsset>,
367}
368
369#[derive(Debug, Clone, Deserialize)]
370pub struct GithubReleaseAsset {
371    pub name: String,
372    pub browser_download_url: String,
373}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// A claim with only the legacy `prompt`/`ext` fields resolves to an
    /// image task built from them.
    #[test]
    fn job_claim_with_no_task_falls_back_to_image() {
        let json = serde_json::json!({
            "jobId": "j-1",
            "gameId": "g-1",
            "assetName": "g-1/creatures/x",
            "model": "synthetic",
            "vramGbEstimate": 1.0,
            "prompt": "a stone golem",
            "ext": "webp",
        });
        let claim: JobClaim = serde_json::from_value(json).unwrap();
        match claim.resolved_task() {
            Task::Image(p) => {
                assert_eq!(p.prompt, "a stone golem");
                assert_eq!(p.ext, "webp");
            }
            other => panic!("expected image, got {:?}", other),
        }
    }

    /// A structured `task` with `kind: "llm"` is returned as-is.
    #[test]
    fn job_claim_with_explicit_llm_task() {
        let json = serde_json::json!({
            "jobId": "j-2",
            "gameId": "g-1",
            "assetName": "g-1/conversations/x",
            "model": "llama-3.1-8b",
            "vramGbEstimate": 8.0,
            "task": {
                "kind": "llm",
                "messages": [{"role": "user", "content": "hi"}],
                "max_tokens": 32,
                "temperature": 0.5,
            },
        });
        let claim: JobClaim = serde_json::from_value(json).unwrap();
        match claim.resolved_task() {
            Task::Llm(p) => {
                assert_eq!(p.messages.len(), 1);
                assert_eq!(p.max_tokens, 32);
            }
            other => panic!("expected llm, got {:?}", other),
        }
    }

    /// A structured image `task` wins over the (absent) legacy fields.
    #[test]
    fn job_claim_with_explicit_image_task() {
        let json = serde_json::json!({
            "jobId": "j-3",
            "gameId": "g-1",
            "assetName": "g-1/creatures/y",
            "model": "synthetic",
            "vramGbEstimate": 8.0,
            "task": {
                "kind": "image",
                "prompt": "a koi",
                "width": 1024,
                "height": 1024,
                "steps": 30,
                "ext": "png",
            },
        });
        let claim: JobClaim = serde_json::from_value(json).unwrap();
        match claim.resolved_task() {
            Task::Image(p) => {
                assert_eq!(p.prompt, "a koi");
                assert_eq!(p.width, 1024);
                assert_eq!(p.ext, "png");
            }
            other => panic!("expected image, got {:?}", other),
        }
    }

    /// Every kind survives a JSON round trip, and its wire name is exactly
    /// what `as_str` reports.
    #[test]
    fn task_kinds_round_trip_via_json() {
        for kind in TaskKind::ALL {
            let s = serde_json::to_string(&kind).unwrap();
            // Guards against `as_str` and the serde rename drifting apart.
            assert_eq!(s, format!("\"{}\"", kind.as_str()));
            let back: TaskKind = serde_json::from_str(&s).unwrap();
            assert_eq!(kind, back);
        }
    }
}
461}