1use std::time::Duration;
2
3use anyhow::{Context, Result};
4use headers::{Authorization, HeaderMapExt};
5use percent_encoding::utf8_percent_encode;
6use reqwest::header::HeaderMap;
7use reqwest::Response;
8use reqwest_middleware::{ClientWithMiddleware, RequestBuilder};
9use serde::{Deserialize, Serialize};
10use shuttle_common::models::{
11 certificate::{
12 AddCertificateRequest, CertificateListResponse, CertificateResponse,
13 DeleteCertificateRequest,
14 },
15 deployment::{
16 DeploymentListResponse, DeploymentRequest, DeploymentResponse, UploadArchiveResponse,
17 },
18 log::LogsResponse,
19 project::{ProjectCreateRequest, ProjectListResponse, ProjectResponse, ProjectUpdateRequest},
20 resource::{ProvisionResourceRequest, ResourceListResponse, ResourceResponse, ResourceType},
21 team::TeamListResponse,
22 user::UserResponse,
23};
24use tokio::net::TcpStream;
25use tokio_tungstenite::tungstenite::client::IntoClientRequest;
26use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
27
28#[cfg(feature = "tracing")]
29mod middleware;
30#[cfg(feature = "tracing")]
31use crate::middleware::LoggingMiddleware;
32#[cfg(feature = "tracing")]
33use tracing::{debug, error};
34
35pub mod util;
36use util::{ParsedJson, ToBodyContent};
37
/// Client for the Shuttle HTTP and WebSocket API.
///
/// Construct it with [`ShuttleApiClient::new`]; all request paths used by the
/// methods on this type are appended directly to [`ShuttleApiClient::api_url`].
#[derive(Clone)]
pub struct ShuttleApiClient {
    /// HTTP client used for all REST calls. When built via `new`, it is wrapped in
    /// `reqwest_middleware` (with `LoggingMiddleware` if the `tracing` feature is on).
    pub client: ClientWithMiddleware,
    /// Base URL of the API; paths such as `/projects` are concatenated onto it.
    pub api_url: String,
    /// API key sent as a bearer token when `Some`; when `None` no bearer
    /// `Authorization` header is added by this client.
    pub api_key: Option<String>,
}
44
45impl ShuttleApiClient {
46 pub fn new(
47 api_url: String,
48 api_key: Option<String>,
49 headers: Option<HeaderMap>,
50 timeout: Option<u64>,
51 ) -> Self {
52 let mut builder = reqwest::Client::builder();
53 if let Some(h) = headers {
54 builder = builder.default_headers(h);
55 }
56 let client = builder
57 .timeout(Duration::from_secs(timeout.unwrap_or(60)))
58 .build()
59 .unwrap();
60
61 let builder = reqwest_middleware::ClientBuilder::new(client);
62 #[cfg(feature = "tracing")]
63 let builder = builder.with(LoggingMiddleware);
64 let client = builder.build();
65
66 Self {
67 client,
68 api_url,
69 api_key,
70 }
71 }
72
73 pub fn set_auth_bearer(&self, builder: RequestBuilder) -> RequestBuilder {
74 if let Some(ref api_key) = self.api_key {
75 builder.bearer_auth(api_key)
76 } else {
77 builder
78 }
79 }
80
81 pub async fn get_device_auth_ws(&self) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
82 self.ws_get("/device-auth/ws")
83 .await
84 .with_context(|| "failed to connect to auth endpoint")
85 }
86
87 pub async fn check_project_name(&self, project_name: &str) -> Result<ParsedJson<bool>> {
88 let url = format!("{}/projects/{project_name}/name", self.api_url);
89
90 self.client
91 .get(url)
92 .send()
93 .await
94 .context("failed to check project name availability")?
95 .to_json()
96 .await
97 .context("parsing name check response")
98 }
99
100 pub async fn get_current_user(&self) -> Result<ParsedJson<UserResponse>> {
101 self.get_json("/users/me").await
102 }
103
104 pub async fn deploy(
105 &self,
106 project: &str,
107 deployment_req: DeploymentRequest,
108 ) -> Result<ParsedJson<DeploymentResponse>> {
109 let path = format!("/projects/{project}/deployments");
110 self.post_json(path, Some(deployment_req)).await
111 }
112
113 pub async fn upload_archive(
114 &self,
115 project: &str,
116 data: Vec<u8>,
117 ) -> Result<ParsedJson<UploadArchiveResponse>> {
118 let path = format!("/projects/{project}/archives");
119
120 let url = format!("{}{}", self.api_url, path);
121 let mut builder = self.client.post(url);
122 builder = self.set_auth_bearer(builder);
123
124 builder
125 .body(data)
126 .send()
127 .await
128 .context("failed to upload archive")?
129 .to_json()
130 .await
131 }
132
133 pub async fn redeploy(
134 &self,
135 project: &str,
136 deployment_id: &str,
137 ) -> Result<ParsedJson<DeploymentResponse>> {
138 let path = format!("/projects/{project}/deployments/{deployment_id}/redeploy");
139
140 self.post_json(path, Option::<()>::None).await
141 }
142
143 pub async fn stop_service(&self, project: &str) -> Result<ParsedJson<String>> {
144 let path = format!("/projects/{project}/deployments");
145
146 self.delete_json(path).await
147 }
148
149 pub async fn get_service_resources(
150 &self,
151 project: &str,
152 ) -> Result<ParsedJson<ResourceListResponse>> {
153 self.get_json(format!("/projects/{project}/resources"))
154 .await
155 }
156
157 async fn _dump_service_resource(
158 &self,
159 project: &str,
160 resource_type: &ResourceType,
161 ) -> Result<Vec<u8>> {
162 let r#type = resource_type.to_string();
163 let r#type = utf8_percent_encode(&r#type, percent_encoding::NON_ALPHANUMERIC).to_owned();
164
165 let bytes = self
166 .get(
167 format!(
168 "/projects/{project}/services/{project}/resources/{}/dump",
169 r#type
170 ),
171 Option::<()>::None,
172 )
173 .await?
174 .to_bytes()
175 .await?
176 .to_vec();
177
178 Ok(bytes)
179 }
180
181 pub async fn delete_service_resource(
182 &self,
183 project: &str,
184 resource_type: &ResourceType,
185 ) -> Result<ParsedJson<String>> {
186 let r#type = resource_type.to_string();
187 let r#type = utf8_percent_encode(&r#type, percent_encoding::NON_ALPHANUMERIC).to_owned();
188
189 self.delete_json(format!("/projects/{project}/resources/{}", r#type))
190 .await
191 }
192 pub async fn provision_resource(
193 &self,
194 project: &str,
195 req: ProvisionResourceRequest,
196 ) -> Result<ParsedJson<ResourceResponse>> {
197 self.post_json(format!("/projects/{project}/resources"), Some(req))
198 .await
199 }
200 pub async fn get_secrets(&self, project: &str) -> Result<ParsedJson<ResourceResponse>> {
201 self.get_json(format!("/projects/{project}/resources/secrets"))
202 .await
203 }
204
205 pub async fn list_certificates(
206 &self,
207 project: &str,
208 ) -> Result<ParsedJson<CertificateListResponse>> {
209 self.get_json(format!("/projects/{project}/certificates"))
210 .await
211 }
212 pub async fn add_certificate(
213 &self,
214 project: &str,
215 subject: String,
216 ) -> Result<ParsedJson<CertificateResponse>> {
217 self.post_json(
218 format!("/projects/{project}/certificates"),
219 Some(AddCertificateRequest { subject }),
220 )
221 .await
222 }
223 pub async fn delete_certificate(
224 &self,
225 project: &str,
226 subject: String,
227 ) -> Result<ParsedJson<String>> {
228 self.delete_json_with_body(
229 format!("/projects/{project}/certificates"),
230 DeleteCertificateRequest { subject },
231 )
232 .await
233 }
234
235 pub async fn create_project(&self, name: &str) -> Result<ParsedJson<ProjectResponse>> {
236 self.post_json(
237 "/projects",
238 Some(ProjectCreateRequest {
239 name: name.to_string(),
240 }),
241 )
242 .await
243 }
244
245 pub async fn get_project(&self, project: &str) -> Result<ParsedJson<ProjectResponse>> {
246 self.get_json(format!("/projects/{project}")).await
247 }
248
249 pub async fn get_projects_list(&self) -> Result<ParsedJson<ProjectListResponse>> {
250 self.get_json("/projects".to_owned()).await
251 }
252
253 pub async fn update_project(
254 &self,
255 project: &str,
256 req: ProjectUpdateRequest,
257 ) -> Result<ParsedJson<ProjectResponse>> {
258 self.put_json(format!("/projects/{project}"), Some(req))
259 .await
260 }
261
262 pub async fn delete_project(&self, project: &str) -> Result<ParsedJson<String>> {
263 self.delete_json(format!("/projects/{project}")).await
264 }
265
266 #[allow(unused)]
267 async fn get_teams_list(&self) -> Result<ParsedJson<TeamListResponse>> {
268 self.get_json("/teams").await
269 }
270
271 pub async fn get_deployment_logs(
272 &self,
273 project: &str,
274 deployment_id: &str,
275 ) -> Result<ParsedJson<LogsResponse>> {
276 let path = format!("/projects/{project}/deployments/{deployment_id}/logs");
277
278 self.get_json(path).await
279 }
280 pub async fn get_project_logs(&self, project: &str) -> Result<ParsedJson<LogsResponse>> {
281 let path = format!("/projects/{project}/logs");
282
283 self.get_json(path).await
284 }
285
286 pub async fn get_deployments(
287 &self,
288 project: &str,
289 page: i32,
290 per_page: i32,
291 ) -> Result<ParsedJson<DeploymentListResponse>> {
292 let path = format!(
293 "/projects/{project}/deployments?page={}&per_page={}",
294 page.saturating_sub(1).max(0),
295 per_page.max(1),
296 );
297
298 self.get_json(path).await
299 }
300 pub async fn get_current_deployment(
301 &self,
302 project: &str,
303 ) -> Result<ParsedJson<Option<DeploymentResponse>>> {
304 let path = format!("/projects/{project}/deployments/current");
305
306 self.get_json(path).await
307 }
308
309 pub async fn get_deployment(
310 &self,
311 project: &str,
312 deployment_id: &str,
313 ) -> Result<ParsedJson<DeploymentResponse>> {
314 let path = format!("/projects/{project}/deployments/{deployment_id}");
315
316 self.get_json(path).await
317 }
318
319 pub async fn reset_api_key(&self) -> Result<()> {
320 self.put("/users/reset-api-key", Option::<()>::None)
321 .await?
322 .to_empty()
323 .await
324 }
325
326 pub async fn ws_get(
327 &self,
328 path: impl AsRef<str>,
329 ) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
330 let ws_url = self
331 .api_url
332 .clone()
333 .replacen("http://", "ws://", 1)
334 .replacen("https://", "wss://", 1);
335 let url = format!("{ws_url}{}", path.as_ref());
336 let mut req = url.into_client_request()?;
337
338 #[cfg(feature = "tracing")]
339 debug!("WS Request: {} {}", req.method(), req.uri());
340
341 if let Some(ref api_key) = self.api_key {
342 let auth_header = Authorization::bearer(api_key.as_ref())?;
343 req.headers_mut().typed_insert(auth_header);
344 }
345
346 let (stream, _) = connect_async(req).await.with_context(|| {
347 #[cfg(feature = "tracing")]
348 error!("failed to connect to websocket");
349 "could not connect to websocket"
350 })?;
351
352 Ok(stream)
353 }
354
355 pub async fn get<T: Serialize>(
356 &self,
357 path: impl AsRef<str>,
358 body: Option<T>,
359 ) -> Result<Response> {
360 let url = format!("{}{}", self.api_url, path.as_ref());
361
362 let mut builder = self.client.get(url);
363 builder = self.set_auth_bearer(builder);
364
365 if let Some(body) = body {
366 let body = serde_json::to_string(&body)?;
367 #[cfg(feature = "tracing")]
368 debug!("Outgoing body: {}", body);
369 builder = builder.body(body);
370 builder = builder.header("Content-Type", "application/json");
371 }
372
373 Ok(builder.send().await?)
374 }
375
376 pub async fn get_json<R>(&self, path: impl AsRef<str>) -> Result<ParsedJson<R>>
377 where
378 R: for<'de> Deserialize<'de>,
379 {
380 self.get(path, Option::<()>::None).await?.to_json().await
381 }
382
383 pub async fn get_json_with_body<R, T: Serialize>(
384 &self,
385 path: impl AsRef<str>,
386 body: T,
387 ) -> Result<ParsedJson<R>>
388 where
389 R: for<'de> Deserialize<'de>,
390 {
391 self.get(path, Some(body)).await?.to_json().await
392 }
393
394 pub async fn post<T: Serialize>(
395 &self,
396 path: impl AsRef<str>,
397 body: Option<T>,
398 ) -> Result<Response> {
399 let url = format!("{}{}", self.api_url, path.as_ref());
400
401 let mut builder = self.client.post(url);
402 builder = self.set_auth_bearer(builder);
403
404 if let Some(body) = body {
405 let body = serde_json::to_string(&body)?;
406 #[cfg(feature = "tracing")]
407 debug!("Outgoing body: {}", body);
408 builder = builder.body(body);
409 builder = builder.header("Content-Type", "application/json");
410 }
411
412 Ok(builder.send().await?)
413 }
414
415 pub async fn post_json<T: Serialize, R>(
416 &self,
417 path: impl AsRef<str>,
418 body: Option<T>,
419 ) -> Result<ParsedJson<R>>
420 where
421 R: for<'de> Deserialize<'de>,
422 {
423 self.post(path, body).await?.to_json().await
424 }
425
426 pub async fn put<T: Serialize>(
427 &self,
428 path: impl AsRef<str>,
429 body: Option<T>,
430 ) -> Result<Response> {
431 let url = format!("{}{}", self.api_url, path.as_ref());
432
433 let mut builder = self.client.put(url);
434 builder = self.set_auth_bearer(builder);
435
436 if let Some(body) = body {
437 let body = serde_json::to_string(&body)?;
438 #[cfg(feature = "tracing")]
439 debug!("Outgoing body: {}", body);
440 builder = builder.body(body);
441 builder = builder.header("Content-Type", "application/json");
442 }
443
444 Ok(builder.send().await?)
445 }
446
447 pub async fn put_json<T: Serialize, R>(
448 &self,
449 path: impl AsRef<str>,
450 body: Option<T>,
451 ) -> Result<ParsedJson<R>>
452 where
453 R: for<'de> Deserialize<'de>,
454 {
455 self.put(path, body).await?.to_json().await
456 }
457
458 pub async fn delete<T: Serialize>(
459 &self,
460 path: impl AsRef<str>,
461 body: Option<T>,
462 ) -> Result<Response> {
463 let url = format!("{}{}", self.api_url, path.as_ref());
464
465 let mut builder = self.client.delete(url);
466 builder = self.set_auth_bearer(builder);
467
468 if let Some(body) = body {
469 let body = serde_json::to_string(&body)?;
470 #[cfg(feature = "tracing")]
471 debug!("Outgoing body: {}", body);
472 builder = builder.body(body);
473 builder = builder.header("Content-Type", "application/json");
474 }
475
476 Ok(builder.send().await?)
477 }
478
479 pub async fn delete_json<R>(&self, path: impl AsRef<str>) -> Result<ParsedJson<R>>
480 where
481 R: for<'de> Deserialize<'de>,
482 {
483 self.delete(path, Option::<()>::None).await?.to_json().await
484 }
485
486 pub async fn delete_json_with_body<R, T: Serialize>(
487 &self,
488 path: impl AsRef<str>,
489 body: T,
490 ) -> Result<ParsedJson<R>>
491 where
492 R: for<'de> Deserialize<'de>,
493 {
494 self.delete(path, Some(body)).await?.to_json().await
495 }
496}