1use std::time::Duration;
2
3use anyhow::{Context, Result};
4use headers::{Authorization, HeaderMapExt};
5use percent_encoding::utf8_percent_encode;
6use reqwest::header::HeaderMap;
7use reqwest::Response;
8use reqwest_middleware::{ClientWithMiddleware, RequestBuilder};
9use serde::{Deserialize, Serialize};
10use shuttle_common::models::{
11 certificate::{
12 AddCertificateRequest, CertificateListResponse, CertificateResponse,
13 DeleteCertificateRequest,
14 },
15 deployment::{
16 DeploymentListResponse, DeploymentRequest, DeploymentResponse, UploadArchiveResponse,
17 },
18 log::LogsResponse,
19 project::{ProjectCreateRequest, ProjectListResponse, ProjectResponse, ProjectUpdateRequest},
20 resource::{ProvisionResourceRequest, ResourceListResponse, ResourceResponse, ResourceType},
21 team::TeamListResponse,
22 user::UserResponse,
23};
24use tokio::net::TcpStream;
25use tokio_tungstenite::tungstenite::client::IntoClientRequest;
26use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
27
28#[cfg(feature = "tracing")]
29mod middleware;
30#[cfg(feature = "tracing")]
31use crate::middleware::LoggingMiddleware;
32#[cfg(feature = "tracing")]
33use tracing::{debug, error};
34
35mod util;
36use util::ToJson;
37
/// HTTP/WebSocket client for the Shuttle API.
///
/// Request URLs are formed by appending an endpoint path to `api_url`; when `api_key` is set,
/// most requests carry it as a bearer token.
#[derive(Clone)]
pub struct ShuttleApiClient {
    /// Underlying HTTP client, wrapped in `reqwest_middleware` so middleware (request logging
    /// when the `tracing` feature is enabled) can be layered on top.
    pub client: ClientWithMiddleware,
    /// Base URL that endpoint paths are appended to verbatim (no separator is inserted).
    pub api_url: String,
    /// Optional API key, sent as a bearer token when present.
    pub api_key: Option<String>,
}
44
45impl ShuttleApiClient {
46 pub fn new(
47 api_url: String,
48 api_key: Option<String>,
49 headers: Option<HeaderMap>,
50 timeout: Option<u64>,
51 ) -> Self {
52 let mut builder = reqwest::Client::builder();
53 if let Some(h) = headers {
54 builder = builder.default_headers(h);
55 }
56 let client = builder
57 .timeout(Duration::from_secs(timeout.unwrap_or(60)))
58 .build()
59 .unwrap();
60
61 let builder = reqwest_middleware::ClientBuilder::new(client);
62 #[cfg(feature = "tracing")]
63 let builder = builder.with(LoggingMiddleware);
64 let client = builder.build();
65
66 Self {
67 client,
68 api_url,
69 api_key,
70 }
71 }
72
73 pub fn set_auth_bearer(&self, builder: RequestBuilder) -> RequestBuilder {
74 if let Some(ref api_key) = self.api_key {
75 builder.bearer_auth(api_key)
76 } else {
77 builder
78 }
79 }
80
81 pub async fn get_device_auth_ws(&self) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
82 self.ws_get("/device-auth/ws")
83 .await
84 .with_context(|| "failed to connect to auth endpoint")
85 }
86
87 pub async fn check_project_name(&self, project_name: &str) -> Result<bool> {
88 let url = format!("{}/projects/{project_name}/name", self.api_url);
89
90 self.client
91 .get(url)
92 .send()
93 .await
94 .context("failed to check project name availability")?
95 .to_json()
96 .await
97 .context("parsing name check response")
98 }
99
100 pub async fn get_current_user(&self) -> Result<UserResponse> {
101 self.get_json("/users/me").await
102 }
103
104 pub async fn deploy(
105 &self,
106 project: &str,
107 deployment_req: DeploymentRequest,
108 ) -> Result<DeploymentResponse> {
109 let path = format!("/projects/{project}/deployments");
110 self.post_json(path, Some(deployment_req)).await
111 }
112
113 pub async fn upload_archive(
114 &self,
115 project: &str,
116 data: Vec<u8>,
117 ) -> Result<UploadArchiveResponse> {
118 let path = format!("/projects/{project}/archives");
119
120 let url = format!("{}{}", self.api_url, path);
121 let mut builder = self.client.post(url);
122 builder = self.set_auth_bearer(builder);
123
124 builder
125 .body(data)
126 .send()
127 .await
128 .context("failed to upload archive")?
129 .to_json()
130 .await
131 }
132
133 pub async fn redeploy(&self, project: &str, deployment_id: &str) -> Result<DeploymentResponse> {
134 let path = format!("/projects/{project}/deployments/{deployment_id}/redeploy");
135
136 self.post_json(path, Option::<()>::None).await
137 }
138
139 pub async fn stop_service(&self, project: &str) -> Result<String> {
140 let path = format!("/projects/{project}/deployments");
141
142 self.delete_json(path).await
143 }
144
145 pub async fn get_service_resources(&self, project: &str) -> Result<ResourceListResponse> {
146 self.get_json(format!("/projects/{project}/resources"))
147 .await
148 }
149
150 async fn _dump_service_resource(
151 &self,
152 project: &str,
153 resource_type: &ResourceType,
154 ) -> Result<Vec<u8>> {
155 let r#type = resource_type.to_string();
156 let r#type = utf8_percent_encode(&r#type, percent_encoding::NON_ALPHANUMERIC).to_owned();
157
158 let res = self
159 .get(
160 format!(
161 "/projects/{project}/services/{project}/resources/{}/dump",
162 r#type
163 ),
164 Option::<()>::None,
165 )
166 .await?;
167
168 let bytes = res.bytes().await?;
169
170 Ok(bytes.to_vec())
171 }
172
173 pub async fn delete_service_resource(
174 &self,
175 project: &str,
176 resource_type: &ResourceType,
177 ) -> Result<String> {
178 let r#type = resource_type.to_string();
179 let r#type = utf8_percent_encode(&r#type, percent_encoding::NON_ALPHANUMERIC).to_owned();
180
181 self.delete_json(format!("/projects/{project}/resources/{}", r#type))
182 .await
183 }
184 pub async fn provision_resource(
185 &self,
186 project: &str,
187 req: ProvisionResourceRequest,
188 ) -> Result<ResourceResponse> {
189 self.post_json(format!("/projects/{project}/resources"), Some(req))
190 .await
191 }
192 pub async fn get_secrets(&self, project: &str) -> Result<ResourceResponse> {
193 self.get_json(format!("/projects/{project}/resources/secrets"))
194 .await
195 }
196
197 pub async fn list_certificates(&self, project: &str) -> Result<CertificateListResponse> {
198 self.get_json(format!("/projects/{project}/certificates"))
199 .await
200 }
201 pub async fn add_certificate(
202 &self,
203 project: &str,
204 subject: String,
205 ) -> Result<CertificateResponse> {
206 self.post_json(
207 format!("/projects/{project}/certificates"),
208 Some(AddCertificateRequest { subject }),
209 )
210 .await
211 }
212 pub async fn delete_certificate(&self, project: &str, subject: String) -> Result<String> {
213 self.delete_json_with_body(
214 format!("/projects/{project}/certificates"),
215 DeleteCertificateRequest { subject },
216 )
217 .await
218 }
219
220 pub async fn create_project(&self, name: &str) -> Result<ProjectResponse> {
221 self.post_json(
222 "/projects",
223 Some(ProjectCreateRequest {
224 name: name.to_string(),
225 }),
226 )
227 .await
228 }
229
230 pub async fn get_project(&self, project: &str) -> Result<ProjectResponse> {
231 self.get_json(format!("/projects/{project}")).await
232 }
233
234 pub async fn get_projects_list(&self) -> Result<ProjectListResponse> {
235 self.get_json("/projects".to_owned()).await
236 }
237
238 pub async fn update_project(
239 &self,
240 project: &str,
241 req: ProjectUpdateRequest,
242 ) -> Result<ProjectResponse> {
243 self.put_json(format!("/projects/{project}"), Some(req))
244 .await
245 }
246
247 pub async fn delete_project(&self, project: &str) -> Result<String> {
248 self.delete_json(format!("/projects/{project}")).await
249 }
250
251 #[allow(unused)]
252 async fn get_teams_list(&self) -> Result<TeamListResponse> {
253 self.get_json("/teams").await
254 }
255
256 pub async fn get_deployment_logs(
257 &self,
258 project: &str,
259 deployment_id: &str,
260 ) -> Result<LogsResponse> {
261 let path = format!("/projects/{project}/deployments/{deployment_id}/logs");
262
263 self.get_json(path).await
264 }
265 pub async fn get_project_logs(&self, project: &str) -> Result<LogsResponse> {
266 let path = format!("/projects/{project}/logs");
267
268 self.get_json(path).await
269 }
270
271 pub async fn get_deployments(
272 &self,
273 project: &str,
274 page: i32,
275 per_page: i32,
276 ) -> Result<DeploymentListResponse> {
277 let path = format!(
278 "/projects/{project}/deployments?page={}&per_page={}",
279 page.saturating_sub(1).max(0),
280 per_page.max(1),
281 );
282
283 self.get_json(path).await
284 }
285 pub async fn get_current_deployment(
286 &self,
287 project: &str,
288 ) -> Result<Option<DeploymentResponse>> {
289 let path = format!("/projects/{project}/deployments/current");
290
291 self.get_json(path).await
292 }
293
294 pub async fn get_deployment(
295 &self,
296 project: &str,
297 deployment_id: &str,
298 ) -> Result<DeploymentResponse> {
299 let path = format!("/projects/{project}/deployments/{deployment_id}");
300
301 self.get_json(path).await
302 }
303
304 pub async fn reset_api_key(&self) -> Result<Response> {
305 self.put("/users/reset-api-key", Option::<()>::None).await
306 }
307
308 pub async fn ws_get(
309 &self,
310 path: impl AsRef<str>,
311 ) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
312 let ws_url = self.api_url.clone().replace("http", "ws");
313 let url = format!("{ws_url}{}", path.as_ref());
314 let mut req = url.into_client_request()?;
315
316 #[cfg(feature = "tracing")]
317 debug!("WS Request: {} {}", req.method(), req.uri());
318
319 if let Some(ref api_key) = self.api_key {
320 let auth_header = Authorization::bearer(api_key.as_ref())?;
321 req.headers_mut().typed_insert(auth_header);
322 }
323
324 let (stream, _) = connect_async(req).await.with_context(|| {
325 #[cfg(feature = "tracing")]
326 error!("failed to connect to websocket");
327 "could not connect to websocket"
328 })?;
329
330 Ok(stream)
331 }
332
333 pub async fn get<T: Serialize>(
334 &self,
335 path: impl AsRef<str>,
336 body: Option<T>,
337 ) -> Result<Response> {
338 let url = format!("{}{}", self.api_url, path.as_ref());
339
340 let mut builder = self.client.get(url);
341 builder = self.set_auth_bearer(builder);
342
343 if let Some(body) = body {
344 let body = serde_json::to_string(&body)?;
345 #[cfg(feature = "tracing")]
346 debug!("Outgoing body: {}", body);
347 builder = builder.body(body);
348 builder = builder.header("Content-Type", "application/json");
349 }
350
351 Ok(builder.send().await?)
352 }
353
354 pub async fn get_json<R>(&self, path: impl AsRef<str>) -> Result<R>
355 where
356 R: for<'de> Deserialize<'de>,
357 {
358 self.get(path, Option::<()>::None).await?.to_json().await
359 }
360
361 pub async fn get_json_with_body<R, T: Serialize>(
362 &self,
363 path: impl AsRef<str>,
364 body: T,
365 ) -> Result<R>
366 where
367 R: for<'de> Deserialize<'de>,
368 {
369 self.get(path, Some(body)).await?.to_json().await
370 }
371
372 pub async fn post<T: Serialize>(
373 &self,
374 path: impl AsRef<str>,
375 body: Option<T>,
376 ) -> Result<Response> {
377 let url = format!("{}{}", self.api_url, path.as_ref());
378
379 let mut builder = self.client.post(url);
380 builder = self.set_auth_bearer(builder);
381
382 if let Some(body) = body {
383 let body = serde_json::to_string(&body)?;
384 #[cfg(feature = "tracing")]
385 debug!("Outgoing body: {}", body);
386 builder = builder.body(body);
387 builder = builder.header("Content-Type", "application/json");
388 }
389
390 Ok(builder.send().await?)
391 }
392
393 pub async fn post_json<T: Serialize, R>(
394 &self,
395 path: impl AsRef<str>,
396 body: Option<T>,
397 ) -> Result<R>
398 where
399 R: for<'de> Deserialize<'de>,
400 {
401 self.post(path, body).await?.to_json().await
402 }
403
404 pub async fn put<T: Serialize>(
405 &self,
406 path: impl AsRef<str>,
407 body: Option<T>,
408 ) -> Result<Response> {
409 let url = format!("{}{}", self.api_url, path.as_ref());
410
411 let mut builder = self.client.put(url);
412 builder = self.set_auth_bearer(builder);
413
414 if let Some(body) = body {
415 let body = serde_json::to_string(&body)?;
416 #[cfg(feature = "tracing")]
417 debug!("Outgoing body: {}", body);
418 builder = builder.body(body);
419 builder = builder.header("Content-Type", "application/json");
420 }
421
422 Ok(builder.send().await?)
423 }
424
425 pub async fn put_json<T: Serialize, R>(
426 &self,
427 path: impl AsRef<str>,
428 body: Option<T>,
429 ) -> Result<R>
430 where
431 R: for<'de> Deserialize<'de>,
432 {
433 self.put(path, body).await?.to_json().await
434 }
435
436 pub async fn delete<T: Serialize>(
437 &self,
438 path: impl AsRef<str>,
439 body: Option<T>,
440 ) -> Result<Response> {
441 let url = format!("{}{}", self.api_url, path.as_ref());
442
443 let mut builder = self.client.delete(url);
444 builder = self.set_auth_bearer(builder);
445
446 if let Some(body) = body {
447 let body = serde_json::to_string(&body)?;
448 #[cfg(feature = "tracing")]
449 debug!("Outgoing body: {}", body);
450 builder = builder.body(body);
451 builder = builder.header("Content-Type", "application/json");
452 }
453
454 Ok(builder.send().await?)
455 }
456
457 pub async fn delete_json<R>(&self, path: impl AsRef<str>) -> Result<R>
458 where
459 R: for<'de> Deserialize<'de>,
460 {
461 self.delete(path, Option::<()>::None).await?.to_json().await
462 }
463
464 pub async fn delete_json_with_body<R, T: Serialize>(
465 &self,
466 path: impl AsRef<str>,
467 body: T,
468 ) -> Result<R>
469 where
470 R: for<'de> Deserialize<'de>,
471 {
472 self.delete(path, Some(body)).await?.to_json().await
473 }
474}