1use std::{
2 collections::BTreeMap,
3 sync::Mutex,
4 time::{Duration, Instant},
5};
6
7use serde::{Serialize, de::DeserializeOwned};
8
9use crate::{
10 builder::ClientBuilder,
11 error::SdkError,
12 models::{
13 CallerIdentityResponse, SdkArtifactRegisterRequest, SdkArtifactRegisterResponse,
14 SdkBootstrapResponse, SdkCapabilitiesResponse, SdkEvidenceIngestRequest,
15 SdkEvidenceIngestResponse, SdkKeyAccessPlanRequest, SdkKeyAccessPlanResponse,
16 SdkPolicyResolveRequest, SdkPolicyResolveResponse, SdkProtectionPlanRequest,
17 SdkProtectionPlanResponse, SdkSessionExchangeResponse,
18 },
19};
20
21pub(crate) enum ClientAuthStrategy {
22 StaticBearer(String),
23 SdkClientCredentials(Box<SdkClientCredentialsAuth>),
24}
25
26pub(crate) struct SdkClientCredentialsAuth {
27 pub tenant_id: String,
28 pub client_id: String,
29 pub client_secret: String,
30 pub token_exchange_path: String,
31 pub requested_scopes: Vec<String>,
32 cached_session: Mutex<Option<CachedSessionToken>>,
33}
34
35struct CachedSessionToken {
36 response: SdkSessionExchangeResponse,
37 refresh_at: Instant,
38}
39
40#[derive(Serialize)]
41#[serde(rename_all = "snake_case")]
42struct SdkSessionExchangeRequest<'a> {
43 tenant_id: &'a str,
44 client_id: &'a str,
45 client_secret: &'a str,
46 requested_scopes: &'a [String],
47}
48
49impl SdkClientCredentialsAuth {
50 pub(crate) fn new(
51 tenant_id: String,
52 client_id: String,
53 client_secret: String,
54 token_exchange_path: String,
55 requested_scopes: Vec<String>,
56 ) -> Self {
57 Self {
58 tenant_id,
59 client_id,
60 client_secret,
61 token_exchange_path,
62 requested_scopes,
63 cached_session: Mutex::new(None),
64 }
65 }
66
67 fn resolve_access_token(
68 &self,
69 agent: &ureq::Agent,
70 base_url: &str,
71 ) -> Result<SdkSessionExchangeResponse, SdkError> {
72 {
73 let cache = self.cached_session.lock().map_err(|_| {
74 SdkError::Connection("failed to acquire sdk session cache".to_string())
75 })?;
76 if let Some(cached) = cache.as_ref()
77 && Instant::now() < cached.refresh_at
78 {
79 return Ok(cached.response.clone());
80 }
81 }
82
83 let endpoint = if self.token_exchange_path.starts_with("http://")
84 || self.token_exchange_path.starts_with("https://")
85 {
86 self.token_exchange_path.clone()
87 } else {
88 format!("{}{}", base_url, self.token_exchange_path)
89 };
90 let response = post_json_with_agent::<_, SdkSessionExchangeResponse>(
91 agent,
92 &endpoint,
93 &SdkSessionExchangeRequest {
94 tenant_id: &self.tenant_id,
95 client_id: &self.client_id,
96 client_secret: &self.client_secret,
97 requested_scopes: &self.requested_scopes,
98 },
99 &BTreeMap::new(),
100 )?;
101
102 let refresh_after_secs = if response.expires_in > 60 {
103 response.expires_in - 60
104 } else {
105 1
106 };
107
108 let mut cache = self
109 .cached_session
110 .lock()
111 .map_err(|_| SdkError::Connection("failed to update sdk session cache".to_string()))?;
112 *cache = Some(CachedSessionToken {
113 response: response.clone(),
114 refresh_at: Instant::now() + Duration::from_secs(refresh_after_secs),
115 });
116
117 Ok(response)
118 }
119}
120
121pub struct Client {
122 pub(crate) base_url: String,
123 pub(crate) agent: ureq::Agent,
124 pub(crate) default_headers: BTreeMap<String, String>,
125 pub(crate) auth_strategy: Option<ClientAuthStrategy>,
126}
127
128impl Client {
129 pub(crate) fn new(
130 base_url: String,
131 agent: ureq::Agent,
132 default_headers: BTreeMap<String, String>,
133 auth_strategy: Option<ClientAuthStrategy>,
134 ) -> Self {
135 Self {
136 base_url,
137 agent,
138 default_headers,
139 auth_strategy,
140 }
141 }
142
143 pub fn builder(base_url: impl Into<String>) -> ClientBuilder {
144 ClientBuilder::new(base_url)
145 }
146
147 pub fn base_url(&self) -> &str {
148 &self.base_url
149 }
150
151 pub fn capabilities(&self) -> Result<SdkCapabilitiesResponse, SdkError> {
152 self.get_json("/v1/sdk/capabilities")
153 }
154
155 pub fn whoami(&self) -> Result<CallerIdentityResponse, SdkError> {
156 self.get_json("/v1/sdk/whoami")
157 }
158
159 pub fn bootstrap(&self) -> Result<SdkBootstrapResponse, SdkError> {
160 self.get_json("/v1/sdk/bootstrap")
161 }
162
163 pub fn exchange_session(&self) -> Result<SdkSessionExchangeResponse, SdkError> {
164 match self.auth_strategy.as_ref() {
165 Some(ClientAuthStrategy::SdkClientCredentials(auth)) => {
166 auth.resolve_access_token(&self.agent, &self.base_url)
167 }
168 Some(ClientAuthStrategy::StaticBearer(_)) => Err(SdkError::InvalidInput(
169 "client is configured with a static bearer token; no sdk session exchange is required"
170 .to_string(),
171 )),
172 None => Err(SdkError::InvalidInput(
173 "client is not configured with sdk client credentials".to_string(),
174 )),
175 }
176 }
177
178 pub fn protection_plan(
179 &self,
180 request: &SdkProtectionPlanRequest,
181 ) -> Result<SdkProtectionPlanResponse, SdkError> {
182 self.post_json("/v1/sdk/protection-plan", request)
183 }
184
185 pub fn policy_resolve(
186 &self,
187 request: &SdkPolicyResolveRequest,
188 ) -> Result<SdkPolicyResolveResponse, SdkError> {
189 self.post_json("/v1/sdk/policy-resolve", request)
190 }
191
192 pub fn key_access_plan(
193 &self,
194 request: &SdkKeyAccessPlanRequest,
195 ) -> Result<SdkKeyAccessPlanResponse, SdkError> {
196 self.post_json("/v1/sdk/key-access-plan", request)
197 }
198
199 pub fn artifact_register(
200 &self,
201 request: &SdkArtifactRegisterRequest,
202 ) -> Result<SdkArtifactRegisterResponse, SdkError> {
203 self.post_json("/v1/sdk/artifact-register", request)
204 }
205
206 pub fn evidence(
207 &self,
208 request: &SdkEvidenceIngestRequest,
209 ) -> Result<SdkEvidenceIngestResponse, SdkError> {
210 self.post_json("/v1/sdk/evidence", request)
211 }
212
213 fn get_json<T>(&self, path: &str) -> Result<T, SdkError>
214 where
215 T: DeserializeOwned,
216 {
217 let response = self
218 .apply_headers(self.agent.get(&self.endpoint(path)))?
219 .call()
220 .map_err(map_ureq_error)?;
221 decode_response(response)
222 }
223
224 fn post_json<TReq, TRes>(&self, path: &str, payload: &TReq) -> Result<TRes, SdkError>
225 where
226 TReq: Serialize,
227 TRes: DeserializeOwned,
228 {
229 let payload_json = serde_json::to_string(payload).map_err(|error| {
230 SdkError::Serialization(format!("failed to serialize request payload: {error}"))
231 })?;
232 let response = self
233 .apply_headers(
234 self.agent
235 .post(&self.endpoint(path))
236 .set("Content-Type", "application/json"),
237 )?
238 .send_string(&payload_json)
239 .map_err(map_ureq_error)?;
240 decode_response(response)
241 }
242
243 fn endpoint(&self, path: &str) -> String {
244 format!("{}{}", self.base_url, path)
245 }
246
247 fn apply_headers(&self, mut request: ureq::Request) -> Result<ureq::Request, SdkError> {
248 for (name, value) in &self.default_headers {
249 request = request.set(name, value);
250 }
251
252 if let Some(authorization_header) = self.resolve_authorization_header()? {
253 request = request.set("Authorization", &authorization_header);
254 }
255
256 Ok(request)
257 }
258
259 fn resolve_authorization_header(&self) -> Result<Option<String>, SdkError> {
260 match self.auth_strategy.as_ref() {
261 Some(ClientAuthStrategy::StaticBearer(header)) => Ok(Some(header.clone())),
262 Some(ClientAuthStrategy::SdkClientCredentials(auth)) => {
263 let session = auth.resolve_access_token(&self.agent, &self.base_url)?;
264 Ok(Some(format!("Bearer {}", session.access_token)))
265 }
266 None => Ok(None),
267 }
268 }
269}
270
271fn post_json_with_agent<TReq, TRes>(
272 agent: &ureq::Agent,
273 endpoint: &str,
274 payload: &TReq,
275 headers: &BTreeMap<String, String>,
276) -> Result<TRes, SdkError>
277where
278 TReq: Serialize,
279 TRes: DeserializeOwned,
280{
281 let payload_json = serde_json::to_string(payload).map_err(|error| {
282 SdkError::Serialization(format!("failed to serialize request payload: {error}"))
283 })?;
284 let mut request = agent.post(endpoint).set("Content-Type", "application/json");
285 for (name, value) in headers {
286 request = request.set(name, value);
287 }
288 let response = request.send_string(&payload_json).map_err(map_ureq_error)?;
289 decode_response(response)
290}
291
292fn decode_response<T>(response: ureq::Response) -> Result<T, SdkError>
293where
294 T: DeserializeOwned,
295{
296 let body = response.into_string().map_err(|error| {
297 SdkError::Connection(format!("failed to read HTTP response body: {error}"))
298 })?;
299 serde_json::from_str(&body).map_err(|error| {
300 SdkError::Serialization(format!("failed to decode JSON response body: {error}"))
301 })
302}
303
304fn map_ureq_error(error: ureq::Error) -> SdkError {
305 match error {
306 ureq::Error::Status(status, response) => {
307 let body = response.into_string().unwrap_or_default();
308 SdkError::Server(format!("HTTP {status}: {body}"))
309 }
310 ureq::Error::Transport(transport) => SdkError::Connection(transport.to_string()),
311 }
312}