// wisp_mux/extensions/password.rs
use std::{
    collections::HashMap,
    error::Error,
    fmt::{self, Display},
    string::FromUtf8Error,
};

use async_trait::async_trait;
use bytes::{Buf, BufMut, Bytes, BytesMut};

use crate::{
    ws::{LockedWebSocketWrite, WebSocketRead},
    Role, WispError,
};

use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder};

46#[derive(Debug, Clone)]
47pub struct PasswordProtocolExtension {
53 pub username: String,
57 pub password: String,
61 role: Role,
62}
63
64impl PasswordProtocolExtension {
65 pub const ID: u8 = 0x02;
67
68 pub fn new_server() -> Self {
72 Self {
73 username: String::new(),
74 password: String::new(),
75 role: Role::Server,
76 }
77 }
78
79 pub fn new_client(username: String, password: String) -> Self {
84 Self {
85 username,
86 password,
87 role: Role::Client,
88 }
89 }
90}
91
92#[async_trait]
93impl ProtocolExtension for PasswordProtocolExtension {
94 fn get_id(&self) -> u8 {
95 Self::ID
96 }
97
98 fn get_supported_packets(&self) -> &'static [u8] {
99 &[]
100 }
101
102 fn encode(&self) -> Bytes {
103 match self.role {
104 Role::Server => Bytes::new(),
105 Role::Client => {
106 let username = Bytes::from(self.username.clone().into_bytes());
107 let password = Bytes::from(self.password.clone().into_bytes());
108 let username_len = u8::try_from(username.len()).expect("username was too long");
109 let password_len = u16::try_from(password.len()).expect("password was too long");
110
111 let mut bytes =
112 BytesMut::with_capacity(3 + username_len as usize + password_len as usize);
113 bytes.put_u8(username_len);
114 bytes.put_u16_le(password_len);
115 bytes.extend(username);
116 bytes.extend(password);
117 bytes.freeze()
118 }
119 }
120 }
121
122 async fn handle_handshake(
123 &mut self,
124 _: &mut dyn WebSocketRead,
125 _: &LockedWebSocketWrite,
126 ) -> Result<(), WispError> {
127 Ok(())
128 }
129
130 async fn handle_packet(
131 &mut self,
132 _: Bytes,
133 _: &mut dyn WebSocketRead,
134 _: &LockedWebSocketWrite,
135 ) -> Result<(), WispError> {
136 Ok(())
137 }
138
139 fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> {
140 Box::new(self.clone())
141 }
142}
143
/// Failures raised while checking the credentials a client sent to the server.
#[derive(Debug)]
enum PasswordProtocolExtensionError {
    /// The username or password bytes were not valid UTF-8.
    Utf8Error(std::string::FromUtf8Error),
    /// No entry for the supplied username exists in the server's user table.
    InvalidUsername,
    /// The username exists but the supplied password does not match the stored one.
    InvalidPassword,
}

151impl Display for PasswordProtocolExtensionError {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 use PasswordProtocolExtensionError as E;
154 match self {
155 E::Utf8Error(e) => write!(f, "{}", e),
156 E::InvalidUsername => write!(f, "Invalid username"),
157 E::InvalidPassword => write!(f, "Invalid password"),
158 }
159 }
160}
161
162impl Error for PasswordProtocolExtensionError {}
163
164impl From<PasswordProtocolExtensionError> for WispError {
165 fn from(value: PasswordProtocolExtensionError) -> Self {
166 WispError::ExtensionImplError(Box::new(value))
167 }
168}
169
170impl From<FromUtf8Error> for PasswordProtocolExtensionError {
171 fn from(value: FromUtf8Error) -> Self {
172 PasswordProtocolExtensionError::Utf8Error(value)
173 }
174}
175
176impl From<PasswordProtocolExtension> for AnyProtocolExtension {
177 fn from(value: PasswordProtocolExtension) -> Self {
178 AnyProtocolExtension(Box::new(value))
179 }
180}
181
/// Builder for the password extension.
///
/// A server builder holds the table of accepted credentials; a client builder holds the
/// credentials it will send.
pub struct PasswordProtocolExtensionBuilder {
    /// Server side: map from each accepted username to its password. Empty on the client.
    pub users: HashMap<String, String>,
    /// Client side: username to send. Empty on the server.
    pub username: String,
    /// Client side: password to send. Empty on the server.
    pub password: String,
}

194impl PasswordProtocolExtensionBuilder {
195 pub fn new_server(users: HashMap<String, String>) -> Self {
198 Self {
199 users,
200 username: String::new(),
201 password: String::new(),
202 }
203 }
204
205 pub fn new_client(username: String, password: String) -> Self {
208 Self {
209 users: HashMap::new(),
210 username,
211 password,
212 }
213 }
214}
215
216impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder {
217 fn get_id(&self) -> u8 {
218 PasswordProtocolExtension::ID
219 }
220
221 fn build_from_bytes(
222 &self,
223 mut payload: Bytes,
224 role: crate::Role,
225 ) -> Result<AnyProtocolExtension, WispError> {
226 match role {
227 Role::Server => {
228 if payload.remaining() < 3 {
229 return Err(WispError::PacketTooSmall);
230 }
231
232 let username_len = payload.get_u8();
233 let password_len = payload.get_u16_le();
234 if payload.remaining() < (password_len + username_len as u16) as usize {
235 return Err(WispError::PacketTooSmall);
236 }
237
238 use PasswordProtocolExtensionError as EError;
239 let username =
240 String::from_utf8(payload.copy_to_bytes(username_len as usize).to_vec())
241 .map_err(|x| WispError::from(EError::from(x)))?;
242 let password =
243 String::from_utf8(payload.copy_to_bytes(password_len as usize).to_vec())
244 .map_err(|x| WispError::from(EError::from(x)))?;
245
246 let Some(user) = self.users.iter().find(|x| *x.0 == username) else {
247 return Err(EError::InvalidUsername.into());
248 };
249
250 if *user.1 != password {
251 return Err(EError::InvalidPassword.into());
252 }
253
254 Ok(PasswordProtocolExtension {
255 username,
256 password,
257 role,
258 }
259 .into())
260 }
261 Role::Client => {
262 Ok(PasswordProtocolExtension::new_client(String::new(), String::new()).into())
263 }
264 }
265 }
266
267 fn build_to_extension(&self, role: Role) -> AnyProtocolExtension {
268 match role {
269 Role::Server => PasswordProtocolExtension::new_server(),
270 Role::Client => {
271 PasswordProtocolExtension::new_client(self.username.clone(), self.password.clone())
272 }
273 }
274 .into()
275 }
276}