systemprompt_cloud/api_client/
client.rs1use std::sync::Arc;
5use std::time::Instant;
6
7use reqwest::{Client, StatusCode};
8use serde::de::DeserializeOwned;
9use systemprompt_models::net::{HTTP_CONNECT_TIMEOUT, HTTP_DEFAULT_TIMEOUT};
10use tokio::sync::Mutex;
11
12use super::types::ApiError;
13use crate::error::{CloudError, CloudResult};
14
15pub(super) type TenantTokenCache = Arc<Mutex<Option<(String, Instant)>>>;
16
17#[derive(Debug)]
18pub struct CloudApiClient {
19 pub(super) client: Client,
20 pub(super) api_url: String,
21 pub(super) token: String,
22 pub(super) tenant_token_cache: TenantTokenCache,
23}
24
25impl CloudApiClient {
26 pub fn new(api_url: &str, token: &str) -> Result<Self, reqwest::Error> {
27 Ok(Self {
28 client: Client::builder()
29 .connect_timeout(HTTP_CONNECT_TIMEOUT)
30 .timeout(HTTP_DEFAULT_TIMEOUT)
31 .build()?,
32 api_url: api_url.to_owned(),
33 token: token.to_owned(),
34 tenant_token_cache: Arc::new(Mutex::new(None)),
35 })
36 }
37
38 #[must_use]
39 pub fn api_url(&self) -> &str {
40 &self.api_url
41 }
42
43 #[must_use]
44 pub fn token(&self) -> &str {
45 &self.token
46 }
47
48 pub(super) async fn handle_response<T: DeserializeOwned>(
49 &self,
50 response: reqwest::Response,
51 ) -> CloudResult<T> {
52 let status = response.status();
53
54 if status == StatusCode::UNAUTHORIZED {
55 return Err(CloudError::Unauthorized);
56 }
57
58 if !status.is_success() {
59 return Err(parse_error_response(status, response).await);
60 }
61
62 response.json().await.map_err(CloudError::from)
63 }
64
65 pub(super) async fn handle_no_content_response(
66 &self,
67 response: reqwest::Response,
68 ) -> CloudResult<()> {
69 let status = response.status();
70 if status == StatusCode::UNAUTHORIZED {
71 return Err(CloudError::Unauthorized);
72 }
73 if status == StatusCode::NO_CONTENT || status.is_success() {
74 return Ok(());
75 }
76 Err(parse_error_response(status, response).await)
77 }
78}
79
80pub(super) async fn parse_error_response(
81 status: StatusCode,
82 response: reqwest::Response,
83) -> CloudError {
84 let error_text = match response.text().await {
85 Ok(t) => t,
86 Err(e) => {
87 tracing::warn!(error = %e, "Failed to read error response body");
88 String::from("<failed to read response body>")
89 },
90 };
91
92 serde_json::from_str::<ApiError>(&error_text).map_or_else(
93 |_| CloudError::HttpStatus {
94 status: status.as_u16(),
95 body: error_text.chars().take(500).collect(),
96 },
97 |parsed| CloudError::ApiError {
98 message: format!("{}: {}", parsed.error.code, parsed.error.message),
99 },
100 )
101}