1#![deny(
40    bad_style,
41    dead_code,
42    improper_ctypes,
43    rustdoc::broken_intra_doc_links,
44    non_shorthand_field_patterns,
45    no_mangle_generic_items,
46    overflowing_literals,
47    path_statements,
48    patterns_in_fns_without_body,
49    private_bounds,
50    private_interfaces,
51    unconditional_recursion,
52    unused,
53    unused_allocation,
54    unused_comparisons,
55    unused_parens,
56    while_true,
57    missing_debug_implementations,
58    missing_copy_implementations,
59    missing_docs,
60    trivial_casts,
61    trivial_numeric_casts,
62    unnameable_types,
63    unused_extern_crates,
64    unused_import_braces,
65    unused_qualifications,
66    unused_results
67)]
68#![no_std]
69
70use generic_array::{ArrayLength, GenericArray};
71use typenum::ToInt;
72
73use zeroize::Zeroize;
74
/// A key that can be used to initialise a [`PseudoRandomFunction`].
///
/// The KBKDF routines in this module never look inside the key; they only hand it to
/// [`PseudoRandomFunction::init`]. What the handle refers to is up to the PRF
/// implementation.
pub trait PseudoRandomFunctionKey {
    /// Handle type through which a PRF implementation accesses the key.
    type KeyHandle;

