1use std::net::SocketAddr;
2
3use axum::{
4 Extension, RequestPartsExt,
5 extract::{ConnectInfo, FromRequestParts},
6 http::{StatusCode, request::Parts},
7};
8use axum_extra::{TypedHeader, extract::CookieJar, headers::UserAgent};
9use torii::{SessionToken, User};
10
11use crate::{error::AuthError, types::ConnectionInfo};
12
13impl<S> FromRequestParts<S> for ConnectionInfo
14where
15 S: Send + Sync,
16{
17 type Rejection = (StatusCode, &'static str);
18
19 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
20 let user_agent = parts
21 .extract::<Option<TypedHeader<UserAgent>>>()
22 .await
23 .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid user agent header"))?
24 .map(|ua| ua.to_string());
25
26 let ip = parts
27 .extract::<ConnectInfo<SocketAddr>>()
28 .await
29 .ok()
30 .map(|addr| addr.ip().to_string());
31
32 Ok(ConnectionInfo { ip, user_agent })
33 }
34}
35
36pub struct AuthUser(pub User);
37
38impl<S> FromRequestParts<S> for AuthUser
39where
40 S: Send + Sync,
41{
42 type Rejection = AuthError;
43
44 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
45 let Extension(user): Extension<User> =
46 parts.extract().await.map_err(|_| AuthError::Unauthorized)?;
47
48 Ok(AuthUser(user))
49 }
50}
51
52pub struct OptionalAuthUser(pub Option<User>);
53
54impl<S> FromRequestParts<S> for OptionalAuthUser
55where
56 S: Send + Sync,
57{
58 type Rejection = (StatusCode, &'static str);
59
60 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
61 let user = parts.extensions.get::<User>().cloned();
62
63 Ok(OptionalAuthUser(user))
64 }
65}
66
67pub struct SessionTokenFromCookie(pub Option<SessionToken>);
68
69impl<S> FromRequestParts<S> for SessionTokenFromCookie
70where
71 S: Send + Sync,
72{
73 type Rejection = (StatusCode, &'static str);
74
75 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
76 let jar = parts
77 .extract::<CookieJar>()
78 .await
79 .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid cookie header"))?;
80
81 let session_token = jar
82 .get("session_id")
83 .and_then(|cookie| cookie.value().parse::<String>().ok())
84 .map(|token| SessionToken::new(&token));
85
86 Ok(SessionTokenFromCookie(session_token))
87 }
88}
89
90pub struct SessionTokenFromBearer(pub Option<SessionToken>);
91
92impl<S> FromRequestParts<S> for SessionTokenFromBearer
93where
94 S: Send + Sync,
95{
96 type Rejection = (StatusCode, &'static str);
97
98 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
99 let session_token = parts
100 .headers
101 .get("Authorization")
102 .and_then(|header| header.to_str().ok())
103 .and_then(|header| header.strip_prefix("Bearer "))
104 .map(SessionToken::new);
105
106 Ok(SessionTokenFromBearer(session_token))
107 }
108}
109
110pub struct SessionTokenFromRequest(pub Option<SessionToken>);
111
112impl<S> FromRequestParts<S> for SessionTokenFromRequest
113where
114 S: Send + Sync,
115{
116 type Rejection = (StatusCode, &'static str);
117
118 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
119 if let Some(token) = parts
121 .headers
122 .get("Authorization")
123 .and_then(|header| header.to_str().ok())
124 .and_then(|header| header.strip_prefix("Bearer "))
125 {
126 return Ok(SessionTokenFromRequest(Some(SessionToken::new(token))));
127 }
128
129 let jar = parts
131 .extract::<CookieJar>()
132 .await
133 .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid cookie header"))?;
134
135 let session_token = jar
136 .get("session_id")
137 .and_then(|cookie| cookie.value().parse::<String>().ok())
138 .map(|token| SessionToken::new(&token));
139
140 Ok(SessionTokenFromRequest(session_token))
141 }
142}