1use std::sync::Arc;
10use std::time::Duration;
11
12use axum::extract::State;
13use axum::http::header::{HeaderName, HeaderValue, HOST};
14use axum::http::{HeaderMap, StatusCode, Uri};
15use axum::middleware::Next;
16use axum::response::{IntoResponse, Response};
17use axum::Json;
18use tonic::transport::Channel;
19
20use envoy_types::pb::envoy::config::core::v3::HeaderValueOption;
21use envoy_types::pb::envoy::service::auth::v3::{
22 attribute_context::{HttpRequest, Request as AttrRequest},
23 authorization_client::AuthorizationClient,
24 check_response::HttpResponse as EnvoyHttpResponse,
25 AttributeContext, CheckRequest, CheckResponse, DeniedHttpResponse,
26};
27
28use super::forbidden;
29use crate::config::AuthzConfig;
30
31pub struct Authz {
33 channel: Channel,
34 timeout: Duration,
35 failure_mode_allow: bool,
36}
37
38impl Authz {
39 pub fn build(config: &AuthzConfig) -> Result<Option<Arc<Self>>, String> {
44 if !config.enabled {
45 return Ok(None);
46 }
47 let channel = Channel::from_shared(config.endpoint.clone())
48 .map_err(|e| format!("invalid authz endpoint: {e}"))?
49 .connect_lazy();
50 Ok(Some(Arc::new(Self {
51 channel,
52 timeout: Duration::from_millis(config.timeout_ms),
53 failure_mode_allow: config.failure_mode_allow,
54 })))
55 }
56}
57
58pub async fn middleware(
60 State(authz): State<Arc<Authz>>,
61 mut request: axum::extract::Request,
62 next: Next,
63) -> Response {
64 let check = build_check_request(request.headers(), request.method().as_str(), request.uri());
65 let mut client = AuthorizationClient::new(authz.channel.clone());
66 let mut grpc_req = tonic::Request::new(check);
67 grpc_req.set_timeout(authz.timeout);
68
69 match client.check(grpc_req).await {
70 Ok(resp) => match evaluate(resp.into_inner()) {
71 Decision::Allow(headers) => {
72 apply_headers(request.headers_mut(), headers);
73 next.run(request).await
74 }
75 Decision::Deny(response) => response,
76 },
77 Err(status) if authz.failure_mode_allow => {
80 tracing::warn!(error = %status, "authz check failed; failing open");
81 next.run(request).await
82 }
83 Err(status) => {
84 tracing::warn!(error = %status, "authz check failed; failing closed");
85 service_unavailable("authorization service unavailable")
86 }
87 }
88}
89
90enum Decision {
92 Allow(Vec<(HeaderName, HeaderValue)>),
94 Deny(Response),
96}
97
98fn build_check_request(headers: &HeaderMap, method: &str, uri: &Uri) -> CheckRequest {
100 let mut header_map = std::collections::HashMap::new();
101 for (name, value) in headers {
102 if let Ok(v) = value.to_str() {
103 header_map.insert(name.as_str().to_string(), v.to_string());
104 }
105 }
106 let host = headers
107 .get(HOST)
108 .and_then(|v| v.to_str().ok())
109 .unwrap_or_default()
110 .to_string();
111 let scheme = headers
112 .get("x-forwarded-proto")
113 .and_then(|v| v.to_str().ok())
114 .unwrap_or("http")
115 .to_string();
116
117 let http = HttpRequest {
118 method: method.to_string(),
119 path: uri.path().to_string(),
120 query: uri.query().unwrap_or_default().to_string(),
121 host,
122 scheme,
123 headers: header_map,
124 ..Default::default()
125 };
126 CheckRequest {
127 attributes: Some(AttributeContext {
128 request: Some(AttrRequest {
129 http: Some(http),
130 ..Default::default()
131 }),
132 ..Default::default()
133 }),
134 }
135}
136
137fn evaluate(resp: CheckResponse) -> Decision {
139 let allowed = resp.status.as_ref().map(|s| s.code == 0).unwrap_or(false);
140 if allowed {
141 let headers = match resp.http_response {
142 Some(EnvoyHttpResponse::OkResponse(ok)) => {
143 ok.headers.into_iter().filter_map(header_kv).collect()
144 }
145 _ => Vec::new(),
146 };
147 Decision::Allow(headers)
148 } else {
149 let response = match resp.http_response {
150 Some(EnvoyHttpResponse::DeniedResponse(denied))
151 | Some(EnvoyHttpResponse::ErrorResponse(denied)) => denied_to_response(denied),
152 _ => forbidden("forbidden by authorization policy"),
153 };
154 Decision::Deny(response)
155 }
156}
157
158fn apply_headers(dst: &mut HeaderMap, headers: Vec<(HeaderName, HeaderValue)>) {
161 for (name, value) in headers {
162 dst.append(name, value);
163 }
164}
165
166fn header_kv(opt: HeaderValueOption) -> Option<(HeaderName, HeaderValue)> {
168 let header = opt.header?;
169 let name = HeaderName::try_from(header.key).ok()?;
170 let value = HeaderValue::try_from(header.value).ok()?;
171 Some((name, value))
172}
173
174fn denied_to_response(denied: DeniedHttpResponse) -> Response {
176 let status = denied
178 .status
179 .and_then(|s| u16::try_from(s.code).ok())
180 .and_then(|c| StatusCode::from_u16(c).ok())
181 .unwrap_or(StatusCode::FORBIDDEN);
182 let mut headers = HeaderMap::new();
183 apply_headers(
184 &mut headers,
185 denied.headers.into_iter().filter_map(header_kv).collect(),
186 );
187 (status, headers, denied.body).into_response()
188}
189
190fn service_unavailable(message: &str) -> Response {
191 (
192 StatusCode::SERVICE_UNAVAILABLE,
193 Json(serde_json::json!({ "error": "UNAVAILABLE", "message": message })),
194 )
195 .into_response()
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use envoy_types::pb::envoy::config::core::v3::HeaderValue as EnvoyHeaderValue;
    use envoy_types::pb::envoy::r#type::v3::HttpStatus;
    use envoy_types::pb::envoy::service::auth::v3::{CheckResponse, OkHttpResponse};
    use envoy_types::pb::google::rpc::Status as RpcStatus;

    /// Wraps a single `key: value` pair in an Envoy `HeaderValueOption`.
    fn hvo(key: &str, value: &str) -> HeaderValueOption {
        let header = EnvoyHeaderValue {
            key: key.to_owned(),
            value: value.to_owned(),
            ..Default::default()
        };
        HeaderValueOption {
            header: Some(header),
            ..Default::default()
        }
    }

    /// Extracts `attributes.request.http` from a `CheckRequest`, panicking if absent.
    fn http_attrs(check: CheckRequest) -> HttpRequest {
        check
            .attributes
            .and_then(|attrs| attrs.request)
            .and_then(|request| request.http)
            .expect("check request carries http attributes")
    }

    /// Runs `resp` through `evaluate`, requires a denial, and returns its HTTP status.
    fn denied_status(resp: CheckResponse) -> StatusCode {
        match evaluate(resp) {
            Decision::Deny(response) => response.status(),
            Decision::Allow(_) => panic!("expected deny"),
        }
    }

    /// Sends `GET /x` through the middleware while the authz endpoint points at a
    /// port nothing listens on, and returns the resulting status code.
    async fn status_with_unreachable_authz(failure_mode_allow: bool) -> StatusCode {
        use axum::routing::get;
        use axum::Router;
        use tower::ServiceExt;

        let authz = Authz::build(&AuthzConfig {
            enabled: true,
            endpoint: "http://127.0.0.1:1".into(),
            timeout_ms: 100,
            failure_mode_allow,
        })
        .unwrap()
        .unwrap();
        let app: Router = Router::new()
            .route("/x", get(|| async { "upstream" }))
            .layer(axum::middleware::from_fn_with_state(authz, middleware));
        let request = axum::http::Request::get("/x")
            .body(axum::body::Body::empty())
            .unwrap();
        app.oneshot(request).await.unwrap().status()
    }

    #[test]
    fn build_check_request_maps_http_attributes() {
        let headers: HeaderMap = vec![
            (HOST, HeaderValue::from_static("api.example.com")),
            (
                HeaderName::from_static("x-forwarded-proto"),
                HeaderValue::from_static("https"),
            ),
            (
                HeaderName::from_static("x-forwarded-user"),
                HeaderValue::from_static("alice"),
            ),
        ]
        .into_iter()
        .collect();
        let uri: Uri = "/v1/things?page=2".parse().unwrap();

        let http = http_attrs(build_check_request(&headers, "POST", &uri));
        assert_eq!(http.method, "POST");
        assert_eq!(http.path, "/v1/things");
        assert_eq!(http.query, "page=2");
        assert_eq!(http.host, "api.example.com");
        assert_eq!(http.scheme, "https");
        assert_eq!(
            http.headers.get("x-forwarded-user").map(String::as_str),
            Some("alice")
        );
    }

    #[test]
    fn scheme_defaults_to_http() {
        let uri: Uri = "/x".parse().unwrap();
        let http = http_attrs(build_check_request(&HeaderMap::new(), "GET", &uri));
        assert_eq!(http.scheme, "http");
    }

    #[test]
    fn ok_status_allows_and_collects_headers() {
        let resp = CheckResponse {
            status: Some(RpcStatus {
                code: 0,
                ..Default::default()
            }),
            http_response: Some(EnvoyHttpResponse::OkResponse(OkHttpResponse {
                headers: vec![hvo("x-authz-decision", "allow")],
                ..Default::default()
            })),
            ..Default::default()
        };

        let headers = match evaluate(resp) {
            Decision::Allow(headers) => headers,
            Decision::Deny(_) => panic!("expected allow"),
        };
        assert_eq!(headers.len(), 1);
        let (name, value) = &headers[0];
        assert_eq!(name.as_str(), "x-authz-decision");
        assert_eq!(*value, "allow");
    }

    #[test]
    fn apply_headers_preserves_duplicate_names() {
        let cookie = HeaderName::from_static("set-cookie");
        let mut dst = HeaderMap::new();
        apply_headers(
            &mut dst,
            vec![
                (cookie.clone(), HeaderValue::from_static("a=1")),
                (cookie.clone(), HeaderValue::from_static("b=2")),
            ],
        );

        let values: Vec<&str> = dst
            .get_all(&cookie)
            .iter()
            .map(|v| v.to_str().unwrap())
            .collect();
        assert_eq!(values, ["a=1", "b=2"]);
    }

    #[test]
    fn missing_status_denies() {
        assert_eq!(denied_status(CheckResponse::default()), StatusCode::FORBIDDEN);
    }

    #[test]
    fn denied_response_uses_its_status() {
        let denied = DeniedHttpResponse {
            status: Some(HttpStatus { code: 401 }),
            body: "nope".to_string(),
            ..Default::default()
        };
        let resp = CheckResponse {
            // gRPC code 7 is PERMISSION_DENIED.
            status: Some(RpcStatus {
                code: 7,
                ..Default::default()
            }),
            http_response: Some(EnvoyHttpResponse::DeniedResponse(denied)),
            ..Default::default()
        };
        assert_eq!(denied_status(resp), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn fail_closed_returns_503_when_authz_unreachable() {
        assert_eq!(
            status_with_unreachable_authz(false).await,
            StatusCode::SERVICE_UNAVAILABLE
        );
    }

    #[tokio::test]
    async fn fail_open_passes_through_when_authz_unreachable() {
        assert_eq!(status_with_unreachable_authz(true).await, StatusCode::OK);
    }
}