Skip to main content

tansu_auth/
lib.rs

1// Copyright ⓒ 2024-2026 Peter Morgan <peter.james.morgan@gmail.com>
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use rsasl::{
16    callback::{Context, Request, SessionCallback, SessionData},
17    config::SASLConfig,
18    mechanisms::scram::properties::ScramStoredPassword,
19    prelude::{SASLError, SASLServer, Session, SessionError, Validation},
20    property::{AuthId, AuthzId, Password},
21    validate::{Validate, ValidationError},
22};
23use std::{
24    fmt::{self, Debug, Formatter},
25    str::FromStr,
26    sync::{Arc, Mutex, PoisonError},
27};
28use tansu_sans_io::ScramMechanism;
29use tansu_storage::Storage;
30use thiserror::Error;
31use tokio::task::JoinError;
32use tracing::{debug, instrument};
33
34mod authenticate;
35mod handshake;
36
37pub use authenticate::SaslAuthenticateService;
38pub use handshake::SaslHandshakeService;
39
40#[derive(Clone, Debug, Error)]
41pub enum Error {
42    Join(Arc<JoinError>),
43    Poison,
44    SansIo(#[from] tansu_sans_io::Error),
45    Sasl(Arc<SASLError>),
46    SaslSession(Arc<SessionError>),
47}
48
49impl fmt::Display for Error {
50    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
51        write!(f, "{self:?}")
52    }
53}
54
55impl From<JoinError> for Error {
56    fn from(value: JoinError) -> Self {
57        Self::Join(Arc::new(value))
58    }
59}
60
61impl<T> From<PoisonError<T>> for Error {
62    fn from(_value: PoisonError<T>) -> Self {
63        Self::Poison
64    }
65}
66
67impl From<SASLError> for Error {
68    fn from(value: SASLError) -> Self {
69        Self::Sasl(Arc::new(value))
70    }
71}
72
73impl From<SessionError> for Error {
74    fn from(value: SessionError) -> Self {
75        Self::SaslSession(Arc::new(value))
76    }
77}
78
79#[derive(Clone, Default)]
80pub struct Authentication {
81    stage: Arc<Mutex<Option<Stage>>>,
82}
83
84pub enum Stage {
85    Server(SASLServer<Justification>),
86    Session(Session<Justification>),
87    Finished(Option<Success>),
88}
89
90impl Debug for Stage {
91    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
92        f.debug_struct(stringify!(Stage)).finish()
93    }
94}
95
96impl Authentication {
97    pub fn server(config: Arc<SASLConfig>) -> Self {
98        Self {
99            stage: Arc::new(Mutex::new(Some(Stage::Server(
100                SASLServer::<Justification>::new(config),
101            )))),
102        }
103    }
104
105    pub fn is_authenticated(&self) -> bool {
106        self.stage
107            .lock()
108            .map(|guard| matches!(guard.as_ref(), Some(Stage::Finished(_))))
109            .ok()
110            .unwrap_or_default()
111    }
112}
113
114impl Debug for Authentication {
115    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
116        f.debug_struct(stringify!(Authentication)).finish()
117    }
118}
119
120#[derive(Debug, Error)]
121pub enum AuthError {
122    Bad,
123    Io(tansu_sans_io::Error),
124    MissingProperty { mechanism: String, property: String },
125    NoSuchUser,
126    UnknownMechanism(String),
127}
128
129impl fmt::Display for AuthError {
130    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
131        write!(f, "{self:?}")
132    }
133}
134
/// Identity established by a successful SASL exchange.
#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Success {
    /// The client's authentication identity.
    auth_id: String,
}

