shuttle_api_client/
lib.rs

1use std::time::Duration;
2
3use anyhow::{Context, Result};
4use headers::{Authorization, HeaderMapExt};
5use percent_encoding::utf8_percent_encode;
6use reqwest::header::HeaderMap;
7use reqwest::Response;
8use reqwest_middleware::{ClientWithMiddleware, RequestBuilder};
9use serde::{Deserialize, Serialize};
10use shuttle_common::models::{
11    certificate::{
12        AddCertificateRequest, CertificateListResponse, CertificateResponse,
13        DeleteCertificateRequest,
14    },
15    deployment::{
16        DeploymentListResponse, DeploymentRequest, DeploymentResponse, UploadArchiveResponse,
17    },
18    log::LogsResponse,
19    project::{ProjectCreateRequest, ProjectListResponse, ProjectResponse, ProjectUpdateRequest},
20    resource::{ProvisionResourceRequest, ResourceListResponse, ResourceResponse, ResourceType},
21    team::TeamListResponse,
22    user::UserResponse,
23};
24use tokio::net::TcpStream;
25use tokio_tungstenite::tungstenite::client::IntoClientRequest;
26use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
27
28#[cfg(feature = "tracing")]
29mod middleware;
30#[cfg(feature = "tracing")]
31use crate::middleware::LoggingMiddleware;
32#[cfg(feature = "tracing")]
33use tracing::{debug, error};
34
35pub mod util;
36use util::{ParsedJson, ToBodyContent};
37
38#[derive(Clone)]
39pub struct ShuttleApiClient {
40    pub client: ClientWithMiddleware,
41    pub api_url: String,
42    pub api_key: Option<String>,
43}
44
45impl ShuttleApiClient {
46    pub fn new(
47        api_url: String,
48        api_key: Option<String>,
49        headers: Option<HeaderMap>,
50        timeout: Option<u64>,
51    ) -> Self {
52        let mut builder = reqwest::Client::builder();
53        if let Some(h) = headers {
54            builder = builder.default_headers(h);
55        }
56        let client = builder
57            .timeout(Duration::from_secs(timeout.unwrap_or(60)))
58            .build()
59            .unwrap();
60
61        let builder = reqwest_middleware::ClientBuilder::new(client);
62        #[cfg(feature = "tracing")]
63        let builder = builder.with(LoggingMiddleware);
64        let client = builder.build();
65
66        Self {
67            client,
68            api_url,
69            api_key,
70        }
71    }
72
73    pub fn set_auth_bearer(&self, builder: RequestBuilder) -> RequestBuilder {
74        if let Some(ref api_key) = self.api_key {
75            builder.bearer_auth(api_key)
76        } else {
77            builder
78        }
79    }
80
81    pub async fn get_device_auth_ws(&self) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
82        self.ws_get("/device-auth/ws")
83            .await
84            .with_context(|| "failed to connect to auth endpoint")
85    }
86
87    pub async fn check_project_name(&self, project_name: &str) -> Result<ParsedJson<bool>> {
88        let url = format!("{}/projects/{project_name}/name", self.api_url);
89
90        self.client
91            .get(url)
92            .send()
93            .await
94            .context("failed to check project name availability")?
95            .to_json()
96            .await
97            .context("parsing name check response")
98    }
99
100    pub async fn get_current_user(&self) -> Result<ParsedJson<UserResponse>> {
101        self.get_json("/users/me").await
102    }
103
104    pub async fn deploy(
105        &self,
106        project: &str,
107        deployment_req: DeploymentRequest,
108    ) -> Result<ParsedJson<DeploymentResponse>> {
109        let path = format!("/projects/{project}/deployments");
110        self.post_json(path, Some(deployment_req)).await
111    }
112
113    pub async fn upload_archive(
114        &self,
115        project: &str,
116        data: Vec<u8>,
117    ) -> Result<ParsedJson<UploadArchiveResponse>> {
118        let path = format!("/projects/{project}/archives");
119
120        let url = format!("{}{}", self.api_url, path);
121        let mut builder = self.client.post(url);
122        builder = self.set_auth_bearer(builder);
123
124        builder
125            .body(data)
126            .send()
127            .await
128            .context("failed to upload archive")?
129            .to_json()
130            .await
131    }
132
133    pub async fn redeploy(
134        &self,
135        project: &str,
136        deployment_id: &str,
137    ) -> Result<ParsedJson<DeploymentResponse>> {
138        let path = format!("/projects/{project}/deployments/{deployment_id}/redeploy");
139
140        self.post_json(path, Option::<()>::None).await
141    }
142
143    pub async fn stop_service(&self, project: &str) -> Result<ParsedJson<String>> {
144        let path = format!("/projects/{project}/deployments");
145
146        self.delete_json(path).await
147    }
148
149    pub async fn get_service_resources(
150        &self,
151        project: &str,
152    ) -> Result<ParsedJson<ResourceListResponse>> {
153        self.get_json(format!("/projects/{project}/resources"))
154            .await
155    }
156
157    async fn _dump_service_resource(
158        &self,
159        project: &str,
160        resource_type: &ResourceType,
161    ) -> Result<Vec<u8>> {
162        let r#type = resource_type.to_string();
163        let r#type = utf8_percent_encode(&r#type, percent_encoding::NON_ALPHANUMERIC).to_owned();
164
165        let bytes = self
166            .get(
167                format!(
168                    "/projects/{project}/services/{project}/resources/{}/dump",
169                    r#type
170                ),
171                Option::<()>::None,
172            )
173            .await?
174            .to_bytes()
175            .await?
176            .to_vec();
177
178        Ok(bytes)
179    }
180
181    pub async fn delete_service_resource(
182        &self,
183        project: &str,
184        resource_type: &ResourceType,
185    ) -> Result<ParsedJson<String>> {
186        let r#type = resource_type.to_string();
187        let r#type = utf8_percent_encode(&r#type, percent_encoding::NON_ALPHANUMERIC).to_owned();
188
189        self.delete_json(format!("/projects/{project}/resources/{}", r#type))
190            .await
191    }
192    pub async fn provision_resource(
193        &self,
194        project: &str,
195        req: ProvisionResourceRequest,
196    ) -> Result<ParsedJson<ResourceResponse>> {
197        self.post_json(format!("/projects/{project}/resources"), Some(req))
198            .await
199    }
200    pub async fn get_secrets(&self, project: &str) -> Result<ParsedJson<ResourceResponse>> {
201        self.get_json(format!("/projects/{project}/resources/secrets"))
202            .await
203    }
204
205    pub async fn list_certificates(
206        &self,
207        project: &str,
208    ) -> Result<ParsedJson<CertificateListResponse>> {
209        self.get_json(format!("/projects/{project}/certificates"))
210            .await
211    }
212    pub async fn add_certificate(
213        &self,
214        project: &str,
215        subject: String,
216    ) -> Result<ParsedJson<CertificateResponse>> {
217        self.post_json(
218            format!("/projects/{project}/certificates"),
219            Some(AddCertificateRequest { subject }),
220        )
221        .await
222    }
223    pub async fn delete_certificate(
224        &self,
225        project: &str,
226        subject: String,
227    ) -> Result<ParsedJson<String>> {
228        self.delete_json_with_body(
229            format!("/projects/{project}/certificates"),
230            DeleteCertificateRequest { subject },
231        )
232        .await
233    }
234
235    pub async fn create_project(&self, name: &str) -> Result<ParsedJson<ProjectResponse>> {
236        self.post_json(
237            "/projects",
238            Some(ProjectCreateRequest {
239                name: name.to_string(),
240            }),
241        )
242        .await
243    }
244
245    pub async fn get_project(&self, project: &str) -> Result<ParsedJson<ProjectResponse>> {
246        self.get_json(format!("/projects/{project}")).await
247    }
248
249    pub async fn get_projects_list(&self) -> Result<ParsedJson<ProjectListResponse>> {
250        self.get_json("/projects".to_owned()).await
251    }
252
253    pub async fn update_project(
254        &self,
255        project: &str,
256        req: ProjectUpdateRequest,
257    ) -> Result<ParsedJson<ProjectResponse>> {
258        self.put_json(format!("/projects/{project}"), Some(req))
259            .await
260    }
261
262    pub async fn delete_project(&self, project: &str) -> Result<ParsedJson<String>> {
263        self.delete_json(format!("/projects/{project}")).await
264    }
265
266    #[allow(unused)]
267    async fn get_teams_list(&self) -> Result<ParsedJson<TeamListResponse>> {
268        self.get_json("/teams").await
269    }
270
271    pub async fn get_deployment_logs(
272        &self,
273        project: &str,
274        deployment_id: &str,
275    ) -> Result<ParsedJson<LogsResponse>> {
276        let path = format!("/projects/{project}/deployments/{deployment_id}/logs");
277
278        self.get_json(path).await
279    }
280    pub async fn get_project_logs(&self, project: &str) -> Result<ParsedJson<LogsResponse>> {
281        let path = format!("/projects/{project}/logs");
282
283        self.get_json(path).await
284    }
285
286    pub async fn get_deployments(
287        &self,
288        project: &str,
289        page: i32,
290        per_page: i32,
291    ) -> Result<ParsedJson<DeploymentListResponse>> {
292        let path = format!(
293            "/projects/{project}/deployments?page={}&per_page={}",
294            page.saturating_sub(1).max(0),
295            per_page.max(1),
296        );
297
298        self.get_json(path).await
299    }
300    pub async fn get_current_deployment(
301        &self,
302        project: &str,
303    ) -> Result<ParsedJson<Option<DeploymentResponse>>> {
304        let path = format!("/projects/{project}/deployments/current");
305
306        self.get_json(path).await
307    }
308
309    pub async fn get_deployment(
310        &self,
311        project: &str,
312        deployment_id: &str,
313    ) -> Result<ParsedJson<DeploymentResponse>> {
314        let path = format!("/projects/{project}/deployments/{deployment_id}");
315
316        self.get_json(path).await
317    }
318
319    pub async fn reset_api_key(&self) -> Result<()> {
320        self.put("/users/reset-api-key", Option::<()>::None)
321            .await?
322            .to_empty()
323            .await
324    }
325
326    pub async fn ws_get(
327        &self,
328        path: impl AsRef<str>,
329    ) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
330        let ws_url = self
331            .api_url
332            .clone()
333            .replacen("http://", "ws://", 1)
334            .replacen("https://", "wss://", 1);
335        let url = format!("{ws_url}{}", path.as_ref());
336        let mut req = url.into_client_request()?;
337
338        #[cfg(feature = "tracing")]
339        debug!("WS Request: {} {}", req.method(), req.uri());
340
341        if let Some(ref api_key) = self.api_key {
342            let auth_header = Authorization::bearer(api_key.as_ref())?;
343            req.headers_mut().typed_insert(auth_header);
344        }
345
346        let (stream, _) = connect_async(req).await.with_context(|| {
347            #[cfg(feature = "tracing")]
348            error!("failed to connect to websocket");
349            "could not connect to websocket"
350        })?;
351
352        Ok(stream)
353    }
354
355    pub async fn get<T: Serialize>(
356        &self,
357        path: impl AsRef<str>,
358        body: Option<T>,
359    ) -> Result<Response> {
360        let url = format!("{}{}", self.api_url, path.as_ref());
361
362        let mut builder = self.client.get(url);
363        builder = self.set_auth_bearer(builder);
364
365        if let Some(body) = body {
366            let body = serde_json::to_string(&body)?;
367            #[cfg(feature = "tracing")]
368            debug!("Outgoing body: {}", body);
369            builder = builder.body(body);
370            builder = builder.header("Content-Type", "application/json");
371        }
372
373        Ok(builder.send().await?)
374    }
375
376    pub async fn get_json<R>(&self, path: impl AsRef<str>) -> Result<ParsedJson<R>>
377    where
378        R: for<'de> Deserialize<'de>,
379    {
380        self.get(path, Option::<()>::None).await?.to_json().await
381    }
382
383    pub async fn get_json_with_body<R, T: Serialize>(
384        &self,
385        path: impl AsRef<str>,
386        body: T,
387    ) -> Result<ParsedJson<R>>
388    where
389        R: for<'de> Deserialize<'de>,
390    {
391        self.get(path, Some(body)).await?.to_json().await
392    }
393
394    pub async fn post<T: Serialize>(
395        &self,
396        path: impl AsRef<str>,
397        body: Option<T>,
398    ) -> Result<Response> {
399        let url = format!("{}{}", self.api_url, path.as_ref());
400
401        let mut builder = self.client.post(url);
402        builder = self.set_auth_bearer(builder);
403
404        if let Some(body) = body {
405            let body = serde_json::to_string(&body)?;
406            #[cfg(feature = "tracing")]
407            debug!("Outgoing body: {}", body);
408            builder = builder.body(body);
409            builder = builder.header("Content-Type", "application/json");
410        }
411
412        Ok(builder.send().await?)
413    }
414
415    pub async fn post_json<T: Serialize, R>(
416        &self,
417        path: impl AsRef<str>,
418        body: Option<T>,
419    ) -> Result<ParsedJson<R>>
420    where
421        R: for<'de> Deserialize<'de>,
422    {
423        self.post(path, body).await?.to_json().await
424    }
425
426    pub async fn put<T: Serialize>(
427        &self,
428        path: impl AsRef<str>,
429        body: Option<T>,
430    ) -> Result<Response> {
431        let url = format!("{}{}", self.api_url, path.as_ref());
432
433        let mut builder = self.client.put(url);
434        builder = self.set_auth_bearer(builder);
435
436        if let Some(body) = body {
437            let body = serde_json::to_string(&body)?;
438            #[cfg(feature = "tracing")]
439            debug!("Outgoing body: {}", body);
440            builder = builder.body(body);
441            builder = builder.header("Content-Type", "application/json");
442        }
443
444        Ok(builder.send().await?)
445    }
446
447    pub async fn put_json<T: Serialize, R>(
448        &self,
449        path: impl AsRef<str>,
450        body: Option<T>,
451    ) -> Result<ParsedJson<R>>
452    where
453        R: for<'de> Deserialize<'de>,
454    {
455        self.put(path, body).await?.to_json().await
456    }
457
458    pub async fn delete<T: Serialize>(
459        &self,
460        path: impl AsRef<str>,
461        body: Option<T>,
462    ) -> Result<Response> {
463        let url = format!("{}{}", self.api_url, path.as_ref());
464
465        let mut builder = self.client.delete(url);
466        builder = self.set_auth_bearer(builder);
467
468        if let Some(body) = body {
469            let body = serde_json::to_string(&body)?;
470            #[cfg(feature = "tracing")]
471            debug!("Outgoing body: {}", body);
472            builder = builder.body(body);
473            builder = builder.header("Content-Type", "application/json");
474        }
475
476        Ok(builder.send().await?)
477    }
478
479    pub async fn delete_json<R>(&self, path: impl AsRef<str>) -> Result<ParsedJson<R>>
480    where
481        R: for<'de> Deserialize<'de>,
482    {
483        self.delete(path, Option::<()>::None).await?.to_json().await
484    }
485
486    pub async fn delete_json_with_body<R, T: Serialize>(
487        &self,
488        path: impl AsRef<str>,
489        body: T,
490    ) -> Result<ParsedJson<R>>
491    where
492        R: for<'de> Deserialize<'de>,
493    {
494        self.delete(path, Some(body)).await?.to_json().await
495    }
496}