1use std::time::Duration;
2
3use base64::engine::general_purpose::URL_SAFE_NO_PAD;
4use base64::Engine as _;
5use ed25519_dalek::{Signer, SigningKey};
6use hkdf::Hkdf;
7use hmac::{Hmac, Mac};
8use reqwest::Client;
9use sha2::{Digest, Sha256};
10use url::Url;
11
12use crate::models::{
13 WaitForWorkloadActivationOpts, WaitForWorkloadActivationResponse, WorkloadActivationPayload,
14 WorkloadActivationWaitRequest, WorkloadIdentity,
15};
16use crate::TrellisAuthError;
17
/// HMAC-SHA256 instance used by `hmac_sha256` for the QR MAC and the
/// confirmation-code MAC.
type HmacSha256 = Hmac<Sha256>;

/// HKDF `info` label for expanding the Ed25519 identity seed from the workload root secret.
const WORKLOAD_IDENTITY_HKDF_INFO: &str = "trellis/workload-identity/v1";
/// HKDF `info` label for expanding the activation (HMAC) key from the workload root secret.
const WORKLOAD_ACTIVATION_HKDF_INFO: &str = "trellis/workload-activate/v1";
/// Domain-separation prefix prepended to the message MACed for the activation QR payload.
const WORKLOAD_QR_MAC_DOMAIN: &str = "trellis-workload-qr/v1";
/// Domain-separation prefix prepended to the message MACed for the confirmation code.
const WORKLOAD_CONFIRMATION_DOMAIN: &str = "trellis-workload-confirm/v1";
/// Crockford base32 symbol set: digits plus uppercase letters, excluding I, L, O and U.
const CROCKFORD_ALPHABET: &[u8; 32] = b"0123456789ABCDEFGHJKMNPQRSTVWXYZ";
25
26fn base64url_encode(bytes: &[u8]) -> String {
27 URL_SAFE_NO_PAD.encode(bytes)
28}
29
30fn base64url_decode(value: &str) -> Result<Vec<u8>, TrellisAuthError> {
31 URL_SAFE_NO_PAD
32 .decode(value)
33 .map_err(|error| TrellisAuthError::InvalidArgument(format!("invalid base64url: {error}")))
34}
35
36fn hmac_sha256(key: &[u8], data: &[u8]) -> Result<Vec<u8>, TrellisAuthError> {
37 let mut mac = HmacSha256::new_from_slice(key)
38 .map_err(|error| TrellisAuthError::InvalidArgument(format!("invalid hmac key: {error}")))?;
39 mac.update(data);
40 Ok(mac.finalize().into_bytes().to_vec())
41}
42
/// Joins byte slices end to end, in order.
///
/// No separators or length prefixes are inserted, so callers must ensure the
/// concatenation is unambiguous for their use.
fn concat_bytes(parts: &[&[u8]]) -> Vec<u8> {
    parts.concat()
}
51
52fn crockford_encode(bytes: &[u8]) -> String {
53 let mut value = 0u32;
54 let mut bits = 0u32;
55 let mut out = String::new();
56
57 for byte in bytes {
58 value = (value << 8) | (*byte as u32);
59 bits += 8;
60 while bits >= 5 {
61 bits -= 5;
62 out.push(CROCKFORD_ALPHABET[((value >> bits) & 31) as usize] as char);
63 }
64 }
65
66 if bits > 0 {
67 out.push(CROCKFORD_ALPHABET[((value << (5 - bits)) & 31) as usize] as char);
68 }
69
70 out
71}
72
/// Canonicalizes user-entered Crockford base32 text for comparison.
///
/// Surrounding whitespace is trimmed and the text is uppercased. The
/// look-alike letters are then folded onto digits: `O` becomes `0`, and `I`
/// and `L` become `1`. Other characters, including interior whitespace and
/// hyphens, are kept as-is.
fn normalize_crockford(value: &str) -> String {
    value
        .trim()
        .to_uppercase()
        .chars()
        .map(|symbol| match symbol {
            'O' => '0',
            'I' | 'L' => '1',
            other => other,
        })
        .collect()
}
80
81pub fn derive_workload_identity(
82 workload_root_secret: &[u8],
83) -> Result<WorkloadIdentity, TrellisAuthError> {
84 if workload_root_secret.len() != 32 {
85 return Err(TrellisAuthError::InvalidArgument(format!(
86 "invalid workload root secret length: {} (expected 32)",
87 workload_root_secret.len()
88 )));
89 }
90
91 let hkdf = Hkdf::<Sha256>::new(Some(&[]), workload_root_secret);
92 let mut identity_seed = [0u8; 32];
93 hkdf.expand(WORKLOAD_IDENTITY_HKDF_INFO.as_bytes(), &mut identity_seed)
94 .map_err(|error| TrellisAuthError::InvalidArgument(format!("failed to derive workload identity seed: {error}")))?;
95 let mut activation_key = [0u8; 32];
96 hkdf.expand(WORKLOAD_ACTIVATION_HKDF_INFO.as_bytes(), &mut activation_key)
97 .map_err(|error| TrellisAuthError::InvalidArgument(format!("failed to derive activation key: {error}")))?;
98
99 let signing_key = SigningKey::from_bytes(&identity_seed);
100 let public_identity_key = base64url_encode(&signing_key.verifying_key().to_bytes());
101
102 Ok(WorkloadIdentity {
103 identity_seed_base64url: base64url_encode(&identity_seed),
104 public_identity_key,
105 activation_key_base64url: base64url_encode(&activation_key),
106 })
107}
108
109pub fn derive_workload_qr_mac(
110 activation_key_base64url: &str,
111 public_identity_key: &str,
112 nonce: &str,
113) -> Result<String, TrellisAuthError> {
114 let activation_key = base64url_decode(activation_key_base64url)?;
115 let mac = hmac_sha256(
116 &activation_key,
117 &concat_bytes(&[
118 WORKLOAD_QR_MAC_DOMAIN.as_bytes(),
119 public_identity_key.as_bytes(),
120 nonce.as_bytes(),
121 ]),
122 )?;
123 Ok(base64url_encode(&mac[..8]))
124}
125
126pub fn build_workload_activation_payload(
127 activation_key_base64url: &str,
128 public_identity_key: &str,
129 nonce: &str,
130) -> Result<WorkloadActivationPayload, TrellisAuthError> {
131 Ok(WorkloadActivationPayload {
132 v: 1,
133 public_identity_key: public_identity_key.to_string(),
134 nonce: nonce.to_string(),
135 qr_mac: derive_workload_qr_mac(
136 activation_key_base64url,
137 public_identity_key,
138 nonce,
139 )?,
140 })
141}
142
143pub fn encode_workload_activation_payload(
144 payload: &WorkloadActivationPayload,
145) -> Result<String, TrellisAuthError> {
146 serde_json::to_vec(payload)
147 .map(|bytes| base64url_encode(&bytes))
148 .map_err(|error| TrellisAuthError::InvalidArgument(format!("invalid workload activation payload: {error}")))
149}
150
151pub fn parse_workload_activation_payload(
152 payload_base64url: &str,
153) -> Result<WorkloadActivationPayload, TrellisAuthError> {
154 let bytes = base64url_decode(payload_base64url)?;
155 serde_json::from_slice(&bytes)
156 .map_err(|error| TrellisAuthError::InvalidArgument(format!("invalid workload activation payload: {error}")))
157}
158
159pub fn build_workload_activation_url(
160 auth_url: &str,
161 payload: &WorkloadActivationPayload,
162) -> Result<String, TrellisAuthError> {
163 let mut url = Url::parse(auth_url)?;
164 url.set_path("/auth/workloads/activate");
165 url.set_query(Some(&format!("payload={}", encode_workload_activation_payload(payload)?)));
166 Ok(url.to_string())
167}
168
/// Builds the byte string whose digest is signed in a wait request.
///
/// The fields are emitted in this order: the public identity key, the nonce,
/// and `iat` rendered as decimal ASCII. Each field is preceded by its byte
/// length as a big-endian `u32`, so field boundaries are unambiguous.
pub fn build_workload_wait_proof_input(
    public_identity_key: &str,
    nonce: &str,
    iat: u64,
) -> Vec<u8> {
    let iat_text = iat.to_string();
    let fields: [&[u8]; 3] = [
        public_identity_key.as_bytes(),
        nonce.as_bytes(),
        iat_text.as_bytes(),
    ];

    let capacity: usize = fields.iter().map(|field| 4 + field.len()).sum();
    fields
        .iter()
        .copied()
        .fold(Vec::with_capacity(capacity), |mut encoded, field| {
            encoded.extend_from_slice(&(field.len() as u32).to_be_bytes());
            encoded.extend_from_slice(field);
            encoded
        })
}
190
191pub fn sign_workload_wait_request(
192 public_identity_key: &str,
193 nonce: &str,
194 identity_seed_base64url: &str,
195 contract_digest: Option<&str>,
196 iat: u64,
197) -> Result<WorkloadActivationWaitRequest, TrellisAuthError> {
198 let identity_seed = base64url_decode(identity_seed_base64url)?;
199 if identity_seed.len() != 32 {
200 return Err(TrellisAuthError::InvalidArgument(format!(
201 "invalid identity seed length: {} (expected 32)",
202 identity_seed.len()
203 )));
204 }
205 let mut seed = [0u8; 32];
206 seed.copy_from_slice(&identity_seed);
207 let signing_key = SigningKey::from_bytes(&seed);
208 let digest = Sha256::digest(build_workload_wait_proof_input(public_identity_key, nonce, iat));
209 let signature = signing_key.sign(&digest);
210
211 Ok(WorkloadActivationWaitRequest {
212 public_identity_key: public_identity_key.to_string(),
213 contract_digest: contract_digest.map(ToOwned::to_owned),
214 nonce: nonce.to_string(),
215 iat,
216 sig: base64url_encode(&signature.to_bytes()),
217 })
218}
219
220pub async fn wait_for_workload_activation_response(
221 auth_url: &str,
222 request: &WorkloadActivationWaitRequest,
223) -> Result<WaitForWorkloadActivationResponse, TrellisAuthError> {
224 let url = Url::parse(auth_url)?.join("/auth/workloads/activate/wait")?;
225 let response = Client::new().post(url).json(request).send().await?;
226 if !response.status().is_success() {
227 let status = response.status().as_u16();
228 let body = response.text().await.unwrap_or_default();
229 return Err(TrellisAuthError::WorkloadActivationWaitFailure(status, body));
230 }
231
232 response.json().await.map_err(TrellisAuthError::from)
233}
234
235pub async fn wait_for_workload_activation(
236 opts: WaitForWorkloadActivationOpts<'_>,
237) -> Result<serde_json::Value, TrellisAuthError> {
238 loop {
239 let request = sign_workload_wait_request(
240 opts.public_identity_key,
241 opts.nonce,
242 opts.identity_seed_base64url,
243 opts.contract_digest,
244 std::time::SystemTime::now()
245 .duration_since(std::time::UNIX_EPOCH)
246 .unwrap_or_default()
247 .as_secs(),
248 )?;
249 match wait_for_workload_activation_response(opts.auth_url, &request).await? {
250 WaitForWorkloadActivationResponse::Activated { connect_info, .. } => {
251 return Ok(connect_info)
252 }
253 WaitForWorkloadActivationResponse::Rejected { reason } => {
254 return Err(TrellisAuthError::WorkloadActivationRejected(match reason {
255 Some(reason) => format!(": {reason}"),
256 None => String::new(),
257 }))
258 }
259 WaitForWorkloadActivationResponse::Pending => tokio::time::sleep(match opts.poll_interval {
260 duration if duration.is_zero() => Duration::from_millis(1),
261 duration => duration,
262 })
263 .await,
264 }
265 }
266}
267
268pub fn derive_workload_confirmation_code(
269 activation_key_base64url: &str,
270 public_identity_key: &str,
271 nonce: &str,
272) -> Result<String, TrellisAuthError> {
273 let activation_key = base64url_decode(activation_key_base64url)?;
274 let mac = hmac_sha256(
275 &activation_key,
276 &concat_bytes(&[
277 WORKLOAD_CONFIRMATION_DOMAIN.as_bytes(),
278 public_identity_key.as_bytes(),
279 nonce.as_bytes(),
280 ]),
281 )?;
282 Ok(crockford_encode(&mac[..5]))
283}
284
285pub fn verify_workload_confirmation_code(
286 activation_key_base64url: &str,
287 public_identity_key: &str,
288 nonce: &str,
289 confirmation_code: &str,
290) -> Result<bool, TrellisAuthError> {
291 Ok(normalize_crockford(&derive_workload_confirmation_code(
292 activation_key_base64url,
293 public_identity_key,
294 nonce,
295 )?) == normalize_crockford(confirmation_code))
296}