1#[cfg(feature = "std")]
19use std::fmt;
20#[cfg(not(feature = "std"))]
21use core::fmt;
22
23#[cfg(feature = "std")]
24use std::num::NonZeroU32;
25#[cfg(not(feature = "std"))]
26use core::num::NonZeroU32;
27
28#[cfg(not(feature = "std"))]
29use alloc::string::String;
30
31#[cfg(not(feature = "std"))]
32use alloc::vec::Vec;
33
34#[cfg(not(feature = "std"))]
35use alloc::vec;
36
37#[cfg(not(feature = "std"))]
38use alloc::format;
39
40#[cfg(not(feature = "std"))]
41use crate::alloc::string::ToString;
42
43#[cfg(feature = "use_ring")]
44use ring::rand::{SecureRandom, SystemRandom};
45
46use crate::ScramServerError;
47
48use super::scram_error::{ScramResult, ScramErrorCode};
49use super::{scram_error, scram_error_map};
50
/// Numeric alias for each SCRAM mechanism declared in this file.
///
/// The enum is `#[repr(usize)]` and every variant has an explicit
/// discriminant, so a variant can be turned into a `usize` with a plain cast.
#[repr(usize)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum ScramTypeAlias {
    /// Alias used by the `SCRAM-SHA-1` entry; compiled out when the
    /// `exclude_sha1` feature is enabled.
    #[cfg(not(feature = "exclude_sha1"))]
    Sha1 = 0,

    /// Alias used by the `SCRAM-SHA-256` entry.
    Sha256 = 1,

    /// Alias used by the `SCRAM-SHA-256-PLUS` entry.
    Sha256Plus = 2,

    /// Alias used by the `SCRAM-SHA-512` entry.
    Sha512 = 3,

    /// Alias used by the `SCRAM-SHA-512-PLUS` entry.
    Sha512Plus = 4,
}
65
66impl From<ScramTypeAlias> for usize
67{
68 fn from(value: ScramTypeAlias) -> Self
69 {
70 return value as usize;
71 }
72}
73
74#[derive(Debug, Eq, PartialEq, Clone, Copy)]
76pub struct ScramType
77{
78 pub scram_name: &'static str,
80
81 pub scram_alias: ScramTypeAlias,
82
83 pub scram_chan_bind: bool,
85}
86
87impl PartialEq<str> for ScramType
88{
89 fn eq(&self, other: &str) -> bool
90 {
91 return self.scram_name == other;
92 }
93}
94
95impl PartialEq<ScramTypeAlias> for ScramType
96{
97 fn eq(&self, other: &ScramTypeAlias) -> bool
98 {
99 return self.scram_alias == *other;
100 }
101}
102
103impl fmt::Display for ScramType
104{
105 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result
106 {
107 write!(f, "scram: {}, channel_bind: {}", self.scram_name, self.scram_chan_bind)
108 }
109}
110
111#[cfg(not(feature = "exclude_sha1"))]
112pub const SCRAM_TYPE_1: ScramType = ScramType{scram_name:"SCRAM-SHA-1", scram_alias: ScramTypeAlias::Sha1, scram_chan_bind: false};
113pub const SCRAM_TYPE_256: ScramType = ScramType{scram_name:"SCRAM-SHA-256", scram_alias: ScramTypeAlias::Sha256, scram_chan_bind: false};
114pub const SCRAM_TYPE_256_PLUS: ScramType = ScramType{scram_name:"SCRAM-SHA-256-PLUS", scram_alias: ScramTypeAlias::Sha256Plus, scram_chan_bind: true};
115pub const SCRAM_TYPE_512: ScramType = ScramType{scram_name:"SCRAM-SHA-512", scram_alias: ScramTypeAlias::Sha512, scram_chan_bind: false};
116pub const SCRAM_TYPE_512_PLUS: ScramType = ScramType{scram_name:"SCRAM-SHA-512-PLUS", scram_alias: ScramTypeAlias::Sha512Plus, scram_chan_bind: true};
117
118#[derive(Debug, Clone)]
120pub struct ScramTypes(&'static [ScramType]);
121
122pub const SCRAM_TYPES: &'static ScramTypes =
124 &ScramTypes(
125 &[
126 #[cfg(not(feature = "exclude_sha1"))]
127 SCRAM_TYPE_1,
128
129 SCRAM_TYPE_256,
130 SCRAM_TYPE_256_PLUS,
131 SCRAM_TYPE_512,
132 SCRAM_TYPE_512_PLUS,
133 ]
134 );
135
136impl ScramTypes
137{
138 pub const
141 fn new(table: &'static [ScramType]) -> Self
142 {
143 return ScramTypes(table);
144 }
145
146 pub
156 fn adrvertise<S: AsRef<str>>(&self, sep: S) -> String
157 {
158 return
159 self
160 .0
161 .iter()
162 .map(|f| f.scram_name)
163 .collect::<Vec<&'static str>>()
164 .join(sep.as_ref());
165 }
166
167 pub
169 fn advertise_to_fmt<S: AsRef<str>>(&self, f: &mut fmt::Formatter, sep: S) -> fmt::Result
170 {
171 for (scr_type, i) in self.0.iter().zip(0..self.0.len())
172 {
173 write!(f, "{}", scr_type.scram_name)?;
174
175 if i+1 < self.0.len()
176 {
177 write!(f, "{}", sep.as_ref())?;
178 }
179 }
180
181 return Ok(());
182 }
183
184 pub
196 fn get_scramtype<S: AsRef<str>>(&self, scram: S) -> ScramResult<&'static ScramType>
197 {
198 let scram_name = scram.as_ref();
199
200 for scr_type in self.0.iter()
201 {
202 if scr_type == scram_name
203 {
204 return Ok(scr_type);
205 }
206 }
207
208 scram_error!(ScramErrorCode::ExternalError, ScramServerError::OtherError,
209 "unknown scram type: {}", scram_name);
210 }
211
212 pub
225 fn get_scramtype_numeric(&self, scram: ScramTypeAlias) -> ScramResult<&'static ScramType>
226 {
227 for scr_type in self.0.iter()
230 {
231 if scr_type == &scram
232 {
233 return Ok(scr_type);
234 }
235 }
236
237 scram_error!(ScramErrorCode::ExternalError, ScramServerError::OtherError,
238 "unknown scram type: {:?}", scram);
239 }
240}
241
242
243pub struct ScramCommon{}
244impl ScramCommon
245{
246 pub const SCRAM_RAW_NONCE_LEN: usize = 32;
248
249 pub const MOCK_AUTH_NONCE_LEN: usize = 16;
251
252 pub const SCRAM_DEFAULT_SALT_ITER: NonZeroU32 = unsafe { NonZeroU32::new_unchecked(4096) };
254
255 pub const SCRAM_MAX_ITERS: u32 = 999999999;
256
257 pub
267 fn sc_random(len: usize) -> ScramResult<Vec<u8>>
268 {
269 let mut data = vec![0_u8; len];
270
271 getrandom::fill(&mut data)
272 .map_err(|e|
273 scram_error_map!(ScramErrorCode::ExternalError, ScramServerError::OtherError,
274 "scram getrandom err, {}", e)
275 )?;
276
277 return Ok(data);
278 }
279
280 pub unsafe
302 fn sc_random_unsafe(len: usize) -> ScramResult<Vec<u8>>
303 {
304 if len == 0
305 {
306 scram_error!(
307 ScramErrorCode::ExternalError, ScramServerError::OtherError,
308 "sc_random_unsafe vec requested, req: {} can not be zero",
309 len,
310 );
311 }
312
313 let mut data = Vec::<u8>::with_capacity(len);
314
315 unsafe { data.set_len(len) };
317
318 getrandom::fill(&mut data)
319 .map_err(|e|
320 scram_error_map!(ScramErrorCode::ExternalError, ScramServerError::OtherError, "scram getrandom err, {}", e)
321 )?;
322
323 return Ok(data);
324 }
325
326 #[cfg(feature = "use_ring")]
338 pub
339 fn sc_random_ring_secure(len: usize) -> ScramResult<Vec<u8>>
340 {
341 let mut data = vec![0_u8; len];
342
343 let sys_random = SystemRandom::new();
344
345 sys_random
346 .fill(&mut data)
347 .map_err(|e|
348 scram_error_map!(ScramErrorCode::ExternalError, ScramServerError::OtherError, "scram ring SystemRandom err, {}", e)
349 )?;
350
351 return Ok(data);
352 }
353
354 #[cfg(not(feature = "use_ring"))]
365 pub
366 fn sc_random_ring_secure(len: usize) -> ScramResult<Vec<u8>>
367 {
368 return Self::sc_random(len);
369 }
370}
371
372impl ScramCommon
373{
374 pub(crate)
375 fn sanitize_char(c: char) -> String
376 {
377 if c.is_ascii_graphic() == true
378 {
379 return c.to_string();
380 }
381 else
382 {
383 let mut buf = [0_u8; 4];
384 c.encode_utf8(&mut buf);
385
386 let formatted: String =
387 buf[0..c.len_utf8()].into_iter()
388 .map(|c| format!("\\x{:02x}", c))
389 .collect();
390
391 return formatted;
392 }
393 }
394
395 #[allow(dead_code)]
396 pub(crate)
397 fn sanitize_str(st: &str) -> String
398 {
399 let mut out = String::with_capacity(st.len());
400
401 for c in st.chars()
402 {
403 if c.is_ascii_alphanumeric() == true ||
404 c.is_ascii_punctuation() == true ||
405 c == ' '
406 {
407 out.push(c);
408 }
409 else
410 {
411 let mut buf = [0_u8; 4];
412 c.encode_utf8(&mut buf);
413
414 let formatted: String =
415 buf[0..c.len_utf8()].into_iter()
416 .map(|c| format!("\\x{:02x}", c))
417 .collect();
418
419 out.push_str(&formatted);
420 }
421 }
422
423 return out;
424 }
425
426 pub(crate)
427 fn sanitize_str_unicode(st: &str) -> String
428 {
429 let mut out = String::with_capacity(st.len());
430
431 for c in st.chars()
432 {
433 if c.is_alphanumeric() == true ||
434 c.is_ascii_punctuation() == true ||
435 c == ' '
436 {
437 out.push(c);
438 }
439 else
440 {
441 let mut buf = [0_u8; 4];
442 c.encode_utf8(&mut buf);
443
444 let formatted: String =
445 buf[0..c.len_utf8()].into_iter()
446 .map(|c| format!("\\x{:02x}", c))
447 .collect();
448
449 out.push_str(&formatted);
450 }
451 }
452
453 return out;
454 }
455}
456
/// Unit tests for the random-byte helpers, the sanitizers and the mechanism
/// table. Only built with the `std` feature.
#[cfg(feature = "std")]
#[cfg(test)]
mod tests {
    use std::time::Instant;

    use crate::{ScramCommon, ScramResult, ScramTypeAlias, SCRAM_TYPES};

    /// Number of random bytes requested from every generator under test.
    const RANDOM_LEN: usize = 128;

    /// Calls `generate` twice; each time prints the elapsed time, checks
    /// that the call succeeded and returned `RANDOM_LEN` bytes, and prints
    /// the bytes as uppercase hex.
    fn check_generator<F>(generate: F)
    where
        F: Fn() -> ScramResult<Vec<u8>>,
    {
        for _ in 0..2 {
            let start = Instant::now();
            let res = generate();
            let end = start.elapsed();

            println!("elapsed: {:?}", end);

            let bytes = res.unwrap_or_else(|e| panic!("error: '{}'", e));
            assert_eq!(bytes.len(), RANDOM_LEN, "length: {}", bytes.len());

            bytes.iter().for_each(|b| print!("{:02X}", b));
            println!();
        }
    }

    #[test]
    fn test_sc_random() {
        check_generator(|| ScramCommon::sc_random(RANDOM_LEN));
    }

    #[test]
    fn test_sc_random_unsafe() {
        check_generator(|| unsafe { ScramCommon::sc_random_unsafe(RANDOM_LEN) });
    }

    #[cfg(feature = "use_ring")]
    #[test]
    fn test_sc_random_ring() {
        check_generator(|| ScramCommon::sc_random_ring_secure(RANDOM_LEN));
    }

    #[test]
    fn sanitize_unicode() {
        let res = ScramCommon::sanitize_str_unicode("る\n\0bp234");

        assert_eq!(res.as_str(), "る\\x0a\\x00bp234");
    }

    #[test]
    fn test_scram_types() {
        #[cfg(not(feature = "exclude_sha1"))]
        let expected_advert =
            "SCRAM-SHA-1, SCRAM-SHA-256, SCRAM-SHA-256-PLUS, SCRAM-SHA-512, SCRAM-SHA-512-PLUS";

        #[cfg(feature = "exclude_sha1")]
        let expected_advert =
            "SCRAM-SHA-256, SCRAM-SHA-256-PLUS, SCRAM-SHA-512, SCRAM-SHA-512-PLUS";

        let start = Instant::now();
        assert_eq!(SCRAM_TYPES.adrvertise(", "), expected_advert);
        println!("took: {:?}", start.elapsed());

        // Expected table contents, in table order. The position of each pair
        // in this array is the index of the matching entry in `SCRAM_TYPES`.
        let expected = [
            #[cfg(not(feature = "exclude_sha1"))]
            ("SCRAM-SHA-1", ScramTypeAlias::Sha1),
            ("SCRAM-SHA-256", ScramTypeAlias::Sha256),
            ("SCRAM-SHA-256-PLUS", ScramTypeAlias::Sha256Plus),
            ("SCRAM-SHA-512", ScramTypeAlias::Sha512),
            ("SCRAM-SHA-512-PLUS", ScramTypeAlias::Sha512Plus),
        ];

        assert_eq!(SCRAM_TYPES.0.len(), expected.len());

        // Lookup by mechanism name.
        for (ind, &(name, _)) in expected.iter().enumerate() {
            let start = Instant::now();
            assert_eq!(
                SCRAM_TYPES.get_scramtype(name),
                ScramResult::Ok(&SCRAM_TYPES.0[ind])
            );
            println!("took: {:?}", start.elapsed());
        }

        // Lookup by numeric alias.
        for (ind, &(_, alias)) in expected.iter().enumerate() {
            let start = Instant::now();
            assert_eq!(
                SCRAM_TYPES.get_scramtype_numeric(alias),
                ScramResult::Ok(&SCRAM_TYPES.0[ind])
            );
            println!("took: {:?}", start.elapsed());
        }
    }
}