tansu_auth/
authenticate.rs1use std::io::Cursor;
16
17use crate::{Authentication, Error, Stage};
18use bytes::Bytes;
19use rama::{Context, Service};
20use rsasl::prelude::State;
21use tansu_sans_io::{ApiKey, ErrorCode, SaslAuthenticateRequest, SaslAuthenticateResponse};
22use tokio::task;
23use tracing::debug;
24
/// Service that answers `SaslAuthenticate` requests by stepping the SASL session
/// stored in the connection's authentication context.
///
/// The only configuration is the session lifetime that is passed through to
/// non-failure responses.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SaslAuthenticateService {
    /// Value placed in the `session_lifetime_ms` field of non-failure responses.
    session_lifetime_ms: Option<i64>,
}
29
30impl Default for SaslAuthenticateService {
31 fn default() -> Self {
32 Self {
33 session_lifetime_ms: Some(60_000),
34 }
35 }
36}
37
38impl SaslAuthenticateService {
39 pub fn session_lifetime_ms(self, session_lifetime_ms: Option<i64>) -> Self {
40 Self {
41 session_lifetime_ms,
42 }
43 }
44}
45
/// Registers this service under the same API key as `SaslAuthenticateRequest`.
/// The key is taken from the request type rather than repeated as a literal, so the
/// two cannot drift apart.
impl ApiKey for SaslAuthenticateService {
    const KEY: i16 = SaslAuthenticateRequest::KEY;
}
49
50impl<S> Service<S, SaslAuthenticateRequest> for SaslAuthenticateService
51where
52 S: Send + Sync + 'static,
53{
54 type Response = SaslAuthenticateResponse;
55 type Error = Error;
56
57 async fn serve(
58 &self,
59 ctx: Context<S>,
60 req: SaslAuthenticateRequest,
61 ) -> Result<Self::Response, Self::Error> {
62 if let Some(authentication) = ctx.get::<Authentication>().cloned() {
63 let session_lifetime_ms = self.session_lifetime_ms;
64
65 task::spawn_blocking(move || {
66 authentication
67 .stage
68 .lock()
69 .map_err(Into::into)
70 .map(|mut guard| {
71 if let Some(Stage::Session(session)) = guard.as_mut() {
72 let mut outcome = Cursor::new(Vec::new());
73
74 let Ok(state) = session
75 .step(Some(&req.auth_bytes), &mut outcome)
76 .inspect(|state| debug!(?state))
77 .inspect_err(|err| debug!(?err))
78 else {
79 _ = guard.take();
80
81 return SaslAuthenticateResponse::default()
82 .error_code(ErrorCode::SaslAuthenticationFailed.into())
83 .error_message(Some(
84 ErrorCode::SaslAuthenticationFailed.to_string(),
85 ))
86 .auth_bytes(Bytes::from_static(b""))
87 .session_lifetime_ms(Some(0));
88 };
89
90 let success = session
91 .validation()
92 .transpose()
93 .ok()
94 .flatten()
95 .inspect(|success| debug!(?success));
96
97 if let State::Finished(_) = state {
98 _ = guard.replace(Stage::Finished(success))
99 }
100
101 SaslAuthenticateResponse::default()
102 .error_code(ErrorCode::None.into())
103 .error_message(Some("NONE".into()))
104 .auth_bytes(Bytes::from(outcome.into_inner()))
105 .session_lifetime_ms(session_lifetime_ms)
106 } else {
107 _ = guard.take();
108
109 SaslAuthenticateResponse::default()
110 .error_code(ErrorCode::IllegalSaslState.into())
111 .error_message(Some(ErrorCode::IllegalSaslState.to_string()))
112 .auth_bytes(Bytes::from_static(b""))
113 .session_lifetime_ms(Some(0))
114 }
115 })
116 })
117 .await?
118 } else {
119 Ok(SaslAuthenticateResponse::default()
120 .error_code(ErrorCode::IllegalSaslState.into())
121 .error_message(Some(ErrorCode::IllegalSaslState.to_string()))
122 .auth_bytes(Bytes::from_static(b""))
123 .session_lifetime_ms(Some(0)))
124 }
125 }
126}