securitydept_session_context/
lib.rs1pub mod config;
2#[cfg(feature = "service")]
3mod service;
4
5use std::collections::HashMap;
6#[cfg(test)]
7use std::time::Duration as StdDuration;
8
9pub use config::{
10 NoopSessionContextConfigValidator, ResolvedSessionContextConfig, SessionContextConfig,
11 SessionContextConfigSource, SessionContextConfigValidationError,
12 SessionContextConfigValidationFailure, SessionContextConfigValidator,
13 SessionContextFixedPostAuthRedirectValidator,
14};
15use http::StatusCode;
16#[cfg(test)]
17use securitydept_utils::redirect::RedirectTargetConfig;
18use securitydept_utils::{
19 error::{ErrorPresentation, ToErrorPresentation, UserRecovery},
20 principal::AuthenticatedPrincipal,
21 redirect::RedirectTargetError,
22};
23use serde::{Serialize, de::DeserializeOwned};
24use serde_json::Value;
25#[cfg(feature = "service")]
26pub use service::{
27 DevSessionAuthService, OidcSessionAuthService, OidcSessionAuthServiceConfig,
28 SessionAuthServiceError, SessionAuthServiceTrait,
29};
30use snafu::Snafu;
31use tower_sessions::{
32 Expiry, Session, SessionManagerLayer, SessionStore,
33 cookie::{SameSite, time::Duration},
34};
35use typed_builder::TypedBuilder;
36
37pub const DEFAULT_COOKIE_NAME: &str = "securitydept_session";
38pub const DEFAULT_SESSION_CONTEXT_KEY: &str = "securitydept.session_context";
39
40pub type SessionPrincipal = AuthenticatedPrincipal;
41
42#[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq, TypedBuilder)]
43pub struct SessionContext<Extra = HashMap<String, Value>> {
44 pub principal: SessionPrincipal,
45 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
46 #[builder(default)]
47 pub attributes: HashMap<String, Value>,
48 #[serde(skip_serializing_if = "Option::is_none")]
49 #[builder(default, setter(strip_option))]
50 pub extra: Option<Extra>,
51}
52
53#[derive(Debug, Clone, Copy, Serialize, serde::Deserialize, PartialEq, Eq, Default)]
54#[serde(rename_all = "snake_case")]
55pub enum SessionCookieSameSite {
56 Strict,
57 #[default]
58 Lax,
59 None,
60}
61
62impl From<SessionCookieSameSite> for SameSite {
63 fn from(value: SessionCookieSameSite) -> Self {
64 match value {
65 SessionCookieSameSite::Strict => SameSite::Strict,
66 SessionCookieSameSite::Lax => SameSite::Lax,
67 SessionCookieSameSite::None => SameSite::None,
68 }
69 }
70}
71
72pub fn build_session_layer<Store>(
73 config: &ResolvedSessionContextConfig,
74 store: Store,
75) -> SessionManagerLayer<Store>
76where
77 Store: SessionStore,
78{
79 let mut layer = SessionManagerLayer::new(store)
80 .with_name(config.cookie_name.clone())
81 .with_path(config.cookie_path.clone())
82 .with_same_site(config.same_site.into())
83 .with_http_only(config.http_only)
84 .with_secure(config.secure);
85
86 if let Some(ttl) = config.ttl {
87 layer = layer.with_expiry(Expiry::OnInactivity(
88 Duration::seconds(ttl.as_secs() as i64),
89 ));
90 }
91
92 layer
93}
94
95#[derive(Debug, Snafu)]
96pub enum SessionContextError {
97 #[snafu(display("session context is missing"))]
98 MissingContext,
99 #[snafu(display("session operation failed: {source}"))]
100 Session {
101 source: tower_sessions::session::Error,
102 },
103 #[snafu(display("post-auth redirect is invalid: {source}"))]
104 RedirectTarget { source: RedirectTargetError },
105}
106
107pub type SessionContextResult<T> = Result<T, SessionContextError>;
108
109impl SessionContextError {
110 pub fn status_code(&self) -> StatusCode {
111 match self {
112 Self::MissingContext => StatusCode::UNAUTHORIZED,
113 Self::Session { .. } | Self::RedirectTarget { .. } => StatusCode::INTERNAL_SERVER_ERROR,
114 }
115 }
116}
117
118impl ToErrorPresentation for SessionContextError {
119 fn to_error_presentation(&self) -> ErrorPresentation {
120 match self {
121 SessionContextError::MissingContext => ErrorPresentation::new(
122 "authentication_required",
123 "Sign in to continue.",
124 UserRecovery::Reauthenticate,
125 ),
126 SessionContextError::Session { .. } => ErrorPresentation::new(
127 "session_unavailable",
128 "The session is temporarily unavailable.",
129 UserRecovery::Retry,
130 ),
131 SessionContextError::RedirectTarget { .. } => ErrorPresentation::new(
132 "session_post_auth_redirect_invalid",
133 "The configured post-auth redirect is invalid.",
134 UserRecovery::ContactSupport,
135 ),
136 }
137 }
138}
139
140#[derive(Clone)]
141pub struct SessionContextSession {
142 session: Session,
143 session_context_key: String,
144}
145
146impl From<Session> for SessionContextSession {
147 fn from(session: Session) -> Self {
148 Self {
149 session,
150 session_context_key: DEFAULT_SESSION_CONTEXT_KEY.to_string(),
151 }
152 }
153}
154
155impl SessionContextSession {
156 pub fn new(session: Session) -> Self {
157 Self::from(session)
158 }
159
160 pub fn from_resolved_config(session: Session, config: &ResolvedSessionContextConfig) -> Self {
161 Self {
162 session,
163 session_context_key: config.session_context_key.clone(),
164 }
165 }
166
167 pub fn with_key(session: Session, session_context_key: impl Into<String>) -> Self {
168 Self {
169 session,
170 session_context_key: session_context_key.into(),
171 }
172 }
173
174 pub fn raw_session(&self) -> &Session {
175 &self.session
176 }
177
178 pub async fn insert<Extra>(&self, context: &SessionContext<Extra>) -> SessionContextResult<()>
179 where
180 Extra: Serialize,
181 {
182 self.session
183 .insert(&self.session_context_key, context)
184 .await
185 .map_err(|source| SessionContextError::Session { source })
186 }
187
188 pub async fn get<Extra>(&self) -> SessionContextResult<Option<SessionContext<Extra>>>
189 where
190 Extra: DeserializeOwned,
191 {
192 self.session
193 .get(&self.session_context_key)
194 .await
195 .map_err(|source| SessionContextError::Session { source })
196 }
197
198 pub async fn require<Extra>(&self) -> SessionContextResult<SessionContext<Extra>>
199 where
200 Extra: DeserializeOwned,
201 {
202 self.get().await?.ok_or(SessionContextError::MissingContext)
203 }
204
205 pub async fn clear(&self) -> SessionContextResult<()> {
206 self.session
207 .remove_value(&self.session_context_key)
208 .await
209 .map(|_| ())
210 .map_err(|source| SessionContextError::Session { source })
211 }
212
213 pub async fn is_authenticated<Extra>(&self) -> SessionContextResult<bool>
214 where
215 Extra: DeserializeOwned,
216 {
217 Ok(self.get::<Extra>().await?.is_some())
218 }
219
220 pub async fn cycle_id(&self) -> SessionContextResult<()> {
221 self.session
222 .cycle_id()
223 .await
224 .map_err(|source| SessionContextError::Session { source })
225 }
226
227 pub async fn flush(&self) -> SessionContextResult<()> {
228 self.session
229 .flush()
230 .await
231 .map_err(|source| SessionContextError::Session { source })
232 }
233}
234
#[cfg(test)]
mod tests {
    use securitydept_utils::redirect::RedirectTargetRule;

