1#![deny(missing_docs)]
4pub mod client_request;
5pub mod error;
6pub mod server_response;
7
8use crate::error::SqrlError;
9use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine};
10use ed25519_dalek::{Signature, VerifyingKey};
11use std::{collections::HashMap, fmt, result};
12use url::Url;
13
14pub const SQRL_PROTOCOL: &str = "sqrl";
16
17pub const PROTOCOL_VERSIONS: &str = "1";
19
20pub type Result<G> = result::Result<G, SqrlError>;
22
23#[derive(Debug, PartialEq)]
25pub struct SqrlUrl {
26 url: Url,
27}
28
29impl SqrlUrl {
30 pub fn parse(url: &str) -> Result<Self> {
37 let parsed = Url::parse(url)?;
38 if parsed.scheme() != SQRL_PROTOCOL {
39 return Err(SqrlError::new(format!(
40 "Invalid sqrl url, incorrect protocol: {}",
41 url
42 )));
43 }
44 if parsed.domain().is_none() {
45 return Err(SqrlError::new(format!(
46 "Invalid sqrl url, missing domain: {}",
47 url
48 )));
49 }
50
51 Ok(SqrlUrl { url: parsed })
52 }
53
54 pub fn get_auth_domain(&self) -> String {
62 format!("{}{}", self.get_domain(), self.get_path())
63 }
64
65 fn get_domain(&self) -> String {
66 self.url.domain().unwrap().to_lowercase()
67 }
68
69 fn get_path(&self) -> String {
70 let path = self.url.path().strip_suffix('/').unwrap_or(self.url.path());
71 path.to_owned()
72 }
73}
74
75impl fmt::Display for SqrlUrl {
76 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
77 write!(f, "{}", self.url)
78 }
79}
80
81pub(crate) fn get_or_error(
82 map: &HashMap<String, String>,
83 key: &str,
84 error_message: &str,
85) -> Result<String> {
86 match map.get(key) {
87 Some(x) => Ok(x.to_owned()),
88 None => Err(SqrlError::new(error_message.to_owned())),
89 }
90}
91
92pub(crate) fn parse_query_data(query: &str) -> Result<HashMap<String, String>> {
93 let mut map = HashMap::<String, String>::new();
94 for token in query.split('&') {
95 if let Some((key, value)) = token.split_once('=') {
96 map.insert(key.to_owned(), value.to_owned());
97 } else {
98 return Err(SqrlError::new("Invalid query data".to_owned()));
99 }
100 }
101 Ok(map)
102}
103
104pub(crate) fn decode_public_key(key: &str) -> Result<VerifyingKey> {
105 let bytes: [u8; 32];
106 match BASE64_URL_SAFE_NO_PAD.decode(key) {
107 Ok(x) => bytes = vec_to_u8_32(&x)?,
108 Err(_) => {
109 return Err(SqrlError::new(format!(
110 "Failed to decode base64 encoded public key {}",
111 key
112 )))
113 }
114 }
115
116 match VerifyingKey::from_bytes(&bytes) {
117 Ok(x) => Ok(x),
118 Err(e) => Err(SqrlError::new(format!(
119 "Failed to generate public key from {}: {}",
120 key, e
121 ))),
122 }
123}
124
125pub(crate) fn decode_signature(key: &str) -> Result<Signature> {
126 let bytes: [u8; 64];
127 match BASE64_URL_SAFE_NO_PAD.decode(key) {
128 Ok(x) => bytes = vec_to_u8_64(&x)?,
129 Err(_) => {
130 return Err(SqrlError::new(format!(
131 "Failed to decode base64 encoded signature {}",
132 key
133 )))
134 }
135 }
136
137 Ok(Signature::from_bytes(&bytes))
138}
139
140pub(crate) fn parse_newline_data(data: &str) -> Result<HashMap<String, String>> {
141 let mut map = HashMap::<String, String>::new();
142 for token in data.split('\n') {
143 if let Some((key, value)) = token.split_once('=') {
144 map.insert(key.to_owned(), value.trim().to_owned());
145 } else if !token.is_empty() {
146 return Err(SqrlError::new(format!("Invalid newline data {}", token)));
147 }
148 }
149
150 Ok(map)
151}
152
/// Serializes `map` as `key=value` pairs, each preceded by `\n`.
///
/// A non-empty result therefore starts with a newline, and an empty map yields `""`.
/// Pair order follows the map's iteration order.
pub(crate) fn encode_newline_data(map: &HashMap<&str, &str>) -> String {
    map.iter()
        .map(|(key, value)| format!("\n{key}={value}"))
        .collect()
}
161
162pub(crate) fn vec_to_u8_32(vector: &[u8]) -> Result<[u8; 32]> {
163 let mut result = [0; 32];
164 if vector.len() != 32 {
165 return Err(SqrlError::new(format!(
166 "Error converting vec<u8> to [u8; 32]: Expected 32 bytes, but found {}",
167 vector.len()
168 )));
169 }
170
171 result[..32].copy_from_slice(&vector[..32]);
172 Ok(result)
173}
174
175pub(crate) fn vec_to_u8_64(vector: &[u8]) -> Result<[u8; 64]> {
176 let mut result = [0; 64];
177 if vector.len() != 64 {
178 return Err(SqrlError::new(format!(
179 "Error converting vec<u8> to [u8; 64]: Expected 64 bytes, but found {}",
180 vector.len()
181 )));
182 }
183
184 result[..64].copy_from_slice(&vector[..64]);
185 Ok(result)
186}
187
188#[derive(Debug, PartialEq)]
190pub struct ProtocolVersion {
191 versions: u128,
192 max_version: u8,
193}
194
195impl ProtocolVersion {
196 pub fn new(versions: &str) -> Result<Self> {
203 let mut prot = ProtocolVersion {
204 versions: 0,
205 max_version: 0,
206 };
207 for sub in versions.split(',') {
208 if sub.contains('-') {
209 let mut versions = sub.split('-');
210
211 let low: u8 = match versions.next() {
213 Some(x) => x.parse::<u8>()?,
214 None => {
215 return Err(SqrlError::new(format!("Invalid version number {}", sub)));
216 }
217 };
218 let high: u8 = match versions.next() {
219 Some(x) => x.parse::<u8>()?,
220 None => {
221 return Err(SqrlError::new(format!("Invalid version number {}", sub)));
222 }
223 };
224
225 if low >= high {
227 return Err(SqrlError::new(format!("Invalid version number {}", sub)));
228 }
229
230 for i in low..high + 1 {
232 prot.versions |= 0b00000001 << (i - 1);
233 }
234 if high > prot.max_version {
235 prot.max_version = high;
236 }
237 } else {
238 let version = sub.parse::<u8>()?;
239 prot.versions |= 0b00000001 << (version - 1);
240 if version > prot.max_version {
241 prot.max_version = version;
242 }
243 }
244 }
245
246 Ok(prot)
247 }
248
249 pub fn get_max_matching_version(&self, other: &ProtocolVersion) -> Result<u8> {
259 let min_max = if self.max_version > other.max_version {
260 other.max_version
261 } else {
262 self.max_version
263 };
264
265 let matches = self.versions & other.versions;
266
267 let bit: u128 = 0b00000001 << min_max;
269 for i in 0..min_max {
270 if matches & (bit >> i) == bit >> i {
271 return Ok(min_max - i + 1);
272 }
273 }
274
275 Err(SqrlError::new(format!(
276 "No matching supported version! Ours: {} Theirs: {}",
277 self, other
278 )))
279 }
280}
281
282impl fmt::Display for ProtocolVersion {
283 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
284 let mut versions: Vec<String> = Vec::new();
285 let mut current_min: Option<u8> = None;
286 let mut bit: u128 = 0b00000001;
287 for i in 0..self.max_version {
288 if self.versions & bit == bit {
289 if current_min.is_none() {
292 current_min = Some(i);
293 }
294 } else {
295 if let Some(min) = current_min {
297 if i == min + 1 {
298 versions.push(format!("{}", min + 1));
300 } else {
301 versions.push(format!("{}-{}", min + 1, i));
302 }
303
304 current_min = None;
305 }
306 }
307
308 bit <<= 1;
309 }
310
311 if let Some(min) = current_min {
313 if self.max_version == min + 1 {
314 versions.push(format!("{}", min + 1));
316 } else {
317 versions.push(format!("{}-{}", min + 1, self.max_version));
318 }
319 }
320
321 write!(f, "{}", versions.join(","))
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    // Parsing of version lists and negotiation of the highest shared version.

    #[test]
    fn protocol_version_create_valid_version() {
        let _parsed = ProtocolVersion::new("1,2,6-7").unwrap();
    }

    #[test]
    fn protocol_version_create_invalid_version() {
        // A descending range must be rejected.
        let outcome = ProtocolVersion::new("1,2,7-3");
        if let Ok(version) = outcome {
            panic!("Version considered valid! {}", version);
        }
    }

    #[test]
    fn protocol_version_match_highest_version() {
        let ours = ProtocolVersion::new("1-7").unwrap();
        let theirs = ProtocolVersion::new("1,3,5").unwrap();
        let best = ours.get_max_matching_version(&theirs).unwrap();
        assert_eq!(5, best);
    }

    #[test]
    fn protocol_version_no_version_match() {
        let ours = ProtocolVersion::new("1-3,5-7").unwrap();
        let theirs = ProtocolVersion::new("4,8-12").unwrap();
        let outcome = ours.get_max_matching_version(&theirs);
        if let Ok(found) = outcome {
            panic!("Matching version found! {}", found);
        }
    }
}