1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
//! Wisp protocol extensions.
pub mod password;
pub mod udp;

use std::ops::{Deref, DerefMut};

use async_trait::async_trait;
use bytes::{BufMut, Bytes, BytesMut};

use crate::{
	ws::{LockedWebSocketWrite, WebSocketRead},
	Role, WispError,
};

/// A boxed, type-erased [`ProtocolExtension`] that can be cloned.
///
/// Cloning goes through [`ProtocolExtension::box_clone`], so the concrete extension type
/// does not need to be known by the holder.
#[derive(Debug)]
pub struct AnyProtocolExtension(Box<dyn ProtocolExtension + Sync + Send>);

impl AnyProtocolExtension {
	/// Erase the concrete type of `extension` by moving it into a box.
	pub fn new<T>(extension: T) -> Self
	where
		T: ProtocolExtension + Sync + Send + 'static,
	{
		let erased: Box<dyn ProtocolExtension + Sync + Send> = Box::new(extension);
		Self(erased)
	}
}

impl Deref for AnyProtocolExtension {
	type Target = dyn ProtocolExtension;

	/// Borrow the boxed extension as a plain trait object.
	fn deref(&self) -> &Self::Target {
		// Reborrow through the box; the `Sync + Send` bounds are dropped by coercion.
		&*self.0
	}
}

impl DerefMut for AnyProtocolExtension {
	/// Mutably borrow the boxed extension as a plain trait object.
	fn deref_mut(&mut self) -> &mut Self::Target {
		// Mutable reborrow through the box; `Sync + Send` are dropped by coercion.
		&mut *self.0
	}
}

impl Clone for AnyProtocolExtension {
	fn clone(&self) -> Self {
		Self(self.0.box_clone())
	}
}

impl From<AnyProtocolExtension> for Bytes {
	/// Serialize an extension into its metadata record.
	///
	/// Layout: the extension ID as a `u8`, then the payload length as a little-endian `u32`,
	/// then the payload bytes returned by [`ProtocolExtension::encode`].
	///
	/// # Panics
	///
	/// Panics if the encoded payload is longer than `u32::MAX` bytes. Such a length cannot be
	/// represented in the length field, and writing a truncated length would produce a record
	/// whose declared size disagrees with its contents.
	fn from(value: AnyProtocolExtension) -> Self {
		// 1-byte extension ID + 4-byte payload length.
		const HEADER_LEN: usize = 5;

		let payload = value.encode();
		assert!(
			payload.len() <= u32::MAX as usize,
			"protocol extension payload length does not fit in a u32"
		);

		// Reserve header and payload up front so the buffer is allocated exactly once.
		let mut bytes = BytesMut::with_capacity(HEADER_LEN + payload.len());
		bytes.put_u8(value.get_id());
		// Cannot truncate: the assertion above bounds the length by u32::MAX.
		bytes.put_u32_le(payload.len() as u32);
		// Bulk copy instead of `Extend<u8>`, which pushes the payload one byte at a time.
		bytes.put_slice(&payload);
		bytes.freeze()
	}
}

/// Behaviour implemented by every Wisp protocol extension.
///
/// Protocol extensions are described in
/// [the Wisp v2 specification](https://github.com/MercuryWorkshop/wisp-protocol/blob/v2/protocol.md#protocol-extensions).
#[async_trait]
pub trait ProtocolExtension: core::fmt::Debug {
	/// Returns the numeric ID identifying this extension.
	fn get_id(&self) -> u8;

	/// Returns the packet types this extension wants to receive.
	///
	/// The caller consults this list to decide whether
	/// [`handle_packet`](Self::handle_packet) should be invoked for a packet.
	fn get_supported_packets(&self) -> &'static [u8];

	/// Produces this extension's encoded payload as [`Bytes`].
	fn encode(&self) -> Bytes;

	/// Performs this extension's part of the Wisp connection handshake.
	///
	/// Intended for sending or receiving data over `read` / `write` before any stream
	/// has been created.
	async fn handle_handshake(
		&mut self,
		read: &mut dyn WebSocketRead,
		write: &LockedWebSocketWrite,
	) -> Result<(), WispError>;

	/// Handles a received packet.
	///
	/// See [`get_supported_packets`](Self::get_supported_packets) for how the caller
	/// decides whether to route a packet here.
	async fn handle_packet(
		&mut self,
		packet: Bytes,
		read: &mut dyn WebSocketRead,
		write: &LockedWebSocketWrite,
	) -> Result<(), WispError>;

	/// Clones this extension into a fresh boxed trait object.
	///
	/// This is what makes the type-erased wrapper cloneable.
	fn box_clone(&self) -> Box<dyn ProtocolExtension + Sync + Send>;
}

/// Factory for one kind of Wisp protocol extension.
///
/// A builder can produce an extension either from received metadata or for sending to
/// the other side.
pub trait ProtocolExtensionBuilder {
	/// Returns the ID of the extension this builder produces.
	///
	/// Callers compare this ID to decide whether this builder applies.
	fn get_id(&self) -> u8;

	/// Builds an extension from that extension's metadata bytes for the given `role`.
	fn build_from_bytes(
		&self,
		bytes: Bytes,
		role: Role,
	) -> Result<AnyProtocolExtension, WispError>;

	/// Builds an extension, for the given `role`, to be sent to the other side.
	fn build_to_extension(&self, role: Role) -> AnyProtocolExtension;
}