Skip to main content

systemprompt_cloud/api_client/
client.rs

//! `CloudApiClient` constructor, accessors, and shared response handling.
//! Lower-level HTTP verbs live in `methods.rs`; high-level endpoints in
//! `endpoints.rs`.
3
4use std::sync::Arc;
5use std::time::Instant;
6
7use reqwest::{Client, StatusCode};
8use serde::de::DeserializeOwned;
9use systemprompt_models::net::{HTTP_CONNECT_TIMEOUT, HTTP_DEFAULT_TIMEOUT};
10use tokio::sync::Mutex;
11
12use super::types::ApiError;
13use crate::error::{CloudError, CloudResult};
14
15pub(super) type TenantTokenCache = Arc<Mutex<Option<(String, Instant)>>>;
16
17#[derive(Debug)]
18pub struct CloudApiClient {
19    pub(super) client: Client,
20    pub(super) api_url: String,
21    pub(super) token: String,
22    pub(super) tenant_token_cache: TenantTokenCache,
23}
24
25impl CloudApiClient {
26    pub fn new(api_url: &str, token: &str) -> Result<Self, reqwest::Error> {
27        Ok(Self {
28            client: Client::builder()
29                .connect_timeout(HTTP_CONNECT_TIMEOUT)
30                .timeout(HTTP_DEFAULT_TIMEOUT)
31                .build()?,
32            api_url: api_url.to_owned(),
33            token: token.to_owned(),
34            tenant_token_cache: Arc::new(Mutex::new(None)),
35        })
36    }
37
38    #[must_use]
39    pub fn api_url(&self) -> &str {
40        &self.api_url
41    }
42
43    #[must_use]
44    pub fn token(&self) -> &str {
45        &self.token
46    }
47
48    pub(super) async fn handle_response<T: DeserializeOwned>(
49        &self,
50        response: reqwest::Response,
51    ) -> CloudResult<T> {
52        let status = response.status();
53
54        if status == StatusCode::UNAUTHORIZED {
55            return Err(CloudError::Unauthorized);
56        }
57
58        if !status.is_success() {
59            return Err(parse_error_response(status, response).await);
60        }
61
62        response.json().await.map_err(CloudError::from)
63    }
64
65    pub(super) async fn handle_no_content_response(
66        &self,
67        response: reqwest::Response,
68    ) -> CloudResult<()> {
69        let status = response.status();
70        if status == StatusCode::UNAUTHORIZED {
71            return Err(CloudError::Unauthorized);
72        }
73        if status == StatusCode::NO_CONTENT || status.is_success() {
74            return Ok(());
75        }
76        Err(parse_error_response(status, response).await)
77    }
78}
79
80pub(super) async fn parse_error_response(
81    status: StatusCode,
82    response: reqwest::Response,
83) -> CloudError {
84    let error_text = match response.text().await {
85        Ok(t) => t,
86        Err(e) => {
87            tracing::warn!(error = %e, "Failed to read error response body");
88            String::from("<failed to read response body>")
89        },
90    };
91
92    serde_json::from_str::<ApiError>(&error_text).map_or_else(
93        |_| CloudError::HttpStatus {
94            status: status.as_u16(),
95            body: error_text.chars().take(500).collect(),
96        },
97        |parsed| CloudError::ApiError {
98            message: format!("{}: {}", parsed.error.code, parsed.error.message),
99        },
100    )
101}