1use rand::prelude::*;
2use serde::{Deserialize, Serialize};
3use sha2::Digest;
4
5#[derive(Serialize, Deserialize, Debug, Clone)]
6#[serde(transparent)]
7pub struct SecretKey(pub String);
8impl SecretKey {
9 pub fn client_id(&self) -> ClientId {
10 ClientId(base64::encode(
11 &sha2::Sha256::digest(self.0.as_bytes()).to_vec(),
12 ))
13 }
14
15 pub fn sha256_hex(&self) -> String {
16 hex::encode(sha2::Sha256::digest(self.0.as_bytes()).to_vec())
17 }
18}
19
20#[derive(Serialize, Deserialize, Debug, Clone)]
21#[serde(transparent)]
22pub struct ReconnectToken(pub String);
23
24#[derive(Serialize, Deserialize, Debug, Clone)]
25#[serde(rename_all = "snake_case")]
26pub enum ServerHello {
27 Success {
28 sub_domain: String,
29 hostname: String,
30 client_id: ClientId,
31 },
32 SubDomainInUse,
33 InvalidSubDomain,
34 AuthFailed,
35 Error(String),
36}
37
38impl ServerHello {
39 #[allow(unused)]
40 pub fn random_domain() -> String {
41 let mut rng = rand::thread_rng();
42 std::iter::repeat(())
43 .map(|_| rng.sample(rand::distributions::Alphanumeric))
44 .take(8)
45 .collect::<String>()
46 .to_lowercase()
47 }
48
49 #[allow(unused)]
50 pub fn prefixed_random_domain(prefix: &str) -> String {
51 format!("{}-{}", prefix, Self::random_domain())
52 }
53
54 pub fn prefixed_client_domain(prefix: &str, client_id: &ClientId, account_id: &str) -> String {
55 let input = format!("{}||{}", client_id, account_id);
56 let hash = sha2::Sha256::digest(input.as_bytes()).to_vec();
57 let encoded = base64::encode_config(&hash, base64::URL_SAFE_NO_PAD)
58 .to_lowercase()
59 .chars()
60 .take_while(|c| c.is_alphanumeric())
61 .take(8)
62 .collect::<String>();
63
64 format!("{}-{}", prefix, encoded)
65 }
66}
67
68#[derive(Serialize, Deserialize, Debug, Clone)]
69pub struct ClientHello {
70 id: ClientId,
72 pub sub_domain: Option<String>,
73 pub client_type: ClientType,
74 pub reconnect_token: Option<ReconnectToken>,
75 version: Option<ProtocolVersion>,
76}
77
78#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd)]
79#[serde(rename_all = "snake_case")]
80pub enum ProtocolVersion {
81 #[serde(rename = "pre")]
82 BeforeVersioning,
83 V1,
84}
85
86impl ProtocolVersion {
87 pub const CURRENT: ProtocolVersion = ProtocolVersion::V1;
88}
89
90impl ClientHello {
91 pub fn get_version(&self) -> ProtocolVersion {
92 self.version.unwrap_or(ProtocolVersion::BeforeVersioning)
93 }
94
95 pub fn generate(sub_domain: Option<String>, typ: ClientType) -> Self {
96 ClientHello {
97 id: ClientId::generate(),
98 client_type: typ,
99 sub_domain,
100 reconnect_token: None,
101 version: Some(ProtocolVersion::CURRENT),
102 }
103 }
104
105 pub fn reconnect(reconnect_token: ReconnectToken) -> Self {
106 ClientHello {
107 id: ClientId::generate(),
108 sub_domain: None,
109 client_type: ClientType::Anonymous,
110 reconnect_token: Some(reconnect_token),
111 version: Some(ProtocolVersion::CURRENT),
112 }
113 }
114}
115
116#[derive(Serialize, Deserialize, Debug, Clone)]
117pub enum ClientType {
118 Auth { key: SecretKey },
119 Anonymous,
120}
121
122#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)]
123#[serde(transparent)]
124pub struct ClientId(String);
125
126impl std::fmt::Display for ClientId {
127 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128 self.0.fmt(f)
129 }
130}
131impl ClientId {
132 pub fn generate() -> Self {
133 let mut id = [0u8; 32];
134 rand::thread_rng().fill_bytes(&mut id);
135 ClientId(base64::encode_config(&id, base64::URL_SAFE_NO_PAD))
136 }
137
138 pub fn safe_id(self) -> ClientId {
139 ClientId(base64::encode(
140 &sha2::Sha256::digest(self.0.as_bytes()).to_vec(),
141 ))
142 }
143}
144
145#[derive(Debug, Clone, PartialEq, Eq, Hash)]
146pub struct StreamId([u8; 8]);
147
148impl StreamId {
149 pub fn generate() -> StreamId {
150 let mut id = [0u8; 8];
151 rand::thread_rng().fill_bytes(&mut id);
152 StreamId(id)
153 }
154
155 pub fn to_string(&self) -> String {
156 format!(
157 "stream_{}",
158 base64::encode_config(&self.0, base64::URL_SAFE_NO_PAD)
159 )
160 }
161}
162
163#[derive(Debug, Clone)]
164pub enum ControlPacket {
165 Init(StreamId),
166 Data(StreamId, Vec<u8>),
167 Refused(StreamId),
168 End(StreamId),
169 Ping(Option<ReconnectToken>),
170 Terminate {
172 reason: String,
173 },
174}
175
176pub const PING_INTERVAL: u64 = 30;
177
178const EMPTY_STREAM: StreamId = StreamId([0xF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);
179const TOKEN_STREAM: StreamId = StreamId([0xF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01]);
180
181impl ControlPacket {
182 pub fn serialize(self) -> Vec<u8> {
183 match self {
184 ControlPacket::Init(sid) => [vec![0x01], sid.0.to_vec()].concat(),
185 ControlPacket::Data(sid, data) => [vec![0x02], sid.0.to_vec(), data].concat(),
186 ControlPacket::Refused(sid) => [vec![0x03], sid.0.to_vec()].concat(),
187 ControlPacket::End(sid) => [vec![0x04], sid.0.to_vec()].concat(),
188 ControlPacket::Ping(tok) => {
189 let data = tok.map_or(EMPTY_STREAM.0.to_vec(), |t| {
190 vec![TOKEN_STREAM.0.to_vec(), t.0.into_bytes()].concat()
191 });
192 [vec![0x05], data].concat()
193 }
194 ControlPacket::Terminate { reason } => {
195 [vec![0x06], EMPTY_STREAM.0.to_vec(), reason.into_bytes()].concat()
196 }
197 }
198 }
199
200 pub fn packet_type(&self) -> &str {
201 match &self {
202 ControlPacket::Ping(_) => "PING",
203 ControlPacket::Init(_) => "INIT STREAM",
204 ControlPacket::Data(_, _) => "STREAM DATA",
205 ControlPacket::Refused(_) => "REFUSED",
206 ControlPacket::End(_) => "END STREAM",
207 ControlPacket::Terminate { .. } => "TERMINATE",
208 }
209 }
210
211 pub fn deserialize(data: &[u8]) -> Result<Self, Box<dyn std::error::Error>> {
212 if data.len() < 9 {
213 return Err("invalid DataPacket, missing stream id".into());
214 }
215
216 let mut stream_id = [0u8; 8];
217 stream_id.clone_from_slice(&data[1..9]);
218 let stream_id = StreamId(stream_id);
219
220 let packet = match data[0] {
221 0x01 => ControlPacket::Init(stream_id),
222 0x02 => ControlPacket::Data(stream_id, data[9..].to_vec()),
223 0x03 => ControlPacket::Refused(stream_id),
224 0x04 => ControlPacket::End(stream_id),
225 0x05 => {
226 if stream_id == EMPTY_STREAM {
227 ControlPacket::Ping(None)
228 } else {
229 ControlPacket::Ping(Some(ReconnectToken(
230 String::from_utf8_lossy(&data[9..]).to_string(),
231 )))
232 }
233 }
234 0x06 => ControlPacket::Terminate {
235 reason: String::from_utf8_lossy(&data[9..]).to_string(),
236 },
237 _ => return Err("invalid control byte in DataPacket".into()),
238 };
239
240 Ok(packet)
241 }
242}