1use aes_gcm::aead::{Aead, KeyInit};
7use aes_gcm::{Aes256Gcm, Nonce};
8use base64::Engine as _;
9use tracing::warn;
10
11use crate::host_match::host_matches;
12
/// Byte sequence that marks the start of an encrypted token inside a buffer.
const TOKEN_PREFIX: &[u8] = "starpod:v1:".as_bytes();

/// Bytes accepted as part of a token body: the standard base64 alphabet
/// (`A-Z`, `a-z`, `0-9`, `+`, `/`) followed by the `=` padding character.
const BASE64_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
                              abcdefghijklmnopqrstuvwxyz\
                              0123456789+/=";
15
/// Outcome of scanning a buffer for encrypted tokens.
///
/// Produced by `scan_and_replace` / `scan_and_replace_str`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct ScanResult {
    /// The rewritten buffer: input bytes with every handled token either
    /// replaced by its decrypted value or removed.
    pub data: Vec<u8>,
    /// Number of tokens that were decrypted and replaced by their value.
    pub replaced: usize,
    /// Number of tokens that decrypted successfully but were removed because
    /// the target host was not accepted for them.
    pub stripped: usize,
}
26
27fn decode_token(cipher: &Aes256Gcm, token: &str) -> Option<(String, Vec<String>)> {
29 let encoded = token.strip_prefix("starpod:v1:")?;
30
31 let blob = base64::engine::general_purpose::STANDARD
32 .decode(encoded)
33 .ok()?;
34
35 if blob.len() < 13 {
36 return None;
37 }
38
39 let (nonce_bytes, ciphertext) = blob.split_at(12);
40 let nonce = Nonce::from_slice(nonce_bytes);
41 let plaintext = cipher.decrypt(nonce, ciphertext).ok()?;
42
43 #[derive(serde::Deserialize)]
44 struct Payload {
45 v: String,
46 h: Vec<String>,
47 }
48
49 let payload: Payload = serde_json::from_slice(&plaintext).ok()?;
50 Some((payload.v, payload.h))
51}
52
53pub fn scan_and_replace(cipher: &Aes256Gcm, data: &[u8], target_host: &str) -> ScanResult {
61 let mut result = Vec::with_capacity(data.len());
62 let mut replaced = 0usize;
63 let mut stripped = 0usize;
64 let mut i = 0;
65
66 while i < data.len() {
67 if data[i..].starts_with(TOKEN_PREFIX) {
69 let token_start = i;
70 i += TOKEN_PREFIX.len();
71
72 while i < data.len() && BASE64_CHARS.contains(&data[i]) {
74 i += 1;
75 }
76
77 let token_bytes = &data[token_start..i];
78 let token_str = match std::str::from_utf8(token_bytes) {
79 Ok(s) => s,
80 Err(_) => {
81 result.extend_from_slice(token_bytes);
83 continue;
84 }
85 };
86
87 match decode_token(cipher, token_str) {
88 Some((value, allowed_hosts)) => {
89 if host_matches(target_host, &allowed_hosts) {
90 result.extend_from_slice(value.as_bytes());
91 replaced += 1;
92 } else {
93 warn!(
94 target_host = %target_host,
95 allowed_hosts = ?allowed_hosts,
96 "Token host mismatch — stripped"
97 );
98 stripped += 1;
100 }
101 }
102 None => {
103 result.extend_from_slice(token_bytes);
105 }
106 }
107 } else {
108 result.push(data[i]);
109 i += 1;
110 }
111 }
112
113 ScanResult {
114 data: result,
115 replaced,
116 stripped,
117 }
118}
119
120pub fn scan_and_replace_str(cipher: &Aes256Gcm, data: &str, target_host: &str) -> ScanResult {
122 scan_and_replace(cipher, data.as_bytes(), target_host)
123}
124
125pub fn cipher_from_key(master_key: &[u8; 32]) -> Aes256Gcm {
127 Aes256Gcm::new_from_slice(master_key).expect("32-byte key is always valid for AES-256")
128}
129
#[cfg(test)]
mod tests {
    use aes_gcm::aead::OsRng;
    use aes_gcm::AeadCore;

    use super::*;

    /// Cipher with a fixed key shared by most tests.
    fn test_cipher() -> Aes256Gcm {
        cipher_from_key(&[0xAB; 32])
    }

    /// Builds a token the way a producer would: JSON payload, fresh random
    /// nonce, AES-256-GCM, `nonce || ciphertext`, standard base64, prefix.
    fn encode_token(cipher: &Aes256Gcm, value: &str, hosts: &[String]) -> String {
        #[derive(serde::Serialize)]
        struct Payload {
            v: String,
            h: Vec<String>,
        }

        let json = serde_json::to_vec(&Payload {
            v: value.to_string(),
            h: hosts.to_vec(),
        })
        .unwrap();

        let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
        let sealed = cipher.encrypt(&nonce, json.as_slice()).unwrap();

        let mut blob = Vec::with_capacity(12 + sealed.len());
        blob.extend_from_slice(nonce.as_slice());
        blob.extend_from_slice(&sealed);

        let encoded = base64::engine::general_purpose::STANDARD.encode(&blob);
        format!("starpod:v1:{encoded}")
    }

