1use axum::Json;
2use axum::extract::FromRequestParts;
3use axum::response::{IntoResponse, Response as AxumResponse};
4use axum_extra::extract::{Query, QueryRejection};
5use http::header::{self, HOST};
6use http::request::Parts;
7use http::uri::InvalidUri;
8use http::{HeaderValue, StatusCode};
9use tracing::trace;
10
11use crate::{Rel, WebFingerRequest, WebFingerResponse};
12
13const JRD_CONTENT_TYPE: HeaderValue = HeaderValue::from_static("application/jrd+json");
14
15impl IntoResponse for WebFingerResponse {
16 fn into_response(self) -> AxumResponse {
42 ([(header::CONTENT_TYPE, JRD_CONTENT_TYPE)], Json(self)).into_response()
43 }
44}
45
46#[derive(Debug, serde::Deserialize)]
48struct RequestParams {
49 resource: String,
50
51 #[serde(default)]
52 rel: Vec<String>,
53}
54
55pub enum Rejection {
60 InvalidQueryString(String),
62
63 MissingHost,
65
66 InvalidResource(InvalidUri),
68}
69
70impl IntoResponse for Rejection {
71 fn into_response(self) -> AxumResponse {
73 let message = match self {
74 Rejection::MissingHost => "missing host".to_string(),
75 Rejection::InvalidQueryString(e) => format!("{e}"),
76 Rejection::InvalidResource(e) => format!("invalid resource: {e}"),
77 };
78 (StatusCode::BAD_REQUEST, message).into_response()
79 }
80}
81
82impl From<QueryRejection> for Rejection {
83 fn from(rejection: QueryRejection) -> Self {
84 Rejection::InvalidQueryString(rejection.to_string())
85 }
86}
87
88impl<S: Send + Sync> FromRequestParts<S> for WebFingerRequest {
89 type Rejection = Rejection;
90
91 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
126 trace!("request parts: {:?}", parts);
127
128 let host = parts
129 .uri
130 .host()
131 .or_else(|| parts.headers.get(HOST).and_then(|host| host.to_str().ok()))
132 .map(str::to_string)
133 .ok_or(Rejection::MissingHost)?;
134
135 let query = Query::<RequestParams>::from_request_parts(parts, state).await?;
138 let resource = query.resource.parse().map_err(Rejection::InvalidResource)?;
139 let rels = query.rel.clone().into_iter().map(Rel::from).collect();
140
141 Ok(WebFingerRequest {
142 host,
143 resource,
144 rels,
145 })
146 }
147}
148
/// End-to-end tests that drive the extractor and the response conversion through a real
/// `axum::Router`.
#[cfg(test)]
mod tests {
    use axum::body::Body;
    use axum::routing::get;
    use http::{Request, Response};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    use super::*;
    use crate::WELL_KNOWN_PATH;

    type Result<T = (), E = Box<dyn std::error::Error>> = std::result::Result<T, E>;

    /// Test helper: consume a response and return its body as UTF-8 text.
    trait IntoText {
        async fn into_text(self) -> Result<String>;
    }

    impl IntoText for Response<Body> {
        async fn into_text(self) -> Result<String> {
            let body = self.into_body().collect().await?.to_bytes();
            let string = String::from_utf8(body.to_vec())?;
            Ok(string)
        }
    }

    /// Returns the response's `Content-Type` header as text, if it is present and readable.
    fn content_type(response: &Response<Body>) -> Option<&str> {
        response
            .headers()
            .get(http::header::CONTENT_TYPE)
            .and_then(|value| value.to_str().ok())
    }

    /// Router with a single WebFinger route backed by the `webfinger` handler below.
    fn app() -> axum::Router {
        axum::Router::new().route(WELL_KNOWN_PATH, get(webfinger))
    }

    /// Minimal handler: echoes the requested resource back as the response subject.
    async fn webfinger(request: WebFingerRequest) -> impl IntoResponse {
        WebFingerResponse::builder(request.resource.to_string()).build()
    }

    const VALID_RESOURCE: &str = "acct:carol@example.com";

    /// Body the `webfinger` handler produces for `VALID_RESOURCE`.
    const VALID_RESPONSE_BODY: &str = r#"{"subject":"acct:carol@example.com","links":[]}"#;

    #[tokio::test]
    async fn valid_request() -> Result {
        let uri = format!("https://example.com{WELL_KNOWN_PATH}?resource={VALID_RESOURCE}");
        let request = Request::builder().uri(uri).body(Body::empty())?;

        let response = app().oneshot(request).await?;

        assert_eq!(response.status(), StatusCode::OK, "{response:?}");
        // The JRD media type is the whole point of the custom `IntoResponse` impl.
        assert_eq!(
            content_type(&response),
            Some("application/jrd+json"),
            "{response:?}"
        );
        let body = response.into_text().await?;
        assert_eq!(body, VALID_RESPONSE_BODY);
        Ok(())
    }

    #[tokio::test]
    async fn valid_request_with_host_header() -> Result {
        let request = Request::builder()
            .uri(format!("{WELL_KNOWN_PATH}?resource={VALID_RESOURCE}"))
            .header(HOST, "example.com")
            .body(Body::empty())?;

        let response = app().oneshot(request).await?;

        assert_eq!(response.status(), StatusCode::OK, "{response:?}");
        assert_eq!(
            content_type(&response),
            Some("application/jrd+json"),
            "{response:?}"
        );
        let body = response.into_text().await?;
        assert_eq!(body, VALID_RESPONSE_BODY);
        Ok(())
    }

    #[tokio::test]
    async fn request_with_no_host() -> Result {
        // Relative URI and no `Host` header: there is nowhere to take the host from.
        let uri = format!("{WELL_KNOWN_PATH}?resource={VALID_RESOURCE}");
        let request = Request::builder().uri(uri).body(Body::empty())?;

        let response = app().oneshot(request).await?;

        assert_eq!(response.status(), StatusCode::BAD_REQUEST, "{response:?}");
        let body = response.into_text().await?;
        assert_eq!(body, "missing host");
        Ok(())
    }

    #[tokio::test]
    async fn request_with_missing_resource() -> Result {
        let request = Request::builder()
            .uri(WELL_KNOWN_PATH)
            .header(HOST, "example.com")
            .body(Body::empty())?;

        let response = app().oneshot(request).await?;

        assert_eq!(response.status(), StatusCode::BAD_REQUEST, "{response:?}");
        let body = response.into_text().await?;
        assert_eq!(
            body,
            "Failed to deserialize query string: missing field `resource`",
        );
        Ok(())
    }

    #[tokio::test]
    async fn request_with_invalid_resource() -> Result {
        let uri = format!("https://example.com{WELL_KNOWN_PATH}?resource=%");
        let request = Request::builder().uri(uri).body(Body::empty())?;

        let response = app().oneshot(request).await?;

        assert_eq!(response.status(), StatusCode::BAD_REQUEST, "{response:?}");
        let body = response.into_text().await?;
        assert_eq!(body, "invalid resource: invalid authority");
        Ok(())
    }
}