1use crate::{Authentication, Error, Stage};
16use rama::{Context, Service};
17use rsasl::prelude::Mechname;
18use tansu_sans_io::{ApiKey, ErrorCode, SaslHandshakeRequest, SaslHandshakeResponse};
19use tracing::{debug, instrument};
20
/// Service that answers `SaslHandshake` requests.
///
/// This is a stateless unit struct: the SASL authentication state it operates on
/// is looked up in the request context when a request is served, not stored here.
#[derive(Clone, Copy, Debug, Default)]
#[derive(Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct SaslHandshakeService;
23
24impl ApiKey for SaslHandshakeService {
25 const KEY: i16 = SaslHandshakeRequest::KEY;
26}
27
28impl<S> Service<S, SaslHandshakeRequest> for SaslHandshakeService
29where
30 S: Send + Sync + 'static,
31{
32 type Response = SaslHandshakeResponse;
33 type Error = Error;
34
35 #[instrument(skip(self, ctx), ret)]
36 async fn serve(
37 &self,
38 ctx: Context<S>,
39 req: SaslHandshakeRequest,
40 ) -> Result<Self::Response, Self::Error> {
41 if let Some(authentication) = ctx.get::<Authentication>().cloned() {
42 authentication.stage
43 .lock()
44 .map_err(Into::into)
45 .and_then(|mut guard| {
46 if let Some(Stage::Server(server)) = guard.take()
47 && let Ok(mechanism) = Mechname::parse(req.mechanism.as_bytes())
48 {
49 debug!(available = ?server.get_available().into_iter().map(|mechanism|mechanism.mechanism.as_str()).collect::<Vec<_>>());
50
51 server
52 .start_suggested(mechanism)
53 .inspect_err(|err| debug!(?err, ?mechanism))
54 .map_err(Into::into)
55 .map(|session| {
56 let mechanisms = [session.get_mechname().to_string()];
57
58 _ = guard.replace(Stage::Session(session));
59
60 SaslHandshakeResponse::default()
61 .error_code(ErrorCode::None.into())
62 .mechanisms(Some(mechanisms.into()))
63 })
64 } else {
65 Ok(SaslHandshakeResponse::default()
66 .error_code(ErrorCode::UnsupportedSaslMechanism.into())
67 .mechanisms(Some([req.mechanism].into())))
68 }
69 })
70 } else {
71 Ok(SaslHandshakeResponse::default()
72 .error_code(ErrorCode::UnsupportedSaslMechanism.into())
73 .mechanisms(Some([req.mechanism].into())))
74 }
75 }
76}