// sdk_rust/client.rs

1use std::{
2    collections::BTreeMap,
3    sync::Mutex,
4    time::{Duration, Instant},
5};
6
7use serde::{Serialize, de::DeserializeOwned};
8
9use crate::{
10    builder::ClientBuilder,
11    error::SdkError,
12    models::{
13        CallerIdentityResponse, SdkArtifactRegisterRequest, SdkArtifactRegisterResponse,
14        SdkBootstrapResponse, SdkCapabilitiesResponse, SdkEvidenceIngestRequest,
15        SdkEvidenceIngestResponse, SdkKeyAccessPlanRequest, SdkKeyAccessPlanResponse,
16        SdkPolicyResolveRequest, SdkPolicyResolveResponse, SdkProtectionPlanRequest,
17        SdkProtectionPlanResponse, SdkSessionExchangeResponse,
18    },
19};
20
21pub(crate) enum ClientAuthStrategy {
22    StaticBearer(String),
23    SdkClientCredentials(Box<SdkClientCredentialsAuth>),
24}
25
26pub(crate) struct SdkClientCredentialsAuth {
27    pub tenant_id: String,
28    pub client_id: String,
29    pub client_secret: String,
30    pub token_exchange_path: String,
31    pub requested_scopes: Vec<String>,
32    cached_session: Mutex<Option<CachedSessionToken>>,
33}
34
35struct CachedSessionToken {
36    response: SdkSessionExchangeResponse,
37    refresh_at: Instant,
38}
39
40#[derive(Serialize)]
41#[serde(rename_all = "snake_case")]
42struct SdkSessionExchangeRequest<'a> {
43    tenant_id: &'a str,
44    client_id: &'a str,
45    client_secret: &'a str,
46    requested_scopes: &'a [String],
47}
48
49impl SdkClientCredentialsAuth {
50    pub(crate) fn new(
51        tenant_id: String,
52        client_id: String,
53        client_secret: String,
54        token_exchange_path: String,
55        requested_scopes: Vec<String>,
56    ) -> Self {
57        Self {
58            tenant_id,
59            client_id,
60            client_secret,
61            token_exchange_path,
62            requested_scopes,
63            cached_session: Mutex::new(None),
64        }
65    }
66
67    fn resolve_access_token(
68        &self,
69        agent: &ureq::Agent,
70        base_url: &str,
71    ) -> Result<SdkSessionExchangeResponse, SdkError> {
72        {
73            let cache = self.cached_session.lock().map_err(|_| {
74                SdkError::Connection("failed to acquire sdk session cache".to_string())
75            })?;
76            if let Some(cached) = cache.as_ref()
77                && Instant::now() < cached.refresh_at
78            {
79                return Ok(cached.response.clone());
80            }
81        }
82
83        let endpoint = if self.token_exchange_path.starts_with("http://")
84            || self.token_exchange_path.starts_with("https://")
85        {
86            self.token_exchange_path.clone()
87        } else {
88            format!("{}{}", base_url, self.token_exchange_path)
89        };
90        let response = post_json_with_agent::<_, SdkSessionExchangeResponse>(
91            agent,
92            &endpoint,
93            &SdkSessionExchangeRequest {
94                tenant_id: &self.tenant_id,
95                client_id: &self.client_id,
96                client_secret: &self.client_secret,
97                requested_scopes: &self.requested_scopes,
98            },
99            &BTreeMap::new(),
100        )?;
101
102        let refresh_after_secs = if response.expires_in > 60 {
103            response.expires_in - 60
104        } else {
105            1
106        };
107
108        let mut cache = self
109            .cached_session
110            .lock()
111            .map_err(|_| SdkError::Connection("failed to update sdk session cache".to_string()))?;
112        *cache = Some(CachedSessionToken {
113            response: response.clone(),
114            refresh_at: Instant::now() + Duration::from_secs(refresh_after_secs),
115        });
116
117        Ok(response)
118    }
119}
120
121pub struct Client {
122    pub(crate) base_url: String,
123    pub(crate) agent: ureq::Agent,
124    pub(crate) default_headers: BTreeMap<String, String>,
125    pub(crate) auth_strategy: Option<ClientAuthStrategy>,
126}
127
128impl Client {
129    pub(crate) fn new(
130        base_url: String,
131        agent: ureq::Agent,
132        default_headers: BTreeMap<String, String>,
133        auth_strategy: Option<ClientAuthStrategy>,
134    ) -> Self {
135        Self {
136            base_url,
137            agent,
138            default_headers,
139            auth_strategy,
140        }
141    }
142
143    pub fn builder(base_url: impl Into<String>) -> ClientBuilder {
144        ClientBuilder::new(base_url)
145    }
146
147    pub fn base_url(&self) -> &str {
148        &self.base_url
149    }
150
151    pub fn capabilities(&self) -> Result<SdkCapabilitiesResponse, SdkError> {
152        self.get_json("/v1/sdk/capabilities")
153    }
154
155    pub fn whoami(&self) -> Result<CallerIdentityResponse, SdkError> {
156        self.get_json("/v1/sdk/whoami")
157    }
158
159    pub fn bootstrap(&self) -> Result<SdkBootstrapResponse, SdkError> {
160        self.get_json("/v1/sdk/bootstrap")
161    }
162
163    pub fn exchange_session(&self) -> Result<SdkSessionExchangeResponse, SdkError> {
164        match self.auth_strategy.as_ref() {
165            Some(ClientAuthStrategy::SdkClientCredentials(auth)) => {
166                auth.resolve_access_token(&self.agent, &self.base_url)
167            }
168            Some(ClientAuthStrategy::StaticBearer(_)) => Err(SdkError::InvalidInput(
169                "client is configured with a static bearer token; no sdk session exchange is required"
170                    .to_string(),
171            )),
172            None => Err(SdkError::InvalidInput(
173                "client is not configured with sdk client credentials".to_string(),
174            )),
175        }
176    }
177
178    pub fn protection_plan(
179        &self,
180        request: &SdkProtectionPlanRequest,
181    ) -> Result<SdkProtectionPlanResponse, SdkError> {
182        self.post_json("/v1/sdk/protection-plan", request)
183    }
184
185    pub fn policy_resolve(
186        &self,
187        request: &SdkPolicyResolveRequest,
188    ) -> Result<SdkPolicyResolveResponse, SdkError> {
189        self.post_json("/v1/sdk/policy-resolve", request)
190    }
191
192    pub fn key_access_plan(
193        &self,
194        request: &SdkKeyAccessPlanRequest,
195    ) -> Result<SdkKeyAccessPlanResponse, SdkError> {
196        self.post_json("/v1/sdk/key-access-plan", request)
197    }
198
199    pub fn artifact_register(
200        &self,
201        request: &SdkArtifactRegisterRequest,
202    ) -> Result<SdkArtifactRegisterResponse, SdkError> {
203        self.post_json("/v1/sdk/artifact-register", request)
204    }
205
206    pub fn evidence(
207        &self,
208        request: &SdkEvidenceIngestRequest,
209    ) -> Result<SdkEvidenceIngestResponse, SdkError> {
210        self.post_json("/v1/sdk/evidence", request)
211    }
212
213    fn get_json<T>(&self, path: &str) -> Result<T, SdkError>
214    where
215        T: DeserializeOwned,
216    {
217        let response = self
218            .apply_headers(self.agent.get(&self.endpoint(path)))?
219            .call()
220            .map_err(map_ureq_error)?;
221        decode_response(response)
222    }
223
224    fn post_json<TReq, TRes>(&self, path: &str, payload: &TReq) -> Result<TRes, SdkError>
225    where
226        TReq: Serialize,
227        TRes: DeserializeOwned,
228    {
229        let payload_json = serde_json::to_string(payload).map_err(|error| {
230            SdkError::Serialization(format!("failed to serialize request payload: {error}"))
231        })?;
232        let response = self
233            .apply_headers(
234                self.agent
235                    .post(&self.endpoint(path))
236                    .set("Content-Type", "application/json"),
237            )?
238            .send_string(&payload_json)
239            .map_err(map_ureq_error)?;
240        decode_response(response)
241    }
242
243    fn endpoint(&self, path: &str) -> String {
244        format!("{}{}", self.base_url, path)
245    }
246
247    fn apply_headers(&self, mut request: ureq::Request) -> Result<ureq::Request, SdkError> {
248        for (name, value) in &self.default_headers {
249            request = request.set(name, value);
250        }
251
252        if let Some(authorization_header) = self.resolve_authorization_header()? {
253            request = request.set("Authorization", &authorization_header);
254        }
255
256        Ok(request)
257    }
258
259    fn resolve_authorization_header(&self) -> Result<Option<String>, SdkError> {
260        match self.auth_strategy.as_ref() {
261            Some(ClientAuthStrategy::StaticBearer(header)) => Ok(Some(header.clone())),
262            Some(ClientAuthStrategy::SdkClientCredentials(auth)) => {
263                let session = auth.resolve_access_token(&self.agent, &self.base_url)?;
264                Ok(Some(format!("Bearer {}", session.access_token)))
265            }
266            None => Ok(None),
267        }
268    }
269}
270
271fn post_json_with_agent<TReq, TRes>(
272    agent: &ureq::Agent,
273    endpoint: &str,
274    payload: &TReq,
275    headers: &BTreeMap<String, String>,
276) -> Result<TRes, SdkError>
277where
278    TReq: Serialize,
279    TRes: DeserializeOwned,
280{
281    let payload_json = serde_json::to_string(payload).map_err(|error| {
282        SdkError::Serialization(format!("failed to serialize request payload: {error}"))
283    })?;
284    let mut request = agent.post(endpoint).set("Content-Type", "application/json");
285    for (name, value) in headers {
286        request = request.set(name, value);
287    }
288    let response = request.send_string(&payload_json).map_err(map_ureq_error)?;
289    decode_response(response)
290}
291
292fn decode_response<T>(response: ureq::Response) -> Result<T, SdkError>
293where
294    T: DeserializeOwned,
295{
296    let body = response.into_string().map_err(|error| {
297        SdkError::Connection(format!("failed to read HTTP response body: {error}"))
298    })?;
299    serde_json::from_str(&body).map_err(|error| {
300        SdkError::Serialization(format!("failed to decode JSON response body: {error}"))
301    })
302}
303
304fn map_ureq_error(error: ureq::Error) -> SdkError {
305    match error {
306        ureq::Error::Status(status, response) => {
307            let body = response.into_string().unwrap_or_default();
308            SdkError::Server(format!("HTTP {status}: {body}"))
309        }
310        ureq::Error::Transport(transport) => SdkError::Connection(transport.to_string()),
311    }
312}