Skip to main content

studio_worker/
types.rs

1//! Shared wire-format mirrors of the studio API.
2use serde::{Deserialize, Serialize};
3use std::collections::BTreeMap;
4
5// ---------------------------------------------------------------------------
6// Task kinds — every job claimed by the worker is one of these.
7// ---------------------------------------------------------------------------
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Hash)]
10#[serde(rename_all = "snake_case")]
11pub enum TaskKind {
12    Image,
13    Llm,
14    AudioStt,
15    AudioTts,
16    Video,
17}
18
19impl TaskKind {
20    pub const ALL: [TaskKind; 5] = [
21        TaskKind::Image,
22        TaskKind::Llm,
23        TaskKind::AudioStt,
24        TaskKind::AudioTts,
25        TaskKind::Video,
26    ];
27
28    pub fn as_str(&self) -> &'static str {
29        match self {
30            TaskKind::Image => "image",
31            TaskKind::Llm => "llm",
32            TaskKind::AudioStt => "audio_stt",
33            TaskKind::AudioTts => "audio_tts",
34            TaskKind::Video => "video",
35        }
36    }
37}
38
39// ---------------------------------------------------------------------------
40// Per-kind task parameters
41// ---------------------------------------------------------------------------
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ImageParams {
45    pub prompt: String,
46    #[serde(default = "default_image_dim")]
47    pub width: u32,
48    #[serde(default = "default_image_dim")]
49    pub height: u32,
50    #[serde(default = "default_steps")]
51    pub steps: u32,
52    #[serde(default)]
53    pub seed: Option<u64>,
54    #[serde(default = "default_image_ext")]
55    pub ext: String,
56}
57
/// Serde default for image (and video) `width` / `height`.
fn default_image_dim() -> u32 {
    512
}

/// Serde default for image `steps`.
fn default_steps() -> u32 {
    20
}