    /// Runs the scanner and returns `(output as text, replaced, stripped)`.
    fn scan(cipher: &Aes256Gcm, input: &str, host: &str) -> (String, usize, usize) {
        let outcome = scan_and_replace_str(cipher, input, host);
        let text = String::from_utf8(outcome.data).unwrap();
        (text, outcome.replaced, outcome.stripped)
    }

    #[test]
    fn replace_token_in_header() {
        let cipher = test_cipher();
        let token = encode_token(&cipher, "ghp_real", &["api.github.com".into()]);

        let (text, replaced, stripped) =
            scan(&cipher, &format!("Bearer {token}"), "api.github.com");

        assert_eq!(replaced, 1);
        assert_eq!(stripped, 0);
        assert_eq!(text, "Bearer ghp_real");
    }

    #[test]
    fn strip_token_on_host_mismatch() {
        let cipher = test_cipher();
        let token = encode_token(&cipher, "ghp_real", &["api.github.com".into()]);

        let (text, replaced, stripped) = scan(&cipher, &format!("Bearer {token}"), "evil.com");

        assert_eq!(replaced, 0);
        assert_eq!(stripped, 1);
        assert_eq!(text, "Bearer ");
    }

    #[test]
    fn unrestricted_token_matches_any_host() {
        let cipher = test_cipher();
        let token = encode_token(&cipher, "secret", &[]);

        let (text, replaced, _) = scan(&cipher, &format!("key={token}"), "any-host.com");

        assert_eq!(replaced, 1);
        assert_eq!(text, "key=secret");
    }

    #[test]
    fn multiple_tokens_in_one_buffer() {
        let cipher = test_cipher();
        let first = encode_token(&cipher, "val1", &[]);
        let second = encode_token(&cipher, "val2", &[]);

        let (text, replaced, _) = scan(&cipher, &format!("a={first}&b={second}"), "host.com");

        assert_eq!(replaced, 2);
        assert_eq!(text, "a=val1&b=val2");
    }

    #[test]
    fn no_tokens_passes_through() {
        let cipher = test_cipher();
        let input = "just normal data with no tokens";

        let (text, replaced, stripped) = scan(&cipher, input, "host.com");

        assert_eq!(replaced, 0);
        assert_eq!(stripped, 0);
        assert_eq!(text, input);
    }

    #[test]
    fn wrong_key_leaves_token_as_is() {
        let producer = test_cipher();
        let consumer = cipher_from_key(&[0xCD; 32]);
        let token = encode_token(&producer, "secret", &[]);
        let input = format!("key={token}");

        // Decryption with a different key fails, so the token is untouched.
        let (text, replaced, stripped) = scan(&consumer, &input, "host.com");

        assert_eq!(replaced, 0);
        assert_eq!(stripped, 0);
        assert_eq!(text, input);
    }

    #[test]
    fn token_at_end_of_buffer() {
        let cipher = test_cipher();
        let token = encode_token(&cipher, "val", &[]);

        let (text, replaced, _) = scan(&cipher, &format!("Authorization: {token}"), "x.com");

        assert_eq!(replaced, 1);
        assert_eq!(text, "Authorization: val");
    }

    #[test]
    fn token_at_start_of_buffer() {
        let cipher = test_cipher();
        let token = encode_token(&cipher, "val", &[]);

        let (text, replaced, _) = scan(&cipher, &token, "x.com");

        assert_eq!(replaced, 1);
        assert_eq!(text, "val");
    }

    #[test]
    fn large_body_with_embedded_token() {
        let cipher = test_cipher();
        let token = encode_token(&cipher, "secret", &[]);
        // '-' is outside the base64 alphabet, so it cleanly delimits the token.
        let padding = "-".repeat(100_000);
        let input = format!("{padding}token={token}{padding}");

        let (text, replaced, _) = scan(&cipher, &input, "host.com");

        assert_eq!(replaced, 1);
        assert_eq!(text, format!("{padding}token=secret{padding}"));
    }

    #[test]
    fn mixed_match_and_mismatch() {
        let cipher = test_cipher();
        let good = encode_token(&cipher, "good", &["ok.com".into()]);
        let bad = encode_token(&cipher, "bad", &["other.com".into()]);

        let (text, replaced, stripped) = scan(&cipher, &format!("a={good}&b={bad}"), "ok.com");

        assert_eq!(replaced, 1);
        assert_eq!(stripped, 1);
        assert_eq!(text, "a=good&b=");
    }
}