sqrl_protocol/
lib.rs

1//! Code needed for SQRL client and server communication
2
3#![deny(missing_docs)]
4pub mod client_request;
5pub mod error;
6pub mod server_response;
7
8use crate::error::SqrlError;
9use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine};
10use ed25519_dalek::{Signature, VerifyingKey};
11use std::{collections::HashMap, fmt, result};
12use url::Url;
13
/// The URL scheme every SQRL url must use (`sqrl://...`)
pub const SQRL_PROTOCOL: &str = "sqrl";

/// The protocol versions currently supported, written as a version list
/// (comma-separated versions and `low-high` ranges, e.g. `1` or `1,3,6-10`)
pub const PROTOCOL_VERSIONS: &str = "1";
19
20/// A default result type for the crate
21pub type Result<G> = result::Result<G, SqrlError>;
22
23/// Parses a SQRL url and breaks it into its parts
24#[derive(Debug, PartialEq)]
25pub struct SqrlUrl {
26    url: Url,
27}
28
29impl SqrlUrl {
30    /// Parse a SQRL url string and convert it into the object
31    /// ```rust
32    /// use sqrl_protocol::SqrlUrl;
33    ///
34    /// let sqrl_url = SqrlUrl::parse("sqrl://example.com?nut=1234abcd").unwrap();
35    /// ```
36    pub fn parse(url: &str) -> Result<Self> {
37        let parsed = Url::parse(url)?;
38        if parsed.scheme() != SQRL_PROTOCOL {
39            return Err(SqrlError::new(format!(
40                "Invalid sqrl url, incorrect protocol: {}",
41                url
42            )));
43        }
44        if parsed.domain().is_none() {
45            return Err(SqrlError::new(format!(
46                "Invalid sqrl url, missing domain: {}",
47                url
48            )));
49        }
50
51        Ok(SqrlUrl { url: parsed })
52    }
53
54    /// Get the auth domain used for calculating identities
55    /// ```rust
56    /// use sqrl_protocol::SqrlUrl;
57    ///
58    /// let sqrl_url = SqrlUrl::parse("sqrl://example.com/auth/path?nut=1234abcd").unwrap();
59    /// assert_eq!("example.com/auth/path", sqrl_url.get_auth_domain())
60    /// ```
61    pub fn get_auth_domain(&self) -> String {
62        format!("{}{}", self.get_domain(), self.get_path())
63    }
64
65    fn get_domain(&self) -> String {
66        self.url.domain().unwrap().to_lowercase()
67    }
68
69    fn get_path(&self) -> String {
70        let path = self.url.path().strip_suffix('/').unwrap_or(self.url.path());
71        path.to_owned()
72    }
73}
74
75impl fmt::Display for SqrlUrl {
76    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
77        write!(f, "{}", self.url)
78    }
79}
80
81pub(crate) fn get_or_error(
82    map: &HashMap<String, String>,
83    key: &str,
84    error_message: &str,
85) -> Result<String> {
86    match map.get(key) {
87        Some(x) => Ok(x.to_owned()),
88        None => Err(SqrlError::new(error_message.to_owned())),
89    }
90}
91
92pub(crate) fn parse_query_data(query: &str) -> Result<HashMap<String, String>> {
93    let mut map = HashMap::<String, String>::new();
94    for token in query.split('&') {
95        if let Some((key, value)) = token.split_once('=') {
96            map.insert(key.to_owned(), value.to_owned());
97        } else {
98            return Err(SqrlError::new("Invalid query data".to_owned()));
99        }
100    }
101    Ok(map)
102}
103
104pub(crate) fn decode_public_key(key: &str) -> Result<VerifyingKey> {
105    let bytes: [u8; 32];
106    match BASE64_URL_SAFE_NO_PAD.decode(key) {
107        Ok(x) => bytes = vec_to_u8_32(&x)?,
108        Err(_) => {
109            return Err(SqrlError::new(format!(
110                "Failed to decode base64 encoded public key {}",
111                key
112            )))
113        }
114    }
115
116    match VerifyingKey::from_bytes(&bytes) {
117        Ok(x) => Ok(x),
118        Err(e) => Err(SqrlError::new(format!(
119            "Failed to generate public key from {}: {}",
120            key, e
121        ))),
122    }
123}
124
125pub(crate) fn decode_signature(key: &str) -> Result<Signature> {
126    let bytes: [u8; 64];
127    match BASE64_URL_SAFE_NO_PAD.decode(key) {
128        Ok(x) => bytes = vec_to_u8_64(&x)?,
129        Err(_) => {
130            return Err(SqrlError::new(format!(
131                "Failed to decode base64 encoded signature {}",
132                key
133            )))
134        }
135    }
136
137    Ok(Signature::from_bytes(&bytes))
138}
139
140pub(crate) fn parse_newline_data(data: &str) -> Result<HashMap<String, String>> {
141    let mut map = HashMap::<String, String>::new();
142    for token in data.split('\n') {
143        if let Some((key, value)) = token.split_once('=') {
144            map.insert(key.to_owned(), value.trim().to_owned());
145        } else if !token.is_empty() {
146            return Err(SqrlError::new(format!("Invalid newline data {}", token)));
147        }
148    }
149
150    Ok(map)
151}
152
/// Serialize `map` as `\n`-prefixed `key=value` entries, in the map's
/// iteration order (which is unspecified for a `HashMap`)
///
/// An empty map yields an empty string. NOTE(review): each entry is
/// *preceded* by `\n`; the published SQRL semantics may instead call for
/// CRLF-terminated pairs — confirm before relying on interop with other
/// implementations.
pub(crate) fn encode_newline_data(map: &HashMap<&str, &str>) -> String {
    map.iter().fold(String::new(), |mut encoded, (name, val)| {
        encoded.push('\n');
        encoded.push_str(name);
        encoded.push('=');
        encoded.push_str(val);
        encoded
    })
}
161
162pub(crate) fn vec_to_u8_32(vector: &[u8]) -> Result<[u8; 32]> {
163    let mut result = [0; 32];
164    if vector.len() != 32 {
165        return Err(SqrlError::new(format!(
166            "Error converting vec<u8> to [u8; 32]: Expected 32 bytes, but found {}",
167            vector.len()
168        )));
169    }
170
171    result[..32].copy_from_slice(&vector[..32]);
172    Ok(result)
173}
174
175pub(crate) fn vec_to_u8_64(vector: &[u8]) -> Result<[u8; 64]> {
176    let mut result = [0; 64];
177    if vector.len() != 64 {
178        return Err(SqrlError::new(format!(
179            "Error converting vec<u8> to [u8; 64]: Expected 64 bytes, but found {}",
180            vector.len()
181        )));
182    }
183
184    result[..64].copy_from_slice(&vector[..64]);
185    Ok(result)
186}
187
188/// The versions of the sqrl protocol supported by a client/server
189#[derive(Debug, PartialEq)]
190pub struct ProtocolVersion {
191    versions: u128,
192    max_version: u8,
193}
194
195impl ProtocolVersion {
196    /// Create a new object based on the version string
197    /// ```rust
198    /// use sqrl_protocol::ProtocolVersion;
199    ///
200    /// let version = ProtocolVersion::new("1,3,6-10").unwrap();
201    /// ```
202    pub fn new(versions: &str) -> Result<Self> {
203        let mut prot = ProtocolVersion {
204            versions: 0,
205            max_version: 0,
206        };
207        for sub in versions.split(',') {
208            if sub.contains('-') {
209                let mut versions = sub.split('-');
210
211                // Parse out the lower and higher end of the range
212                let low: u8 = match versions.next() {
213                    Some(x) => x.parse::<u8>()?,
214                    None => {
215                        return Err(SqrlError::new(format!("Invalid version number {}", sub)));
216                    }
217                };
218                let high: u8 = match versions.next() {
219                    Some(x) => x.parse::<u8>()?,
220                    None => {
221                        return Err(SqrlError::new(format!("Invalid version number {}", sub)));
222                    }
223                };
224
225                // Make sure the range is valid
226                if low >= high {
227                    return Err(SqrlError::new(format!("Invalid version number {}", sub)));
228                }
229
230                // Set the neccesary values
231                for i in low..high + 1 {
232                    prot.versions |= 0b00000001 << (i - 1);
233                }
234                if high > prot.max_version {
235                    prot.max_version = high;
236                }
237            } else {
238                let version = sub.parse::<u8>()?;
239                prot.versions |= 0b00000001 << (version - 1);
240                if version > prot.max_version {
241                    prot.max_version = version;
242                }
243            }
244        }
245
246        Ok(prot)
247    }
248
249    /// Compares two protocol version objects, returning the highest version
250    /// supported by both
251    /// ```rust
252    /// use sqrl_protocol::ProtocolVersion;
253    ///
254    /// let version = ProtocolVersion::new("1,3,5,7,9").unwrap();
255    /// let version2 = ProtocolVersion::new("2,4,5,8,10").unwrap();
256    /// assert_eq!(5, version.get_max_matching_version(&version2).unwrap());
257    /// ```
258    pub fn get_max_matching_version(&self, other: &ProtocolVersion) -> Result<u8> {
259        let min_max = if self.max_version > other.max_version {
260            other.max_version
261        } else {
262            self.max_version
263        };
264
265        let matches = self.versions & other.versions;
266
267        // Start from the highest match and work our way back
268        let bit: u128 = 0b00000001 << min_max;
269        for i in 0..min_max {
270            if matches & (bit >> i) == bit >> i {
271                return Ok(min_max - i + 1);
272            }
273        }
274
275        Err(SqrlError::new(format!(
276            "No matching supported version! Ours: {} Theirs: {}",
277            self, other
278        )))
279    }
280}
281
282impl fmt::Display for ProtocolVersion {
283    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
284        let mut versions: Vec<String> = Vec::new();
285        let mut current_min: Option<u8> = None;
286        let mut bit: u128 = 0b00000001;
287        for i in 0..self.max_version {
288            if self.versions & bit == bit {
289                // If we don't have a current min set it.
290                // Otherwise, keep going until the range ends
291                if current_min.is_none() {
292                    current_min = Some(i);
293                }
294            } else {
295                // Did we experience a range, or just a single one?
296                if let Some(min) = current_min {
297                    if i == min + 1 {
298                        // A streak of one
299                        versions.push(format!("{}", min + 1));
300                    } else {
301                        versions.push(format!("{}-{}", min + 1, i));
302                    }
303
304                    current_min = None;
305                }
306            }
307
308            bit <<= 1;
309        }
310
311        // If we still have a min set, we need to run that same code again
312        if let Some(min) = current_min {
313            if self.max_version == min + 1 {
314                // A streak of one
315                versions.push(format!("{}", min + 1));
316            } else {
317                versions.push(format!("{}-{}", min + 1, self.max_version));
318            }
319        }
320
321        write!(f, "{}", versions.join(","))
322    }
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a version list the test expects to be valid
    fn versions(spec: &str) -> ProtocolVersion {
        ProtocolVersion::new(spec).unwrap()
    }

    #[test]
    fn protocol_version_create_valid_version() {
        versions("1,2,6-7");
    }

    #[test]
    fn protocol_version_create_invalid_version() {
        let parsed = ProtocolVersion::new("1,2,7-3");
        if let Ok(version) = parsed {
            panic!("Version considered valid! {}", version);
        }
    }

    #[test]
    fn protocol_version_match_highest_version() {
        let (client, server) = (versions("1-7"), versions("1,3,5"));
        let agreed = client.get_max_matching_version(&server).unwrap();
        assert_eq!(5, agreed);
    }

    #[test]
    fn protocol_version_no_version_match() {
        let (client, server) = (versions("1-3,5-7"), versions("4,8-12"));
        let negotiated = client.get_max_matching_version(&server);
        if let Ok(x) = negotiated {
            panic!("Matching version found! {}", x);
        }
    }
}