/// Serde default for image `ext`; also backs the legacy `JobClaim::ext` field.
fn default_image_ext() -> String {
    String::from("webp")
}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct ChatMessage {
70    pub role: String,
71    pub content: String,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct LlmParams {
76    pub messages: Vec<ChatMessage>,
77    #[serde(default = "default_max_tokens")]
78    pub max_tokens: u32,
79    #[serde(default = "default_temperature")]
80    pub temperature: f32,
81}
82
/// Serde default for `LlmParams::max_tokens`.
fn default_max_tokens() -> u32 {
    512
}

/// Serde default for `LlmParams::temperature`.
fn default_temperature() -> f32 {
    0.7
}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct AudioSttParams {
92    /// HTTPS URL to fetch the audio bytes from (e.g. R2 signed URL).
93    pub input_url: String,
94    #[serde(default)]
95    pub language: Option<String>,
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct AudioTtsParams {
100    pub text: String,
101    #[serde(default = "default_voice")]
102    pub voice: String,
103    #[serde(default = "default_audio_ext")]
104    pub ext: String,
105}
106
/// Serde default for `AudioTtsParams::voice`.
fn default_voice() -> String {
    "default".to_owned()
}

/// Serde default for `AudioTtsParams::ext`.
fn default_audio_ext() -> String {
    "wav".to_owned()
}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct VideoParams {
116    pub prompt: String,
117    #[serde(default = "default_video_seconds")]
118    pub seconds: f32,
119    #[serde(default = "default_image_dim")]
120    pub width: u32,
121    #[serde(default = "default_image_dim")]
122    pub height: u32,
123    #[serde(default = "default_video_ext")]
124    pub ext: String,
125}
126
/// Serde default for `VideoParams::seconds`.
fn default_video_seconds() -> f32 {
    2.0
}

/// Serde default for `VideoParams::ext`.
fn default_video_ext() -> String {
    String::from("mp4")
}
133
134#[derive(Debug, Clone, Serialize, Deserialize)]
135#[serde(tag = "kind", rename_all = "snake_case")]
136pub enum Task {
137    Image(ImageParams),
138    Llm(LlmParams),
139    AudioStt(AudioSttParams),
140    AudioTts(AudioTtsParams),
141    Video(VideoParams),
142}
143
144impl Task {
145    pub fn kind(&self) -> TaskKind {
146        match self {
147            Task::Image(_) => TaskKind::Image,
148            Task::Llm(_) => TaskKind::Llm,
149            Task::AudioStt(_) => TaskKind::AudioStt,
150            Task::AudioTts(_) => TaskKind::AudioTts,
151            Task::Video(_) => TaskKind::Video,
152        }
153    }
154}
155
156// ---------------------------------------------------------------------------
157// Per-kind task results
158// ---------------------------------------------------------------------------
159
160#[derive(Debug, Clone)]
161pub enum TaskResult {
162    /// Binary image (webp/png/...) with the chosen extension.
163    Image { bytes: Vec<u8>, ext: String },
164    /// JSON response from an LLM call (free-form to mirror common APIs).
165    Llm { json: serde_json::Value },
166    /// JSON transcript.
167    AudioStt { json: serde_json::Value },
168    /// Binary audio (wav/mp3/...) with the chosen extension.
169    AudioTts { bytes: Vec<u8>, ext: String },
170    /// Binary video (mp4/webm/...) with the chosen extension.
171    Video { bytes: Vec<u8>, ext: String },
172}
173
174impl TaskResult {
175    pub fn kind(&self) -> TaskKind {
176        match self {
177            TaskResult::Image { .. } => TaskKind::Image,
178            TaskResult::Llm { .. } => TaskKind::Llm,
179            TaskResult::AudioStt { .. } => TaskKind::AudioStt,
180            TaskResult::AudioTts { .. } => TaskKind::AudioTts,
181            TaskResult::Video { .. } => TaskKind::Video,
182        }
183    }
184}
185
186// ---------------------------------------------------------------------------
187// Worker capabilities + registration
188// ---------------------------------------------------------------------------
189
190#[derive(Debug, Clone, Serialize, Deserialize)]
191pub struct WorkerCapabilities {
192    #[serde(rename = "machineName")]
193    pub machine_name: String,
194    pub username: String,
195    #[serde(rename = "agentVersion")]
196    pub agent_version: String,
197    pub engine: String,
198    #[serde(rename = "vramTotalGb")]
199    pub vram_total_gb: f32,
200    #[serde(rename = "vramThresholdGb")]
201    pub vram_threshold_gb: f32,
202    #[serde(rename = "autoEnabled")]
203    pub auto_enabled: bool,
204    #[serde(rename = "autoStart")]
205    pub auto_start: bool,
206    /// Flat list of models, kept for backward compat with the existing studio
207    /// API that doesn't know about kinds yet.  Equivalent to
208    /// `supported_models_per_kind[Image]`.
209    #[serde(rename = "supportedModels")]
210    pub supported_models: Vec<String>,
211    /// New: task kinds this worker can serve.
212    #[serde(rename = "taskKinds", default)]
213    pub task_kinds: Vec<TaskKind>,
214    /// New: per-kind supported model ids.
215    #[serde(rename = "supportedModelsPerKind", default)]
216    pub supported_models_per_kind: BTreeMap<TaskKind, Vec<String>>,
217}
218
219#[derive(Debug, Clone, Serialize)]
220pub struct RegisterRequest {
221    #[serde(rename = "bootstrapToken")]
222    pub bootstrap_token: String,
223    pub capabilities: WorkerCapabilities,
224    #[serde(rename = "workerId", skip_serializing_if = "Option::is_none")]
225    pub worker_id: Option<String>,
226}
227
228#[derive(Debug, Clone, Deserialize)]
229pub struct RegisterResponse {
230    #[serde(rename = "workerId")]
231    pub worker_id: String,
232    #[serde(rename = "authToken")]
233    pub auth_token: String,
234}
235
236#[derive(Debug, Clone, Serialize)]
237pub struct HeartbeatRequest {
238    pub capabilities: WorkerCapabilities,
239    #[serde(rename = "currentJobId", skip_serializing_if = "Option::is_none")]
240    pub current_job_id: Option<String>,
241}
242
243// ---------------------------------------------------------------------------
244// JobClaim — backward-compatible with the existing image-only studio API.
245// ---------------------------------------------------------------------------
246
247#[derive(Debug, Clone, Deserialize)]
248pub struct JobClaim {
249    #[serde(rename = "jobId")]
250    pub job_id: String,
251    #[serde(rename = "gameId")]
252    #[allow(dead_code)]
253    pub game_id: String,
254    #[serde(rename = "assetName")]
255    pub asset_name: String,
256    pub model: String,
257    #[serde(rename = "vramGbEstimate")]
258    pub vram_gb_estimate: f32,
259    /// Legacy field (image prompt).  Ignored when `task` is present.
260    #[serde(default)]
261    pub prompt: String,
262    /// Legacy field (image extension).  Ignored when `task` is present.
263    #[serde(default = "default_image_ext")]
264    pub ext: String,
265    /// New: structured task spec.  When absent the worker reconstructs
266    /// an `Task::Image` from `prompt` + `ext` above for backward compat.
267    #[serde(default)]
268    pub task: Option<Task>,
269}
270
271impl JobClaim {
272    /// Resolve the structured task, applying the legacy fallback if
273    /// `task` is missing.
274    pub fn resolved_task(&self) -> Task {
275        if let Some(t) = self.task.clone() {
276            return t;
277        }
278        Task::Image(ImageParams {
279            prompt: self.prompt.clone(),
280            width: 512,
281            height: 512,
282            steps: 20,
283            seed: None,
284            ext: self.ext.clone(),
285        })
286    }
287}
288
289#[derive(Debug, Clone, Serialize)]
290pub struct FailRequest {
291    pub error: String,
292    pub retryable: bool,
293}
294
295#[derive(Debug, Clone, Serialize)]
296#[cfg_attr(feature = "ui", derive(PartialEq, Eq))]
297pub struct LogEntry {
298    pub ts: String,
299    pub level: String,
300    pub category: String,
301    pub message: String,
302    #[serde(rename = "jobId", skip_serializing_if = "Option::is_none")]
303    pub job_id: Option<String>,
304}
305
306#[derive(Debug, Clone, Serialize)]
307pub struct LogBatch {
308    pub entries: Vec<LogEntry>,
309}
310
311// ---------------------------------------------------------------------------
312// Release feed (auto-update)
313// ---------------------------------------------------------------------------
314
315/// Subset of the GitHub Releases API we care about.
316#[derive(Debug, Clone, Deserialize)]
317pub struct GithubRelease {
318    pub tag_name: String,
319    #[serde(default)]
320    pub prerelease: bool,
321    #[serde(default)]
322    pub draft: bool,
323    #[serde(default)]
324    pub assets: Vec<GithubReleaseAsset>,
325}
326
327#[derive(Debug, Clone, Deserialize)]
328pub struct GithubReleaseAsset {
329    pub name: String,
330    pub browser_download_url: String,
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn job_claim_with_no_task_falls_back_to_image() {
        let json = serde_json::json!({
            "jobId": "j-1",
            "gameId": "g-1",
            "assetName": "g-1/creatures/x",
            "model": "synthetic",
            "vramGbEstimate": 1.0,
            "prompt": "a stone golem",
            "ext": "webp",
        });
        let claim: JobClaim = serde_json::from_value(json).unwrap();
        match claim.resolved_task() {
            Task::Image(p) => {
                assert_eq!(p.prompt, "a stone golem");
                assert_eq!(p.ext, "webp");
                // Parameters the legacy API never sends must match the serde
                // defaults an explicit image task would get.
                assert_eq!(p.width, default_image_dim());
                assert_eq!(p.height, default_image_dim());
                assert_eq!(p.steps, default_steps());
                assert!(p.seed.is_none());
            }
            other => panic!("expected image, got {:?}", other),
        }
    }

    #[test]
    fn job_claim_with_explicit_llm_task() {
        let json = serde_json::json!({
            "jobId": "j-2",
            "gameId": "g-1",
            "assetName": "g-1/conversations/x",
            "model": "llama-3.1-8b",
            "vramGbEstimate": 8.0,
            "task": {
                "kind": "llm",
                "messages": [{"role": "user", "content": "hi"}],
                "max_tokens": 32,
                "temperature": 0.5,
            },
        });
        let claim: JobClaim = serde_json::from_value(json).unwrap();
        match claim.resolved_task() {
            Task::Llm(p) => {
                assert_eq!(p.messages.len(), 1);
                assert_eq!(p.max_tokens, 32);
            }
            other => panic!("expected llm, got {:?}", other),
        }
    }

    #[test]
    fn job_claim_with_explicit_image_task() {
        let json = serde_json::json!({
            "jobId": "j-3",
            "gameId": "g-1",
            "assetName": "g-1/creatures/y",
            "model": "synthetic",
            "vramGbEstimate": 8.0,
            "task": {
                "kind": "image",
                "prompt": "a koi",
                "width": 1024,
                "height": 1024,
                "steps": 30,
                "ext": "png",
            },
        });
        let claim: JobClaim = serde_json::from_value(json).unwrap();
        match claim.resolved_task() {
            Task::Image(p) => {
                assert_eq!(p.prompt, "a koi");
                assert_eq!(p.width, 1024);
                assert_eq!(p.ext, "png");
            }
            other => panic!("expected image, got {:?}", other),
        }
    }

    #[test]
    fn task_kinds_round_trip_via_json() {
        for kind in TaskKind::ALL {
            let s = serde_json::to_string(&kind).unwrap();
            let back: TaskKind = serde_json::from_str(&s).unwrap();
            assert_eq!(kind, back);
        }
    }

    /// `as_str` is a hand-written copy of serde's `snake_case` names; this
    /// catches the two mappings drifting apart when a variant is added or renamed.
    #[test]
    fn task_kind_as_str_matches_serde_wire_name() {
        for kind in TaskKind::ALL {
            let wire = serde_json::to_value(kind).unwrap();
            assert_eq!(wire.as_str(), Some(kind.as_str()));
        }
    }
}
415            let back: TaskKind = serde_json::from_str(&s).unwrap();
416            assert_eq!(kind, back);
417        }
418    }
419}