1use base64::Engine as _;
2use bhttp::{Message, Mode};
3use ohttp::ClientRequest;
4use reqwest::{Client, StatusCode};
5use serde::{Deserialize, Serialize};
6
7use crate::AuthenticatorError;
8
/// Settings needed to reach a target service through an Oblivious HTTP (OHTTP) relay.
///
/// Derives `Serialize`/`Deserialize`, so it can be read from and written to serde formats.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct OhttpClientConfig {
    /// URL of the relay; every encapsulated request is POSTed here as `message/ohttp-req`.
    pub relay_url: String,
    /// Standard-alphabet base64 encoding of an `application/ohttp-keys` key configuration
    /// list, used to encapsulate (encrypt) outgoing requests.
    pub key_config_base64: String,
}
22
23impl OhttpClientConfig {
24 pub fn new(relay_url: String, key_config_base64: String) -> Self {
25 Self {
26 relay_url,
27 key_config_base64,
28 }
29 }
30}
31
/// Response recovered from an OHTTP exchange after decapsulating the relay's reply.
#[derive(Debug)]
pub struct OhttpResponse {
    /// Status code of the inner (decapsulated) HTTP response, not the relay's own status.
    pub status: StatusCode,
    /// Body content of the inner HTTP response.
    pub body: Vec<u8>,
}
40
/// Client that encodes requests as Binary HTTP, encapsulates them with OHTTP and sends
/// them to a target service through a relay.
#[derive(Clone, Debug)]
pub struct OhttpClient {
    /// HTTP client used for the outer request to the relay.
    client: Client,
    /// Relay URL that encapsulated requests are POSTed to.
    relay_url: String,
    /// Scheme placed in the inner request's control data (the part before `://` of the
    /// configured target URL).
    target_scheme: String,
    /// Authority placed in the inner request's control data (the part after `://` of the
    /// configured target URL, with trailing slashes removed).
    target_authority: String,
    /// Decoded `application/ohttp-keys` payload. It is validated once in `new` and parsed
    /// again for each request.
    encoded_config_list: Vec<u8>,
}
51
52impl OhttpClient {
53 pub fn new(
67 client: Client,
68 config_scope: &str,
69 target_url: &str,
70 config: OhttpClientConfig,
71 ) -> Result<Self, AuthenticatorError> {
72 let (target_scheme, target_authority) =
73 target_url
74 .split_once("://")
75 .ok_or_else(|| AuthenticatorError::InvalidConfig {
76 attribute: format!("{config_scope}.target_url"),
77 reason: format!("expected scheme://authority, got {:?}", target_url),
78 })?;
79
80 let target_scheme = target_scheme.to_owned();
81 let target_authority = target_authority.trim_end_matches('/').to_owned();
82
83 let attribute = format!("{config_scope}.key_config_base64");
84
85 let encoded_config_list = base64::engine::general_purpose::STANDARD
86 .decode(&config.key_config_base64)
87 .map_err(|err| AuthenticatorError::InvalidConfig {
88 attribute: attribute.clone(),
89 reason: format!("invalid base64: {err}"),
90 })?;
91
92 ClientRequest::from_encoded_config_list(&encoded_config_list).map_err(|err| {
93 AuthenticatorError::InvalidConfig {
94 attribute,
95 reason: format!("invalid application/ohttp-keys payload: {err}"),
96 }
97 })?;
98
99 Ok(Self {
100 client,
101 relay_url: config.relay_url,
102 target_scheme,
103 target_authority,
104 encoded_config_list,
105 })
106 }
107
108 pub async fn post_json<T: serde::Serialize>(
114 &self,
115 path: &str,
116 body: &T,
117 ) -> Result<OhttpResponse, AuthenticatorError> {
118 let body = serde_json::to_vec(body).map_err(|e| {
119 AuthenticatorError::Generic(format!("failed to serialize request body: {e}"))
120 })?;
121 self.request(b"POST", path, Some(&body)).await
122 }
123
124 pub async fn get(&self, path: &str) -> Result<OhttpResponse, AuthenticatorError> {
129 self.request(b"GET", path, None).await
130 }
131
132 async fn request(
133 &self,
134 method: &[u8],
135 path: &str,
136 body: Option<&[u8]>,
137 ) -> Result<OhttpResponse, AuthenticatorError> {
138 let mut msg = Message::request(
139 method.to_vec(),
140 self.target_scheme.as_bytes().to_vec(),
141 self.target_authority.as_bytes().to_vec(),
142 path.as_bytes().to_vec(),
143 );
144 if let Some(body) = body {
145 msg.put_header("content-type", "application/json");
146 msg.write_content(body);
147 }
148 let mut bhttp_buf = Vec::new();
149 msg.write_bhttp(Mode::KnownLength, &mut bhttp_buf)?;
150
151 let ohttp_req = ClientRequest::from_encoded_config_list(&self.encoded_config_list)?;
152 let (enc_request, ohttp_resp_ctx) = ohttp_req.encapsulate(&bhttp_buf)?;
153
154 let resp = self
155 .client
156 .post(&self.relay_url)
157 .header("content-type", "message/ohttp-req")
158 .body(enc_request)
159 .send()
160 .await?;
161
162 if !resp.status().is_success() {
163 return Err(AuthenticatorError::OhttpRelayError {
164 status: resp.status(),
165 body: resp.text().await.unwrap_or_default(),
166 });
167 }
168
169 let enc_response = resp.bytes().await?;
170 let response_buf = ohttp_resp_ctx.decapsulate(&enc_response)?;
171
172 let response_msg = Message::read_bhttp(&mut std::io::Cursor::new(&response_buf))?;
173 let status_code = response_msg
174 .control()
175 .status()
176 .map(|s| s.code())
177 .ok_or_else(|| {
178 AuthenticatorError::Generic("OHTTP response missing HTTP status line".into())
179 })?;
180 let status = StatusCode::from_u16(status_code).map_err(|_| bhttp::Error::InvalidStatus)?;
181
182 Ok(OhttpResponse {
183 status,
184 body: response_msg.content().to_vec(),
185 })
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AuthenticatorError;

    /// Encodes `bytes` with the standard base64 alphabet that `OhttpClient::new` decodes.
    fn b64(bytes: &[u8]) -> String {
        base64::engine::general_purpose::STANDARD.encode(bytes)
    }

    /// Runs `OhttpClient::new` with a fresh `reqwest::Client` and the given configuration.
    fn build(
        relay_url: &str,
        key_config_base64: String,
        config_scope: &str,
        target_url: &str,
    ) -> Result<OhttpClient, AuthenticatorError> {
        let config = OhttpClientConfig::new(relay_url.to_owned(), key_config_base64);
        OhttpClient::new(reqwest::Client::new(), config_scope, target_url, config)
    }

    /// Asserts that `result` is an `InvalidConfig` error for `expected_attribute` whose
    /// reason contains `expected_reason`.
    #[track_caller]
    fn assert_invalid_config(
        result: Result<OhttpClient, AuthenticatorError>,
        expected_attribute: &str,
        expected_reason: &str,
    ) {
        match result {
            Err(AuthenticatorError::InvalidConfig { attribute, reason }) => {
                assert_eq!(attribute, expected_attribute);
                assert!(
                    reason.contains(expected_reason),
                    "unexpected reason: {reason}"
                );
            }
            other => panic!("expected InvalidConfig, got: {other:?}"),
        }
    }

    #[test]
    fn invalid_base64_key_config_returns_invalid_config() {
        let result = build(
            "http://localhost:1234",
            "not valid base64 !!!".into(),
            "test_scope",
            "https://localhost:9999",
        );
        assert_invalid_config(result, "test_scope.key_config_base64", "invalid base64");
    }

    #[test]
    fn invalid_ohttp_keys_payload_returns_invalid_config() {
        let result = build(
            "http://localhost:1234",
            b64(b"definitely not an ohttp-keys payload"),
            "my_scope",
            "https://localhost:9999",
        );
        assert_invalid_config(
            result,
            "my_scope.key_config_base64",
            "invalid application/ohttp-keys payload",
        );
    }

    #[test]
    fn garbage_ohttp_keys_bytes_returns_invalid_config() {
        let result = build(
            "http://127.0.0.1:0/does-not-exist",
            b64(b"not-a-valid-ohttp-keys"),
            "test",
            "http://localhost:1234",
        );
        assert_invalid_config(
            result,
            "test.key_config_base64",
            "invalid application/ohttp-keys payload",
        );
    }

    #[test]
    fn missing_scheme_in_target_url_returns_invalid_config() {
        let result = build(
            "http://localhost:1234",
            b64(b"irrelevant"),
            "test_scope",
            "localhost:9999",
        );
        assert_invalid_config(
            result,
            "test_scope.target_url",
            "expected scheme://authority",
        );
    }

    #[test]
    fn empty_key_config_returns_invalid_config() {
        // "" is valid base64 (decoding to no bytes), so the failure must come from the
        // ohttp-keys parser rather than the base64 decoder.
        let result = build(
            "http://localhost:1234",
            b64(b""),
            "test_scope",
            "https://localhost:9999",
        );
        assert_invalid_config(
            result,
            "test_scope.key_config_base64",
            "invalid application/ohttp-keys payload",
        );
    }
}