// streaming_crypto/core_api/stream_v2/core.rs
2use std::sync::Arc;
5use std::ops::Deref;
6
7use crate::{
8 constants::{MASTER_KEY_LENGTHS, MAGIC_DICT, MAX_DICT_LEN, MIN_DICT_LEN, DEFAULT_QUEUE_CAP, DEFAULT_WORKERS, QUEUE_CAPS, WORKERS_COUNT},
9 crypto::{CryptoError, DigestAlg, derive_session_key_32},
10 headers::HeaderV1, recovery::AsyncLogManager,
11 parallelism::{HybridParallelismProfile, ParallelismConfig},
12 stream_v2::{
13 io::{InputSource, OutputSink, PayloadReader, open_input, open_output},
14 pipeline::{PipelineConfig, decrypt_pipeline, encrypt_pipeline},
15 segment_worker::{DecryptContext, EncryptContext}},
16 telemetry::TelemetrySnapshot,
17 types::StreamError
18};
19
20#[derive(Clone)]
21pub struct MasterKey(Vec<u8>);
22
23impl Deref for MasterKey {
24 type Target = [u8];
25
26 fn deref(&self) -> &Self::Target {
27 &self.0
28 }
29}
30
31impl MasterKey {
32 pub fn new(bytes: Vec<u8>) -> Self {
33 Self(bytes)
46 }
47
48 pub fn as_slice(&self) -> &[u8] {
49 &self.0
50 }
51
52 pub fn validate(bytes: &[u8]) -> Result<(), StreamError> {
54 let len = bytes.len();
55
56 if MASTER_KEY_LENGTHS.contains(&len) {
57 Ok(())
58 } else {
59 Err(StreamError::Crypto(
60 CryptoError::InvalidKeyLen {
61 expected: &MASTER_KEY_LENGTHS,
62 actual: len,
63 },
64 ))
65 }
66 }
67}
68
69
70#[derive(Clone, Debug)]
71pub struct EncryptParams<'a> {
72 pub header: HeaderV1,
73 pub dict: Option<&'a [u8]>,
74}
75impl<'a> EncryptParams<'a> {
76 pub fn validate(&self) -> Result<(), StreamError> {
77 validate_dictionary(self.dict.as_deref())?;
78 Ok(())
81 }
82}
83#[derive(Clone, Debug)]
84pub struct DecryptParams;
85impl DecryptParams {
86 pub fn validate(&self) -> Result<(), StreamError> {
87 Ok(())
88 }
89}
90
91#[derive(Debug, Clone)]
92pub struct ApiConfig {
93 pub with_buf: Option<bool>,
97
98 pub collect_metrics: Option<bool>,
101
102 pub alg: Option<DigestAlg>,
105
106 pub parallelism: Option<ParallelismConfig>,
109}
110
111impl Default for ApiConfig {
112 fn default() -> Self {
113 Self {
114 with_buf: Some(false), collect_metrics: Some(false), alg: Some(DigestAlg::Blake3), parallelism: Some(ParallelismConfig::default()),
118 }
119 }
120}
121
122impl ApiConfig {
123 pub fn new(with_buf: Option<bool>, collect_metrics: Option<bool>, alg: Option<DigestAlg>, parallelism: Option<ParallelismConfig>) -> Self {
124 Self {
125 with_buf: with_buf.or(Some(false)),
126 collect_metrics: collect_metrics.or(Some(false)),
127 alg: alg.or(Some(DigestAlg::Blake3)),
128 parallelism: Some(parallelism.unwrap_or_default()),
129 }
130 }
131 pub fn with_defaults(self) -> Self {
133 let defaults = ApiConfig::default();
134
135 Self {
136 with_buf: self.with_buf.or(defaults.with_buf),
137 collect_metrics: self.collect_metrics.or(defaults.collect_metrics),
138 alg: self.alg.or(defaults.alg),
139 parallelism: self.parallelism.or(defaults.parallelism),
140 }
141 }
142}
143
144fn setup_enc_context(master_key: &MasterKey, header: &HeaderV1, config: ApiConfig)
145 -> Result<(EncryptContext, HybridParallelismProfile, Arc<AsyncLogManager>), StreamError>
146{
147 let session_key = derive_session_key_32(&master_key, header).map_err(StreamError::Crypto)?;
148
149 let profile = HybridParallelismProfile::from_stream_header(header.clone(), config.parallelism)?;
150 let context = EncryptContext::new(header.clone(), profile.clone(), &session_key, config.alg.unwrap())
151 .map_err(StreamError::SegmentWorker)?;
152 let log_manager = Arc::new(AsyncLogManager::new("stream_v2_enc.log", 100)?);
153
154 Ok((context, profile, log_manager))
155}
156
157fn setup_dec_context(master_key: &MasterKey, header: &HeaderV1, config: ApiConfig)
158 -> Result<(DecryptContext, HybridParallelismProfile, Arc<AsyncLogManager>), StreamError>
159{
160 let session_key = derive_session_key_32(&master_key, header).map_err(StreamError::Crypto)?;
161
162 let profile = HybridParallelismProfile::from_stream_header(header.clone(), config.parallelism)?;
163 let context = DecryptContext::from_stream_header(header.clone(), profile.clone(), &session_key, config.alg.unwrap())
164 .map_err(StreamError::SegmentWorker)?;
165 let log_manager = Arc::new(AsyncLogManager::new("stream_v2_dec.log", 100)?);
166
167 Ok((context, profile, log_manager))
168}
169
170pub fn encrypt_stream_v2(
172 input: InputSource,
173 output: OutputSink,
174 master_key: &MasterKey,
175 params: EncryptParams,
176 config: ApiConfig,
177) -> Result<TelemetrySnapshot, StreamError> {
178 validate_encrypt_params(&master_key, ¶ms, None, None)?;
179
180 let final_config = config.with_defaults();
182
183 let reader = open_input(input)?;
184 let (writer, maybe_buf) = open_output(output, final_config.with_buf)?;
185
186 let mut payload_reader = PayloadReader::new(reader);
188
189 let (crypto, profile, log_manager) = setup_enc_context(&master_key, ¶ms.header, final_config)?;
190 let config_pipe = PipelineConfig::new(profile, maybe_buf.clone());
191
192 let crypto = Arc::new(crypto);
194
195 let mut snapshot = encrypt_pipeline(
197 &mut payload_reader,
198 writer,
199 crypto,
200 &config_pipe,
201 log_manager,
202 )?;
203
204 if let Some(ref arc_buf) = maybe_buf {
206 let buf = arc_buf.lock().unwrap();
207 snapshot.attach_output(buf.clone());
208 }
210
211 Ok(snapshot)
212}
213
214pub fn decrypt_stream_v2(
216 input: InputSource,
217 output: OutputSink,
218 master_key: &MasterKey,
219 params: DecryptParams,
220 config: ApiConfig,
221) -> Result<TelemetrySnapshot, StreamError> {
222 validate_decrypt_params(&master_key, ¶ms, None, None)?;
224
225 let final_config = config.with_defaults();
227
228 let reader = open_input(input)?;
229 let (writer, maybe_buf) = open_output(output, final_config.with_buf)?;
230
231 let (header, mut payload_reader) = PayloadReader::with_header(reader)?;
234
235 let (crypto, profile, log_manager) = setup_dec_context(&master_key, &header, final_config)?;
236 let config_pipe = PipelineConfig::new(profile, maybe_buf.clone());
237
238 let crypto = Arc::new(crypto);
240
241 let mut snapshot = decrypt_pipeline(
243 &mut payload_reader,
244 writer,
245 crypto,
246 &config_pipe,
247 log_manager,
248 )?;
249
250 if let Some(ref arc_buf) = maybe_buf {
252 let buf = arc_buf.lock().unwrap();
253 snapshot.attach_output(buf.clone());
254 }
256
257 Ok(snapshot)
258}
259
260
261pub fn validate_encrypt_params(
262 master_key: &MasterKey,
263 params: &EncryptParams,
264 workers: Option<usize>,
265 queue_cap: Option<usize>,
266
267) -> Result<(), StreamError> {
268 MasterKey::validate(&master_key)?;
270
271 let w = workers.unwrap_or(DEFAULT_WORKERS);
273 let q = queue_cap.unwrap_or(DEFAULT_QUEUE_CAP);
274
275 if !WORKERS_COUNT.contains(&w) {
276 return Err(StreamError::Validation(format!(
277 "invalid workers count: {w}, must be one of {:?}",
278 WORKERS_COUNT
279 )));
280 }
281 if !QUEUE_CAPS.contains(&q) {
282 return Err(StreamError::Validation(format!(
283 "invalid queue capacity: {q}, must be one of {:?}",
284 QUEUE_CAPS
285 )));
286 }
287
288 params.validate()?;
289 Ok(())
290}
291
292pub fn validate_decrypt_params(
293 master_key: &MasterKey,
294 params: &DecryptParams,
295 workers: Option<usize>,
296 queue_cap: Option<usize>,
297) -> Result<(), StreamError> {
298 if !MASTER_KEY_LENGTHS.contains(&master_key.len()) {
299 return Err(StreamError::Crypto(CryptoError::InvalidKeyLen {
300 expected: &MASTER_KEY_LENGTHS,
301 actual: master_key.len(),
302 }));
303 }
304
305 let w = workers.unwrap_or(DEFAULT_WORKERS);
307 let q = queue_cap.unwrap_or(DEFAULT_QUEUE_CAP);
308
309 if !WORKERS_COUNT.contains(&w) {
310 return Err(StreamError::Validation(format!(
311 "invalid workers count: {w}, must be one of {:?}",
312 WORKERS_COUNT
313 )));
314 }
315 if !QUEUE_CAPS.contains(&q) {
316 return Err(StreamError::Validation(format!(
317 "invalid queue capacity: {q}, must be one of {:?}",
318 QUEUE_CAPS
319 )));
320 }
321
322 params.validate()?;
323 Ok(())
324}
325
326pub fn validate_dictionary(dict: Option<&[u8]>) -> Result<(), StreamError> {
327 match dict {
328 None => Ok(()), Some(d) if d.is_empty() => Ok(()), Some(d) => {
331 if !is_valid_dictionary(d) {
333 Err(StreamError::Validation("invalid dictionary payload".into()))
334 } else {
335 Ok(())
336 }
337 }
338 }
339}
340
341pub fn is_valid_dictionary(dict: &[u8]) -> bool {
342 if dict.len() < MIN_DICT_LEN || dict.len() > MAX_DICT_LEN {
345 return false;
346 }
347
348 let magic = MAGIC_DICT;
350 dict.len() >= magic.len() && &dict[..magic.len()] == magic
351}