Skip to main content

securitydept_session_context/
lib.rs

1#[cfg(feature = "service")]
2mod service;
3
4use std::{collections::HashMap, time::Duration as StdDuration};
5
6use http::StatusCode;
7use securitydept_utils::{
8    error::{ErrorPresentation, ToErrorPresentation, UserRecovery},
9    principal::AuthenticatedPrincipal,
10    redirect::{RedirectTargetConfig, RedirectTargetError, UriRelativeRedirectTargetResolver},
11};
12use serde::{Serialize, de::DeserializeOwned};
13use serde_json::Value;
14#[cfg(feature = "service")]
15pub use service::{
16    DevSessionAuthService, OidcSessionAuthService, SessionAuthServiceError, SessionAuthServiceTrait,
17};
18use snafu::Snafu;
19use tower_sessions::{
20    Expiry, Session, SessionManagerLayer, SessionStore,
21    cookie::{SameSite, time::Duration},
22};
23use typed_builder::TypedBuilder;
24
/// Name of the session cookie when [`SessionContextConfig::cookie_name`] is not set.
pub const DEFAULT_COOKIE_NAME: &str = "securitydept_session";
/// Session key under which the serialized [`SessionContext`] is stored when no
/// other key is configured.
pub const DEFAULT_SESSION_CONTEXT_KEY: &str = "securitydept.session_context";
27
28pub type SessionPrincipal = AuthenticatedPrincipal;
29
30#[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq, TypedBuilder)]
31pub struct SessionContext<Extra = HashMap<String, Value>> {
32    pub principal: SessionPrincipal,
33    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
34    #[builder(default)]
35    pub attributes: HashMap<String, Value>,
36    #[serde(skip_serializing_if = "Option::is_none")]
37    #[builder(default, setter(strip_option))]
38    pub extra: Option<Extra>,
39}
40
41#[derive(Debug, Clone, Copy, Serialize, serde::Deserialize, PartialEq, Eq, Default)]
42#[serde(rename_all = "snake_case")]
43pub enum SessionCookieSameSite {
44    Strict,
45    #[default]
46    Lax,
47    None,
48}
49
50impl From<SessionCookieSameSite> for SameSite {
51    fn from(value: SessionCookieSameSite) -> Self {
52        match value {
53            SessionCookieSameSite::Strict => SameSite::Strict,
54            SessionCookieSameSite::Lax => SameSite::Lax,
55            SessionCookieSameSite::None => SameSite::None,
56        }
57    }
58}
59
60#[derive(Debug, Clone, Serialize, serde::Deserialize, TypedBuilder)]
61pub struct SessionContextConfig {
62    #[builder(default = DEFAULT_COOKIE_NAME.to_string())]
63    #[serde(default = "default_cookie_name")]
64    pub cookie_name: String,
65    #[builder(default = DEFAULT_SESSION_CONTEXT_KEY.to_string())]
66    #[serde(default = "default_session_context_key")]
67    pub session_context_key: String,
68    #[builder(default = "/".to_string())]
69    #[serde(default = "default_cookie_path")]
70    pub cookie_path: String,
71    #[builder(default = true)]
72    #[serde(default = "default_true")]
73    pub http_only: bool,
74    #[builder(default = false)]
75    #[serde(default)]
76    pub secure: bool,
77    #[builder(default)]
78    #[serde(default)]
79    pub same_site: SessionCookieSameSite,
80    #[builder(default = Some(StdDuration::from_secs(86_400)))]
81    #[serde(default = "default_ttl", with = "humantime_serde::option")]
82    pub ttl: Option<StdDuration>,
83    #[builder(default = default_post_auth_redirect())]
84    #[serde(default = "default_post_auth_redirect")]
85    pub post_auth_redirect: RedirectTargetConfig,
86}
87
88fn default_cookie_name() -> String {
89    DEFAULT_COOKIE_NAME.to_string()
90}
91
92fn default_session_context_key() -> String {
93    DEFAULT_SESSION_CONTEXT_KEY.to_string()
94}
95
96fn default_cookie_path() -> String {
97    "/".to_string()
98}
99
100fn default_true() -> bool {
101    true
102}
103
104fn default_ttl() -> Option<StdDuration> {
105    Some(StdDuration::from_secs(86_400))
106}
107
108fn default_post_auth_redirect() -> RedirectTargetConfig {
109    RedirectTargetConfig::strict_default("/")
110}
111
112impl Default for SessionContextConfig {
113    fn default() -> Self {
114        Self {
115            cookie_name: default_cookie_name(),
116            session_context_key: default_session_context_key(),
117            cookie_path: default_cookie_path(),
118            http_only: default_true(),
119            secure: false,
120            same_site: SessionCookieSameSite::default(),
121            ttl: default_ttl(),
122            post_auth_redirect: default_post_auth_redirect(),
123        }
124    }
125}
126
127pub fn build_session_layer<Store>(
128    config: &SessionContextConfig,
129    store: Store,
130) -> SessionManagerLayer<Store>
131where
132    Store: SessionStore,
133{
134    let mut layer = SessionManagerLayer::new(store)
135        .with_name(config.cookie_name.clone())
136        .with_path(config.cookie_path.clone())
137        .with_same_site(config.same_site.into())
138        .with_http_only(config.http_only)
139        .with_secure(config.secure);
140
141    if let Some(ttl) = config.ttl {
142        layer = layer.with_expiry(Expiry::OnInactivity(
143            Duration::seconds(ttl.as_secs() as i64),
144        ));
145    }
146
147    layer
148}
149
150#[derive(Debug, Snafu)]
151pub enum SessionContextError {
152    #[snafu(display("session context is missing"))]
153    MissingContext,
154    #[snafu(display("session operation failed: {source}"))]
155    Session {
156        source: tower_sessions::session::Error,
157    },
158    #[snafu(display("post-auth redirect is invalid: {source}"))]
159    RedirectTarget { source: RedirectTargetError },
160}
161
162pub type SessionContextResult<T> = Result<T, SessionContextError>;
163
164impl SessionContextError {
165    pub fn status_code(&self) -> StatusCode {
166        match self {
167            Self::MissingContext => StatusCode::UNAUTHORIZED,
168            Self::Session { .. } | Self::RedirectTarget { .. } => StatusCode::INTERNAL_SERVER_ERROR,
169        }
170    }
171}
172
173impl ToErrorPresentation for SessionContextError {
174    fn to_error_presentation(&self) -> ErrorPresentation {
175        match self {
176            SessionContextError::MissingContext => ErrorPresentation::new(
177                "authentication_required",
178                "Sign in to continue.",
179                UserRecovery::Reauthenticate,
180            ),
181            SessionContextError::Session { .. } => ErrorPresentation::new(
182                "session_unavailable",
183                "The session is temporarily unavailable.",
184                UserRecovery::Retry,
185            ),
186            SessionContextError::RedirectTarget { .. } => ErrorPresentation::new(
187                "session_post_auth_redirect_invalid",
188                "The configured post-auth redirect is invalid.",
189                UserRecovery::ContactSupport,
190            ),
191        }
192    }
193}
194
195impl SessionContextConfig {
196    pub fn resolve_post_auth_redirect(
197        &self,
198        requested_post_auth_redirect: Option<&str>,
199    ) -> SessionContextResult<String> {
200        UriRelativeRedirectTargetResolver::from_config(self.post_auth_redirect.clone())
201            .map_err(|source| SessionContextError::RedirectTarget { source })?
202            .resolve_redirect_target(requested_post_auth_redirect)
203            .map(|value| value.to_string())
204            .map_err(|source| SessionContextError::RedirectTarget { source })
205    }
206}
207
208#[derive(Clone)]
209pub struct SessionContextSession {
210    session: Session,
211    session_context_key: String,
212}
213
214impl From<Session> for SessionContextSession {
215    fn from(session: Session) -> Self {
216        Self {
217            session,
218            session_context_key: DEFAULT_SESSION_CONTEXT_KEY.to_string(),
219        }
220    }
221}
222
223impl SessionContextSession {
224    pub fn new(session: Session) -> Self {
225        Self::from(session)
226    }
227
228    pub fn from_config(session: Session, config: &SessionContextConfig) -> Self {
229        Self {
230            session,
231            session_context_key: config.session_context_key.clone(),
232        }
233    }
234
235    pub fn with_key(session: Session, session_context_key: impl Into<String>) -> Self {
236        Self {
237            session,
238            session_context_key: session_context_key.into(),
239        }
240    }
241
242    pub fn raw_session(&self) -> &Session {
243        &self.session
244    }
245
246    pub async fn insert<Extra>(&self, context: &SessionContext<Extra>) -> SessionContextResult<()>
247    where
248        Extra: Serialize,
249    {
250        self.session
251            .insert(&self.session_context_key, context)
252            .await
253            .map_err(|source| SessionContextError::Session { source })
254    }
255
256    pub async fn get<Extra>(&self) -> SessionContextResult<Option<SessionContext<Extra>>>
257    where
258        Extra: DeserializeOwned,
259    {
260        self.session
261            .get(&self.session_context_key)
262            .await
263            .map_err(|source| SessionContextError::Session { source })
264    }
265
266    pub async fn require<Extra>(&self) -> SessionContextResult<SessionContext<Extra>>
267    where
268        Extra: DeserializeOwned,
269    {
270        self.get().await?.ok_or(SessionContextError::MissingContext)
271    }
272
273    pub async fn clear(&self) -> SessionContextResult<()> {
274        self.session
275            .remove_value(&self.session_context_key)
276            .await
277            .map(|_| ())
278            .map_err(|source| SessionContextError::Session { source })
279    }
280
281    pub async fn is_authenticated<Extra>(&self) -> SessionContextResult<bool>
282    where
283        Extra: DeserializeOwned,
284    {
285        Ok(self.get::<Extra>().await?.is_some())
286    }
287
288    pub async fn cycle_id(&self) -> SessionContextResult<()> {
289        self.session
290            .cycle_id()
291            .await
292            .map_err(|source| SessionContextError::Session { source })
293    }
294
295    pub async fn flush(&self) -> SessionContextResult<()> {
296        self.session
297            .flush()
298            .await
299            .map_err(|source| SessionContextError::Session { source })
300    }
301}
302
#[cfg(test)]
mod tests {
    use securitydept_utils::redirect::RedirectTargetRule;

