Skip to main content

tansu_auth/
authenticate.rs

1// Copyright ⓒ 2024-2026 Peter Morgan <peter.james.morgan@gmail.com>
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::io::Cursor;
16
17use crate::{Authentication, Error, Stage};
18use bytes::Bytes;
19use rama::{Context, Service};
20use rsasl::prelude::State;
21use tansu_sans_io::{ApiKey, ErrorCode, SaslAuthenticateRequest, SaslAuthenticateResponse};
22use tokio::task;
23use tracing::debug;
24
/// Service answering Kafka `SaslAuthenticate` requests for a connection.
///
/// `session_lifetime_ms` is the value placed in the `session_lifetime_ms`
/// field of non-error responses; `None` is passed through unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SaslAuthenticateService {
    /// Session lifetime, in milliseconds, reported to the client.
    session_lifetime_ms: Option<i64>,
}

impl Default for SaslAuthenticateService {
    /// A service reporting a session lifetime of one minute (60 000 ms).
    fn default() -> Self {
        const ONE_MINUTE_MS: i64 = 60 * 1_000;

        Self {
            session_lifetime_ms: Some(ONE_MINUTE_MS),
        }
    }
}

impl SaslAuthenticateService {
    /// Returns this service with its reported session lifetime replaced by
    /// `session_lifetime_ms`.
    pub fn session_lifetime_ms(mut self, session_lifetime_ms: Option<i64>) -> Self {
        self.session_lifetime_ms = session_lifetime_ms;
        self
    }
}
45
46impl ApiKey for SaslAuthenticateService {
47    const KEY: i16 = SaslAuthenticateRequest::KEY;
48}
49
50impl<S> Service<S, SaslAuthenticateRequest> for SaslAuthenticateService
51where
52    S: Send + Sync + 'static,
53{
54    type Response = SaslAuthenticateResponse;
55    type Error = Error;
56
57    async fn serve(
58        &self,
59        ctx: Context<S>,
60        req: SaslAuthenticateRequest,
61    ) -> Result<Self::Response, Self::Error> {
62        if let Some(authentication) = ctx.get::<Authentication>().cloned() {
63            let session_lifetime_ms = self.session_lifetime_ms;
64
65            task::spawn_blocking(move || {
66                authentication
67                    .stage
68                    .lock()
69                    .map_err(Into::into)
70                    .map(|mut guard| {
71                        if let Some(Stage::Session(session)) = guard.as_mut() {
72                            let mut outcome = Cursor::new(Vec::new());
73
74                            let Ok(state) = session
75                                .step(Some(&req.auth_bytes), &mut outcome)
76                                .inspect(|state| debug!(?state))
77                                .inspect_err(|err| debug!(?err))
78                            else {
79                                _ = guard.take();
80
81                                return SaslAuthenticateResponse::default()
82                                    .error_code(ErrorCode::SaslAuthenticationFailed.into())
83                                    .error_message(Some(
84                                        ErrorCode::SaslAuthenticationFailed.to_string(),
85                                    ))
86                                    .auth_bytes(Bytes::from_static(b""))
87                                    .session_lifetime_ms(Some(0));
88                            };
89
90                            let success = session
91                                .validation()
92                                .transpose()
93                                .ok()
94                                .flatten()
95                                .inspect(|success| debug!(?success));
96
97                            if let State::Finished(_) = state {
98                                _ = guard.replace(Stage::Finished(success))
99                            }
100
101                            SaslAuthenticateResponse::default()
102                                .error_code(ErrorCode::None.into())
103                                .error_message(Some("NONE".into()))
104                                .auth_bytes(Bytes::from(outcome.into_inner()))
105                                .session_lifetime_ms(session_lifetime_ms)
106                        } else {
107                            _ = guard.take();
108
109                            SaslAuthenticateResponse::default()
110                                .error_code(ErrorCode::IllegalSaslState.into())
111                                .error_message(Some(ErrorCode::IllegalSaslState.to_string()))
112                                .auth_bytes(Bytes::from_static(b""))
113                                .session_lifetime_ms(Some(0))
114                        }
115                    })
116            })
117            .await?
118        } else {
119            Ok(SaslAuthenticateResponse::default()
120                .error_code(ErrorCode::IllegalSaslState.into())
121                .error_message(Some(ErrorCode::IllegalSaslState.to_string()))
122                .auth_bytes(Bytes::from_static(b""))
123                .session_lifetime_ms(Some(0)))
124        }
125    }
126}