rust_kbkdf/
lib.rs

1//! # Rust Implementation of NIST SP800-108 Key Based Key Derivation Function (KBKDF)
2//!
3//! This crate provides a Rust implementation of the [NIST SP800-108](https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-108.pdf)
4//! standard for performing key-derivation based on a source key.
5//!
6//! This crate implements the KBKDF in the following modes:
7//!
8//! * Counter
9//! * Feedback
10//! * Double-Pipeline Iteration
11//!
12//! This crate was designed such that the user may provide their own Pseudo Random Function (as defined in Section 4 of
13//! [SP800-108](https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-108.pdf)) via the implementation of
14//! two traits:
15//!
16//! * [`PseudoRandomFunctionKey`]
17//! * [`PseudoRandomFunction`]
18//!
//! ## Pseudo Random Function Trait
20//!
21//! The purpose of the PRF trait is to allow a user to provide their own implementation of a PRF (as defined in Section 4
22//! of [SP800-108](https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-108.pdf)).
23//!
//! **Please note that in order for an implementation of KBKDF to be NIST approved, an approved PRF must be used!**
25//!
26//! The author of this crate _does not_ guarantee that this implementation is NIST approved!
27//!
28//! ## Pseudo Random Function Key
29//!
30//! This trait is used to ensure that the implementation of the `PseudoRandomFunction` trait can access the necessary
31//! source key in a way that passes Rust's borrow checker.
32//!
33//! ## Example
34//!
//! An example of how to use the two traits is found in the `tests` module, which uses the [OpenSSL crate](https://crates.io/crates/openssl).
36
37// This list comes from
38// https://github.com/rust-unofficial/patterns/blob/master/anti_patterns/deny-warnings.md
39#![deny(
40    bad_style,
41    dead_code,
42    improper_ctypes,
43    rustdoc::broken_intra_doc_links,
44    non_shorthand_field_patterns,
45    no_mangle_generic_items,
46    overflowing_literals,
47    path_statements,
48    patterns_in_fns_without_body,
49    private_bounds,
50    private_interfaces,
51    unconditional_recursion,
52    unused,
53    unused_allocation,
54    unused_comparisons,
55    unused_parens,
56    while_true,
57    missing_debug_implementations,
58    missing_copy_implementations,
59    missing_docs,
60    trivial_casts,
61    trivial_numeric_casts,
62    unnameable_types,
63    unused_extern_crates,
64    unused_import_braces,
65    unused_qualifications,
66    unused_results
67)]
68#![no_std]
69
70use generic_array::{ArrayLength, GenericArray};
71use typenum::ToInt;
72
73use zeroize::Zeroize;
74
/// Defines how a [`PseudoRandomFunction`] gets access to its key (K<sub>I</sub>).
///
/// The KDF functions in this crate never look inside the key themselves. They only forward
/// the `&dyn PseudoRandomFunctionKey` to [`PseudoRandomFunction::init`]. The PRF
/// implementation then retrieves the handle it needs through
/// [`key_handle`](PseudoRandomFunctionKey::key_handle).
pub trait PseudoRandomFunctionKey {
    /// The key handle type this returns. It must match the PRF's
    /// [`KeyHandle`](PseudoRandomFunction::KeyHandle), because `init` takes a key whose
    /// handle type equals the PRF's.
    type KeyHandle;