    /// Returns a reference to this key's handle.
    fn key_handle(&self) -> &Self::KeyHandle;
}
83
84pub trait PseudoRandomFunction<'a> {
89    type KeyHandle;
91    type PrfOutputSize: ArrayLength<u8> + ToInt<usize>;
93    type Error;
95
96    fn init(
110        &mut self,
111        key: &'a dyn PseudoRandomFunctionKey<KeyHandle = Self::KeyHandle>,
112    ) -> Result<(), Self::Error>;
113
114    fn update(&mut self, msg: &[u8]) -> Result<(), Self::Error>;
129
130    fn finish(&mut self, out: &mut [u8]) -> Result<usize, Self::Error>;
145}
146
/// Parameters for KBKDF in counter mode.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct CounterMode {
    /// Width of the block counter in bits.
    ///
    /// The counter is encoded as the low `counter_length / 8` bytes of the big-endian
    /// block index, so this should be a multiple of 8 no larger than the bit width of
    /// `usize`.
    pub counter_length: usize,
}
153
/// Parameters for KBKDF in feedback mode.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct FeedbackMode<'a> {
    /// Initial chaining value `K(0)`.
    ///
    /// When present it must be exactly as long as the PRF output. When `None`, the first
    /// PRF invocation is made without a chaining value.
    pub iv: Option<&'a [u8]>,
    /// Width of the optional block counter in bits, encoded as the low
    /// `counter_length / 8` bytes of the big-endian block index.
    ///
    /// `None` omits the counter; with [`InputType::FixedInput`] only
    /// [`CounterLocation::NoCounter`] is then accepted.
    pub counter_length: Option<usize>,
}
162
/// Parameters for KBKDF in double-pipeline iteration mode.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct DoublePipelineIterationMode {
    /// Width of the optional block counter in bits, encoded as the low
    /// `counter_length / 8` bytes of the big-endian block index.
    ///
    /// `None` omits the counter; with [`InputType::FixedInput`] only
    /// [`CounterLocation::NoCounter`] is then accepted.
    pub counter_length: Option<usize>,
}
169
170#[derive(Copy, Clone, Debug)]
172pub enum KDFMode<'a> {
173    CounterMode(CounterMode),
175    FeedbackMode(FeedbackMode<'a>),
177    DoublePipelineIterationMode(DoublePipelineIterationMode),
179}
180
/// Where the block counter is placed in the PRF input when using [`FixedInput`].
///
/// Not every location is valid in every mode; using an unsupported one makes [`kbkdf`]
/// panic.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum CounterLocation {
    /// The counter is not fed to the PRF. Accepted by all modes.
    NoCounter,
    /// Counter, then fixed input. Counter mode only.
    BeforeFixedInput,
    /// Counter, then chaining value, then fixed input. Feedback and double-pipeline
    /// modes only.
    BeforeIter,
    /// The fixed input is split at this byte offset and the counter is inserted between
    /// the two parts. Counter mode only; the offset must not exceed the fixed input
    /// length.
    MiddleOfFixedInput(usize),
    /// Fixed input, then counter (preceded by the chaining value in feedback and
    /// double-pipeline modes). Accepted by all modes.
    AfterFixedInput,
    /// Chaining value, then counter, then fixed input. Feedback and double-pipeline
    /// modes only.
    AfterIter,
}
197
198#[derive(Debug)]
200pub struct FixedInput<'a> {
201    pub fixed_input: &'a [u8],
203    pub counter_location: CounterLocation,
205}
206
/// Input assembled by [`kbkdf`] as `label || 0x00 || context || [L]_32`, where `[L]_32`
/// is the derived key length in bits as a 32-bit big-endian integer.
///
/// The counter (if any) is placed first in counter mode and right after the chaining
/// value in feedback and double-pipeline modes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SpecifiedInput<'a> {
    /// Label bytes, placed before the `0x00` separator.
    pub label: &'a [u8],
    /// Context bytes, placed after the `0x00` separator.
    pub context: &'a [u8],
}
215
216#[derive(Debug)]
218pub enum InputType<'a> {
219    FixedInput(FixedInput<'a>),
222    SpecifiedInput(SpecifiedInput<'a>),
224}
225
226pub fn kbkdf<'a, T: PseudoRandomFunction<'a>>(
241    kdf_mode: &KDFMode,
242    input_type: &InputType,
243    key: &'a dyn PseudoRandomFunctionKey<KeyHandle = T::KeyHandle>,
244    prf: &mut T,
245    derived_key: &mut [u8],
246) -> Result<(), T::Error> {
247    match kdf_mode {
248        KDFMode::CounterMode(counter_mode) => {
249            kbkdf_counter::<T>(counter_mode, input_type, key, prf, derived_key)
250        }
251        KDFMode::FeedbackMode(feedback_mode) => {
252            kbkdf_feedback::<T>(feedback_mode, input_type, key, prf, derived_key)
253        }
254        KDFMode::DoublePipelineIterationMode(double_pipeline) => {
255            kbkdf_double_pipeline::<T>(double_pipeline, input_type, key, prf, derived_key)
256        }
257    }
258}
259
260fn kbkdf_counter<'a, T: PseudoRandomFunction<'a>>(
261    counter_mode: &CounterMode,
262    input_type: &InputType,
263    key: &'a dyn PseudoRandomFunctionKey<KeyHandle = T::KeyHandle>,
264    prf: &mut T,
265    derived_key: &mut [u8],
266) -> Result<(), T::Error> {
267    let l = derived_key.len() * 8;
269    let h = T::PrfOutputSize::to_int() * 8;
270    let n = calculate_counter(l, h);
271    let mut intermediate_key = GenericArray::<u8, T::PrfOutputSize>::default();
272    assert!(
273        n < 2_usize.pow(counter_mode.counter_length as u32),
274        "Invalid derived key length"
275    );
276    for i in 1..=n {
277        prf.init(key)?;
278        let counter = i.to_be_bytes();
279        let counter = &counter[(counter.len() - counter_mode.counter_length / 8)..];
280        match input_type {
281            InputType::FixedInput(fixed_input) => match fixed_input.counter_location {
282                CounterLocation::NoCounter => prf.update(fixed_input.fixed_input)?,
283                CounterLocation::BeforeFixedInput => {
284                    prf.update(counter)?;
285                    prf.update(fixed_input.fixed_input)?;
286                }
287                CounterLocation::MiddleOfFixedInput(position) => {
288                    prf.update(&fixed_input.fixed_input[..position])?;
289                    prf.update(counter)?;
290                    prf.update(&fixed_input.fixed_input[position..])?;
291                }
292                CounterLocation::AfterFixedInput => {
293                    prf.update(fixed_input.fixed_input)?;
294                    prf.update(counter)?;
295                }
296                _ => panic!(
297                    "Invalid counter location for KBKDF In Counter Mode: {:?}",
298                    fixed_input.counter_location
299                ),
300            },
301            InputType::SpecifiedInput(specified_input) => {
302                prf.update(counter)?;
303                prf.update(specified_input.label)?;
304                prf.update(b"\0")?;
305                prf.update(specified_input.context)?;
306                let length = (l as u32).to_be_bytes();
307                prf.update(&length)?;
308            }
309        }
310        let _ = prf.finish(intermediate_key.as_mut_slice())?;
311        insert_result(i, intermediate_key.as_slice(), derived_key);
312        intermediate_key.zeroize();
313    }
314
315    Ok(())
316}
317
318fn kbkdf_double_pipeline<'a, T: PseudoRandomFunction<'a>>(
319    double_feedback: &DoublePipelineIterationMode,
320    input_type: &InputType,
321    key: &'a dyn PseudoRandomFunctionKey<KeyHandle = T::KeyHandle>,
322    prf: &mut T,
323    derived_key: &mut [u8],
324) -> Result<(), T::Error> {
325    let l = derived_key.len() * 8;
326    let h = T::PrfOutputSize::to_int() * 8;
327    let n = calculate_counter(l, h);
328    let mut intermediate_key = GenericArray::<u8, T::PrfOutputSize>::default();
329    let mut feedback = GenericArray::<u8, T::PrfOutputSize>::default();
330    let length = (l as u32).to_be_bytes();
331    assert!(
332        n < 2_usize.pow(32),
333        "Invalid length provided for derived key"
334    );
335    for i in 1..=n {
336        let counter = i.to_be_bytes();
337        let counter = feedback_counter(double_feedback.counter_length, counter.as_slice());
338        prf.init(key)?;
340        if i == 1 {
341            match input_type {
342                InputType::FixedInput(fixed_input) => {
343                    prf.update(fixed_input.fixed_input)?;
344                }
345                InputType::SpecifiedInput(specified_input) => {
346                    prf.update(specified_input.label)?;
347                    prf.update(b"\0")?;
348                    prf.update(specified_input.context)?;
349                    prf.update(length.as_slice())?;
350                }
351            }
352        } else {
353            prf.update(feedback.as_slice())?;
354        }
355        let _ = prf.finish(feedback.as_mut_slice())?;
356
357        prf.init(key)?;
358
359        match input_type {
360            InputType::FixedInput(fixed_input) => match fixed_input.counter_location {
361                CounterLocation::NoCounter => {
362                    prf.update(feedback.as_slice())?;
363                    prf.update(fixed_input.fixed_input)?;
364                }
365                CounterLocation::BeforeIter => {
366                    prf.update(
367                        counter
368                            .expect("Counter length not provided for BeforeIter counter location"),
369                    )?;
370                    prf.update(feedback.as_slice())?;
371                    prf.update(fixed_input.fixed_input)?;
372                }
373                CounterLocation::AfterFixedInput => {
374                    prf.update(feedback.as_slice())?;
375                    prf.update(fixed_input.fixed_input)?;
376                    prf.update(counter.expect(
377                        "Counter length not provided for AfterFixedInput counter location",
378                    ))?;
379                }
380                CounterLocation::AfterIter => {
381                    prf.update(feedback.as_slice())?;
382                    prf.update(
383                        counter
384                            .expect("Counter length not provided for AfterIter counter location"),
385                    )?;
386                    prf.update(fixed_input.fixed_input)?;
387                }
388                _ => panic!(
389                    "Invalid counter location for double feedback: {:?}",
390                    fixed_input.counter_location
391                ),
392            },
393            InputType::SpecifiedInput(specified_input) => {
394                prf.update(feedback.as_slice())?;
395                if let Some(counter) = counter {
396                    prf.update(counter)?;
397                }
398                prf.update(specified_input.label)?;
399                prf.update(b"\0")?;
400                prf.update(specified_input.context)?;
401                prf.update(&length)?;
402            }
403        }
404
405        let _ = prf.finish(intermediate_key.as_mut_slice())?;
406        insert_result(i, intermediate_key.as_slice(), derived_key);
407        intermediate_key.zeroize();
408    }
409
410    Ok(())
411}
412
413fn kbkdf_feedback<'a, T: PseudoRandomFunction<'a>>(
414    feedback_mode: &FeedbackMode,
415    input_type: &InputType,
416    key: &'a dyn PseudoRandomFunctionKey<KeyHandle = T::KeyHandle>,
417    prf: &mut T,
418    derived_key: &mut [u8],
419) -> Result<(), T::Error> {
420    let l = derived_key.len() * 8;
421    let h = T::PrfOutputSize::to_int() * 8;
422    let n = calculate_counter(l, h);
423
424    let mut intermediate_key = GenericArray::<u8, T::PrfOutputSize>::default();
425    let mut has_intermediate = feedback_mode.iv.is_some();
426    if let Some(iv) = feedback_mode.iv {
427        assert_eq!(iv.len(), T::PrfOutputSize::to_int());
428        intermediate_key.copy_from_slice(iv);
429    }
430    assert!(n < 2_usize.pow(32), "Invalid derived_key length provided");
431    for i in 1..=n {
432        prf.init(key)?;
433        let counter = i.to_be_bytes();
434        let counter = feedback_counter(feedback_mode.counter_length, counter.as_slice());
435        match input_type {
436            InputType::FixedInput(fixed_input) => match fixed_input.counter_location {
437                CounterLocation::NoCounter => {
438                    if has_intermediate {
439                        prf.update(intermediate_key.as_slice())?;
440                    }
441                    prf.update(fixed_input.fixed_input)?;
442                }
443                CounterLocation::BeforeIter => {
444                    prf.update(
445                        counter
446                            .expect("Counter length not provided for BeforeIter counter location"),
447                    )?;
448                    if has_intermediate {
449                        prf.update(intermediate_key.as_slice())?;
450                    }
451                    prf.update(fixed_input.fixed_input)?;
452                }
453                CounterLocation::AfterIter => {
454                    if has_intermediate {
455                        prf.update(intermediate_key.as_slice())?;
456                    }
457                    prf.update(
458                        counter
459                            .expect("Counter length not provided for AfterIter counter location"),
460                    )?;
461                    prf.update(fixed_input.fixed_input)?;
462                }
463                CounterLocation::AfterFixedInput => {
464                    if has_intermediate {
465                        prf.update(intermediate_key.as_slice())?;
466                    }
467                    prf.update(fixed_input.fixed_input)?;
468                    prf.update(counter.expect(
469                        "Counter length not provided for AfterFixedInput counter location",
470                    ))?;
471                }
472                _ => panic!(
473                    "Invalid counter location provided for KDF feedback mode: {:?}",
474                    fixed_input.counter_location
475                ),
476            },
477            InputType::SpecifiedInput(specified_input) => {
478                if has_intermediate {
479                    prf.update(intermediate_key.as_slice())?;
480                }
481                if let Some(counter) = counter {
482                    prf.update(counter)?;
483                }
484                prf.update(specified_input.label)?;
485                prf.update(b"\0")?;
486                prf.update(specified_input.context)?;
487                let length = (l as u32).to_be_bytes();
488                prf.update(&length)?;
489            }
490        }
491        let _ = prf.finish(intermediate_key.as_mut_slice())?;
492        insert_result(i, intermediate_key.as_slice(), derived_key);
493        has_intermediate = true;
494    }
495
496    Ok(())
497}
498
/// Number of PRF invocations needed to produce `derived_key_len_bits` bits when each
/// invocation yields `prf_output_size_in_bits` bits, i.e. the ceiling of their quotient.
///
/// Panics (division by zero) if `prf_output_size_in_bits` is 0.
fn calculate_counter(derived_key_len_bits: usize, prf_output_size_in_bits: usize) -> usize {
    let full_blocks = derived_key_len_bits / prf_output_size_in_bits;
    let has_partial_block = derived_key_len_bits % prf_output_size_in_bits > 0;
    full_blocks + usize::from(has_partial_block)
}
507
/// Returns the trailing `counter_length / 8` bytes of the big-endian `counter`, or `None`
/// when no counter length is configured.
///
/// Panics if `counter_length / 8` exceeds `counter.len()`.
fn feedback_counter(counter_length: Option<usize>, counter: &[u8]) -> Option<&[u8]> {
    counter_length.map(|bits| {
        let skip = counter.len() - bits / 8;
        &counter[skip..]
    })
}
514
/// Copies PRF output block number `counter` (1-based) into its slot in `result`.
///
/// Block `counter` starts at offset `(counter - 1) * intermediate.len()`. When fewer than
/// `intermediate.len()` bytes remain in `result`, only the leading bytes of `intermediate`
/// that fit are copied.
///
/// # Panics
///
/// Panics if the block's starting offset is not inside `result`, or if `counter` is 0.
fn insert_result(counter: usize, intermediate: &[u8], result: &mut [u8]) {
    let low_index = (counter - 1) * intermediate.len();
    assert!(
        low_index < result.len(),
        "The starting insert index should not exceed bounds of result slice"
    );
    // `min` already clamps the end to `result.len()`, so no second bounds check is needed.
    let high_index = core::cmp::min(low_index + intermediate.len(), result.len());
    result[low_index..high_index].copy_from_slice(&intermediate[..high_index - low_index]);
}