// securitydept_session_context/lib.rs

1pub mod config;
2#[cfg(feature = "service")]
3mod service;
4
5use std::collections::HashMap;
6#[cfg(test)]
7use std::time::Duration as StdDuration;
8
9pub use config::{
10    NoopSessionContextConfigValidator, ResolvedSessionContextConfig, SessionContextConfig,
11    SessionContextConfigSource, SessionContextConfigValidationError,
12    SessionContextConfigValidationFailure, SessionContextConfigValidator,
13    SessionContextFixedPostAuthRedirectValidator,
14};
15use http::StatusCode;
16#[cfg(test)]
17use securitydept_utils::redirect::RedirectTargetConfig;
18use securitydept_utils::{
19    error::{ErrorPresentation, ToErrorPresentation, UserRecovery},
20    principal::AuthenticatedPrincipal,
21    redirect::RedirectTargetError,
22};
23use serde::{Serialize, de::DeserializeOwned};
24use serde_json::Value;
25#[cfg(feature = "service")]
26pub use service::{
27    DevSessionAuthService, OidcSessionAuthService, OidcSessionAuthServiceConfig,
28    SessionAuthServiceError, SessionAuthServiceTrait,
29};
30use snafu::Snafu;
31use tower_sessions::{
32    Expiry, Session, SessionManagerLayer, SessionStore,
33    cookie::{SameSite, time::Duration},
34};
35use typed_builder::TypedBuilder;
36
/// Name of the session cookie used when configuration does not override it.
pub const DEFAULT_COOKIE_NAME: &str = "securitydept_session";
/// Session-store key under which the serialized [`SessionContext`] is kept
/// when configuration does not override it.
pub const DEFAULT_SESSION_CONTEXT_KEY: &str = "securitydept.session_context";
39
40pub type SessionPrincipal = AuthenticatedPrincipal;
41
42#[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq, TypedBuilder)]
43pub struct SessionContext<Extra = HashMap<String, Value>> {
44    pub principal: SessionPrincipal,
45    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
46    #[builder(default)]
47    pub attributes: HashMap<String, Value>,
48    #[serde(skip_serializing_if = "Option::is_none")]
49    #[builder(default, setter(strip_option))]
50    pub extra: Option<Extra>,
51}
52
53#[derive(Debug, Clone, Copy, Serialize, serde::Deserialize, PartialEq, Eq, Default)]
54#[serde(rename_all = "snake_case")]
55pub enum SessionCookieSameSite {
56    Strict,
57    #[default]
58    Lax,
59    None,
60}
61
62impl From<SessionCookieSameSite> for SameSite {
63    fn from(value: SessionCookieSameSite) -> Self {
64        match value {
65            SessionCookieSameSite::Strict => SameSite::Strict,
66            SessionCookieSameSite::Lax => SameSite::Lax,
67            SessionCookieSameSite::None => SameSite::None,
68        }
69    }
70}
71
72pub fn build_session_layer<Store>(
73    config: &ResolvedSessionContextConfig,
74    store: Store,
75) -> SessionManagerLayer<Store>
76where
77    Store: SessionStore,
78{
79    let mut layer = SessionManagerLayer::new(store)
80        .with_name(config.cookie_name.clone())
81        .with_path(config.cookie_path.clone())
82        .with_same_site(config.same_site.into())
83        .with_http_only(config.http_only)
84        .with_secure(config.secure);
85
86    if let Some(ttl) = config.ttl {
87        layer = layer.with_expiry(Expiry::OnInactivity(
88            Duration::seconds(ttl.as_secs() as i64),
89        ));
90    }
91
92    layer
93}
94
95#[derive(Debug, Snafu)]
96pub enum SessionContextError {
97    #[snafu(display("session context is missing"))]
98    MissingContext,
99    #[snafu(display("session operation failed: {source}"))]
100    Session {
101        source: tower_sessions::session::Error,
102    },
103    #[snafu(display("post-auth redirect is invalid: {source}"))]
104    RedirectTarget { source: RedirectTargetError },
105}
106
107pub type SessionContextResult<T> = Result<T, SessionContextError>;
108
109impl SessionContextError {
110    pub fn status_code(&self) -> StatusCode {
111        match self {
112            Self::MissingContext => StatusCode::UNAUTHORIZED,
113            Self::Session { .. } | Self::RedirectTarget { .. } => StatusCode::INTERNAL_SERVER_ERROR,
114        }
115    }
116}
117
118impl ToErrorPresentation for SessionContextError {
119    fn to_error_presentation(&self) -> ErrorPresentation {
120        match self {
121            SessionContextError::MissingContext => ErrorPresentation::new(
122                "authentication_required",
123                "Sign in to continue.",
124                UserRecovery::Reauthenticate,
125            ),
126            SessionContextError::Session { .. } => ErrorPresentation::new(
127                "session_unavailable",
128                "The session is temporarily unavailable.",
129                UserRecovery::Retry,
130            ),
131            SessionContextError::RedirectTarget { .. } => ErrorPresentation::new(
132                "session_post_auth_redirect_invalid",
133                "The configured post-auth redirect is invalid.",
134                UserRecovery::ContactSupport,
135            ),
136        }
137    }
138}
139
140#[derive(Clone)]
141pub struct SessionContextSession {
142    session: Session,
143    session_context_key: String,
144}
145
146impl From<Session> for SessionContextSession {
147    fn from(session: Session) -> Self {
148        Self {
149            session,
150            session_context_key: DEFAULT_SESSION_CONTEXT_KEY.to_string(),
151        }
152    }
153}
154
155impl SessionContextSession {
156    pub fn new(session: Session) -> Self {
157        Self::from(session)
158    }
159
160    pub fn from_resolved_config(session: Session, config: &ResolvedSessionContextConfig) -> Self {
161        Self {
162            session,
163            session_context_key: config.session_context_key.clone(),
164        }
165    }
166
167    pub fn with_key(session: Session, session_context_key: impl Into<String>) -> Self {
168        Self {
169            session,
170            session_context_key: session_context_key.into(),
171        }
172    }
173
174    pub fn raw_session(&self) -> &Session {
175        &self.session
176    }
177
178    pub async fn insert<Extra>(&self, context: &SessionContext<Extra>) -> SessionContextResult<()>
179    where
180        Extra: Serialize,
181    {
182        self.session
183            .insert(&self.session_context_key, context)
184            .await
185            .map_err(|source| SessionContextError::Session { source })
186    }
187
188    pub async fn get<Extra>(&self) -> SessionContextResult<Option<SessionContext<Extra>>>
189    where
190        Extra: DeserializeOwned,
191    {
192        self.session
193            .get(&self.session_context_key)
194            .await
195            .map_err(|source| SessionContextError::Session { source })
196    }
197
198    pub async fn require<Extra>(&self) -> SessionContextResult<SessionContext<Extra>>
199    where
200        Extra: DeserializeOwned,
201    {
202        self.get().await?.ok_or(SessionContextError::MissingContext)
203    }
204
205    pub async fn clear(&self) -> SessionContextResult<()> {
206        self.session
207            .remove_value(&self.session_context_key)
208            .await
209            .map(|_| ())
210            .map_err(|source| SessionContextError::Session { source })
211    }
212
213    pub async fn is_authenticated<Extra>(&self) -> SessionContextResult<bool>
214    where
215        Extra: DeserializeOwned,
216    {
217        Ok(self.get::<Extra>().await?.is_some())
218    }
219
220    pub async fn cycle_id(&self) -> SessionContextResult<()> {
221        self.session
222            .cycle_id()
223            .await
224            .map_err(|source| SessionContextError::Session { source })
225    }
226
227    pub async fn flush(&self) -> SessionContextResult<()> {
228        self.session
229            .flush()
230            .await
231            .map_err(|source| SessionContextError::Session { source })
232    }
233}
234
#[cfg(test)]
mod tests {
    use securitydept_utils::redirect::RedirectTargetRule;

