wisp_mux/extensions/
password.rs

//! Password protocol extension.
//!
//! **Passwords are sent in plain text!!**
//!
//! # Example
//! Server:
//! ```ignore
//! let mut passwords = HashMap::new();
//! passwords.insert("user1".to_string(), "pw".to_string());
//! let (mux, fut) = ServerMux::new(
//!     rx,
//!     tx,
//!     128,
//!     Some(&[Box::new(PasswordProtocolExtensionBuilder::new_server(passwords))])
//! );
//! ```
//!
//! Client:
//! ```ignore
//! let (mux, fut) = ClientMux::new(
//!     rx,
//!     tx,
//!     128,
//!     Some(&[
//!         Box::new(PasswordProtocolExtensionBuilder::new_client(
//!             "user1".to_string(),
//!             "pw".to_string()
//!         ))
//!     ])
//! );
//! ```
//!
//! See [the protocol specification](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#0x02---password-authentication)
//! for the details of this extension.

34use std::{collections::HashMap, error::Error, fmt::Display, string::FromUtf8Error};
35
36use async_trait::async_trait;
37use bytes::{Buf, BufMut, Bytes, BytesMut};
38
39use crate::{
40	ws::{LockedWebSocketWrite, WebSocketRead},
41	Role, WispError,
42};
43
44use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder};
45
46#[derive(Debug, Clone)]
47/// Password protocol extension.
48///
49/// **Passwords are sent in plain text!!**
50/// **This extension will panic when encoding if the username's length does not fit within a u8
51/// or the password's length does not fit within a u16.**
52pub struct PasswordProtocolExtension {
53	/// The username to log in with.
54	///
55	/// This string's length must fit within a u8.
56	pub username: String,
57	/// The password to log in with.
58	///
59	/// This string's length must fit within a u16.
60	pub password: String,
61	role: Role,
62}
63
64impl PasswordProtocolExtension {
65	/// Password protocol extension ID.
66	pub const ID: u8 = 0x02;
67
68	/// Create a new password protocol extension for the server.
69	///
70	/// This signifies that the server requires a password.
71	pub fn new_server() -> Self {
72		Self {
73			username: String::new(),
74			password: String::new(),
75			role: Role::Server,
76		}
77	}
78
79	/// Create a new password protocol extension for the client, with a username and password.
80	///
81	/// The username's length must fit within a u8. The password's length must fit within a
82	/// u16.
83	pub fn new_client(username: String, password: String) -> Self {
84		Self {
85			username,
86			password,
87			role: Role::Client,
88		}
89	}
90}
91
92#[async_trait]
93impl ProtocolExtension for PasswordProtocolExtension {
94	fn get_id(&self) -> u8 {
95		Self::ID
96	}
97
98	fn get_supported_packets(&self) -> &'static [u8] {
99		&[]
100	}
101
102	fn encode(&self) -> Bytes {
103		match self.role {
104			Role::Server => Bytes::new(),
105			Role::Client => {
106				let username = Bytes::from(self.username.clone().into_bytes());
107				let password = Bytes::from(self.password.clone().into_bytes());
108				let username_len = u8::try_from(username.len()).expect("username was too long");
109				let password_len = u16::try_from(password.len()).expect("password was too long");
110
111				let mut bytes =
112					BytesMut::with_capacity(3 + username_len as usize + password_len as usize);
113				bytes.put_u8(username_len);
114				bytes.put_u16_le(password_len);
115				bytes.extend(username);
116				bytes.extend(password);
117				bytes.freeze()
118			}
119		}
120	}
121
122	async fn handle_handshake(
123		&mut self,
124		_: &mut dyn WebSocketRead,
125		_: &LockedWebSocketWrite,
126	) -> Result<(), WispError> {
127		Ok(())
128	}
129
130	async fn handle_packet(
131		&mut self,
132		_: Bytes,
133		_: &mut dyn WebSocketRead,
134		_: &LockedWebSocketWrite,
135	) -> Result<(), WispError> {
136		Ok(())
137	}
138
139	fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send> {
140		Box::new(self.clone())
141	}
142}
143
/// Errors raised while decoding and checking a client's password extension payload.
#[derive(Debug)]
enum PasswordProtocolExtensionError {
	/// The username or password bytes were not valid UTF-8.
	Utf8Error(FromUtf8Error),
	/// The supplied username is not in the server's user map.
	InvalidUsername,
	/// The supplied password does not match the one configured for the user.
	InvalidPassword,
}

