1use std::path::{Path, PathBuf};
4
5use futures::stream::{FuturesUnordered, StreamExt};
6use serde::{Deserialize, Serialize};
7use tokio::io::AsyncWriteExt;
8
9use crate::client::Client;
10use crate::error::{Error, Result};
11use crate::types::{Task, TaskId};
12
/// Identifies one of the downloadable outputs attached to a task.
///
/// Each variant corresponds to one optional URL on the task's output and
/// determines the suffix of the file name the download is stored under.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputKind {
    /// The task's `model` output; saved without a name suffix.
    Model,
    /// The task's `base_model` output; saved with a `_base` suffix.
    BaseModel,
    /// The task's `pbr_model` output; saved with a `_pbr` suffix.
    PbrModel,
    /// The task's `rendered_image` output; saved with a `_rendered` suffix.
    RenderedImage,
}
25
26impl OutputKind {
27 fn filename(self, id: &TaskId, ext: &str) -> String {
28 match self {
29 Self::Model => format!("{id}.{ext}"),
30 Self::BaseModel => format!("{id}_base.{ext}"),
31 Self::PbrModel => format!("{id}_pbr.{ext}"),
32 Self::RenderedImage => format!("{id}_rendered.{ext}"),
33 }
34 }
35}
36
37#[derive(Debug, Clone)]
39pub struct DownloadOptions {
40 pub max_concurrency: usize,
42 pub overwrite: bool,
44 pub kinds: Vec<OutputKind>,
46}
47
48impl Default for DownloadOptions {
49 fn default() -> Self {
50 Self {
51 max_concurrency: 4,
52 overwrite: false,
53 kinds: vec![
54 OutputKind::Model,
55 OutputKind::BaseModel,
56 OutputKind::PbrModel,
57 OutputKind::RenderedImage,
58 ],
59 }
60 }
61}
62
63#[derive(Debug, Clone, Default, Serialize, Deserialize)]
65#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
66pub struct DownloadedFiles {
67 pub model: Option<PathBuf>,
69 pub base_model: Option<PathBuf>,
71 pub pbr_model: Option<PathBuf>,
73 pub rendered_image: Option<PathBuf>,
75}
76
/// Derives a file extension from the last path segment of `url`.
///
/// The query string and fragment are ignored, and so are the scheme and
/// authority, so a host such as `cdn.example.com` is never mistaken for a
/// file named `cdn.example` with extension `com`. `url` may also be a bare
/// path or file name without a scheme.
///
/// Returns `default_ext` when the last segment has no extension, or when the
/// extension is empty or contains anything other than ASCII letters and
/// digits (for example percent-escapes), so the result is always safe to
/// append to a file name.
fn extension_of(url: &str, default_ext: &str) -> String {
    // Cut at whichever of `?` / `#` comes first; neither belongs to the path.
    let without_suffix = url.split(['?', '#']).next().unwrap_or(url);

    // Skip `scheme://authority` so only the path component is inspected.
    let path = match without_suffix.split_once("://") {
        Some((_scheme, rest)) => rest
            .split_once('/')
            .map_or("", |(_authority, path)| path),
        None => without_suffix,
    };

    // A trailing `/` yields an empty segment, which has no extension.
    let last_segment = path.rsplit('/').next().unwrap_or(path);

    Path::new(last_segment)
        .extension()
        .and_then(|ext| ext.to_str())
        .filter(|ext| !ext.is_empty() && ext.bytes().all(|b| b.is_ascii_alphanumeric()))
        .unwrap_or(default_ext)
        .to_string()
}
85
86impl Client {
87 #[tracing::instrument(skip(self, task, opts), fields(task_id = %task.task_id))]
90 pub async fn download_task_models(
91 &self,
92 task: &Task,
93 dir: &Path,
94 opts: DownloadOptions,
95 ) -> Result<DownloadedFiles> {
96 tokio::fs::create_dir_all(dir).await?;
97
98 let mut jobs: Vec<(OutputKind, String, PathBuf)> = Vec::new();
99 for kind in &opts.kinds {
100 let (url, default_ext) = match kind {
101 OutputKind::Model => (&task.output.model, "glb"),
102 OutputKind::BaseModel => (&task.output.base_model, "glb"),
103 OutputKind::PbrModel => (&task.output.pbr_model, "glb"),
104 OutputKind::RenderedImage => (&task.output.rendered_image, "jpg"),
105 };
106 let Some(url) = url.clone() else { continue };
107 let ext = extension_of(&url, default_ext);
108 let target = dir.join(kind.filename(&task.task_id, &ext));
109 if !opts.overwrite && tokio::fs::try_exists(&target).await? {
110 return Err(Error::FileExists(target));
111 }
112 jobs.push((*kind, url, target));
113 }
114
115 let max = opts.max_concurrency.max(1);
116 let mut in_flight = FuturesUnordered::new();
117 let mut pending = jobs.into_iter();
118
119 let mut out = DownloadedFiles::default();
120 for _ in 0..max {
121 if let Some(job) = pending.next() {
122 in_flight.push(download_one(self, job));
123 }
124 }
125 while let Some(done) = in_flight.next().await {
126 let (kind, path) = done?;
127 match kind {
128 OutputKind::Model => out.model = Some(path),
129 OutputKind::BaseModel => out.base_model = Some(path),
130 OutputKind::PbrModel => out.pbr_model = Some(path),
131 OutputKind::RenderedImage => out.rendered_image = Some(path),
132 }
133 if let Some(job) = pending.next() {
134 in_flight.push(download_one(self, job));
135 }
136 }
137 Ok(out)
138 }
139}
140
141async fn download_one(
142 client: &Client,
143 (kind, url, target): (OutputKind, String, PathBuf),
144) -> Result<(OutputKind, PathBuf)> {
145 let mut partial = target.clone();
146 partial.as_mut_os_string().push(".partial");
147 let mut resp = client.http.get(&url).send().await?.error_for_status()?;
148 let mut f = tokio::fs::File::create(&partial).await?;
149 while let Some(chunk) = resp.chunk().await? {
150 f.write_all(&chunk).await?;
151 }
152 f.flush().await?;
153 drop(f);
154 tokio::fs::rename(&partial, &target).await?;
155 Ok((kind, target))
156}