shuttle_api_client/
lib.rs

1use std::time::Duration;
2
3use anyhow::{Context, Result};
4use headers::{Authorization, HeaderMapExt};
5use percent_encoding::utf8_percent_encode;
6use reqwest::{header::HeaderMap, Response};
7use reqwest_middleware::{ClientWithMiddleware, RequestBuilder};
8use serde::{Deserialize, Serialize};
9use shuttle_common::models::{
10    certificate::{
11        AddCertificateRequest, CertificateListResponse, CertificateResponse,
12        DeleteCertificateRequest,
13    },
14    deployment::{
15        DeploymentListResponse, DeploymentRequest, DeploymentResponse, UploadArchiveResponse,
16    },
17    log::LogsResponse,
18    project::{ProjectCreateRequest, ProjectListResponse, ProjectResponse, ProjectUpdateRequest},
19    resource::{ProvisionResourceRequest, ResourceListResponse, ResourceResponse, ResourceType},
20    team::TeamListResponse,
21    user::UserResponse,
22};
23use tokio::net::TcpStream;
24use tokio_tungstenite::{
25    connect_async, tungstenite::client::IntoClientRequest, MaybeTlsStream, WebSocketStream,
26};
27
28#[cfg(feature = "tracing")]
29mod middleware;
30#[cfg(feature = "tracing")]
31use tracing::{debug, error};
32
33#[cfg(feature = "tracing")]
34use crate::middleware::LoggingMiddleware;
35
36pub mod util;
37use util::{ParsedJson, ToBodyContent};
38
39#[derive(Clone)]
40pub struct ShuttleApiClient {
41    pub client: ClientWithMiddleware,
42    pub api_url: String,
43    pub api_key: Option<String>,
44}
45
46impl ShuttleApiClient {
47    pub fn new(
48        api_url: String,
49        api_key: Option<String>,
50        headers: Option<HeaderMap>,
51        timeout: Option<u64>,
52    ) -> Self {
53        let mut builder = reqwest::Client::builder();
54
55        if let Ok(proxy) = std::env::var("HTTP_PROXY") {
56            builder = builder.proxy(reqwest::Proxy::http(proxy).unwrap());
57        }
58
59        if let Ok(proxy) = std::env::var("HTTPS_PROXY") {
60            builder = builder.proxy(reqwest::Proxy::https(proxy).unwrap());
61        }
62
63        if std::env::var("SHUTTLE_TLS_INSECURE_SKIP_VERIFY")
64            .is_ok_and(|value| value == "1" || value == "true")
65        {
66            builder = builder.danger_accept_invalid_certs(true);
67        }
68
69        if let Some(headers) = headers {
70            builder = builder.default_headers(headers);
71        }
72
73        let client = builder
74            .timeout(Duration::from_secs(timeout.unwrap_or(60)))
75            .build()
76            .unwrap();
77
78        let builder = reqwest_middleware::ClientBuilder::new(client);
79
80        #[cfg(feature = "tracing")]
81        let builder = builder.with(LoggingMiddleware);
82
83        let client = builder.build();
84
85        Self {
86            client,
87            api_url,
88            api_key,
89        }
90    }
91
92    pub fn set_auth_bearer(&self, builder: RequestBuilder) -> RequestBuilder {
93        if let Some(ref api_key) = self.api_key {
94            builder.bearer_auth(api_key)
95        } else {
96            builder
97        }
98    }
99
100    pub async fn get_device_auth_ws(&self) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
101        self.ws_get("/device-auth/ws")
102            .await
103            .with_context(|| "failed to connect to auth endpoint")
104    }
105
106    pub async fn check_project_name(&self, project_name: &str) -> Result<ParsedJson<bool>> {
107        let url = format!("{}/projects/{project_name}/name", self.api_url);
108
109        self.client
110            .get(url)
111            .send()
112            .await
113            .context("failed to check project name availability")?
114            .to_json()
115            .await
116            .context("parsing name check response")
117    }
118
119    pub async fn get_current_user(&self) -> Result<ParsedJson<UserResponse>> {
120        self.get_json("/users/me").await
121    }
122
123    pub async fn deploy(
124        &self,
125        project: &str,
126        deployment_req: DeploymentRequest,
127    ) -> Result<ParsedJson<DeploymentResponse>> {
128        let path = format!("/projects/{project}/deployments");
129        self.post_json(path, Some(deployment_req)).await
130    }
131
132    pub async fn upload_archive(
133        &self,
134        project: &str,
135        data: Vec<u8>,
136    ) -> Result<ParsedJson<UploadArchiveResponse>> {
137        let path = format!("/projects/{project}/archives");
138
139        let url = format!("{}{}", self.api_url, path);
140        let mut builder = self.client.post(url);
141        builder = self.set_auth_bearer(builder);
142
143        builder
144            .body(data)
145            .send()
146            .await
147            .context("failed to upload archive")?
148            .to_json()
149            .await
150    }
151
152    pub async fn redeploy(
153        &self,
154        project: &str,
155        deployment_id: &str,
156    ) -> Result<ParsedJson<DeploymentResponse>> {
157        let path = format!("/projects/{project}/deployments/{deployment_id}/redeploy");
158
159        self.post_json(path, Option::<()>::None).await
160    }
161
162    pub async fn stop_service(&self, project: &str) -> Result<ParsedJson<String>> {
163        let path = format!("/projects/{project}/deployments");
164
165        self.delete_json(path).await
166    }
167
168    pub async fn get_service_resources(
169        &self,
170        project: &str,
171    ) -> Result<ParsedJson<ResourceListResponse>> {
172        self.get_json(format!("/projects/{project}/resources"))
173            .await
174    }
175
176    pub async fn dump_service_resource(
177        &self,
178        project: &str,
179        resource_type: &ResourceType,
180    ) -> Result<Vec<u8>> {
181        let r#type = resource_type.to_string();
182        let r#type = utf8_percent_encode(&r#type, percent_encoding::NON_ALPHANUMERIC).to_owned();
183
184        let bytes = self
185            .get(
186                format!("/projects/{project}/resources/{type}/dump"),
187                Option::<()>::None,
188            )
189            .await?
190            .to_bytes()
191            .await?
192            .to_vec();
193
194        Ok(bytes)
195    }
196
197    pub async fn delete_service_resource(
198        &self,
199        project: &str,
200        resource_type: &ResourceType,
201    ) -> Result<ParsedJson<String>> {
202        let r#type = resource_type.to_string();
203        let r#type = utf8_percent_encode(&r#type, percent_encoding::NON_ALPHANUMERIC).to_owned();
204
205        self.delete_json(format!("/projects/{project}/resources/{}", r#type))
206            .await
207    }
208    pub async fn provision_resource(
209        &self,
210        project: &str,
211        req: ProvisionResourceRequest,
212    ) -> Result<ParsedJson<ResourceResponse>> {
213        self.post_json(format!("/projects/{project}/resources"), Some(req))
214            .await
215    }
216    pub async fn get_secrets(&self, project: &str) -> Result<ParsedJson<ResourceResponse>> {
217        self.get_json(format!("/projects/{project}/resources/secrets"))
218            .await
219    }
220
221    pub async fn list_certificates(
222        &self,
223        project: &str,
224    ) -> Result<ParsedJson<CertificateListResponse>> {
225        self.get_json(format!("/projects/{project}/certificates"))
226            .await
227    }
228    pub async fn add_certificate(
229        &self,
230        project: &str,
231        subject: String,
232    ) -> Result<ParsedJson<CertificateResponse>> {
233        self.post_json(
234            format!("/projects/{project}/certificates"),
235            Some(AddCertificateRequest { subject }),
236        )
237        .await
238    }
239    pub async fn delete_certificate(
240        &self,
241        project: &str,
242        subject: String,
243    ) -> Result<ParsedJson<String>> {
244        self.delete_json_with_body(
245            format!("/projects/{project}/certificates"),
246            DeleteCertificateRequest { subject },
247        )
248        .await
249    }
250
251    pub async fn create_project(&self, name: &str) -> Result<ParsedJson<ProjectResponse>> {
252        self.post_json(
253            "/projects",
254            Some(ProjectCreateRequest {
255                name: name.to_string(),
256            }),
257        )
258        .await
259    }
260
261    pub async fn get_project(&self, project: &str) -> Result<ParsedJson<ProjectResponse>> {
262        self.get_json(format!("/projects/{project}")).await
263    }
264
265    pub async fn get_projects_list(&self) -> Result<ParsedJson<ProjectListResponse>> {
266        self.get_json("/projects".to_owned()).await
267    }
268
269    pub async fn update_project(
270        &self,
271        project: &str,
272        req: ProjectUpdateRequest,
273    ) -> Result<ParsedJson<ProjectResponse>> {
274        self.put_json(format!("/projects/{project}"), Some(req))
275            .await
276    }
277
278    pub async fn delete_project(&self, project: &str) -> Result<ParsedJson<String>> {
279        self.delete_json(format!("/projects/{project}")).await
280    }
281
282    #[allow(unused)]
283    async fn get_teams_list(&self) -> Result<ParsedJson<TeamListResponse>> {
284        self.get_json("/teams").await
285    }
286
287    pub async fn get_deployment_logs(
288        &self,
289        project: &str,
290        deployment_id: &str,
291    ) -> Result<ParsedJson<LogsResponse>> {
292        let path = format!("/projects/{project}/deployments/{deployment_id}/logs");
293
294        self.get_json(path).await
295    }
296    pub async fn get_project_logs(&self, project: &str) -> Result<ParsedJson<LogsResponse>> {
297        let path = format!("/projects/{project}/logs");
298
299        self.get_json(path).await
300    }
301
302    pub async fn get_deployments(
303        &self,
304        project: &str,
305        page: i32,
306        per_page: i32,
307    ) -> Result<ParsedJson<DeploymentListResponse>> {
308        let path = format!(
309            "/projects/{project}/deployments?page={}&per_page={}",
310            page.saturating_sub(1).max(0),
311            per_page.max(1),
312        );
313
314        self.get_json(path).await
315    }
316    pub async fn get_current_deployment(
317        &self,
318        project: &str,
319    ) -> Result<ParsedJson<Option<DeploymentResponse>>> {
320        let path = format!("/projects/{project}/deployments/current");
321
322        self.get_json(path).await
323    }
324
325    pub async fn get_deployment(
326        &self,
327        project: &str,
328        deployment_id: &str,
329    ) -> Result<ParsedJson<DeploymentResponse>> {
330        let path = format!("/projects/{project}/deployments/{deployment_id}");
331
332        self.get_json(path).await
333    }
334
335    pub async fn reset_api_key(&self) -> Result<()> {
336        self.put("/users/reset-api-key", Option::<()>::None)
337            .await?
338            .to_empty()
339            .await
340    }
341
342    pub async fn ws_get(
343        &self,
344        path: impl AsRef<str>,
345    ) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
346        let ws_url = self
347            .api_url
348            .clone()
349            .replacen("http://", "ws://", 1)
350            .replacen("https://", "wss://", 1);
351        let url = format!("{ws_url}{}", path.as_ref());
352        let mut req = url.into_client_request()?;
353
354        #[cfg(feature = "tracing")]
355        debug!("WS Request: {} {}", req.method(), req.uri());
356
357        if let Some(ref api_key) = self.api_key {
358            let auth_header = Authorization::bearer(api_key.as_ref())?;
359            req.headers_mut().typed_insert(auth_header);
360        }
361
362        let (stream, _) = connect_async(req).await.with_context(|| {
363            #[cfg(feature = "tracing")]
364            error!("failed to connect to websocket");
365            "could not connect to websocket"
366        })?;
367
368        Ok(stream)
369    }
370
371    pub async fn get<T: Serialize>(
372        &self,
373        path: impl AsRef<str>,
374        body: Option<T>,
375    ) -> Result<Response> {
376        let url = format!("{}{}", self.api_url, path.as_ref());
377
378        let mut builder = self.client.get(url);
379        builder = self.set_auth_bearer(builder);
380
381        if let Some(body) = body {
382            let body = serde_json::to_string(&body)?;
383            #[cfg(feature = "tracing")]
384            debug!("Outgoing body: {}", body);
385            builder = builder.body(body);
386            builder = builder.header("Content-Type", "application/json");
387        }
388
389        Ok(builder.send().await?)
390    }
391
392    pub async fn get_json<R>(&self, path: impl AsRef<str>) -> Result<ParsedJson<R>>
393    where
394        R: for<'de> Deserialize<'de>,
395    {
396        self.get(path, Option::<()>::None).await?.to_json().await
397    }
398
399    pub async fn get_json_with_body<R, T: Serialize>(
400        &self,
401        path: impl AsRef<str>,
402        body: T,
403    ) -> Result<ParsedJson<R>>
404    where
405        R: for<'de> Deserialize<'de>,
406    {
407        self.get(path, Some(body)).await?.to_json().await
408    }
409
410    pub async fn post<T: Serialize>(
411        &self,
412        path: impl AsRef<str>,
413        body: Option<T>,
414    ) -> Result<Response> {
415        let url = format!("{}{}", self.api_url, path.as_ref());
416
417        let mut builder = self.client.post(url);
418        builder = self.set_auth_bearer(builder);
419
420        if let Some(body) = body {
421            let body = serde_json::to_string(&body)?;
422            #[cfg(feature = "tracing")]
423            debug!("Outgoing body: {}", body);
424            builder = builder.body(body);
425            builder = builder.header("Content-Type", "application/json");
426        }
427
428        Ok(builder.send().await?)
429    }
430
431    pub async fn post_json<T: Serialize, R>(
432        &self,
433        path: impl AsRef<str>,
434        body: Option<T>,
435    ) -> Result<ParsedJson<R>>
436    where
437        R: for<'de> Deserialize<'de>,
438    {
439        self.post(path, body).await?.to_json().await
440    }
441
442    pub async fn put<T: Serialize>(
443        &self,
444        path: impl AsRef<str>,
445        body: Option<T>,
446    ) -> Result<Response> {
447        let url = format!("{}{}", self.api_url, path.as_ref());
448
449        let mut builder = self.client.put(url);
450        builder = self.set_auth_bearer(builder);
451
452        if let Some(body) = body {
453            let body = serde_json::to_string(&body)?;
454            #[cfg(feature = "tracing")]
455            debug!("Outgoing body: {}", body);
456            builder = builder.body(body);
457            builder = builder.header("Content-Type", "application/json");
458        }
459
460        Ok(builder.send().await?)
461    }
462
463    pub async fn put_json<T: Serialize, R>(
464        &self,
465        path: impl AsRef<str>,
466        body: Option<T>,
467    ) -> Result<ParsedJson<R>>
468    where
469        R: for<'de> Deserialize<'de>,
470    {
471        self.put(path, body).await?.to_json().await
472    }
473
474    pub async fn delete<T: Serialize>(
475        &self,
476        path: impl AsRef<str>,
477        body: Option<T>,
478    ) -> Result<Response> {
479        let url = format!("{}{}", self.api_url, path.as_ref());
480
481        let mut builder = self.client.delete(url);
482        builder = self.set_auth_bearer(builder);
483
484        if let Some(body) = body {
485            let body = serde_json::to_string(&body)?;
486            #[cfg(feature = "tracing")]
487            debug!("Outgoing body: {}", body);
488            builder = builder.body(body);
489            builder = builder.header("Content-Type", "application/json");
490        }
491
492        Ok(builder.send().await?)
493    }
494
495    pub async fn delete_json<R>(&self, path: impl AsRef<str>) -> Result<ParsedJson<R>>
496    where
497        R: for<'de> Deserialize<'de>,
498    {
499        self.delete(path, Option::<()>::None).await?.to_json().await
500    }
501
502    pub async fn delete_json_with_body<R, T: Serialize>(
503        &self,
504        path: impl AsRef<str>,
505        body: T,
506    ) -> Result<ParsedJson<R>>
507    where
508        R: for<'de> Deserialize<'de>,
509    {
510        self.delete(path, Some(body)).await?.to_json().await
511    }
512}