    use super::*;

    /// `SessionContextConfig::default()` exposes the crate-level defaults.
    #[test]
    fn test_default_config() {
        let defaults = SessionContextConfig::default();

        assert_eq!(defaults.cookie_name, DEFAULT_COOKIE_NAME);
        assert_eq!(defaults.session_context_key, DEFAULT_SESSION_CONTEXT_KEY);
        assert_eq!(defaults.cookie_path, "/");
        assert!(defaults.http_only);
        assert!(!defaults.secure);
        assert_eq!(defaults.same_site, SessionCookieSameSite::Lax);
        assert_eq!(defaults.ttl, Some(StdDuration::from_secs(86_400)));

        let default_target = defaults
            .post_auth_redirect
            .default_redirect_target
            .as_deref();
        assert_eq!(default_target, Some("/"));
    }

    /// The builder carries principal, attributes, and the typed extra payload.
    #[test]
    fn test_context_with_extra_data() {
        let principal = SessionPrincipal::builder()
            .subject("dev-session")
            .display_name("dev")
            .build();
        let attributes = HashMap::from([(
            "mode".to_string(),
            Value::String("dev".to_string()),
        )]);
        let extra = HashMap::from([(
            "provider".to_string(),
            Value::String("local".to_string()),
        )]);

        let context = SessionContext::builder()
            .principal(principal)
            .attributes(attributes)
            .extra(extra)
            .build();

        let expected_mode = Value::String("dev".to_string());
        let expected_provider = Value::String("local".to_string());

        assert_eq!(context.principal.subject, "dev-session");
        assert_eq!(context.principal.display_name, "dev");
        assert_eq!(context.attributes.get("mode"), Some(&expected_mode));

        let provider = context
            .extra
            .as_ref()
            .and_then(|extra| extra.get("provider"));
        assert_eq!(provider, Some(&expected_provider));
    }