    /// Returns a reference to the key handle held by this instance
    fn key_handle(&self) -> &Self::KeyHandle;
}
83
/// Defines how the KBKDF functions in this crate interact with a pseudo random function (PRF).
///
/// This allows the user of this crate to provide their own implementation of a PRF. However,
/// only SP800-108 specified PRFs are allowed in the approved mode of operation. This crate
/// cannot test for that and assumes that the user is using an approved PRF.
///
/// For every PRF invocation, the KDF functions call:
///
/// 1. [`init`](PseudoRandomFunction::init);
/// 2. [`update`](PseudoRandomFunction::update) one or more times;
/// 3. [`finish`](PseudoRandomFunction::finish).
///
/// The same instance is then reused for the next invocation, so an implementation must
/// accept a fresh `init` after `finish`.
pub trait PseudoRandomFunction<'a> {
    /// The type of key handle the PRF is expecting. It must match the
    /// [`KeyHandle`](PseudoRandomFunctionKey::KeyHandle) of the key passed to `init`.
    type KeyHandle;
    /// The PRF output size in bytes (h / 8). The KDF functions create output buffers of
    /// exactly this many bytes and pass them to [`finish`](PseudoRandomFunction::finish).
    type PrfOutputSize: ArrayLength<u8> + ToInt<usize>;
    /// The error type returned by the PRF operations. The KDF functions propagate it unchanged.
    type Error;

    /// Initializes the pseudo random function
    ///
    /// # Arguments
    ///
    /// * `key` - The key-derivation key (K<sub>I</sub>)
    ///
    /// # Returns
    ///
    /// Either nothing or an [`Error`](PseudoRandomFunction::Error)
    ///
    /// # Panics
    ///
    /// This function is allowed to panic if [`init`](PseudoRandomFunction::init) is called while
    /// already initialized. Note that the KDF functions call `init` again after every
    /// [`finish`](PseudoRandomFunction::finish), so `finish` must end the initialized state.
    fn init(
        &mut self,
        key: &'a dyn PseudoRandomFunctionKey<KeyHandle = Self::KeyHandle>,
    ) -> Result<(), Self::Error>;

    /// Feeds the next piece of the message into the PRF
    ///
    /// # Arguments
    ///
    /// * `msg` - The next message to input into the PRF. The KDF functions may call this with
    ///   an empty slice, for example an empty context or a zero-length counter.
    ///
    /// # Returns
    ///
    /// Either nothing or an [`Error`](PseudoRandomFunction::Error)
    ///
    /// # Panics
    ///
    /// This function is allowed to panic if [`update`](PseudoRandomFunction::update)
    /// is called before [`init`](PseudoRandomFunction::init)
    fn update(&mut self, msg: &[u8]) -> Result<(), Self::Error>;

    /// Finishes the PRF and writes its output into a buffer
    ///
    /// # Arguments
    ///
    /// * `out` - Destination for the PRF output. The KDF functions in this crate always pass a
    ///   buffer of exactly `PrfOutputSize` bytes.
    ///
    /// # Returns
    ///
    /// A `usize` on success, or an [`Error`](PseudoRandomFunction::Error). The KDF functions in
    /// this crate discard the `usize`. NOTE(review): presumably it is the number of bytes
    /// written to `out`; confirm against the implementations.
    ///
    /// # Panics
    ///
    /// This function is allowed to panic if [`finish`](PseudoRandomFunction::finish)
    /// is called before [`init`](PseudoRandomFunction::init)
    fn finish(&mut self, out: &mut [u8]) -> Result<usize, Self::Error>;
}
146
/// Counter mode options (see [`KDFMode::CounterMode`])
#[derive(Copy, Clone, Debug)]
pub struct CounterMode {
    /// Length of the binary representation of the counter (r), in bits.
    ///
    /// The counter `i` is encoded big-endian in `counter_length / 8` bytes, so this should be
    /// a multiple of 8. The number of PRF blocks needed for the derived key must be below
    /// 2<sup>`counter_length`</sup>; otherwise the KDF panics.
    pub counter_length: usize,
}
153
/// Defines options for KDF in feedback mode (see [`KDFMode::FeedbackMode`])
#[derive(Copy, Clone, Debug)]
pub struct FeedbackMode<'a> {
    /// Initial value (IV, used as K(0)) for the first iteration of feedback mode.
    ///
    /// When provided, it must be exactly the PRF output size in bytes; the KDF asserts this.
    /// When `None`, the first PRF invocation has no fed-back input.
    pub iv: Option<&'a [u8]>,
    /// Length of the binary representation of the counter, in bits. If not provided, the
    /// counter is unused. The counter is encoded big-endian in `counter_length / 8` bytes, so
    /// this should be a multiple of 8.
    pub counter_length: Option<usize>,
}
162
/// Defines options for KDF in double-pipeline iteration mode
/// (see [`KDFMode::DoublePipelineIterationMode`])
#[derive(Copy, Clone, Debug)]
pub struct DoublePipelineIterationMode {
    /// Length of the binary representation of the counter, in bits. If not provided, the
    /// counter is unused. The counter is encoded big-endian in `counter_length / 8` bytes, so
    /// this should be a multiple of 8.
    pub counter_length: Option<usize>,
}
169
/// Selects the KDF mode and carries the options specific to that mode
#[derive(Copy, Clone, Debug)]
pub enum KDFMode<'a> {
    /// KDF in counter mode (SP800-108 Section 5.1)
    CounterMode(CounterMode),
    /// KDF in feedback mode (SP800-108 Section 5.2)
    FeedbackMode(FeedbackMode<'a>),
    /// KDF in double-pipeline iteration mode (SP800-108 Section 5.3)
    DoublePipelineIterationMode(DoublePipelineIterationMode),
}
180
/// Where the counter is placed in the PRF input when using [`InputType::FixedInput`].
///
/// Not every location is valid in every mode; an unsupported location makes the KDF panic.
///
/// * Counter mode accepts `NoCounter`, `BeforeFixedInput`, `MiddleOfFixedInput` and
///   `AfterFixedInput`.
/// * Feedback and double-pipeline iteration modes accept `NoCounter`, `BeforeIter`,
///   `AfterIter` and `AfterFixedInput`.
#[derive(Copy, Clone, Debug)]
pub enum CounterLocation {
    /// No counter is included in the PRF input
    NoCounter,
    /// Counter before fixed input
    BeforeFixedInput,
    /// Counter before the iteration variable (the fed-back value in feedback and
    /// double-pipeline iteration modes)
    BeforeIter,
    /// Counter is inserted into the fixed input at the given offset.
    ///
    /// The PRF input becomes `fixed_input[..offset] || counter || fixed_input[offset..]`, and
    /// slicing panics if the offset is past the end of the fixed input.
    ///
    /// NOTE(review): the counter-mode KDF uses this value as a *byte* index into the fixed
    /// input, not a bit position. A bit position must be byte-aligned and converted to bytes
    /// by the caller.
    MiddleOfFixedInput(usize),
    /// Counter after fixed input
    AfterFixedInput,
    /// Counter after the iteration variable (the fed-back value in feedback and
    /// double-pipeline iteration modes)
    AfterIter,
}
197
/// Fixed input used when the implementation is under test (e.g. ACVP validation).
///
/// The KDF passes `fixed_input` to the PRF verbatim. It only adds the counter at
/// `counter_location` and, in feedback and double-pipeline modes, the fed-back value. In
/// double-pipeline mode the fixed input is also the PRF input for the first A(1) computation.
#[derive(Debug)]
pub struct FixedInput<'a> {
    /// The fixed input data, passed to the PRF as-is
    pub fixed_input: &'a [u8],
    /// The location of the counter relative to the fixed input
    pub counter_location: CounterLocation,
}
206
/// Label and context from which the KDF builds its fixed input data.
///
/// The PRF input contains `Label || 0x00 || Context || [L]_32`. Here `[L]_32` is the derived
/// key length in bits, encoded as a big-endian 32-bit integer. The mode adds the counter
/// and/or fed-back value in front of this data as required.
#[derive(Debug)]
pub struct SpecifiedInput<'a> {
    /// Identifies the purpose of the derived keying material
    pub label: &'a [u8],
    /// Information related to the derived keying material
    pub context: &'a [u8],
}
215
/// The input the KDF feeds to the PRF: either raw fixed input (testing) or label and context
#[derive(Debug)]
pub enum InputType<'a> {
    /// Fixed input with a specific counter location. This should only be used when the
    /// implementation is undergoing ACVP testing
    /// (see <https://pages.nist.gov/ACVP/draft-celi-acvp-kbkdf.html#SP800-108>)
    FixedInput(FixedInput<'a>),
    /// Input specifying label and context
    SpecifiedInput(SpecifiedInput<'a>),
}
225
226/// Performs [SP800-108](https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-108.pdf)
227/// key-based key derivation function
228///
229/// # Inputs
230///
231/// * `kdf_mode` - Which mode the the derivation function will run in
232/// * `input_type` - The type of input used to derive the key
233/// * `key` - The base key to use to derive the key
234/// * `prf` - The Pseudo-random function used to derive the key
235/// * `derived_key` - The output key
236///
237/// # Panics
238///
239/// If invalid options are provided, this function will panic
240pub fn kbkdf<'a, T: PseudoRandomFunction<'a>>(
241    kdf_mode: &KDFMode,
242    input_type: &InputType,
243    key: &'a dyn PseudoRandomFunctionKey<KeyHandle = T::KeyHandle>,
244    prf: &mut T,
245    derived_key: &mut [u8],
246) -> Result<(), T::Error> {
247    match kdf_mode {
248        KDFMode::CounterMode(counter_mode) => {
249            kbkdf_counter::<T>(counter_mode, input_type, key, prf, derived_key)
250        }
251        KDFMode::FeedbackMode(feedback_mode) => {
252            kbkdf_feedback::<T>(feedback_mode, input_type, key, prf, derived_key)
253        }
254        KDFMode::DoublePipelineIterationMode(double_pipeline) => {
255            kbkdf_double_pipeline::<T>(double_pipeline, input_type, key, prf, derived_key)
256        }
257    }
258}
259
260fn kbkdf_counter<'a, T: PseudoRandomFunction<'a>>(
261    counter_mode: &CounterMode,
262    input_type: &InputType,
263    key: &'a dyn PseudoRandomFunctionKey<KeyHandle = T::KeyHandle>,
264    prf: &mut T,
265    derived_key: &mut [u8],
266) -> Result<(), T::Error> {
267    // Step 1 -> n = CEIL(L/h)
268    let l = derived_key.len() * 8;
269    let h = T::PrfOutputSize::to_int() * 8;
270    let n = calculate_counter(l, h);
271    let mut intermediate_key = GenericArray::<u8, T::PrfOutputSize>::default();
272    assert!(
273        n < 2_usize.pow(counter_mode.counter_length as u32),
274        "Invalid derived key length"
275    );
276    for i in 1..=n {
277        prf.init(key)?;
278        let counter = i.to_be_bytes();
279        let counter = &counter[(counter.len() - counter_mode.counter_length / 8)..];
280        match input_type {
281            InputType::FixedInput(fixed_input) => match fixed_input.counter_location {
282                CounterLocation::NoCounter => prf.update(fixed_input.fixed_input)?,
283                CounterLocation::BeforeFixedInput => {
284                    prf.update(counter)?;
285                    prf.update(fixed_input.fixed_input)?;
286                }
287                CounterLocation::MiddleOfFixedInput(position) => {
288                    prf.update(&fixed_input.fixed_input[..position])?;
289                    prf.update(counter)?;
290                    prf.update(&fixed_input.fixed_input[position..])?;
291                }
292                CounterLocation::AfterFixedInput => {
293                    prf.update(fixed_input.fixed_input)?;
294                    prf.update(counter)?;
295                }
296                _ => panic!(
297                    "Invalid counter location for KBKDF In Counter Mode: {:?}",
298                    fixed_input.counter_location
299                ),
300            },
301            InputType::SpecifiedInput(specified_input) => {
302                prf.update(counter)?;
303                prf.update(specified_input.label)?;
304                prf.update(b"\0")?;
305                prf.update(specified_input.context)?;
306                let length = (l as u32).to_be_bytes();
307                prf.update(&length)?;
308            }
309        }
310        let _ = prf.finish(intermediate_key.as_mut_slice())?;
311        insert_result(i, intermediate_key.as_slice(), derived_key);
312        intermediate_key.zeroize();
313    }
314
315    Ok(())
316}
317
318fn kbkdf_double_pipeline<'a, T: PseudoRandomFunction<'a>>(
319    double_feedback: &DoublePipelineIterationMode,
320    input_type: &InputType,
321    key: &'a dyn PseudoRandomFunctionKey<KeyHandle = T::KeyHandle>,
322    prf: &mut T,
323    derived_key: &mut [u8],
324) -> Result<(), T::Error> {
325    let l = derived_key.len() * 8;
326    let h = T::PrfOutputSize::to_int() * 8;
327    let n = calculate_counter(l, h);
328    let mut intermediate_key = GenericArray::<u8, T::PrfOutputSize>::default();
329    let mut feedback = GenericArray::<u8, T::PrfOutputSize>::default();
330    let length = (l as u32).to_be_bytes();
331    assert!(
332        n < 2_usize.pow(32),
333        "Invalid length provided for derived key"
334    );
335    for i in 1..=n {
336        let counter = i.to_be_bytes();
337        let counter = feedback_counter(double_feedback.counter_length, counter.as_slice());
338        // First calculate feedback, if the first iteration use the provided input vaalue
339        prf.init(key)?;
340        if i == 1 {
341            match input_type {
342                InputType::FixedInput(fixed_input) => {
343                    prf.update(fixed_input.fixed_input)?;
344                }
345                InputType::SpecifiedInput(specified_input) => {
346                    prf.update(specified_input.label)?;
347                    prf.update(b"\0")?;
348                    prf.update(specified_input.context)?;
349                    prf.update(length.as_slice())?;
350                }
351            }
352        } else {
353            prf.update(feedback.as_slice())?;
354        }
355        let _ = prf.finish(feedback.as_mut_slice())?;
356
357        prf.init(key)?;
358
359        match input_type {
360            InputType::FixedInput(fixed_input) => match fixed_input.counter_location {
361                CounterLocation::NoCounter => {
362                    prf.update(feedback.as_slice())?;
363                    prf.update(fixed_input.fixed_input)?;
364                }
365                CounterLocation::BeforeIter => {
366                    prf.update(
367                        counter
368                            .expect("Counter length not provided for BeforeIter counter location"),
369                    )?;
370                    prf.update(feedback.as_slice())?;
371                    prf.update(fixed_input.fixed_input)?;
372                }
373                CounterLocation::AfterFixedInput => {
374                    prf.update(feedback.as_slice())?;
375                    prf.update(fixed_input.fixed_input)?;
376                    prf.update(counter.expect(
377                        "Counter length not provided for AfterFixedInput counter location",
378                    ))?;
379                }
380                CounterLocation::AfterIter => {
381                    prf.update(feedback.as_slice())?;
382                    prf.update(
383                        counter
384                            .expect("Counter length not provided for AfterIter counter location"),
385                    )?;
386                    prf.update(fixed_input.fixed_input)?;
387                }
388                _ => panic!(
389                    "Invalid counter location for double feedback: {:?}",
390                    fixed_input.counter_location
391                ),
392            },
393            InputType::SpecifiedInput(specified_input) => {
394                prf.update(feedback.as_slice())?;
395                if let Some(counter) = counter {
396                    prf.update(counter)?;
397                }
398                prf.update(specified_input.label)?;
399                prf.update(b"\0")?;
400                prf.update(specified_input.context)?;
401                prf.update(&length)?;
402            }
403        }
404
405        let _ = prf.finish(intermediate_key.as_mut_slice())?;
406        insert_result(i, intermediate_key.as_slice(), derived_key);
407        intermediate_key.zeroize();
408    }
409
410    Ok(())
411}
412
413fn kbkdf_feedback<'a, T: PseudoRandomFunction<'a>>(
414    feedback_mode: &FeedbackMode,
415    input_type: &InputType,
416    key: &'a dyn PseudoRandomFunctionKey<KeyHandle = T::KeyHandle>,
417    prf: &mut T,
418    derived_key: &mut [u8],
419) -> Result<(), T::Error> {
420    let l = derived_key.len() * 8;
421    let h = T::PrfOutputSize::to_int() * 8;
422    let n = calculate_counter(l, h);
423
424    let mut intermediate_key = GenericArray::<u8, T::PrfOutputSize>::default();
425    let mut has_intermediate = feedback_mode.iv.is_some();
426    if let Some(iv) = feedback_mode.iv {
427        assert_eq!(iv.len(), T::PrfOutputSize::to_int());
428        intermediate_key.copy_from_slice(iv);
429    }
430    assert!(n < 2_usize.pow(32), "Invalid derived_key length provided");
431    for i in 1..=n {
432        prf.init(key)?;
433        let counter = i.to_be_bytes();
434        let counter = feedback_counter(feedback_mode.counter_length, counter.as_slice());
435        match input_type {
436            InputType::FixedInput(fixed_input) => match fixed_input.counter_location {
437                CounterLocation::NoCounter => {
438                    if has_intermediate {
439                        prf.update(intermediate_key.as_slice())?;
440                    }
441                    prf.update(fixed_input.fixed_input)?;
442                }
443                CounterLocation::BeforeIter => {
444                    prf.update(
445                        counter
446                            .expect("Counter length not provided for BeforeIter counter location"),
447                    )?;
448                    if has_intermediate {
449                        prf.update(intermediate_key.as_slice())?;
450                    }
451                    prf.update(fixed_input.fixed_input)?;
452                }
453                CounterLocation::AfterIter => {
454                    if has_intermediate {
455                        prf.update(intermediate_key.as_slice())?;
456                    }
457                    prf.update(
458                        counter
459                            .expect("Counter length not provided for AfterIter counter location"),
460                    )?;
461                    prf.update(fixed_input.fixed_input)?;
462                }
463                CounterLocation::AfterFixedInput => {
464                    if has_intermediate {
465                        prf.update(intermediate_key.as_slice())?;
466                    }
467                    prf.update(fixed_input.fixed_input)?;
468                    prf.update(counter.expect(
469                        "Counter length not provided for AfterFixedInput counter location",
470                    ))?;
471                }
472                _ => panic!(
473                    "Invalid counter location provided for KDF feedback mode: {:?}",
474                    fixed_input.counter_location
475                ),
476            },
477            InputType::SpecifiedInput(specified_input) => {
478                if has_intermediate {
479                    prf.update(intermediate_key.as_slice())?;
480                }
481                if let Some(counter) = counter {
482                    prf.update(counter)?;
483                }
484                prf.update(specified_input.label)?;
485                prf.update(b"\0")?;
486                prf.update(specified_input.context)?;
487                let length = (l as u32).to_be_bytes();
488                prf.update(&length)?;
489            }
490        }
491        let _ = prf.finish(intermediate_key.as_mut_slice())?;
492        insert_result(i, intermediate_key.as_slice(), derived_key);
493        has_intermediate = true;
494    }
495
496    Ok(())
497}
498
/// Returns how many PRF blocks are needed for the output: `n = ceil(L / h)`.
///
/// Both arguments are in bits. Panics on division by zero if `prf_output_size_in_bits` is 0.
fn calculate_counter(derived_key_len_bits: usize, prf_output_size_in_bits: usize) -> usize {
    let full_blocks = derived_key_len_bits / prf_output_size_in_bits;
    // A trailing partial block still needs one more PRF invocation.
    let has_partial_block = derived_key_len_bits % prf_output_size_in_bits != 0;
    full_blocks + usize::from(has_partial_block)
}
507
/// Returns the counter bytes to feed to the PRF, or `None` when no counter is configured.
///
/// `counter` is a big-endian encoding of the iteration number. The result is its trailing
/// `counter_length / 8` bytes. Panics (subtraction overflow or out-of-bounds slice) if that
/// is more bytes than `counter` holds.
fn feedback_counter(counter_length: Option<usize>, counter: &[u8]) -> Option<&[u8]> {
    counter_length.map(|bits| {
        let start = counter.len() - bits / 8;
        &counter[start..]
    })
}
514
/// Copies one PRF output block into its slot in the derived key.
///
/// Block number `counter` (1-based) occupies the bytes starting at
/// `(counter - 1) * intermediate.len()`. The copy is clipped at the end of `result`, so only
/// a prefix of the final block is kept.
///
/// # Panics
///
/// Panics if the block would start at or beyond the end of `result`.
fn insert_result(counter: usize, intermediate: &[u8], result: &mut [u8]) {
    let block_len = intermediate.len();
    let start = (counter - 1) * block_len;
    assert!(
        start < result.len(),
        "The starting insert index should not exceed bounds of result slice"
    );
    // Clamping with `min` guarantees `end <= result.len()`, so the slice cannot overrun.
    let end = (start + block_len).min(result.len());
    let copied = end - start;
    result[start..end].copy_from_slice(&intermediate[..copied]);
}