1use base64::Engine as _;
2use bhttp::{Message, Mode};
3use ohttp::ClientRequest;
4use reqwest::{Client, StatusCode};
5use serde::{Deserialize, Serialize};
6
7use crate::AuthenticatorError;
8
9#[derive(Clone, Debug, Serialize, Deserialize)]
15pub struct OhttpClientConfig {
16 pub relay_url: String,
18 pub key_config_base64: String,
20}
21
22impl OhttpClientConfig {
23 pub fn new(relay_url: String, key_config_base64: String) -> Self {
24 Self {
25 relay_url,
26 key_config_base64,
27 }
28 }
29}
30
31#[derive(Debug)]
33pub struct OhttpResponse {
34 pub status: StatusCode,
36 pub body: Vec<u8>,
38}
39
40#[derive(Clone, Debug)]
43pub struct OhttpClient {
44 client: Client,
45 relay_url: String,
46 target_scheme: String,
47 target_authority: String,
48 encoded_config_list: Vec<u8>,
49}
50
51impl OhttpClient {
52 pub fn new(
66 client: Client,
67 config_scope: &str,
68 target_url: &str,
69 config: OhttpClientConfig,
70 ) -> Result<Self, AuthenticatorError> {
71 let (target_scheme, target_authority) =
72 target_url
73 .split_once("://")
74 .ok_or_else(|| AuthenticatorError::InvalidConfig {
75 attribute: format!("{config_scope}.target_url"),
76 reason: format!("expected scheme://authority, got {:?}", target_url),
77 })?;
78
79 let target_scheme = target_scheme.to_owned();
80 let target_authority = target_authority.trim_end_matches('/').to_owned();
81
82 let attribute = format!("{config_scope}.key_config_base64");
83
84 let encoded_config_list = base64::engine::general_purpose::STANDARD
85 .decode(&config.key_config_base64)
86 .map_err(|err| AuthenticatorError::InvalidConfig {
87 attribute: attribute.clone(),
88 reason: format!("invalid base64: {err}"),
89 })?;
90
91 ClientRequest::from_encoded_config_list(&encoded_config_list).map_err(|err| {
92 AuthenticatorError::InvalidConfig {
93 attribute,
94 reason: format!("invalid application/ohttp-keys payload: {err}"),
95 }
96 })?;
97
98 Ok(Self {
99 client,
100 relay_url: config.relay_url,
101 target_scheme,
102 target_authority,
103 encoded_config_list,
104 })
105 }
106
107 pub async fn post_json<T: serde::Serialize>(
113 &self,
114 path: &str,
115 body: &T,
116 ) -> Result<OhttpResponse, AuthenticatorError> {
117 let body = serde_json::to_vec(body).map_err(|e| {
118 AuthenticatorError::Generic(format!("failed to serialize request body: {e}"))
119 })?;
120 self.request(b"POST", path, Some(&body)).await
121 }
122
123 pub async fn get(&self, path: &str) -> Result<OhttpResponse, AuthenticatorError> {
128 self.request(b"GET", path, None).await
129 }
130
131 async fn request(
132 &self,
133 method: &[u8],
134 path: &str,
135 body: Option<&[u8]>,
136 ) -> Result<OhttpResponse, AuthenticatorError> {
137 let mut msg = Message::request(
138 method.to_vec(),
139 self.target_scheme.as_bytes().to_vec(),
140 self.target_authority.as_bytes().to_vec(),
141 path.as_bytes().to_vec(),
142 );
143 if let Some(body) = body {
144 msg.put_header("content-type", "application/json");
145 msg.write_content(body);
146 }
147 let mut bhttp_buf = Vec::new();
148 msg.write_bhttp(Mode::KnownLength, &mut bhttp_buf)?;
149
150 let ohttp_req = ClientRequest::from_encoded_config_list(&self.encoded_config_list)?;
151 let (enc_request, ohttp_resp_ctx) = ohttp_req.encapsulate(&bhttp_buf)?;
152
153 let resp = self
154 .client
155 .post(&self.relay_url)
156 .header("content-type", "message/ohttp-req")
157 .body(enc_request)
158 .send()
159 .await?;
160
161 if !resp.status().is_success() {
162 return Err(AuthenticatorError::OhttpRelayError {
163 status: resp.status(),
164 body: resp.text().await.unwrap_or_default(),
165 });
166 }
167
168 let enc_response = resp.bytes().await?;
169 let response_buf = ohttp_resp_ctx.decapsulate(&enc_response)?;
170
171 let response_msg = Message::read_bhttp(&mut std::io::Cursor::new(&response_buf))?;
172 let status_code = response_msg
173 .control()
174 .status()
175 .map(|s| s.code())
176 .ok_or_else(|| {
177 AuthenticatorError::Generic("OHTTP response missing HTTP status line".into())
178 })?;
179 let status = StatusCode::from_u16(status_code).map_err(|_| bhttp::Error::InvalidStatus)?;
180
181 Ok(OhttpResponse {
182 status,
183 body: response_msg.content().to_vec(),
184 })
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AuthenticatorError;

    /// Encodes `bytes` with the standard base64 alphabet the client decodes with.
    fn b64(bytes: &[u8]) -> String {
        base64::engine::general_purpose::STANDARD.encode(bytes)
    }

    /// Builds a config from `relay_url` / `key_config_base64`, then tries to
    /// construct a client for `target_url` under `config_scope`.
    fn build(
        relay_url: &str,
        key_config_base64: String,
        config_scope: &str,
        target_url: &str,
    ) -> Result<OhttpClient, AuthenticatorError> {
        let config = OhttpClientConfig::new(relay_url.to_owned(), key_config_base64);
        OhttpClient::new(reqwest::Client::new(), config_scope, target_url, config)
    }

    /// Asserts `result` is `InvalidConfig` naming `expected_attribute` with a
    /// reason that contains `expected_reason`.
    #[track_caller]
    fn assert_invalid_config(
        result: Result<OhttpClient, AuthenticatorError>,
        expected_attribute: &str,
        expected_reason: &str,
    ) {
        match result {
            Err(AuthenticatorError::InvalidConfig { attribute, reason }) => {
                assert_eq!(attribute, expected_attribute);
                assert!(
                    reason.contains(expected_reason),
                    "unexpected reason: {reason}"
                );
            }
            other => panic!("expected InvalidConfig, got: {other:?}"),
        }
    }

    #[test]
    fn invalid_base64_key_config_returns_invalid_config() {
        let result = build(
            "http://localhost:1234",
            "not valid base64 !!!".to_owned(),
            "test_scope",
            "https://localhost:9999",
        );
        assert_invalid_config(result, "test_scope.key_config_base64", "invalid base64");
    }

    #[test]
    fn invalid_ohttp_keys_payload_returns_invalid_config() {
        let result = build(
            "http://localhost:1234",
            b64(b"definitely not an ohttp-keys payload"),
            "my_scope",
            "https://localhost:9999",
        );
        assert_invalid_config(
            result,
            "my_scope.key_config_base64",
            "invalid application/ohttp-keys payload",
        );
    }

    #[test]
    fn garbage_ohttp_keys_bytes_returns_invalid_config() {
        let result = build(
            "http://127.0.0.1:0/does-not-exist",
            b64(b"not-a-valid-ohttp-keys"),
            "test",
            "http://localhost:1234",
        );
        assert!(
            matches!(result, Err(AuthenticatorError::InvalidConfig { .. })),
            "expected InvalidConfig for garbage key config, got: {result:?}"
        );
    }

    #[test]
    fn missing_scheme_in_target_url_returns_invalid_config() {
        let result = build(
            "http://localhost:1234",
            b64(b"irrelevant"),
            "test_scope",
            "localhost:9999",
        );
        assert_invalid_config(result, "test_scope.target_url", "expected scheme://authority");
    }

    #[test]
    fn empty_key_config_returns_invalid_config() {
        let result = build(
            "http://localhost:1234",
            b64(b""),
            "test_scope",
            "https://localhost:9999",
        );
        assert!(
            matches!(result, Err(AuthenticatorError::InvalidConfig { .. })),
            "expected InvalidConfig for empty key config, got: {result:?}"
        );
    }
}