scram_rs/
scram_common.rs

1/*-
 * Scram-rs - a SCRAM authentication and authorization library
3 * 
4 * Copyright (C) 2021  Aleksandr Morozov
5 * Copyright (C) 2025 Aleksandr Morozov
6 * 
 * The scram-rs crate can be redistributed and/or modified
 * under the terms of any of the following licenses:
9 *
10 *   1. the Mozilla Public License Version 2.0 (the “MPL”) OR
11 *
12 *   2. The MIT License (MIT)
13 *                     
14 *   3. EUROPEAN UNION PUBLIC LICENCE v. 1.2 EUPL © the European Union 2007, 2016
15 */
16
17
18#[cfg(feature = "std")]
19use std::fmt;
20#[cfg(not(feature = "std"))]
21use core::fmt;
22
23#[cfg(feature = "std")]
24use std::num::NonZeroU32;
25#[cfg(not(feature = "std"))]
26use core::num::NonZeroU32;
27
28#[cfg(not(feature = "std"))]
29use alloc::string::String;
30
31#[cfg(not(feature = "std"))]
32use alloc::vec::Vec;
33
34#[cfg(not(feature = "std"))]
35use alloc::vec;
36
37#[cfg(not(feature = "std"))]
38use alloc::format;
39
40#[cfg(not(feature = "std"))]
41use crate::alloc::string::ToString;
42
43#[cfg(feature = "use_ring")]
44use ring::rand::{SecureRandom, SystemRandom};
45
46use crate::ScramServerError;
47
48use super::scram_error::{ScramResult, ScramErrorCode};
49use super::{scram_error, scram_error_map};
50
/// Numeric identifier for every mechanism listed in [SCRAM_TYPES].
///
/// The explicit discriminants follow the order of the entries in
/// [SCRAM_TYPES]; whenever that table is edited, re-check these values.
#[repr(usize)]
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ScramTypeAlias
{
    /// `SCRAM-SHA-1`; compiled out when the `exclude_sha1` feature is enabled.
    #[cfg(not(feature = "exclude_sha1"))]
    Sha1 = 0,

    /// `SCRAM-SHA-256`.
    Sha256 = 1,

    /// `SCRAM-SHA-256-PLUS` (channel binding).
    Sha256Plus = 2,

    /// `SCRAM-SHA-512`.
    Sha512 = 3,

    /// `SCRAM-SHA-512-PLUS` (channel binding).
    Sha512Plus = 4,
}

