tunnelto_protocol/
lib.rs

1use rand::prelude::*;
2use serde::{Deserialize, Serialize};
3use sha2::Digest;
4
5#[derive(Serialize, Deserialize, Debug, Clone)]
6#[serde(transparent)]
7pub struct SecretKey(pub String);
8impl SecretKey {
9    pub fn client_id(&self) -> ClientId {
10        ClientId(base64::encode(
11            &sha2::Sha256::digest(self.0.as_bytes()).to_vec(),
12        ))
13    }
14
15    pub fn sha256_hex(&self) -> String {
16        hex::encode(sha2::Sha256::digest(self.0.as_bytes()).to_vec())
17    }
18}
19
20#[derive(Serialize, Deserialize, Debug, Clone)]
21#[serde(transparent)]
22pub struct ReconnectToken(pub String);
23
24#[derive(Serialize, Deserialize, Debug, Clone)]
25#[serde(rename_all = "snake_case")]
26pub enum ServerHello {
27    Success {
28        sub_domain: String,
29        hostname: String,
30        client_id: ClientId,
31    },
32    SubDomainInUse,
33    InvalidSubDomain,
34    AuthFailed,
35    Error(String),
36}
37
38impl ServerHello {
39    #[allow(unused)]
40    pub fn random_domain() -> String {
41        let mut rng = rand::thread_rng();
42        std::iter::repeat(())
43            .map(|_| rng.sample(rand::distributions::Alphanumeric))
44            .take(8)
45            .collect::<String>()
46            .to_lowercase()
47    }
48
49    #[allow(unused)]
50    pub fn prefixed_random_domain(prefix: &str) -> String {
51        format!("{}-{}", prefix, Self::random_domain())
52    }
53
54    pub fn prefixed_client_domain(prefix: &str, client_id: &ClientId, account_id: &str) -> String {
55        let input = format!("{}||{}", client_id, account_id);
56        let hash = sha2::Sha256::digest(input.as_bytes()).to_vec();
57        let encoded = base64::encode_config(&hash, base64::URL_SAFE_NO_PAD)
58            .to_lowercase()
59            .chars()
60            .take_while(|c| c.is_alphanumeric())
61            .take(8)
62            .collect::<String>();
63
64        format!("{}-{}", prefix, encoded)
65    }
66}
67
68#[derive(Serialize, Deserialize, Debug, Clone)]
69pub struct ClientHello {
70    /// deprecated: just send some garbage
71    id: ClientId,
72    pub sub_domain: Option<String>,
73    pub client_type: ClientType,
74    pub reconnect_token: Option<ReconnectToken>,
75    version: Option<ProtocolVersion>,
76}
77
78#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd)]
79#[serde(rename_all = "snake_case")]
80pub enum ProtocolVersion {
81    #[serde(rename = "pre")]
82    BeforeVersioning,
83    V1,
84}
85
86impl ProtocolVersion {
87    pub const CURRENT: ProtocolVersion = ProtocolVersion::V1;
88}
89
90impl ClientHello {
91    pub fn get_version(&self) -> ProtocolVersion {
92        self.version.unwrap_or(ProtocolVersion::BeforeVersioning)
93    }
94
95    pub fn generate(sub_domain: Option<String>, typ: ClientType) -> Self {
96        ClientHello {
97            id: ClientId::generate(),
98            client_type: typ,
99            sub_domain,
100            reconnect_token: None,
101            version: Some(ProtocolVersion::CURRENT),
102        }
103    }
104
105    pub fn reconnect(reconnect_token: ReconnectToken) -> Self {
106        ClientHello {
107            id: ClientId::generate(),
108            sub_domain: None,
109            client_type: ClientType::Anonymous,
110            reconnect_token: Some(reconnect_token),
111            version: Some(ProtocolVersion::CURRENT),
112        }
113    }
114}
115
116#[derive(Serialize, Deserialize, Debug, Clone)]
117pub enum ClientType {
118    Auth { key: SecretKey },
119    Anonymous,
120}
121
122#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)]
123#[serde(transparent)]
124pub struct ClientId(String);
125
126impl std::fmt::Display for ClientId {
127    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128        self.0.fmt(f)
129    }
130}
131impl ClientId {
132    pub fn generate() -> Self {
133        let mut id = [0u8; 32];
134        rand::thread_rng().fill_bytes(&mut id);
135        ClientId(base64::encode_config(&id, base64::URL_SAFE_NO_PAD))
136    }
137
138    pub fn safe_id(self) -> ClientId {
139        ClientId(base64::encode(
140            &sha2::Sha256::digest(self.0.as_bytes()).to_vec(),
141        ))
142    }
143}
144
/// Eight-byte identifier of a single stream multiplexed over the control
/// connection. It occupies bytes 1..9 of every serialized [`ControlPacket`].
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct StreamId([u8; 8]);
147
148impl StreamId {
149    pub fn generate() -> StreamId {
150        let mut id = [0u8; 8];
151        rand::thread_rng().fill_bytes(&mut id);
152        StreamId(id)
153    }
154
155    pub fn to_string(&self) -> String {
156        format!(
157            "stream_{}",
158            base64::encode_config(&self.0, base64::URL_SAFE_NO_PAD)
159        )
160    }
161}
162
163#[derive(Debug, Clone)]
164pub enum ControlPacket {
165    Init(StreamId),
166    Data(StreamId, Vec<u8>),
167    Refused(StreamId),
168    End(StreamId),
169    Ping(Option<ReconnectToken>),
170    /// introdueced in V1
171    Terminate {
172        reason: String,
173    },
174}
175
176pub const PING_INTERVAL: u64 = 30;
177
178const EMPTY_STREAM: StreamId = StreamId([0xF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);
179const TOKEN_STREAM: StreamId = StreamId([0xF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01]);
180
181impl ControlPacket {
182    pub fn serialize(self) -> Vec<u8> {
183        match self {
184            ControlPacket::Init(sid) => [vec![0x01], sid.0.to_vec()].concat(),
185            ControlPacket::Data(sid, data) => [vec![0x02], sid.0.to_vec(), data].concat(),
186            ControlPacket::Refused(sid) => [vec![0x03], sid.0.to_vec()].concat(),
187            ControlPacket::End(sid) => [vec![0x04], sid.0.to_vec()].concat(),
188            ControlPacket::Ping(tok) => {
189                let data = tok.map_or(EMPTY_STREAM.0.to_vec(), |t| {
190                    vec![TOKEN_STREAM.0.to_vec(), t.0.into_bytes()].concat()
191                });
192                [vec![0x05], data].concat()
193            }
194            ControlPacket::Terminate { reason } => {
195                [vec![0x06], EMPTY_STREAM.0.to_vec(), reason.into_bytes()].concat()
196            }
197        }
198    }
199
200    pub fn packet_type(&self) -> &str {
201        match &self {
202            ControlPacket::Ping(_) => "PING",
203            ControlPacket::Init(_) => "INIT STREAM",
204            ControlPacket::Data(_, _) => "STREAM DATA",
205            ControlPacket::Refused(_) => "REFUSED",
206            ControlPacket::End(_) => "END STREAM",
207            ControlPacket::Terminate { .. } => "TERMINATE",
208        }
209    }
210
211    pub fn deserialize(data: &[u8]) -> Result<Self, Box<dyn std::error::Error>> {
212        if data.len() < 9 {
213            return Err("invalid DataPacket, missing stream id".into());
214        }
215
216        let mut stream_id = [0u8; 8];
217        stream_id.clone_from_slice(&data[1..9]);
218        let stream_id = StreamId(stream_id);
219
220        let packet = match data[0] {
221            0x01 => ControlPacket::Init(stream_id),
222            0x02 => ControlPacket::Data(stream_id, data[9..].to_vec()),
223            0x03 => ControlPacket::Refused(stream_id),
224            0x04 => ControlPacket::End(stream_id),
225            0x05 => {
226                if stream_id == EMPTY_STREAM {
227                    ControlPacket::Ping(None)
228                } else {
229                    ControlPacket::Ping(Some(ReconnectToken(
230                        String::from_utf8_lossy(&data[9..]).to_string(),
231                    )))
232                }
233            }
234            0x06 => ControlPacket::Terminate {
235                reason: String::from_utf8_lossy(&data[9..]).to_string(),
236            },
237            _ => return Err("invalid control byte in DataPacket".into()),
238        };
239
240        Ok(packet)
241    }
242}