/// Marker type naming this crate's `rsasl` validation outcome.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Justification;
142
143impl Validation for Justification {
144    type Value = Result<Success, AuthError>;
145}
146
147#[derive(Clone, Debug)]
148pub struct Callback<S> {
149    storage: S,
150}
151
152impl<S> Callback<S>
153where
154    S: Storage,
155{
156    pub fn new(storage: S) -> Self
157    where
158        S: Storage,
159    {
160        Self { storage }
161    }
162
163    #[instrument(skip_all)]
164    fn check(
165        &self,
166        session_data: &SessionData,
167        context: &Context<'_>,
168    ) -> Result<Result<Success, AuthError>, Error> {
169        debug!(mechanism = %session_data.mechanism().mechanism);
170
171        if session_data.mechanism().mechanism == "PLAIN" {
172            Ok(context
173                .get_ref::<Password>()
174                .ok_or(AuthError::MissingProperty {
175                    mechanism: session_data.mechanism().mechanism.to_string(),
176                    property: "Password".into(),
177                })
178                .and(
179                    context
180                        .get_ref::<AuthId>()
181                        .inspect(|auth_id| {
182                            debug!(mechanism = %session_data.mechanism().mechanism, auth_id)
183                        })
184                        .ok_or(AuthError::MissingProperty {
185                            mechanism: session_data.mechanism().mechanism.to_string(),
186                            property: "AuthId".into(),
187                        }).map(ToString::to_string).map(|auth_id| {
188                            Success { auth_id }
189                        })
190                ))
191        } else if session_data.mechanism().mechanism.starts_with("SCRAM-") {
192            Ok(context
193                .get_ref::<AuthId>()
194                .inspect(|auth_id| debug!(mechanism = %session_data.mechanism().mechanism, auth_id))
195                .ok_or(AuthError::MissingProperty {
196                    mechanism: session_data.mechanism().mechanism.to_string(),
197                    property: "AuthId".into(),
198                })
199                .and_then(|auth_id| {
200                    context
201                        .get_ref::<AuthzId>()
202                        .inspect(|authz_id| {
203                            debug!(mechanism = %session_data.mechanism().mechanism, authz_id)
204                        })
205                        .map_or(Ok(Success{
206                            auth_id:auth_id.to_string()
207                        }), |authz_id| {
208                            if authz_id == auth_id {
209                                Ok(Success{
210                                    auth_id:auth_id.to_string()
211                                })
212                            } else {
213                                Err(AuthError::Bad)
214                            }
215                        })
216                }))
217        } else {
218            Ok(Err(AuthError::UnknownMechanism(
219                session_data.mechanism().mechanism.to_string(),
220            )))
221        }
222    }
223}
224
225impl<S> SessionCallback for Callback<S>
226where
227    S: Storage,
228{
229    #[instrument(skip_all)]
230    fn callback(
231        &self,
232        session_data: &SessionData,
233        context: &Context<'_>,
234        request: &mut Request<'_>,
235    ) -> Result<(), SessionError> {
236        debug!(?session_data);
237
238        if session_data.mechanism().mechanism.starts_with("SCRAM-") {
239            let mechanism = ScramMechanism::from_str(session_data.mechanism().mechanism)
240                .map_err(|error| SessionError::Boxed(Box::new(error)))?;
241
242            let auth_id = context
243                .get_ref::<AuthId>()
244                .ok_or(SessionError::ValidationError(
245                    ValidationError::MissingRequiredProperty,
246                ))?;
247
248            debug!(?auth_id, ?mechanism);
249
250            let rt = tokio::runtime::Builder::new_current_thread()
251                .enable_all()
252                .build()?;
253
254            let storage = self.storage.clone();
255
256            if let Ok(Some(credential)) = rt
257                .block_on(async { storage.user_scram_credential(auth_id, mechanism).await })
258                .inspect_err(|err| debug!(auth_id, ?mechanism, ?err))
259            {
260                _ = request
261                    .satisfy::<ScramStoredPassword<'_>>(&ScramStoredPassword::new(
262                        credential.iterations as u32,
263                        &credential.salt[..],
264                        &credential.stored_key[..],
265                        &credential.server_key[..],
266                    ))
267                    .inspect_err(|err| debug!(auth_id, ?mechanism, ?err))?;
268            }
269        }
270
271        Ok(())
272    }
273
274    #[instrument(skip_all)]
275    fn validate(
276        &self,
277        session_data: &SessionData,
278        context: &Context<'_>,
279        validate: &mut Validate<'_>,
280    ) -> Result<(), ValidationError> {
281        debug!(?session_data);
282
283        _ = validate.with::<Justification, _>(|| {
284            self.check(session_data, context)
285                .map_err(|e| ValidationError::Boxed(Box::new(e)))
286        })?;
287
288        Ok(())
289    }
290}
291
292pub fn configuration<S>(storage: S) -> Result<Arc<SASLConfig>, Error>
293where
294    S: Storage,
295{
296    SASLConfig::builder()
297        .with_defaults()
298        .with_callback(Callback::new(storage))
299        .map_err(Into::into)
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only if `T` can be both moved and shared across threads.
    fn assert_send_and_sync<T: Send + Sync>() {}

    #[test]
    fn authentication() {
        assert_send_and_sync::<Authentication>();
    }
}