impl Display for PasswordProtocolExtensionError {
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		let message = match self {
			// The UTF-8 case forwards the underlying error's message unchanged.
			Self::Utf8Error(err) => return write!(f, "{err}"),
			Self::InvalidUsername => "Invalid username",
			Self::InvalidPassword => "Invalid password",
		};
		f.write_str(message)
	}
}

impl Error for PasswordProtocolExtensionError {}

164impl From<PasswordProtocolExtensionError> for WispError {
165	fn from(value: PasswordProtocolExtensionError) -> Self {
166		WispError::ExtensionImplError(Box::new(value))
167	}
168}
169
170impl From<FromUtf8Error> for PasswordProtocolExtensionError {
171	fn from(value: FromUtf8Error) -> Self {
172		PasswordProtocolExtensionError::Utf8Error(value)
173	}
174}
175
176impl From<PasswordProtocolExtension> for AnyProtocolExtension {
177	fn from(value: PasswordProtocolExtension) -> Self {
178		AnyProtocolExtension(Box::new(value))
179	}
180}
181
/// Password protocol extension builder.
///
/// **Passwords are sent in plain text!!**
pub struct PasswordProtocolExtensionBuilder {
	/// Accepted usernames mapped to their passwords. Only consulted on the server.
	pub users: HashMap<String, String>,
	/// Username presented when authenticating. Only consulted on the client.
	pub username: String,
	/// Password presented when authenticating. Only consulted on the client.
	pub password: String,
}

impl PasswordProtocolExtensionBuilder {
	/// Builder for the server side. `users` maps every accepted username to its password;
	/// the client-only credential fields are left empty.
	pub fn new_server(users: HashMap<String, String>) -> Self {
		Self {
			username: String::default(),
			password: String::default(),
			users,
		}
	}

	/// Builder for the client side. `username` and `password` are the credentials to present;
	/// the server-only user map is left empty.
	pub fn new_client(username: String, password: String) -> Self {
		Self {
			users: HashMap::default(),
			username,
			password,
		}
	}
}

216impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder {
217	fn get_id(&self) -> u8 {
218		PasswordProtocolExtension::ID
219	}
220
221	fn build_from_bytes(
222		&self,
223		mut payload: Bytes,
224		role: crate::Role,
225	) -> Result<AnyProtocolExtension, WispError> {
226		match role {
227			Role::Server => {
228				if payload.remaining() < 3 {
229					return Err(WispError::PacketTooSmall);
230				}
231
232				let username_len = payload.get_u8();
233				let password_len = payload.get_u16_le();
234				if payload.remaining() < (password_len + username_len as u16) as usize {
235					return Err(WispError::PacketTooSmall);
236				}
237
238				use PasswordProtocolExtensionError as EError;
239				let username =
240					String::from_utf8(payload.copy_to_bytes(username_len as usize).to_vec())
241						.map_err(|x| WispError::from(EError::from(x)))?;
242				let password =
243					String::from_utf8(payload.copy_to_bytes(password_len as usize).to_vec())
244						.map_err(|x| WispError::from(EError::from(x)))?;
245
246				let Some(user) = self.users.iter().find(|x| *x.0 == username) else {
247					return Err(EError::InvalidUsername.into());
248				};
249
250				if *user.1 != password {
251					return Err(EError::InvalidPassword.into());
252				}
253
254				Ok(PasswordProtocolExtension {
255					username,
256					password,
257					role,
258				}
259				.into())
260			}
261			Role::Client => {
262				Ok(PasswordProtocolExtension::new_client(String::new(), String::new()).into())
263			}
264		}
265	}
266
267	fn build_to_extension(&self, role: Role) -> AnyProtocolExtension {
268		match role {
269			Role::Server => PasswordProtocolExtension::new_server(),
270			Role::Client => {
271				PasswordProtocolExtension::new_client(self.username.clone(), self.password.clone())
272			}
273		}
274		.into()
275	}
276}