rust_genai/
files.rs

1//! Files API surface.
2
3use std::path::Path;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use rust_genai_types::enums::FileState;
8use rust_genai_types::files::{
9    DownloadFileConfig, File, ListFilesConfig, ListFilesResponse, UploadFileConfig,
10};
11use serde_json::Value;
12use tokio::io::AsyncReadExt;
13
14use crate::client::{Backend, ClientInner};
15use crate::error::{Error, Result};
16
17const CHUNK_SIZE: usize = 8 * 1024 * 1024;
18
/// Handle for the Files API surface.
///
/// Wraps the shared client state; cloning this handle only clones the `Arc`.
#[derive(Clone)]
pub struct Files {
    /// Shared client internals used to build and send every request.
    pub(crate) inner: Arc<ClientInner>,
}
23
24impl Files {
25    pub(crate) fn new(inner: Arc<ClientInner>) -> Self {
26        Self { inner }
27    }
28
29    /// 上传文件(直接上传字节数据)。
30    pub async fn upload(&self, data: Vec<u8>, mime_type: impl Into<String>) -> Result<File> {
31        let config = UploadFileConfig {
32            mime_type: Some(mime_type.into()),
33            ..UploadFileConfig::default()
34        };
35        self.upload_with_config(data, config).await
36    }
37
38    /// 上传文件(自定义配置)。
39    pub async fn upload_with_config(
40        &self,
41        data: Vec<u8>,
42        config: UploadFileConfig,
43    ) -> Result<File> {
44        ensure_gemini_backend(&self.inner)?;
45
46        let mime_type = config
47            .mime_type
48            .clone()
49            .ok_or_else(|| Error::InvalidConfig {
50                message: "mime_type is required when uploading raw bytes".into(),
51            })?;
52        let size_bytes = data.len() as u64;
53        let file = build_upload_file(config, size_bytes, &mime_type)?;
54        let upload_url = self
55            .start_resumable_upload(file, size_bytes, &mime_type, None)
56            .await?;
57        self.upload_bytes(&upload_url, &data).await
58    }
59
60    /// 从文件路径上传。
61    pub async fn upload_from_path(&self, path: impl AsRef<Path>) -> Result<File> {
62        self.upload_from_path_with_config(path, UploadFileConfig::default())
63            .await
64    }
65
66    /// 从文件路径上传(自定义配置)。
67    pub async fn upload_from_path_with_config(
68        &self,
69        path: impl AsRef<Path>,
70        mut config: UploadFileConfig,
71    ) -> Result<File> {
72        ensure_gemini_backend(&self.inner)?;
73
74        let path = path.as_ref();
75        let metadata = tokio::fs::metadata(path).await?;
76        if !metadata.is_file() {
77            return Err(Error::InvalidConfig {
78                message: format!("{} is not a valid file path", path.display()),
79            });
80        }
81
82        let size_bytes = metadata.len();
83        let mime_type = if let Some(value) = config.mime_type.take() {
84            value
85        } else {
86            mime_guess::from_path(path)
87                .first_or_octet_stream()
88                .essence_str()
89                .to_string()
90        };
91
92        let file_name = path.file_name().and_then(|name| name.to_str());
93        let file = build_upload_file(config, size_bytes, &mime_type)?;
94        let upload_url = self
95            .start_resumable_upload(file, size_bytes, &mime_type, file_name)
96            .await?;
97        let mut file_handle = tokio::fs::File::open(path).await?;
98        self.upload_reader(&upload_url, &mut file_handle, size_bytes)
99            .await
100    }
101
102    /// 下载文件(返回字节内容)。
103    pub async fn download(&self, name_or_uri: impl AsRef<str>) -> Result<Vec<u8>> {
104        ensure_gemini_backend(&self.inner)?;
105
106        let file_name = normalize_file_name(name_or_uri.as_ref())?;
107        let url = build_file_download_url(&self.inner, &file_name)?;
108        let request = self.inner.http.get(url);
109        let response = self.inner.send(request).await?;
110        if !response.status().is_success() {
111            return Err(Error::ApiError {
112                status: response.status().as_u16(),
113                message: response.text().await.unwrap_or_default(),
114            });
115        }
116        let bytes = response.bytes().await?;
117        Ok(bytes.to_vec())
118    }
119
120    #[allow(unused_variables)]
121    /// 下载文件(自定义配置)。
122    pub async fn download_with_config(
123        &self,
124        name_or_uri: impl AsRef<str>,
125        _config: DownloadFileConfig,
126    ) -> Result<Vec<u8>> {
127        self.download(name_or_uri).await
128    }
129
130    /// 列出文件。
131    pub async fn list(&self) -> Result<ListFilesResponse> {
132        self.list_with_config(ListFilesConfig::default()).await
133    }
134
135    /// 列出文件(自定义配置)。
136    pub async fn list_with_config(&self, config: ListFilesConfig) -> Result<ListFilesResponse> {
137        ensure_gemini_backend(&self.inner)?;
138        let url = build_files_list_url(&self.inner, &config)?;
139        let request = self.inner.http.get(url);
140        let response = self.inner.send(request).await?;
141        if !response.status().is_success() {
142            return Err(Error::ApiError {
143                status: response.status().as_u16(),
144                message: response.text().await.unwrap_or_default(),
145            });
146        }
147        Ok(response.json::<ListFilesResponse>().await?)
148    }
149
150    /// 列出所有文件(自动翻页)。
151    pub async fn all(&self) -> Result<Vec<File>> {
152        self.all_with_config(ListFilesConfig::default()).await
153    }
154
155    /// 列出所有文件(带配置,自动翻页)。
156    pub async fn all_with_config(&self, mut config: ListFilesConfig) -> Result<Vec<File>> {
157        let mut files = Vec::new();
158        loop {
159            let response = self.list_with_config(config.clone()).await?;
160            if let Some(items) = response.files {
161                files.extend(items);
162            }
163            match response.next_page_token {
164                Some(token) if !token.is_empty() => {
165                    config.page_token = Some(token);
166                }
167                _ => break,
168            }
169        }
170        Ok(files)
171    }
172
173    /// 获取文件元数据。
174    pub async fn get(&self, name_or_uri: impl AsRef<str>) -> Result<File> {
175        ensure_gemini_backend(&self.inner)?;
176
177        let file_name = normalize_file_name(name_or_uri.as_ref())?;
178        let url = build_file_url(&self.inner, &file_name)?;
179        let request = self.inner.http.get(url);
180        let response = self.inner.send(request).await?;
181        if !response.status().is_success() {
182            return Err(Error::ApiError {
183                status: response.status().as_u16(),
184                message: response.text().await.unwrap_or_default(),
185            });
186        }
187        Ok(response.json::<File>().await?)
188    }
189
190    /// 删除文件。
191    pub async fn delete(&self, name_or_uri: impl AsRef<str>) -> Result<()> {
192        ensure_gemini_backend(&self.inner)?;
193
194        let file_name = normalize_file_name(name_or_uri.as_ref())?;
195        let url = build_file_url(&self.inner, &file_name)?;
196        let request = self.inner.http.delete(url);
197        let response = self.inner.send(request).await?;
198        if !response.status().is_success() {
199            return Err(Error::ApiError {
200                status: response.status().as_u16(),
201                message: response.text().await.unwrap_or_default(),
202            });
203        }
204        Ok(())
205    }
206
207    /// 轮询直到文件状态变为 ACTIVE。
208    pub async fn wait_for_active(
209        &self,
210        name_or_uri: impl AsRef<str>,
211        config: WaitForFileConfig,
212    ) -> Result<File> {
213        ensure_gemini_backend(&self.inner)?;
214
215        let start = Instant::now();
216        loop {
217            let file = self.get(name_or_uri.as_ref()).await?;
218            match file.state {
219                Some(FileState::Active) => return Ok(file),
220                Some(FileState::Failed) => {
221                    return Err(Error::ApiError {
222                        status: 500,
223                        message: "File processing failed".into(),
224                    })
225                }
226                _ => {}
227            }
228
229            if let Some(timeout) = config.timeout {
230                if start.elapsed() >= timeout {
231                    return Err(Error::Timeout {
232                        message: "Timed out waiting for file to become ACTIVE".into(),
233                    });
234                }
235            }
236
237            tokio::time::sleep(config.poll_interval).await;
238        }
239    }
240
241    async fn start_resumable_upload(
242        &self,
243        file: File,
244        size_bytes: u64,
245        mime_type: &str,
246        file_name: Option<&str>,
247    ) -> Result<String> {
248        let url = build_files_upload_url(&self.inner)?;
249        let mut request = self
250            .inner
251            .http
252            .post(url)
253            .header("X-Goog-Upload-Protocol", "resumable")
254            .header("X-Goog-Upload-Command", "start")
255            .header(
256                "X-Goog-Upload-Header-Content-Length",
257                size_bytes.to_string(),
258            )
259            .header("X-Goog-Upload-Header-Content-Type", mime_type);
260
261        if let Some(file_name) = file_name {
262            request = request.header("X-Goog-Upload-File-Name", file_name);
263        }
264
265        let body = serde_json::json!({ "file": file });
266        let request = request.json(&body);
267        let response = self.inner.send(request).await?;
268        if !response.status().is_success() {
269            return Err(Error::ApiError {
270                status: response.status().as_u16(),
271                message: response.text().await.unwrap_or_default(),
272            });
273        }
274
275        let upload_url = response
276            .headers()
277            .get("x-goog-upload-url")
278            .and_then(|value| value.to_str().ok())
279            .ok_or_else(|| Error::Parse {
280                message: "Missing x-goog-upload-url header".into(),
281            })?;
282
283        Ok(upload_url.to_string())
284    }
285
286    async fn upload_bytes(&self, upload_url: &str, data: &[u8]) -> Result<File> {
287        if data.is_empty() {
288            let (status, file) = self.send_upload_chunk(upload_url, &[], 0, true).await?;
289            return finalize_upload(status, file);
290        }
291
292        let mut offset: usize = 0;
293        while offset < data.len() {
294            let end = (offset + CHUNK_SIZE).min(data.len());
295            let finalize = end == data.len();
296            let (status, file) = self
297                .send_upload_chunk(upload_url, &data[offset..end], offset as u64, finalize)
298                .await?;
299
300            if finalize {
301                return finalize_upload(status, file);
302            }
303
304            if status != "active" {
305                return Err(Error::Parse {
306                    message: format!("Unexpected upload status: {status}"),
307                });
308            }
309
310            offset = end;
311        }
312
313        Err(Error::Parse {
314            message: "Upload finished without final response".into(),
315        })
316    }
317
318    async fn upload_reader(
319        &self,
320        upload_url: &str,
321        reader: &mut tokio::fs::File,
322        total_size: u64,
323    ) -> Result<File> {
324        if total_size == 0 {
325            let (status, file) = self.send_upload_chunk(upload_url, &[], 0, true).await?;
326            return finalize_upload(status, file);
327        }
328
329        let mut offset: u64 = 0;
330        let mut buffer = vec![0u8; CHUNK_SIZE];
331        while offset < total_size {
332            let read_bytes = reader.read(&mut buffer).await?;
333            if read_bytes == 0 {
334                return Err(Error::Parse {
335                    message: "Unexpected EOF while uploading file".into(),
336                });
337            }
338
339            let finalize = offset + read_bytes as u64 >= total_size;
340            let (status, file) = self
341                .send_upload_chunk(upload_url, &buffer[..read_bytes], offset, finalize)
342                .await?;
343
344            if finalize {
345                return finalize_upload(status, file);
346            }
347
348            if status != "active" {
349                return Err(Error::Parse {
350                    message: format!("Unexpected upload status: {status}"),
351                });
352            }
353
354            offset += read_bytes as u64;
355        }
356
357        Err(Error::Parse {
358            message: "Upload finished without final response".into(),
359        })
360    }
361
362    async fn send_upload_chunk(
363        &self,
364        upload_url: &str,
365        chunk: &[u8],
366        offset: u64,
367        finalize: bool,
368    ) -> Result<(String, Option<File>)> {
369        let command = if finalize {
370            "upload, finalize"
371        } else {
372            "upload"
373        };
374        let response = self
375            .inner
376            .http
377            .post(upload_url)
378            .header("X-Goog-Upload-Command", command)
379            .header("X-Goog-Upload-Offset", offset.to_string())
380            .header("Content-Length", chunk.len().to_string())
381            .body(chunk.to_vec())
382            .send()
383            .await?;
384
385        if !response.status().is_success() {
386            return Err(Error::ApiError {
387                status: response.status().as_u16(),
388                message: response.text().await.unwrap_or_default(),
389            });
390        }
391
392        let upload_status = response
393            .headers()
394            .get("x-goog-upload-status")
395            .and_then(|value| value.to_str().ok())
396            .ok_or_else(|| Error::Parse {
397                message: "Missing x-goog-upload-status header".into(),
398            })?
399            .to_string();
400
401        let body = response.bytes().await?;
402        if body.is_empty() {
403            return Ok((upload_status, None));
404        }
405
406        let value: Value = serde_json::from_slice(&body)?;
407        let file_value = value.get("file").cloned().unwrap_or(value);
408        let file: File = serde_json::from_value(file_value)?;
409
410        Ok((upload_status, Some(file)))
411    }
412}
413
/// Polling settings used while waiting for a file to finish processing.
#[derive(Debug, Clone)]
pub struct WaitForFileConfig {
    /// Pause between two consecutive status checks.
    pub poll_interval: Duration,
    /// Upper bound on the total wait; `None` disables the timeout.
    pub timeout: Option<Duration>,
}

