1use reqwest::{multipart, Client};
5use serde::{Deserialize, Serialize};
6use std::path::Path;
7use thiserror::Error;
8use tokio::fs::File;
9use tokio_util::codec::{BytesCodec, FramedRead};
10
11use crate::{
12 utils::{api_key, remove_trailing_slash, OpenAiApiKeyError},
13 OpenAiError,
14};
15
/// Client for the files endpoints of the OpenAI API.
///
/// All fields are public so a caller can point the client at a different
/// base URL, endpoint path or HTTP client.
pub struct FilesClient {
    /// API key sent as a bearer token in the `Authorization` header.
    pub api_key: String,
    /// Base URL that `files_path` is joined onto to form the endpoint URL.
    pub base_url: url::Url,
    /// Path of the files endpoint, relative to `base_url`.
    pub files_path: String,
    /// HTTP client used to send every request.
    pub http_client: Client,
}
27
28impl From<&crate::chat_completions::ChatClient> for FilesClient {
29 fn from(client: &crate::chat_completions::ChatClient) -> Self {
30 Self {
31 api_key: client.api_key.clone(),
32 base_url: client.base_url.clone(),
33 files_path: "files/".to_string(),
34 http_client: client.http_client.clone(),
35 }
36 }
37}
38#[derive(Serialize, Clone, Copy)]
40#[serde(rename_all = "snake_case")]
41pub enum FilePurpose {
42 #[serde(rename = "fine-tune")]
44 FineTune,
45 #[serde(rename = "assistants")]
47 Assistants,
48 #[serde(rename = "batch")]
50 Batch,
51 #[serde(rename = "user_data")]
53 UserData,
54 #[serde(rename = "vision")]
56 Vision,
57 #[serde(rename = "evals")]
59 Evals,
60}
61
62impl std::fmt::Debug for FilePurpose {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 match self {
65 FilePurpose::FineTune => write!(f, "fine-tune"),
66 FilePurpose::Assistants => write!(f, "assistants"),
67 FilePurpose::Batch => write!(f, "batch"),
68 FilePurpose::UserData => write!(f, "user_data"),
69 FilePurpose::Vision => write!(f, "vision"),
70 FilePurpose::Evals => write!(f, "evals"),
71 }
72 }
73}
74
75impl std::fmt::Display for FilePurpose {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 write!(f, "{:?}", self)
78 }
79}
80
/// A file record as returned by the files endpoints.
///
/// Every field is required when deserializing; a response that lacks one of
/// them fails to parse as a `FileObject`.
#[derive(Debug, Deserialize)]
pub struct FileObject {
    /// Value of the `id` field of the response.
    pub id: String,
    /// Value of the `object` field (presumably the object type tag — confirm against the API reference).
    pub object: String,
    /// Value of the `bytes` field (presumably the file size in bytes).
    pub bytes: u64,
    /// Value of the `created_at` field (presumably a Unix timestamp — confirm the unit).
    pub created_at: u64,
    /// Value of the `filename` field.
    pub filename: String,
    /// Value of the `purpose` field, kept as the raw string sent by the API.
    pub purpose: String,
}
97
/// Body of an upload response: either an error envelope or a file record.
#[derive(Debug, Deserialize)]
enum UploadFileResponse {
    /// Matches a JSON object of the form `{"error": ...}`.
    #[serde(rename = "error")]
    Error(OpenAiError),
    /// Fallback: the body itself is deserialized as a [`FileObject`].
    #[serde(untagged)]
    File(FileObject),
}
105
/// Response body of the file listing request.
#[derive(Debug, Deserialize)]
pub struct FileList {
    /// The file records contained in the `data` field of the response.
    pub data: Vec<FileObject>,
    /// Value of the `object` field (presumably the object type tag — confirm against the API reference).
    pub object: String,
}
114
/// Errors returned by [`FilesClient`] operations.
#[derive(Error, Debug)]
pub enum FilesError {
    /// The HTTP layer failed (converted from `reqwest::Error`).
    #[error("Request error: {0}")]
    RequestError(#[from] reqwest::Error),

    /// A response body could not be parsed as the expected JSON.
    #[error("API {url} returned an unknown response: {response}")]
    ApiParseError {
        /// URL the request was sent to.
        url: String,
        /// Raw response body that failed to parse.
        response: String,
        /// Underlying JSON parse error.
        #[source]
        error: serde_json::Error,
    },

    /// The API answered with an error payload; the payload is available as the error source.
    #[error("API returned an error response")]
    ApiError(#[from] OpenAiError),

    /// A local file operation failed (converted from `std::io::Error`).
    #[error("File error: {0}")]
    IoError(#[from] std::io::Error),

    /// The path given for upload has no final component, or that component is not valid UTF-8.
    #[error("Invalid file path")]
    InvalidFilePath,
}
146
147impl FilesClient {
148 pub fn new(api_key: impl Into<String>) -> Self {
157 Self {
158 api_key: api_key.into(),
159 base_url: url::Url::parse("https://api.openai.com/v1/").unwrap(),
160 files_path: "files/".to_string(),
161 http_client: crate::utils::pooled_client(),
162 }
163 }
164
165 fn files_url(&self) -> url::Url {
166 self.base_url.join(&self.files_path).unwrap()
167 }
168
169 pub fn from_env() -> Result<Self, OpenAiApiKeyError> {
178 Ok(Self::new(api_key()?))
179 }
180
181 pub async fn upload_file(
193 &self,
194 file_path: impl AsRef<Path>,
195 purpose: FilePurpose,
196 ) -> Result<FileObject, FilesError> {
197 let file_path = file_path.as_ref();
198 let file_name = file_path
199 .file_name()
200 .and_then(|name| name.to_str())
201 .ok_or(FilesError::InvalidFilePath)?;
202
203 let file = File::open(file_path).await?;
204 let stream = FramedRead::new(file, BytesCodec::new());
205 let file_part = multipart::Part::stream(reqwest::Body::wrap_stream(stream))
206 .file_name(file_name.to_string());
207
208 let form = multipart::Form::new()
209 .text("purpose", format!("{:?}", purpose).to_lowercase())
210 .part("file", file_part);
211
212 let url = remove_trailing_slash(self.files_url());
213 let response = self
214 .http_client
215 .post(url.clone())
216 .header("Authorization", format!("Bearer {}", self.api_key))
217 .multipart(form)
218 .send()
219 .await?;
220
221 let response_text = response.text().await?;
222
223 let file_object: UploadFileResponse =
224 serde_json::from_str(&response_text).map_err(|e| FilesError::ApiParseError {
225 url: url.to_string(),
226 response: response_text.clone(),
227 error: e,
228 })?;
229
230 match file_object {
231 UploadFileResponse::File(file) => Ok(file),
232 UploadFileResponse::Error(error) => Err(FilesError::ApiError(error)),
233 }
234 }
235
236 pub async fn upload_bytes(
249 &self,
250 filename: &str,
251 bytes: Vec<u8>,
252 purpose: FilePurpose,
253 ) -> Result<FileObject, FilesError> {
254 let file_part = multipart::Part::bytes(bytes).file_name(filename.to_string());
255
256 let form = multipart::Form::new()
257 .text("purpose", format!("{:?}", purpose).to_lowercase())
258 .part("file", file_part);
259
260 let url = remove_trailing_slash(self.files_url());
261 let response = self
262 .http_client
263 .post(url.clone())
264 .header("Authorization", format!("Bearer {}", self.api_key))
265 .multipart(form)
266 .send()
267 .await?;
268
269 let response_text = response.text().await?;
270
271 let file_object: UploadFileResponse =
272 serde_json::from_str(&response_text).map_err(|e| FilesError::ApiParseError {
273 url: url.to_string(),
274 response: response_text.clone(),
275 error: e,
276 })?;
277
278 match file_object {
279 UploadFileResponse::File(file) => Ok(file),
280 UploadFileResponse::Error(error) => Err(FilesError::ApiError(error)),
281 }
282 }
283
284 pub async fn list_files(&self) -> Result<FileList, FilesError> {
298 let response = self
299 .http_client
300 .get(self.files_url())
301 .header("Authorization", format!("Bearer {}", self.api_key))
302 .send()
303 .await?;
304
305 let file_list = response.json::<FileList>().await?;
306 Ok(file_list)
307 }
308
309 pub async fn retrieve_file(&self, file_id: &str) -> Result<FileObject, FilesError> {
321 let response = self
322 .http_client
323 .get(self.files_url().join(file_id).unwrap())
324 .header("Authorization", format!("Bearer {}", self.api_key))
325 .send()
326 .await?;
327
328 let file_object = response.json::<FileObject>().await?;
329 Ok(file_object)
330 }
331
332 pub async fn delete_file(&self, file_id: &str) -> Result<DeletedFile, FilesError> {
344 let response = self
345 .http_client
346 .delete(self.files_url().join(file_id).unwrap())
347 .header("Authorization", format!("Bearer {}", self.api_key))
348 .send()
349 .await?;
350
351 let deleted_file = response.json::<DeletedFile>().await?;
352 Ok(deleted_file)
353 }
354
355 pub async fn download_file(&self, file_id: &str) -> Result<String, FilesError> {
367 let url = self
368 .files_url()
369 .join(&format!("{file_id}/content"))
370 .unwrap();
371 let response = self
372 .http_client
373 .get(url)
374 .header("Authorization", format!("Bearer {}", self.api_key))
375 .send()
376 .await?;
377
378 let content = response.text().await?;
379 Ok(content)
380 }
381}
382
/// Response body of a file deletion request.
#[derive(Debug, Deserialize)]
pub struct DeletedFile {
    /// Value of the `id` field of the response.
    pub id: String,
    /// Value of the `object` field (presumably the object type tag — confirm against the API reference).
    pub object: String,
    /// Value of the `deleted` field of the response.
    pub deleted: bool,
}