impl From<ScramTypeAlias> for usize
{
    /// Yields the discriminant of `alias`.
    fn from(alias: ScramTypeAlias) -> usize
    {
        alias as usize
    }
}
73
74/// A structured data about supported mechanisms
75#[derive(Debug, Eq, PartialEq, Clone, Copy)]
76pub struct ScramType
77{
78    /// Scram type encoded as in RFC without trailing \r\n or \n
79    pub scram_name: &'static str,
80
81    pub scram_alias: ScramTypeAlias,
82
83    /// Is channel binding supported (-PLUS)
84    pub scram_chan_bind: bool,
85}
86
87impl PartialEq<str> for ScramType
88{
89    fn eq(&self, other: &str) -> bool 
90    {
91        return self.scram_name == other;
92    }
93}
94
95impl PartialEq<ScramTypeAlias> for ScramType
96{
97    fn eq(&self, other: &ScramTypeAlias) -> bool 
98    {
99        return self.scram_alias == *other;
100    }
101}
102
103impl fmt::Display for ScramType
104{
105    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result 
106    {
107        write!(f, "scram: {}, channel_bind: {}", self.scram_name, self.scram_chan_bind)
108    }
109}
110
111#[cfg(not(feature = "exclude_sha1"))]
112pub const SCRAM_TYPE_1: ScramType =         ScramType{scram_name:"SCRAM-SHA-1",         scram_alias: ScramTypeAlias::Sha1,          scram_chan_bind: false};
113pub const SCRAM_TYPE_256: ScramType =       ScramType{scram_name:"SCRAM-SHA-256",       scram_alias: ScramTypeAlias::Sha256,        scram_chan_bind: false};
114pub const SCRAM_TYPE_256_PLUS: ScramType =  ScramType{scram_name:"SCRAM-SHA-256-PLUS",  scram_alias: ScramTypeAlias::Sha256Plus,    scram_chan_bind: true};
115pub const SCRAM_TYPE_512: ScramType =       ScramType{scram_name:"SCRAM-SHA-512",       scram_alias: ScramTypeAlias::Sha512,        scram_chan_bind: false};
116pub const SCRAM_TYPE_512_PLUS: ScramType =  ScramType{scram_name:"SCRAM-SHA-512-PLUS",  scram_alias: ScramTypeAlias::Sha512Plus,    scram_chan_bind: true};
117
118/// All supported SCRAM types.
119#[derive(Debug, Clone)]
120pub struct ScramTypes(&'static [ScramType]);
121
122/// A table of all supported versions.
123pub const SCRAM_TYPES: &'static ScramTypes = 
124    &ScramTypes(
125        &[
126            #[cfg(not(feature = "exclude_sha1"))]
127            SCRAM_TYPE_1,
128
129            SCRAM_TYPE_256,
130            SCRAM_TYPE_256_PLUS,
131            SCRAM_TYPE_512,
132            SCRAM_TYPE_512_PLUS,
133        ]
134    );
135
136impl ScramTypes
137{
138    /// Creates a new table which can be used later. It also can be used to construct
139    /// overrided table during compilation.
140    pub const 
141    fn new(table: &'static [ScramType]) -> Self
142    {
143        return ScramTypes(table);
144    }
145
146    /// Outputs all supported types with separator.
147    /// 
148    /// # Arguments
149    /// 
150    /// * `sep` - a [str] which should separate the output.
151    /// 
152    /// # Returns 
153    /// 
154    /// A [String] is retuned.
155    pub 
156    fn adrvertise<S: AsRef<str>>(&self, sep: S) -> String
157    {
158        return 
159            self
160                .0
161                .iter()
162                .map(|f| f.scram_name)
163                .collect::<Vec<&'static str>>()
164                .join(sep.as_ref());
165    }
166
167    /// Outputs all supported types to [fmt] with separator `sep`.
168    pub 
169    fn advertise_to_fmt<S: AsRef<str>>(&self, f: &mut fmt::Formatter, sep: S) -> fmt::Result 
170    {
171        for (scr_type, i) in self.0.iter().zip(0..self.0.len())
172        {
173            write!(f, "{}", scr_type.scram_name)?;
174
175            if i+1 < self.0.len()
176            {
177                write!(f, "{}", sep.as_ref())?;
178            }
179        }
180
181        return Ok(());
182    }
183
184    /// Retrieves the SCRAM type by name which are hardcoded in [SCRAM_TYPES] 
185    /// i.e SCRAM-SHA256.
186    /// 
187    /// # Arguments
188    /// 
189    /// * `scram` - a scram auth type
190    /// 
191    /// # Returns
192    /// 
193    /// * [ScramResult] - a reference to record from table with static lifetime
194    ///                     or Error [ScramErrorCode::ExternalError] if not found
195    pub 
196    fn get_scramtype<S: AsRef<str>>(&self, scram: S) -> ScramResult<&'static ScramType>
197    {
198        let scram_name = scram.as_ref();
199
200        for scr_type in self.0.iter()
201        {
202            if scr_type == scram_name
203            {
204                return Ok(scr_type);
205            }
206        }
207
208        scram_error!(ScramErrorCode::ExternalError, ScramServerError::OtherError, 
209            "unknown scram type: {}", scram_name);
210    }
211
212    /// Retrieves the SCRAM type from [SCRAM_TYPES] by the numeric alias which 
213    /// are hardcoded in [ScramTypeAlias] 
214    /// i.e SCRAM-SHA256.
215    /// 
216    /// # Arguments
217    /// 
218    /// * `scram` - a scram numeric auth type [ScramTypeAlias]
219    /// 
220    /// # Returns
221    /// 
222    /// * [ScramResult] - a reference to record from table with static lifetime
223    ///                     or Error [ScramErrorCode::ExternalError] if not found
224    pub 
225    fn get_scramtype_numeric(&self, scram: ScramTypeAlias) -> ScramResult<&'static ScramType>
226    {
227        // binary search would be faster, but the list should be strictly sorted!.
228
229        for scr_type in self.0.iter()
230        {
231            if scr_type == &scram
232            {
233                return Ok(scr_type);
234            }
235        }
236
237        scram_error!(ScramErrorCode::ExternalError, ScramServerError::OtherError,
238            "unknown scram type: {:?}", scram);
239    }
240}
241
242
243pub struct ScramCommon{}
244impl ScramCommon
245{
246    /// A default raw (non base64) nonce length
247    pub const SCRAM_RAW_NONCE_LEN: usize = 32;
248
249    /// A mock salt default len
250    pub const MOCK_AUTH_NONCE_LEN: usize = 16;
251
252    /// Default HMAC iterations
253    pub const SCRAM_DEFAULT_SALT_ITER: NonZeroU32 = unsafe { NonZeroU32::new_unchecked(4096) };
254
255    pub const SCRAM_MAX_ITERS: u32 = 999999999;
256
257    /// Generates random secuence of bytes. Slower than unsafe function.
258    /// 
259    /// # Arguments
260    /// 
261    /// * `len` - a length of the array
262    /// 
263    /// # Returns
264    /// 
265    /// * [ScramResult] Ok - elements or Error
266    pub 
267    fn sc_random(len: usize) -> ScramResult<Vec<u8>>
268    {
269        let mut data = vec![0_u8; len];
270
271        getrandom::fill(&mut data)
272            .map_err(|e| 
273                scram_error_map!(ScramErrorCode::ExternalError, ScramServerError::OtherError, 
274                    "scram getrandom err, {}", e)
275            )?;
276
277        return Ok(data);
278    }
279
280    /// Generates random secuence of bytes in unsafe way i.e not initializing the vec.
281    /// 
282    /// If capacity will be larger (if the len is not aligned) than requested, will contain 
283    /// garbadge at the rest indexes.
284    /// 
285    /// # Panic
286    /// 
287    /// Does not panic
288    /// 
289    /// # Arguments
290    /// 
291    /// * `len` - a length of the array
292    /// 
293    /// # Returns
294    /// 
295    /// A [ScramResult] is returned:
296    /// 
297    /// * [Result::Ok] with the [Vec] if [u8]
298    /// 
299    /// * [Result::Err] with error: 
300    /// 
301    pub unsafe 
302    fn sc_random_unsafe(len: usize) -> ScramResult<Vec<u8>>
303    {
304        if len == 0
305        {
306            scram_error!(
307                ScramErrorCode::ExternalError, ScramServerError::OtherError, 
308                "sc_random_unsafe vec requested, req: {} can not be zero", 
309                len,
310            );
311        }
312
313        let mut data = Vec::<u8>::with_capacity(len);
314
315        // override length
316        unsafe { data.set_len(len) };
317
318        getrandom::fill(&mut data)
319            .map_err(|e| 
320                scram_error_map!(ScramErrorCode::ExternalError, ScramServerError::OtherError, "scram getrandom err, {}", e)
321            )?;
322
323        return Ok(data);
324    }
325
326    /// Generates random secuence of bytes using Ring crate.
327    /// 
328    /// Considered unsafe because of the Ring crate.
329    /// 
330    /// # Arguments
331    /// 
332    /// * `len` - a length of the array
333    /// 
334    /// # Returns
335    /// 
336    /// * [ScramResult] Ok - elements or Error
337    #[cfg(feature = "use_ring")]
338    pub 
339    fn sc_random_ring_secure(len: usize) -> ScramResult<Vec<u8>>
340    {
341        let mut data = vec![0_u8; len];
342
343        let sys_random = SystemRandom::new();
344
345        sys_random
346            .fill(&mut data)
347            .map_err(|e| 
348                scram_error_map!(ScramErrorCode::ExternalError, ScramServerError::OtherError, "scram ring SystemRandom err, {}", e)
349            )?;
350
351        return Ok(data);    
352    }
353
354    /// Generates random secuence of bytes using Ring crate, but Ring crate is
355    /// not enabled, so the `sc_random` will be used.
356    /// 
357    /// # Arguments
358    /// 
359    /// * `len` - a length of the array
360    /// 
361    /// # Returns
362    /// 
363    /// * [ScramResult] Ok - elements or Error
364    #[cfg(not(feature = "use_ring"))]
365    pub 
366    fn sc_random_ring_secure(len: usize) -> ScramResult<Vec<u8>>
367    {
368        return Self::sc_random(len);
369    }
370}
371
372impl ScramCommon
373{
374    pub(crate)
375    fn sanitize_char(c: char) -> String
376    {
377        if c.is_ascii_graphic() == true
378        {
379            return c.to_string();
380        }
381        else
382        {
383            let mut buf = [0_u8; 4];
384                c.encode_utf8(&mut buf);
385
386            let formatted: String = 
387                buf[0..c.len_utf8()].into_iter()
388                    .map(|c| format!("\\x{:02x}", c))
389                    .collect();
390
391            return formatted;
392        }
393    }
394
395    #[allow(dead_code)]
396    pub(crate)
397    fn sanitize_str(st: &str) -> String
398    {
399        let mut out = String::with_capacity(st.len());
400
401        for c in st.chars()
402        {
403            if c.is_ascii_alphanumeric() == true ||
404                c.is_ascii_punctuation() == true ||
405                c == ' '
406            {
407                out.push(c);
408            }
409            else
410            {
411                let mut buf = [0_u8; 4];
412                c.encode_utf8(&mut buf);
413
414                let formatted: String = 
415                    buf[0..c.len_utf8()].into_iter()
416                        .map(|c| format!("\\x{:02x}", c))
417                        .collect();
418
419                out.push_str(&formatted);
420            }
421        }
422
423        return out;
424    }
425
426    pub(crate)
427    fn sanitize_str_unicode(st: &str) -> String
428    {
429        let mut out = String::with_capacity(st.len());
430
431        for c in st.chars()
432        {
433            if c.is_alphanumeric() == true ||
434                c.is_ascii_punctuation() == true ||
435                c == ' '
436            {
437                out.push(c);
438            }
439            else
440            {
441                let mut buf = [0_u8; 4];
442                c.encode_utf8(&mut buf);
443
444                let formatted: String = 
445                    buf[0..c.len_utf8()].into_iter()
446                        .map(|c| format!("\\x{:02x}", c))
447                        .collect();
448
449                out.push_str(&formatted);
450            }
451        }
452
453        return out;
454    }
455}
456
#[cfg(feature = "std")]
#[cfg(test)]
mod tests
{
    use std::time::Instant;

    use crate::{ScramCommon, ScramResult, ScramTypeAlias, SCRAM_TYPES};

    /// Calls `produce(len)` twice. Each call's elapsed time and hex dump are
    /// printed, and the helper asserts that every call succeeds and returns
    /// exactly `len` bytes.
    fn check_random<F>(len: usize, produce: F)
    where 
        F: Fn(usize) -> Result<Vec<u8>, crate::ScramRuntimeError>
    {
        for _ in 0..2
        {
            let start = Instant::now();
            let res = produce(len);
            println!("elapsed: {:?}", start.elapsed());

            let data = 
                match res
                {
                    Ok(data) => data,
                    Err(e) => panic!("error: '{}'", e),
                };

            assert_eq!(data.len(), len, "length: {}", data.len());

            data.iter().for_each(|i| print!("{:02X}", i));
            println!();
        }
    }

    #[test]
    fn test_sc_random()
    {
        check_random(128, ScramCommon::sc_random);
    }

    #[test]
    fn test_sc_random_unsafe()
    {
        check_random(128, |len| unsafe { ScramCommon::sc_random_unsafe(len) });
    }

    #[cfg(feature = "use_ring")]
    #[test]
    fn test_sc_random_ring()
    {
        check_random(128, ScramCommon::sc_random_ring_secure);
    }

    #[test]
    fn test_sc_random_zero_len()
    {
        // The zero-filling generator accepts an empty request...
        let res = ScramCommon::sc_random(0);
        assert_eq!(res.map(|v| v.len()).ok(), Some(0));

        // ...while the uninitialized-buffer variant rejects it.
        let res = unsafe { ScramCommon::sc_random_unsafe(0) };
        assert!(res.is_err());
    }

    #[test]
    fn sanitize_unicode()
    {
        let res = ScramCommon::sanitize_str_unicode("る\n\0bp234");

        assert_eq!(res.as_str(), "る\\x0a\\x00bp234");
    }

    #[test]
    fn sanitize_ascii()
    {
        assert_eq!(ScramCommon::sanitize_str("ab, c\n"), "ab, c\\x0a");
        assert_eq!(ScramCommon::sanitize_str("る"), "\\xe3\\x82\\x8b");
    }

    #[test]
    fn sanitize_single_char()
    {
        assert_eq!(ScramCommon::sanitize_char('a'), "a");
        assert_eq!(ScramCommon::sanitize_char(' '), "\\x20");
        assert_eq!(ScramCommon::sanitize_char('\0'), "\\x00");
    }

    #[test]
    fn test_scram_types()
    {
        let start = Instant::now();
        #[cfg(not(feature = "exclude_sha1"))]
        assert_eq!(
            SCRAM_TYPES.adrvertise(", "), 
            "SCRAM-SHA-1, SCRAM-SHA-256, SCRAM-SHA-256-PLUS, SCRAM-SHA-512, SCRAM-SHA-512-PLUS"
        );

        #[cfg(feature = "exclude_sha1")]
        assert_eq!(
            SCRAM_TYPES.adrvertise(", "), 
            "SCRAM-SHA-256, SCRAM-SHA-256-PLUS, SCRAM-SHA-512, SCRAM-SHA-512-PLUS"
        );
        println!("took: {:?}", start.elapsed());

        // Expected contents of SCRAM_TYPES, in table order.
        let expected: &[(&str, ScramTypeAlias)] = &[
            #[cfg(not(feature = "exclude_sha1"))]
            ("SCRAM-SHA-1", ScramTypeAlias::Sha1),
            ("SCRAM-SHA-256", ScramTypeAlias::Sha256),
            ("SCRAM-SHA-256-PLUS", ScramTypeAlias::Sha256Plus),
            ("SCRAM-SHA-512", ScramTypeAlias::Sha512),
            ("SCRAM-SHA-512-PLUS", ScramTypeAlias::Sha512Plus),
        ];

        assert_eq!(SCRAM_TYPES.0.len(), expected.len());

        for (ind, (name, alias)) in expected.iter().enumerate()
        {
            let start = Instant::now();

            assert_eq!(
                SCRAM_TYPES.get_scramtype(*name),
                ScramResult::Ok(&SCRAM_TYPES.0[ind])
            );

            assert_eq!(
                SCRAM_TYPES.get_scramtype_numeric(*alias),
                ScramResult::Ok(&SCRAM_TYPES.0[ind])
            );

            println!("took: {:?}", start.elapsed());
        }

        // Unknown names are rejected.
        assert!(SCRAM_TYPES.get_scramtype("SCRAM-SHA-3").is_err());
    }
}