    use super::*;

    /// Shorthand for a JSON string value.
    fn json_str(text: &str) -> Value {
        Value::String(text.to_string())
    }

    #[test]
    fn test_default_config() {
        let defaults = SessionContextConfig::default();

        // Cookie attributes.
        assert_eq!(defaults.cookie_name, DEFAULT_COOKIE_NAME);
        assert_eq!(defaults.cookie_path, "/");
        assert!(defaults.http_only);
        assert!(!defaults.secure);
        assert_eq!(defaults.same_site, SessionCookieSameSite::Lax);

        // Storage key, lifetime (one day) and post-auth redirect fallback.
        assert_eq!(defaults.session_context_key, DEFAULT_SESSION_CONTEXT_KEY);
        assert_eq!(defaults.ttl, Some(StdDuration::from_secs(86_400)));
        assert_eq!(
            defaults.post_auth_redirect.default_redirect_target.as_deref(),
            Some("/")
        );
    }

    #[test]
    fn test_context_with_extra_data() {
        let principal = SessionPrincipal::builder()
            .subject("dev-session")
            .display_name("dev")
            .build();
        let attributes = HashMap::from([("mode".to_string(), json_str("dev"))]);
        let extra = HashMap::from([("provider".to_string(), json_str("local"))]);

        let context = SessionContext::builder()
            .principal(principal)
            .attributes(attributes)
            .extra(extra)
            .build();

        assert_eq!(context.principal.subject, "dev-session");
        assert_eq!(context.principal.display_name, "dev");
        assert_eq!(context.attributes.get("mode"), Some(&json_str("dev")));

        let provider = context
            .extra
            .as_ref()
            .and_then(|payload| payload.get("provider"));
        assert_eq!(provider, Some(&json_str("local")));
    }

    #[test]
    fn test_post_auth_redirect_resolution() {
        let app_only = [RedirectTargetRule::Strict {
            value: "/app".to_string(),
        }];
        let raw = SessionContextConfig::builder()
            .post_auth_redirect(RedirectTargetConfig::dynamic_default_and_dynamic_targets(
                "/", app_only,
            ))
            .build();
        let resolved = SessionContextConfigSource::resolve_all(&raw)
            .expect("session context config should resolve");

        let fallback = resolved
            .resolve_post_auth_redirect(None)
            .expect("default redirect should resolve");
        assert_eq!(fallback, "/");

        let requested = resolved
            .resolve_post_auth_redirect(Some("/app"))
            .expect("dynamic redirect should resolve");
        assert_eq!(requested, "/app");
    }

    #[test]
    fn fixed_post_auth_redirect_validator_rejects_override() {
        let validator = SessionContextFixedPostAuthRedirectValidator::new(
            RedirectTargetConfig::strict_default("/"),
        );
        let overriding = SessionContextConfig::builder()
            .post_auth_redirect(RedirectTargetConfig::strict_default("/admin"))
            .build();

        let failure =
            SessionContextConfigSource::resolve_all_with_validator(&overriding, &validator)
                .expect_err("unexpected session post_auth_redirect should be rejected");

        let is_fixed_redirect_conflict = matches!(
            failure,
            SessionContextConfigValidationFailure::Validation { source }
                if source.field_path == "post_auth_redirect"
                    && source.code == "fixed_post_auth_redirect_conflict"
        );
        assert!(is_fixed_redirect_conflict);
    }
}