impl Default for WaitForFileConfig {
    /// Polls every 2 seconds and gives up after 300 seconds.
    fn default() -> Self {
        let poll_interval = Duration::from_secs(2);
        let timeout = Some(Duration::from_secs(300));
        Self {
            poll_interval,
            timeout,
        }
    }
}
428
429fn finalize_upload(status: String, file: Option<File>) -> Result<File> {
430    if status != "final" {
431        return Err(Error::Parse {
432            message: format!("Upload finalize failed: {status}"),
433        });
434    }
435    file.ok_or_else(|| Error::Parse {
436        message: "Upload completed but response body was empty".into(),
437    })
438}
439
440fn ensure_gemini_backend(inner: &ClientInner) -> Result<()> {
441    if inner.config.backend == Backend::VertexAi {
442        return Err(Error::InvalidConfig {
443            message: "Files API is only supported in Gemini API".into(),
444        });
445    }
446    Ok(())
447}
448
449fn build_upload_file(config: UploadFileConfig, size_bytes: u64, mime_type: &str) -> Result<File> {
450    let mut file = File::default();
451    if let Some(name) = config.name {
452        file.name = Some(normalize_upload_name(&name));
453    }
454    file.display_name = config.display_name;
455    file.mime_type = Some(mime_type.to_string());
456    file.size_bytes = Some(size_bytes.to_string());
457    Ok(file)
458}
459
/// Ensures an upload resource name carries the `files/` prefix.
///
/// Names that already start with `files/` are returned unchanged.
fn normalize_upload_name(name: &str) -> String {
    if !name.starts_with("files/") {
        return format!("files/{name}");
    }
    name.to_owned()
}
467
468fn normalize_file_name(value: &str) -> Result<String> {
469    if value.starts_with("http://") || value.starts_with("https://") {
470        let marker = "files/";
471        let start = value.find(marker).ok_or_else(|| Error::InvalidConfig {
472            message: format!("Could not find 'files/' in URI: {value}"),
473        })?;
474        let suffix = &value[start + marker.len()..];
475        let name: String = suffix
476            .chars()
477            .take_while(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || *c == '-')
478            .collect();
479        if name.is_empty() {
480            return Err(Error::InvalidConfig {
481                message: format!("Could not extract file name from URI: {value}"),
482            });
483        }
484        Ok(name)
485    } else if value.starts_with("files/") {
486        Ok(value.trim_start_matches("files/").to_string())
487    } else {
488        Ok(value.to_string())
489    }
490}
491
492fn build_files_upload_url(inner: &ClientInner) -> Result<String> {
493    let base = &inner.api_client.base_url;
494    let version = &inner.api_client.api_version;
495    Ok(format!("{base}upload/{version}/files"))
496}
497
498fn build_files_list_url(inner: &ClientInner, config: &ListFilesConfig) -> Result<String> {
499    let base = &inner.api_client.base_url;
500    let version = &inner.api_client.api_version;
501    let url = format!("{base}{version}/files");
502    add_list_query_params(url, config)
503}
504
505fn build_file_url(inner: &ClientInner, name: &str) -> Result<String> {
506    let base = &inner.api_client.base_url;
507    let version = &inner.api_client.api_version;
508    Ok(format!("{base}{version}/files/{name}"))
509}
510
511fn build_file_download_url(inner: &ClientInner, name: &str) -> Result<String> {
512    let base = &inner.api_client.base_url;
513    let version = &inner.api_client.api_version;
514    Ok(format!("{base}{version}/files/{name}:download?alt=media"))
515}
516
517fn add_list_query_params(url: String, config: &ListFilesConfig) -> Result<String> {
518    let mut url = reqwest::Url::parse(&url).map_err(|err| Error::InvalidConfig {
519        message: err.to_string(),
520    })?;
521    {
522        let mut pairs = url.query_pairs_mut();
523        if let Some(page_size) = config.page_size {
524            pairs.append_pair("pageSize", &page_size.to_string());
525        }
526        if let Some(page_token) = &config.page_token {
527            pairs.append_pair("pageToken", page_token);
528        }
529    }
530    Ok(url.to_string())
531}
532
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::Client;

    /// Resource names, bare ids and full URIs all reduce to the bare id.
    #[test]
    fn test_normalize_file_name() {
        let cases = [
            ("files/abc-123", "abc-123"),
            ("abc-123", "abc-123"),
            ("https://example.com/files/abc-123?foo=bar", "abc-123"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_file_name(input).unwrap(), expected);
        }
    }

    /// The upload endpoint is derived from the client's base URL and API version.
    #[test]
    fn test_build_urls() {
        let client = Client::new("test-key").unwrap();
        let upload_url = build_files_upload_url(&client.files().inner).unwrap();
        let expected = "https://generativelanguage.googleapis.com/upload/v1beta/files";
        assert_eq!(upload_url, expected);
    }
}
558}