    use super::*;

    #[test]
    fn test_default_config() {
        let config = SessionContextConfig::default();
        assert_eq!(config.cookie_name, DEFAULT_COOKIE_NAME);
        assert_eq!(config.session_context_key, DEFAULT_SESSION_CONTEXT_KEY);
        assert_eq!(config.cookie_path, "/");
        assert!(config.http_only);
        assert!(!config.secure);
        assert_eq!(config.same_site, SessionCookieSameSite::Lax);
        assert_eq!(config.ttl, Some(StdDuration::from_secs(86_400)));
        assert_eq!(
            config.post_auth_redirect.default_redirect_target.as_deref(),
            Some("/")
        );
    }

    #[test]
    fn test_context_with_extra_data() {
        let context = SessionContext::builder()
            .principal(
                SessionPrincipal::builder()
                    .subject("dev-session")
                    .display_name("dev")
                    .build(),
            )
            .attributes(HashMap::from([(
                "mode".to_string(),
                Value::String("dev".to_string()),
            )]))
            .extra(HashMap::from([(
                "provider".to_string(),
                Value::String("local".to_string()),
            )]))
            .build();

        assert_eq!(context.principal.subject, "dev-session");
        assert_eq!(context.principal.display_name, "dev");
        assert_eq!(
            context.attributes.get("mode"),
            Some(&Value::String("dev".to_string()))
        );
        assert_eq!(
            context
                .extra
                .as_ref()
                .and_then(|extra| extra.get("provider")),
            Some(&Value::String("local".to_string()))
        );
    }

