// webrtc_srtp/context/mod.rs

1#[cfg(test)]
2mod context_test;
3#[cfg(test)]
4mod srtcp_test;
5#[cfg(test)]
6mod srtp_test;
7
8use std::collections::HashMap;
9
10use aes::Aes128;
11use aes::Aes256;
12use util::replay_detector::*;
13
14use crate::cipher::cipher_aead_aes_gcm::*;
15use crate::cipher::cipher_aes_cm_hmac_sha1::*;
16use crate::cipher::*;
17use crate::error::{Error, Result};
18use crate::option::*;
19use crate::protection_profile::*;
20
21pub mod srtcp;
22pub mod srtp;
23
/// Largest value the 32-bit SRTP rollover counter (ROC) can hold.
const MAX_ROC: u32 = u32::MAX;
/// Midpoint of the 16-bit sequence-number space (2^15); used as the threshold when
/// guessing whether a sequence number belongs to the previous, current, or next ROC.
const SEQ_NUM_MEDIAN: u16 = 0x8000;
/// Largest 16-bit RTP sequence number.
const SEQ_NUM_MAX: u16 = 0xFFFF;
27
28/// Encrypt/Decrypt state for a single SRTP SSRC
29#[derive(Default)]
30pub(crate) struct SrtpSsrcState {
31    ssrc: u32,
32    index: u64,
33    rollover_has_processed: bool,
34    replay_detector: Option<Box<dyn ReplayDetector + Send + 'static>>,
35}
36
37/// Encrypt/Decrypt state for a single SRTCP SSRC
38#[derive(Default)]
39pub(crate) struct SrtcpSsrcState {
40    srtcp_index: usize,
41    ssrc: u32,
42    replay_detector: Option<Box<dyn ReplayDetector + Send + 'static>>,
43}
44
45impl SrtpSsrcState {
46    pub fn next_rollover_count(&self, sequence_number: u16) -> (u32, i32, bool) {
47        let local_roc = (self.index >> 16) as u32;
48        let local_seq = self.index as u16;
49
50        let mut guess_roc = local_roc;
51
52        let diff = if self.rollover_has_processed {
53            let seq = (sequence_number as i32).wrapping_sub(local_seq as i32);
54            // When local_roc is equal to 0, and entering seq-local_seq > SEQ_NUM_MEDIAN
55            // judgment, it will cause guess_roc calculation error
56            if self.index > SEQ_NUM_MEDIAN as _ {
57                if local_seq < SEQ_NUM_MEDIAN {
58                    if seq > SEQ_NUM_MEDIAN as i32 {
59                        guess_roc = local_roc.wrapping_sub(1);
60                        seq.wrapping_sub(SEQ_NUM_MAX as i32 + 1)
61                    } else {
62                        seq
63                    }
64                } else if local_seq - SEQ_NUM_MEDIAN > sequence_number {
65                    guess_roc = local_roc.wrapping_add(1);
66                    seq.wrapping_add(SEQ_NUM_MAX as i32 + 1)
67                } else {
68                    seq
69                }
70            } else {
71                // local_roc is equal to 0
72                seq
73            }
74        } else {
75            0i32
76        };
77
78        (guess_roc, diff, (guess_roc == 0 && local_roc == MAX_ROC))
79    }
80
81    /// https://tools.ietf.org/html/rfc3550#appendix-A.1
82    pub fn update_rollover_count(&mut self, sequence_number: u16, diff: i32) {
83        if !self.rollover_has_processed {
84            self.index |= sequence_number as u64;
85            self.rollover_has_processed = true;
86        } else {
87            self.index = self.index.wrapping_add(diff as _);
88        }
89    }
90}
91
92/// Context represents a SRTP cryptographic context
93/// Context can only be used for one-way operations
94/// it must either used ONLY for encryption or ONLY for decryption
95pub struct Context {
96    cipher: Box<dyn Cipher + Send>,
97
98    srtp_ssrc_states: HashMap<u32, SrtpSsrcState>,
99    srtcp_ssrc_states: HashMap<u32, SrtcpSsrcState>,
100
101    new_srtp_replay_detector: ContextOption,
102    new_srtcp_replay_detector: ContextOption,
103}
104
105impl Context {
106    /// CreateContext creates a new SRTP Context
107    pub fn new(
108        master_key: &[u8],
109        master_salt: &[u8],
110        profile: ProtectionProfile,
111        srtp_ctx_opt: Option<ContextOption>,
112        srtcp_ctx_opt: Option<ContextOption>,
113    ) -> Result<Context> {
114        let key_len = profile.key_len();
115        let salt_len = profile.salt_len();
116
117        if master_key.len() != key_len {
118            return Err(Error::SrtpMasterKeyLength(key_len, master_key.len()));
119        } else if master_salt.len() != salt_len {
120            return Err(Error::SrtpSaltLength(salt_len, master_salt.len()));
121        }
122
123        let cipher: Box<dyn Cipher + Send> = match profile {
124            ProtectionProfile::Aes128CmHmacSha1_32 | ProtectionProfile::Aes128CmHmacSha1_80 => {
125                Box::new(CipherAesCmHmacSha1::new(profile, master_key, master_salt)?)
126            }
127
128            ProtectionProfile::AeadAes128Gcm => Box::new(CipherAeadAesGcm::<Aes128>::new(
129                profile,
130                master_key,
131                master_salt,
132            )?),
133
134            ProtectionProfile::AeadAes256Gcm => Box::new(CipherAeadAesGcm::<Aes256>::new(
135                profile,
136                master_key,
137                master_salt,
138            )?),
139        };
140
141        let srtp_ctx_opt = if let Some(ctx_opt) = srtp_ctx_opt {
142            ctx_opt
143        } else {
144            srtp_no_replay_protection()
145        };
146
147        let srtcp_ctx_opt = if let Some(ctx_opt) = srtcp_ctx_opt {
148            ctx_opt
149        } else {
150            srtcp_no_replay_protection()
151        };
152
153        Ok(Context {
154            cipher,
155            srtp_ssrc_states: HashMap::new(),
156            srtcp_ssrc_states: HashMap::new(),
157            new_srtp_replay_detector: srtp_ctx_opt,
158            new_srtcp_replay_detector: srtcp_ctx_opt,
159        })
160    }
161
162    fn get_srtp_ssrc_state(&mut self, ssrc: u32) -> &mut SrtpSsrcState {
163        let s = SrtpSsrcState {
164            ssrc,
165            replay_detector: Some((self.new_srtp_replay_detector)()),
166            ..Default::default()
167        };
168
169        self.srtp_ssrc_states.entry(ssrc).or_insert(s)
170    }
171
172    fn get_srtcp_ssrc_state(&mut self, ssrc: u32) -> &mut SrtcpSsrcState {
173        let s = SrtcpSsrcState {
174            ssrc,
175            replay_detector: Some((self.new_srtcp_replay_detector)()),
176            ..Default::default()
177        };
178        self.srtcp_ssrc_states.entry(ssrc).or_insert(s)
179    }
180
181    /// roc returns SRTP rollover counter value of specified SSRC.
182    fn get_roc(&self, ssrc: u32) -> Option<u32> {
183        self.srtp_ssrc_states
184            .get(&ssrc)
185            .map(|s| (s.index >> 16) as _)
186    }
187
188    /// set_roc sets SRTP rollover counter value of specified SSRC.
189    fn set_roc(&mut self, ssrc: u32, roc: u32) {
190        let state = self.get_srtp_ssrc_state(ssrc);
191        state.index = (roc as u64) << 16;
192        state.rollover_has_processed = false;
193    }
194
195    /// index returns SRTCP index value of specified SSRC.
196    fn get_index(&self, ssrc: u32) -> Option<usize> {
197        self.srtcp_ssrc_states.get(&ssrc).map(|s| s.srtcp_index)
198    }
199
200    /// set_index sets SRTCP index value of specified SSRC.
201    fn set_index(&mut self, ssrc: u32, index: usize) {
202        self.get_srtcp_ssrc_state(ssrc).srtcp_index = index % (MAX_SRTCP_INDEX + 1);
203    }
204}