    /// A dynamic redirect config resolves both the default and an allowed target.
    #[test]
    fn test_post_auth_redirect_resolution() {
        let redirect = RedirectTargetConfig::dynamic_default_and_dynamic_targets(
            "/",
            [RedirectTargetRule::Strict {
                value: "/app".to_string(),
            }],
        );
        let raw_config = SessionContextConfig::builder()
            .post_auth_redirect(redirect)
            .build();
        let config = SessionContextConfigSource::resolve_all(&raw_config)
            .expect("session context config should resolve");

        let default_target = config
            .resolve_post_auth_redirect(None)
            .expect("default redirect should resolve");
        assert_eq!(default_target, "/");

        let dynamic_target = config
            .resolve_post_auth_redirect(Some("/app"))
            .expect("dynamic redirect should resolve");
        assert_eq!(dynamic_target, "/app");
    }

    /// A fixed-redirect validator rejects a config that overrides the redirect.
    #[test]
    fn fixed_post_auth_redirect_validator_rejects_override() {
        let validator = SessionContextFixedPostAuthRedirectValidator::new(
            RedirectTargetConfig::strict_default("/"),
        );
        let config = SessionContextConfig::builder()
            .post_auth_redirect(RedirectTargetConfig::strict_default("/admin"))
            .build();

        let failure = SessionContextConfigSource::resolve_all_with_validator(&config, &validator)
            .expect_err("unexpected session post_auth_redirect should be rejected");

        let rejected_for_conflict = matches!(
            failure,
            SessionContextConfigValidationFailure::Validation { source }
                if source.field_path == "post_auth_redirect"
                    && source.code == "fixed_post_auth_redirect_conflict"
        );
        assert!(rejected_for_conflict);
    }
}