    #[test]
    fn test_post_auth_redirect_resolution() {
        let config = SessionContextConfig::builder()
            .post_auth_redirect(RedirectTargetConfig::dynamic_default_and_dynamic_targets(
                "/",
                [RedirectTargetRule::Strict {
                    value: "/app".to_string(),
                }],
            ))
            .build();

        assert_eq!(
            config
                .resolve_post_auth_redirect(None)
                .expect("default redirect should resolve"),
            "/"
        );
        assert_eq!(
            config
                .resolve_post_auth_redirect(Some("/app"))
                .expect("dynamic redirect should resolve"),
            "/app"
        );
    }

    /// Every configuration variant must map onto the cookie crate variant of the same name.
    #[test]
    fn test_same_site_conversion() {
        assert_eq!(SameSite::from(SessionCookieSameSite::Strict), SameSite::Strict);
        assert_eq!(SameSite::from(SessionCookieSameSite::Lax), SameSite::Lax);
        assert_eq!(SameSite::from(SessionCookieSameSite::None), SameSite::None);
    }

    /// The `snake_case` serde names form part of the configuration file format.
    #[test]
    fn test_same_site_serde_names() {
        let parsed: SessionCookieSameSite =
            serde_json::from_str("\"none\"").expect("\"none\" should deserialize");
        assert_eq!(parsed, SessionCookieSameSite::None);
        assert_eq!(
            serde_json::to_string(&SessionCookieSameSite::Strict).expect("should serialize"),
            "\"strict\""
        );
    }

    #[test]
    fn test_missing_context_is_unauthorized() {
        assert_eq!(
            SessionContextError::MissingContext.status_code(),
            StatusCode::UNAUTHORIZED
        );
    }
}