1use rand::prelude::*;
2use serde::{Deserialize, Serialize};
3use sha2::Digest;
4
5#[derive(Serialize, Deserialize, Debug, Clone)]
6#[serde(transparent)]
7pub struct SecretKey(pub String);
8impl SecretKey {
9 pub fn generate() -> Self {
10 let mut rng = rand::thread_rng();
11 Self(
12 std::iter::repeat(())
13 .map(|_| rng.sample(rand::distributions::Alphanumeric))
14 .take(22)
15 .collect::<String>(),
16 )
17 }
18
19 pub fn client_id(&self) -> ClientId {
20 ClientId(base64::encode(
21 &sha2::Sha256::digest(self.0.as_bytes()).to_vec(),
22 ))
23 }
24}
25
26#[derive(Serialize, Deserialize, Debug, Clone)]
27#[serde(transparent)]
28pub struct ReconnectToken(pub String);
29
30#[derive(Serialize, Deserialize, Debug, Clone)]
31#[serde(rename_all = "snake_case")]
32pub enum ServerHello {
33 Success {
34 sub_domain: String,
35 hostname: String,
36 client_id: ClientId,
37 },
38 SubDomainInUse,
39 InvalidSubDomain,
40 AuthFailed,
41 Error(String),
42}
43
44impl ServerHello {
45 #[allow(unused)]
46 pub fn random_domain() -> String {
47 let mut rng = rand::thread_rng();
48 std::iter::repeat(())
49 .map(|_| rng.sample(rand::distributions::Alphanumeric))
50 .take(8)
51 .collect::<String>()
52 .to_lowercase()
53 }
54
55 #[allow(unused)]
56 pub fn prefixed_random_domain(prefix: &str) -> String {
57 format!("{}-{}", prefix, Self::random_domain())
58 }
59}
60
61#[derive(Serialize, Deserialize, Debug, Clone)]
62pub struct ClientHello {
63 id: ClientId,
65 pub sub_domain: Option<String>,
66 pub client_type: ClientType,
67 pub reconnect_token: Option<ReconnectToken>,
68}
69
70impl ClientHello {
71 pub fn generate(sub_domain: Option<String>, typ: ClientType) -> Self {
72 ClientHello {
73 id: ClientId::generate(),
74 client_type: typ,
75 sub_domain,
76 reconnect_token: None,
77 }
78 }
79
80 pub fn reconnect(reconnect_token: ReconnectToken) -> Self {
81 ClientHello {
82 id: ClientId::generate(),
83 sub_domain: None,
84 client_type: ClientType::Anonymous,
85 reconnect_token: Some(reconnect_token),
86 }
87 }
88}
89
90#[derive(Serialize, Deserialize, Debug, Clone)]
91pub enum ClientType {
92 Auth { key: SecretKey },
93 Anonymous,
94}
95
96#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)]
97#[serde(transparent)]
98pub struct ClientId(String);
99
100impl std::fmt::Display for ClientId {
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 self.0.fmt(f)
103 }
104}
105impl ClientId {
106 pub fn generate() -> Self {
107 let mut id = [0u8; 32];
108 rand::thread_rng().fill_bytes(&mut id);
109 ClientId(base64::encode_config(&id, base64::URL_SAFE_NO_PAD))
110 }
111
112 pub fn safe_id(self) -> ClientId {
113 ClientId(base64::encode(
114 &sha2::Sha256::digest(self.0.as_bytes()).to_vec(),
115 ))
116 }
117}
118
119#[derive(Debug, Clone, PartialEq, Eq, Hash)]
120pub struct StreamId([u8; 8]);
121
122impl StreamId {
123 pub fn generate() -> StreamId {
124 let mut id = [0u8; 8];
125 rand::thread_rng().fill_bytes(&mut id);
126 StreamId(id)
127 }
128
129 pub fn to_string(&self) -> String {
130 format!(
131 "stream_{}",
132 base64::encode_config(&self.0, base64::URL_SAFE_NO_PAD)
133 )
134 }
135}
136
137#[derive(Debug, Clone)]
138pub enum ControlPacket {
139 Init(StreamId),
140 Data(StreamId, Vec<u8>),
141 Refused(StreamId),
142 End(StreamId),
143 Ping(Option<ReconnectToken>),
144}
145
146pub const PING_INTERVAL: u64 = 30;
147
148const EMPTY_STREAM: StreamId = StreamId([0xF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);
149const TOKEN_STREAM: StreamId = StreamId([0xF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01]);
150
151impl ControlPacket {
152 pub fn serialize(self) -> Vec<u8> {
153 match self {
154 ControlPacket::Init(sid) => [vec![0x01], sid.0.to_vec()].concat(),
155 ControlPacket::Data(sid, data) => [vec![0x02], sid.0.to_vec(), data].concat(),
156 ControlPacket::Refused(sid) => [vec![0x03], sid.0.to_vec()].concat(),
157 ControlPacket::End(sid) => [vec![0x04], sid.0.to_vec()].concat(),
158 ControlPacket::Ping(tok) => {
159 let data = tok.map_or(EMPTY_STREAM.0.to_vec(), |t| {
160 vec![TOKEN_STREAM.0.to_vec(), t.0.into_bytes()].concat()
161 });
162 [vec![0x05], data].concat()
163 }
164 }
165 }
166
167 pub fn packet_type(&self) -> &str {
168 match &self {
169 ControlPacket::Ping(_) => "PING",
170 ControlPacket::Init(_) => "INIT STREAM",
171 ControlPacket::Data(_, _) => "STREAM DATA",
172 ControlPacket::Refused(_) => "REFUSED",
173 ControlPacket::End(_) => "END STREAM",
174 }
175 }
176
177 pub fn deserialize(data: &[u8]) -> Result<Self, Box<dyn std::error::Error>> {
178 if data.len() < 9 {
179 return Err("invalid DataPacket, missing stream id".into());
180 }
181
182 let mut stream_id = [0u8; 8];
183 stream_id.clone_from_slice(&data[1..9]);
184 let stream_id = StreamId(stream_id);
185
186 let packet = match data[0] {
187 0x01 => ControlPacket::Init(stream_id),
188 0x02 => ControlPacket::Data(stream_id, data[9..].to_vec()),
189 0x03 => ControlPacket::Refused(stream_id),
190 0x04 => ControlPacket::End(stream_id),
191 0x05 => {
192 if stream_id == EMPTY_STREAM {
193 ControlPacket::Ping(None)
194 } else {
195 ControlPacket::Ping(Some(ReconnectToken(
196 String::from_utf8_lossy(&data[9..]).to_string(),
197 )))
198 }
199 }
200 _ => return Err("invalid control byte in DataPacket".into()),
201 };
202
203 Ok(packet)
204 }
205}