// shuttle_api_client/lib.rs

use std::time::Duration;

use anyhow::{Context, Result};
use headers::{Authorization, HeaderMapExt};
use percent_encoding::utf8_percent_encode;
use reqwest::header::{HeaderMap, CONTENT_TYPE};
use reqwest::{Method, Response};
use reqwest_middleware::{ClientWithMiddleware, RequestBuilder};
use serde::{Deserialize, Serialize};
use shuttle_common::models::{
    certificate::{
        AddCertificateRequest, CertificateListResponse, CertificateResponse,
        DeleteCertificateRequest,
    },
    deployment::{
        DeploymentListResponse, DeploymentRequest, DeploymentResponse, UploadArchiveResponse,
    },
    log::LogsResponse,
    project::{ProjectCreateRequest, ProjectListResponse, ProjectResponse, ProjectUpdateRequest},
    resource::{ProvisionResourceRequest, ResourceListResponse, ResourceResponse, ResourceType},
    team::TeamListResponse,
    user::UserResponse,
};
use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};

#[cfg(feature = "tracing")]
mod middleware;
#[cfg(feature = "tracing")]
use crate::middleware::LoggingMiddleware;
#[cfg(feature = "tracing")]
use tracing::{debug, error};

mod util;
use util::ToJson;
37
38#[derive(Clone)]
39pub struct ShuttleApiClient {
40    pub client: ClientWithMiddleware,
41    pub api_url: String,
42    pub api_key: Option<String>,
43}
44
45impl ShuttleApiClient {
46    pub fn new(
47        api_url: String,
48        api_key: Option<String>,
49        headers: Option<HeaderMap>,
50        timeout: Option<u64>,
51    ) -> Self {
52        let mut builder = reqwest::Client::builder();
53        if let Some(h) = headers {
54            builder = builder.default_headers(h);
55        }
56        let client = builder
57            .timeout(Duration::from_secs(timeout.unwrap_or(60)))
58            .build()
59            .unwrap();
60
61        let builder = reqwest_middleware::ClientBuilder::new(client);
62        #[cfg(feature = "tracing")]
63        let builder = builder.with(LoggingMiddleware);
64        let client = builder.build();
65
66        Self {
67            client,
68            api_url,
69            api_key,
70        }
71    }
72
73    pub fn set_auth_bearer(&self, builder: RequestBuilder) -> RequestBuilder {
74        if let Some(ref api_key) = self.api_key {
75            builder.bearer_auth(api_key)
76        } else {
77            builder
78        }
79    }
80
81    pub async fn get_device_auth_ws(&self) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
82        self.ws_get("/device-auth/ws")
83            .await
84            .with_context(|| "failed to connect to auth endpoint")
85    }
86
87    pub async fn check_project_name(&self, project_name: &str) -> Result<bool> {
88        let url = format!("{}/projects/{project_name}/name", self.api_url);
89
90        self.client
91            .get(url)
92            .send()
93            .await
94            .context("failed to check project name availability")?
95            .to_json()
96            .await
97            .context("parsing name check response")
98    }
99
100    pub async fn get_current_user(&self) -> Result<UserResponse> {
101        self.get_json("/users/me").await
102    }
103
104    pub async fn deploy(
105        &self,
106        project: &str,
107        deployment_req: DeploymentRequest,
108    ) -> Result<DeploymentResponse> {
109        let path = format!("/projects/{project}/deployments");
110        self.post_json(path, Some(deployment_req)).await
111    }
112
113    pub async fn upload_archive(
114        &self,
115        project: &str,
116        data: Vec<u8>,
117    ) -> Result<UploadArchiveResponse> {
118        let path = format!("/projects/{project}/archives");
119
120        let url = format!("{}{}", self.api_url, path);
121        let mut builder = self.client.post(url);
122        builder = self.set_auth_bearer(builder);
123
124        builder
125            .body(data)
126            .send()
127            .await
128            .context("failed to upload archive")?
129            .to_json()
130            .await
131    }
132
133    pub async fn redeploy(&self, project: &str, deployment_id: &str) -> Result<DeploymentResponse> {
134        let path = format!("/projects/{project}/deployments/{deployment_id}/redeploy");
135
136        self.post_json(path, Option::<()>::None).await
137    }
138
139    pub async fn stop_service(&self, project: &str) -> Result<String> {
140        let path = format!("/projects/{project}/deployments");
141
142        self.delete_json(path).await
143    }
144
145    pub async fn get_service_resources(&self, project: &str) -> Result<ResourceListResponse> {
146        self.get_json(format!("/projects/{project}/resources"))
147            .await
148    }
149
150    async fn _dump_service_resource(
151        &self,
152        project: &str,
153        resource_type: &ResourceType,
154    ) -> Result<Vec<u8>> {
155        let r#type = resource_type.to_string();
156        let r#type = utf8_percent_encode(&r#type, percent_encoding::NON_ALPHANUMERIC).to_owned();
157
158        let res = self
159            .get(
160                format!(
161                    "/projects/{project}/services/{project}/resources/{}/dump",
162                    r#type
163                ),
164                Option::<()>::None,
165            )
166            .await?;
167
168        let bytes = res.bytes().await?;
169
170        Ok(bytes.to_vec())
171    }
172
173    pub async fn delete_service_resource(
174        &self,
175        project: &str,
176        resource_type: &ResourceType,
177    ) -> Result<String> {
178        let r#type = resource_type.to_string();
179        let r#type = utf8_percent_encode(&r#type, percent_encoding::NON_ALPHANUMERIC).to_owned();
180
181        self.delete_json(format!("/projects/{project}/resources/{}", r#type))
182            .await
183    }
184    pub async fn provision_resource(
185        &self,
186        project: &str,
187        req: ProvisionResourceRequest,
188    ) -> Result<ResourceResponse> {
189        self.post_json(format!("/projects/{project}/resources"), Some(req))
190            .await
191    }
192    pub async fn get_secrets(&self, project: &str) -> Result<ResourceResponse> {
193        self.get_json(format!("/projects/{project}/resources/secrets"))
194            .await
195    }
196
197    pub async fn list_certificates(&self, project: &str) -> Result<CertificateListResponse> {
198        self.get_json(format!("/projects/{project}/certificates"))
199            .await
200    }
201    pub async fn add_certificate(
202        &self,
203        project: &str,
204        subject: String,
205    ) -> Result<CertificateResponse> {
206        self.post_json(
207            format!("/projects/{project}/certificates"),
208            Some(AddCertificateRequest { subject }),
209        )
210        .await
211    }
212    pub async fn delete_certificate(&self, project: &str, subject: String) -> Result<String> {
213        self.delete_json_with_body(
214            format!("/projects/{project}/certificates"),
215            DeleteCertificateRequest { subject },
216        )
217        .await
218    }
219
220    pub async fn create_project(&self, name: &str) -> Result<ProjectResponse> {
221        self.post_json(
222            "/projects",
223            Some(ProjectCreateRequest {
224                name: name.to_string(),
225            }),
226        )
227        .await
228    }
229
230    pub async fn get_project(&self, project: &str) -> Result<ProjectResponse> {
231        self.get_json(format!("/projects/{project}")).await
232    }
233
234    pub async fn get_projects_list(&self) -> Result<ProjectListResponse> {
235        self.get_json("/projects".to_owned()).await
236    }
237
238    pub async fn update_project(
239        &self,
240        project: &str,
241        req: ProjectUpdateRequest,
242    ) -> Result<ProjectResponse> {
243        self.put_json(format!("/projects/{project}"), Some(req))
244            .await
245    }
246
247    pub async fn delete_project(&self, project: &str) -> Result<String> {
248        self.delete_json(format!("/projects/{project}")).await
249    }
250
251    #[allow(unused)]
252    async fn get_teams_list(&self) -> Result<TeamListResponse> {
253        self.get_json("/teams").await
254    }
255
256    pub async fn get_deployment_logs(
257        &self,
258        project: &str,
259        deployment_id: &str,
260    ) -> Result<LogsResponse> {
261        let path = format!("/projects/{project}/deployments/{deployment_id}/logs");
262
263        self.get_json(path).await
264    }
265    pub async fn get_project_logs(&self, project: &str) -> Result<LogsResponse> {
266        let path = format!("/projects/{project}/logs");
267
268        self.get_json(path).await
269    }
270
271    pub async fn get_deployments(
272        &self,
273        project: &str,
274        page: i32,
275        per_page: i32,
276    ) -> Result<DeploymentListResponse> {
277        let path = format!(
278            "/projects/{project}/deployments?page={}&per_page={}",
279            page.saturating_sub(1).max(0),
280            per_page.max(1),
281        );
282
283        self.get_json(path).await
284    }
285    pub async fn get_current_deployment(
286        &self,
287        project: &str,
288    ) -> Result<Option<DeploymentResponse>> {
289        let path = format!("/projects/{project}/deployments/current");
290
291        self.get_json(path).await
292    }
293
294    pub async fn get_deployment(
295        &self,
296        project: &str,
297        deployment_id: &str,
298    ) -> Result<DeploymentResponse> {
299        let path = format!("/projects/{project}/deployments/{deployment_id}");
300
301        self.get_json(path).await
302    }
303
304    pub async fn reset_api_key(&self) -> Result<Response> {
305        self.put("/users/reset-api-key", Option::<()>::None).await
306    }
307
308    pub async fn ws_get(
309        &self,
310        path: impl AsRef<str>,
311    ) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
312        let ws_url = self.api_url.clone().replace("http", "ws");
313        let url = format!("{ws_url}{}", path.as_ref());
314        let mut req = url.into_client_request()?;
315
316        #[cfg(feature = "tracing")]
317        debug!("WS Request: {} {}", req.method(), req.uri());
318
319        if let Some(ref api_key) = self.api_key {
320            let auth_header = Authorization::bearer(api_key.as_ref())?;
321            req.headers_mut().typed_insert(auth_header);
322        }
323
324        let (stream, _) = connect_async(req).await.with_context(|| {
325            #[cfg(feature = "tracing")]
326            error!("failed to connect to websocket");
327            "could not connect to websocket"
328        })?;
329
330        Ok(stream)
331    }
332
333    pub async fn get<T: Serialize>(
334        &self,
335        path: impl AsRef<str>,
336        body: Option<T>,
337    ) -> Result<Response> {
338        let url = format!("{}{}", self.api_url, path.as_ref());
339
340        let mut builder = self.client.get(url);
341        builder = self.set_auth_bearer(builder);
342
343        if let Some(body) = body {
344            let body = serde_json::to_string(&body)?;
345            #[cfg(feature = "tracing")]
346            debug!("Outgoing body: {}", body);
347            builder = builder.body(body);
348            builder = builder.header("Content-Type", "application/json");
349        }
350
351        Ok(builder.send().await?)
352    }
353
354    pub async fn get_json<R>(&self, path: impl AsRef<str>) -> Result<R>
355    where
356        R: for<'de> Deserialize<'de>,
357    {
358        self.get(path, Option::<()>::None).await?.to_json().await
359    }
360
361    pub async fn get_json_with_body<R, T: Serialize>(
362        &self,
363        path: impl AsRef<str>,
364        body: T,
365    ) -> Result<R>
366    where
367        R: for<'de> Deserialize<'de>,
368    {
369        self.get(path, Some(body)).await?.to_json().await
370    }
371
372    pub async fn post<T: Serialize>(
373        &self,
374        path: impl AsRef<str>,
375        body: Option<T>,
376    ) -> Result<Response> {
377        let url = format!("{}{}", self.api_url, path.as_ref());
378
379        let mut builder = self.client.post(url);
380        builder = self.set_auth_bearer(builder);
381
382        if let Some(body) = body {
383            let body = serde_json::to_string(&body)?;
384            #[cfg(feature = "tracing")]
385            debug!("Outgoing body: {}", body);
386            builder = builder.body(body);
387            builder = builder.header("Content-Type", "application/json");
388        }
389
390        Ok(builder.send().await?)
391    }
392
393    pub async fn post_json<T: Serialize, R>(
394        &self,
395        path: impl AsRef<str>,
396        body: Option<T>,
397    ) -> Result<R>
398    where
399        R: for<'de> Deserialize<'de>,
400    {
401        self.post(path, body).await?.to_json().await
402    }
403
404    pub async fn put<T: Serialize>(
405        &self,
406        path: impl AsRef<str>,
407        body: Option<T>,
408    ) -> Result<Response> {
409        let url = format!("{}{}", self.api_url, path.as_ref());
410
411        let mut builder = self.client.put(url);
412        builder = self.set_auth_bearer(builder);
413
414        if let Some(body) = body {
415            let body = serde_json::to_string(&body)?;
416            #[cfg(feature = "tracing")]
417            debug!("Outgoing body: {}", body);
418            builder = builder.body(body);
419            builder = builder.header("Content-Type", "application/json");
420        }
421
422        Ok(builder.send().await?)
423    }
424
425    pub async fn put_json<T: Serialize, R>(
426        &self,
427        path: impl AsRef<str>,
428        body: Option<T>,
429    ) -> Result<R>
430    where
431        R: for<'de> Deserialize<'de>,
432    {
433        self.put(path, body).await?.to_json().await
434    }
435
436    pub async fn delete<T: Serialize>(
437        &self,
438        path: impl AsRef<str>,
439        body: Option<T>,
440    ) -> Result<Response> {
441        let url = format!("{}{}", self.api_url, path.as_ref());
442
443        let mut builder = self.client.delete(url);
444        builder = self.set_auth_bearer(builder);
445
446        if let Some(body) = body {
447            let body = serde_json::to_string(&body)?;
448            #[cfg(feature = "tracing")]
449            debug!("Outgoing body: {}", body);
450            builder = builder.body(body);
451            builder = builder.header("Content-Type", "application/json");
452        }
453
454        Ok(builder.send().await?)
455    }
456
457    pub async fn delete_json<R>(&self, path: impl AsRef<str>) -> Result<R>
458    where
459        R: for<'de> Deserialize<'de>,
460    {
461        self.delete(path, Option::<()>::None).await?.to_json().await
462    }
463
464    pub async fn delete_json_with_body<R, T: Serialize>(
465        &self,
466        path: impl AsRef<str>,
467        body: T,
468    ) -> Result<R>
469    where
470        R: for<'de> Deserialize<'de>,
471    {
472        self.delete(path, Some(body)).await?.to_json().await
473    }
474}