1use std::time::Duration;
2
3use anyhow::{Context, Result};
4use headers::{Authorization, HeaderMapExt};
5use percent_encoding::utf8_percent_encode;
6use reqwest::{header::HeaderMap, Response};
7use reqwest_middleware::{ClientWithMiddleware, RequestBuilder};
8use serde::{Deserialize, Serialize};
9use shuttle_common::models::{
10 certificate::{
11 AddCertificateRequest, CertificateListResponse, CertificateResponse,
12 DeleteCertificateRequest,
13 },
14 deployment::{
15 DeploymentListResponse, DeploymentRequest, DeploymentResponse, UploadArchiveResponse,
16 },
17 log::LogsResponse,
18 project::{ProjectCreateRequest, ProjectListResponse, ProjectResponse, ProjectUpdateRequest},
19 resource::{ProvisionResourceRequest, ResourceListResponse, ResourceResponse, ResourceType},
20 team::TeamListResponse,
21 user::UserResponse,
22};
23use tokio::net::TcpStream;
24use tokio_tungstenite::{
25 connect_async, tungstenite::client::IntoClientRequest, MaybeTlsStream, WebSocketStream,
26};
27
28#[cfg(feature = "tracing")]
29mod middleware;
30#[cfg(feature = "tracing")]
31use tracing::{debug, error};
32
33#[cfg(feature = "tracing")]
34use crate::middleware::LoggingMiddleware;
35
36pub mod util;
37use util::{ParsedJson, ToBodyContent};
38
39#[derive(Clone)]
40pub struct ShuttleApiClient {
41 pub client: ClientWithMiddleware,
42 pub api_url: String,
43 pub api_key: Option<String>,
44}
45
46impl ShuttleApiClient {
47 pub fn new(
48 api_url: String,
49 api_key: Option<String>,
50 headers: Option<HeaderMap>,
51 timeout: Option<u64>,
52 ) -> Self {
53 let mut builder = reqwest::Client::builder();
54
55 if let Ok(proxy) = std::env::var("HTTP_PROXY") {
56 builder = builder.proxy(reqwest::Proxy::http(proxy).unwrap());
57 }
58
59 if let Ok(proxy) = std::env::var("HTTPS_PROXY") {
60 builder = builder.proxy(reqwest::Proxy::https(proxy).unwrap());
61 }
62
63 if std::env::var("SHUTTLE_TLS_INSECURE_SKIP_VERIFY")
64 .is_ok_and(|value| value == "1" || value == "true")
65 {
66 builder = builder.danger_accept_invalid_certs(true);
67 }
68
69 if let Some(headers) = headers {
70 builder = builder.default_headers(headers);
71 }
72
73 let client = builder
74 .timeout(Duration::from_secs(timeout.unwrap_or(60)))
75 .build()
76 .unwrap();
77
78 let builder = reqwest_middleware::ClientBuilder::new(client);
79
80 #[cfg(feature = "tracing")]
81 let builder = builder.with(LoggingMiddleware);
82
83 let client = builder.build();
84
85 Self {
86 client,
87 api_url,
88 api_key,
89 }
90 }
91
92 pub fn set_auth_bearer(&self, builder: RequestBuilder) -> RequestBuilder {
93 if let Some(ref api_key) = self.api_key {
94 builder.bearer_auth(api_key)
95 } else {
96 builder
97 }
98 }
99
100 pub async fn get_device_auth_ws(&self) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
101 self.ws_get("/device-auth/ws")
102 .await
103 .with_context(|| "failed to connect to auth endpoint")
104 }
105
106 pub async fn check_project_name(&self, project_name: &str) -> Result<ParsedJson<bool>> {
107 let url = format!("{}/projects/{project_name}/name", self.api_url);
108
109 self.client
110 .get(url)
111 .send()
112 .await
113 .context("failed to check project name availability")?
114 .to_json()
115 .await
116 .context("parsing name check response")
117 }
118
119 pub async fn get_current_user(&self) -> Result<ParsedJson<UserResponse>> {
120 self.get_json("/users/me").await
121 }
122
123 pub async fn deploy(
124 &self,
125 project: &str,
126 deployment_req: DeploymentRequest,
127 ) -> Result<ParsedJson<DeploymentResponse>> {
128 let path = format!("/projects/{project}/deployments");
129 self.post_json(path, Some(deployment_req)).await
130 }
131
132 pub async fn upload_archive(
133 &self,
134 project: &str,
135 data: Vec<u8>,
136 ) -> Result<ParsedJson<UploadArchiveResponse>> {
137 let path = format!("/projects/{project}/archives");
138
139 let url = format!("{}{}", self.api_url, path);
140 let mut builder = self.client.post(url);
141 builder = self.set_auth_bearer(builder);
142
143 builder
144 .body(data)
145 .send()
146 .await
147 .context("failed to upload archive")?
148 .to_json()
149 .await
150 }
151
152 pub async fn redeploy(
153 &self,
154 project: &str,
155 deployment_id: &str,
156 ) -> Result<ParsedJson<DeploymentResponse>> {
157 let path = format!("/projects/{project}/deployments/{deployment_id}/redeploy");
158
159 self.post_json(path, Option::<()>::None).await
160 }
161
162 pub async fn stop_service(&self, project: &str) -> Result<ParsedJson<String>> {
163 let path = format!("/projects/{project}/deployments");
164
165 self.delete_json(path).await
166 }
167
168 pub async fn get_service_resources(
169 &self,
170 project: &str,
171 ) -> Result<ParsedJson<ResourceListResponse>> {
172 self.get_json(format!("/projects/{project}/resources"))
173 .await
174 }
175
176 pub async fn dump_service_resource(
177 &self,
178 project: &str,
179 resource_type: &ResourceType,
180 ) -> Result<Vec<u8>> {
181 let r#type = resource_type.to_string();
182 let r#type = utf8_percent_encode(&r#type, percent_encoding::NON_ALPHANUMERIC).to_owned();
183
184 let bytes = self
185 .get(
186 format!("/projects/{project}/resources/{type}/dump"),
187 Option::<()>::None,
188 )
189 .await?
190 .to_bytes()
191 .await?
192 .to_vec();
193
194 Ok(bytes)
195 }
196
197 pub async fn delete_service_resource(
198 &self,
199 project: &str,
200 resource_type: &ResourceType,
201 ) -> Result<ParsedJson<String>> {
202 let r#type = resource_type.to_string();
203 let r#type = utf8_percent_encode(&r#type, percent_encoding::NON_ALPHANUMERIC).to_owned();
204
205 self.delete_json(format!("/projects/{project}/resources/{}", r#type))
206 .await
207 }
208 pub async fn provision_resource(
209 &self,
210 project: &str,
211 req: ProvisionResourceRequest,
212 ) -> Result<ParsedJson<ResourceResponse>> {
213 self.post_json(format!("/projects/{project}/resources"), Some(req))
214 .await
215 }
216 pub async fn get_secrets(&self, project: &str) -> Result<ParsedJson<ResourceResponse>> {
217 self.get_json(format!("/projects/{project}/resources/secrets"))
218 .await
219 }
220
221 pub async fn list_certificates(
222 &self,
223 project: &str,
224 ) -> Result<ParsedJson<CertificateListResponse>> {
225 self.get_json(format!("/projects/{project}/certificates"))
226 .await
227 }
228 pub async fn add_certificate(
229 &self,
230 project: &str,
231 subject: String,
232 ) -> Result<ParsedJson<CertificateResponse>> {
233 self.post_json(
234 format!("/projects/{project}/certificates"),
235 Some(AddCertificateRequest { subject }),
236 )
237 .await
238 }
239 pub async fn delete_certificate(
240 &self,
241 project: &str,
242 subject: String,
243 ) -> Result<ParsedJson<String>> {
244 self.delete_json_with_body(
245 format!("/projects/{project}/certificates"),
246 DeleteCertificateRequest { subject },
247 )
248 .await
249 }
250
251 pub async fn create_project(&self, name: &str) -> Result<ParsedJson<ProjectResponse>> {
252 self.post_json(
253 "/projects",
254 Some(ProjectCreateRequest {
255 name: name.to_string(),
256 }),
257 )
258 .await
259 }
260
261 pub async fn get_project(&self, project: &str) -> Result<ParsedJson<ProjectResponse>> {
262 self.get_json(format!("/projects/{project}")).await
263 }
264
265 pub async fn get_projects_list(&self) -> Result<ParsedJson<ProjectListResponse>> {
266 self.get_json("/projects".to_owned()).await
267 }
268
269 pub async fn update_project(
270 &self,
271 project: &str,
272 req: ProjectUpdateRequest,
273 ) -> Result<ParsedJson<ProjectResponse>> {
274 self.put_json(format!("/projects/{project}"), Some(req))
275 .await
276 }
277
278 pub async fn delete_project(&self, project: &str) -> Result<ParsedJson<String>> {
279 self.delete_json(format!("/projects/{project}")).await
280 }
281
282 #[allow(unused)]
283 async fn get_teams_list(&self) -> Result<ParsedJson<TeamListResponse>> {
284 self.get_json("/teams").await
285 }
286
287 pub async fn get_deployment_logs(
288 &self,
289 project: &str,
290 deployment_id: &str,
291 ) -> Result<ParsedJson<LogsResponse>> {
292 let path = format!("/projects/{project}/deployments/{deployment_id}/logs");
293
294 self.get_json(path).await
295 }
296 pub async fn get_project_logs(&self, project: &str) -> Result<ParsedJson<LogsResponse>> {
297 let path = format!("/projects/{project}/logs");
298
299 self.get_json(path).await
300 }
301
302 pub async fn get_deployments(
303 &self,
304 project: &str,
305 page: i32,
306 per_page: i32,
307 ) -> Result<ParsedJson<DeploymentListResponse>> {
308 let path = format!(
309 "/projects/{project}/deployments?page={}&per_page={}",
310 page.saturating_sub(1).max(0),
311 per_page.max(1),
312 );
313
314 self.get_json(path).await
315 }
316 pub async fn get_current_deployment(
317 &self,
318 project: &str,
319 ) -> Result<ParsedJson<Option<DeploymentResponse>>> {
320 let path = format!("/projects/{project}/deployments/current");
321
322 self.get_json(path).await
323 }
324
325 pub async fn get_deployment(
326 &self,
327 project: &str,
328 deployment_id: &str,
329 ) -> Result<ParsedJson<DeploymentResponse>> {
330 let path = format!("/projects/{project}/deployments/{deployment_id}");
331
332 self.get_json(path).await
333 }
334
335 pub async fn reset_api_key(&self) -> Result<()> {
336 self.put("/users/reset-api-key", Option::<()>::None)
337 .await?
338 .to_empty()
339 .await
340 }
341
342 pub async fn ws_get(
343 &self,
344 path: impl AsRef<str>,
345 ) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
346 let ws_url = self
347 .api_url
348 .clone()
349 .replacen("http://", "ws://", 1)
350 .replacen("https://", "wss://", 1);
351 let url = format!("{ws_url}{}", path.as_ref());
352 let mut req = url.into_client_request()?;
353
354 #[cfg(feature = "tracing")]
355 debug!("WS Request: {} {}", req.method(), req.uri());
356
357 if let Some(ref api_key) = self.api_key {
358 let auth_header = Authorization::bearer(api_key.as_ref())?;
359 req.headers_mut().typed_insert(auth_header);
360 }
361
362 let (stream, _) = connect_async(req).await.with_context(|| {
363 #[cfg(feature = "tracing")]
364 error!("failed to connect to websocket");
365 "could not connect to websocket"
366 })?;
367
368 Ok(stream)
369 }
370
371 pub async fn get<T: Serialize>(
372 &self,
373 path: impl AsRef<str>,
374 body: Option<T>,
375 ) -> Result<Response> {
376 let url = format!("{}{}", self.api_url, path.as_ref());
377
378 let mut builder = self.client.get(url);
379 builder = self.set_auth_bearer(builder);
380
381 if let Some(body) = body {
382 let body = serde_json::to_string(&body)?;
383 #[cfg(feature = "tracing")]
384 debug!("Outgoing body: {}", body);
385 builder = builder.body(body);
386 builder = builder.header("Content-Type", "application/json");
387 }
388
389 Ok(builder.send().await?)
390 }
391
392 pub async fn get_json<R>(&self, path: impl AsRef<str>) -> Result<ParsedJson<R>>
393 where
394 R: for<'de> Deserialize<'de>,
395 {
396 self.get(path, Option::<()>::None).await?.to_json().await
397 }
398
399 pub async fn get_json_with_body<R, T: Serialize>(
400 &self,
401 path: impl AsRef<str>,
402 body: T,
403 ) -> Result<ParsedJson<R>>
404 where
405 R: for<'de> Deserialize<'de>,
406 {
407 self.get(path, Some(body)).await?.to_json().await
408 }
409
410 pub async fn post<T: Serialize>(
411 &self,
412 path: impl AsRef<str>,
413 body: Option<T>,
414 ) -> Result<Response> {
415 let url = format!("{}{}", self.api_url, path.as_ref());
416
417 let mut builder = self.client.post(url);
418 builder = self.set_auth_bearer(builder);
419
420 if let Some(body) = body {
421 let body = serde_json::to_string(&body)?;
422 #[cfg(feature = "tracing")]
423 debug!("Outgoing body: {}", body);
424 builder = builder.body(body);
425 builder = builder.header("Content-Type", "application/json");
426 }
427
428 Ok(builder.send().await?)
429 }
430
431 pub async fn post_json<T: Serialize, R>(
432 &self,
433 path: impl AsRef<str>,
434 body: Option<T>,
435 ) -> Result<ParsedJson<R>>
436 where
437 R: for<'de> Deserialize<'de>,
438 {
439 self.post(path, body).await?.to_json().await
440 }
441
442 pub async fn put<T: Serialize>(
443 &self,
444 path: impl AsRef<str>,
445 body: Option<T>,
446 ) -> Result<Response> {
447 let url = format!("{}{}", self.api_url, path.as_ref());
448
449 let mut builder = self.client.put(url);
450 builder = self.set_auth_bearer(builder);
451
452 if let Some(body) = body {
453 let body = serde_json::to_string(&body)?;
454 #[cfg(feature = "tracing")]
455 debug!("Outgoing body: {}", body);
456 builder = builder.body(body);
457 builder = builder.header("Content-Type", "application/json");
458 }
459
460 Ok(builder.send().await?)
461 }
462
463 pub async fn put_json<T: Serialize, R>(
464 &self,
465 path: impl AsRef<str>,
466 body: Option<T>,
467 ) -> Result<ParsedJson<R>>
468 where
469 R: for<'de> Deserialize<'de>,
470 {
471 self.put(path, body).await?.to_json().await
472 }
473
474 pub async fn delete<T: Serialize>(
475 &self,
476 path: impl AsRef<str>,
477 body: Option<T>,
478 ) -> Result<Response> {
479 let url = format!("{}{}", self.api_url, path.as_ref());
480
481 let mut builder = self.client.delete(url);
482 builder = self.set_auth_bearer(builder);
483
484 if let Some(body) = body {
485 let body = serde_json::to_string(&body)?;
486 #[cfg(feature = "tracing")]
487 debug!("Outgoing body: {}", body);
488 builder = builder.body(body);
489 builder = builder.header("Content-Type", "application/json");
490 }
491
492 Ok(builder.send().await?)
493 }
494
495 pub async fn delete_json<R>(&self, path: impl AsRef<str>) -> Result<ParsedJson<R>>
496 where
497 R: for<'de> Deserialize<'de>,
498 {
499 self.delete(path, Option::<()>::None).await?.to_json().await
500 }
501
502 pub async fn delete_json_with_body<R, T: Serialize>(
503 &self,
504 path: impl AsRef<str>,
505 body: T,
506 ) -> Result<ParsedJson<R>>
507 where
508 R: for<'de> Deserialize<'de>,
509 {
510 self.delete(path, Some(body)).await?.to_json().await
511 }
512}