// streaming_crypto/core_api/stream_v2/core.rs
//! Stable public API for stream_v2.
use std::sync::Arc;
use std::ops::Deref;

use crate::{
    constants::{MASTER_KEY_LENGTHS, MAGIC_DICT, MAX_DICT_LEN, MIN_DICT_LEN, DEFAULT_QUEUE_CAP, DEFAULT_WORKERS, QUEUE_CAPS, WORKERS_COUNT},
    crypto::{CryptoError, DigestAlg, derive_session_key_32},
    headers::HeaderV1, recovery::AsyncLogManager,
    parallelism::{HybridParallelismProfile, ParallelismConfig},
    stream_v2::{
        io::{InputSource, OutputSink, PayloadReader, open_input, open_output},
        pipeline::{PipelineConfig, decrypt_pipeline, encrypt_pipeline},
    segment_worker::{DecryptContext, EncryptContext}},
    telemetry::TelemetrySnapshot,
    types::StreamError
};
/// Wrapper around raw master-key bytes.
///
/// NOTE(review): derives `Clone`, so key material can be duplicated in
/// memory, and the bytes are not zeroized on drop — confirm this matches
/// the threat model.
#[derive(Clone)]
pub struct MasterKey(Vec<u8>);

/// Lets a `MasterKey` be used anywhere a `&[u8]` is expected
/// (e.g. passed directly to key-derivation helpers).
impl Deref for MasterKey {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

31impl MasterKey {
32    pub fn new(bytes: Vec<u8>) -> Self {
33        // let len = bytes.len();
34
35        // if MASTER_KEY_LENGTHS.contains(&len) {
36        //     Ok(Self(bytes))
37        // } else {
38        //     Err(StreamError::Crypto(
39        //         CryptoError::InvalidKeyLen {
40        //             expected: &MASTER_KEY_LENGTHS,
41        //             actual: len,
42        //         },
43        //     ))
44        // }
45        Self(bytes)
46    }
47
48    pub fn as_slice(&self) -> &[u8] {
49        &self.0
50    }
51
52    /// Validate that the provided bytes match one of the allowed key lengths.
53    pub fn validate(bytes: &[u8]) -> Result<(), StreamError> {
54        let len = bytes.len();
55
56        if MASTER_KEY_LENGTHS.contains(&len) {
57            Ok(())
58        } else {
59            Err(StreamError::Crypto(
60                CryptoError::InvalidKeyLen {
61                    expected: &MASTER_KEY_LENGTHS,
62                    actual: len,
63                },
64            ))
65        }
66    }
67}
68
69
70#[derive(Clone, Debug)]
71pub struct EncryptParams<'a> {
72    pub header: HeaderV1,
73    pub dict: Option<&'a [u8]>,
74}
75impl<'a> EncryptParams<'a> {
76    pub fn validate(&self) -> Result<(), StreamError> {
77        validate_dictionary(self.dict.as_deref())?;
78        // If HeaderV1 has validation logic, we can enable it here:
79        // self.header.validate_header()?;
80        Ok(())
81    }
82}
83#[derive(Clone, Debug)]
84pub struct DecryptParams;
85impl DecryptParams {
86    pub fn validate(&self) -> Result<(), StreamError> {
87        Ok(())
88    }
89}
90
/// User-facing configuration for the v2 stream API.
///
/// All fields are `Option` so callers can specify only what they care
/// about; `ApiConfig::with_defaults` / `ApiConfig::new` fill in the rest.
#[derive(Debug, Clone)]
pub struct ApiConfig {
    /// Whether to capture the output buffer in memory.
    /// - `None` or `Some(false)` → no buffer capture (production default).
    /// - `Some(true)` → capture buffer for tests/benchmarks.
    pub with_buf: Option<bool>,

    /// Whether to collect detailed metrics during pipeline execution.
    /// Currently unused, reserved for future expansion.
    pub collect_metrics: Option<bool>,

    /// Digest algorithm to use (defaults to Blake3; extensible).
    pub alg: Option<DigestAlg>,

    /// Parallelism configuration.
    pub parallelism: Option<ParallelismConfig>,
}

111impl Default for ApiConfig {
112    fn default() -> Self {
113        Self {
114            with_buf: Some(false),      // default: no buffer
115            collect_metrics: Some(false), // default: no metrics
116            alg: Some(DigestAlg::Blake3), // default: Blake3
117            parallelism: Some(ParallelismConfig::default()),
118        }
119    }
120}
121
122impl ApiConfig {
123    pub fn new(with_buf: Option<bool>, collect_metrics: Option<bool>, alg: Option<DigestAlg>, parallelism: Option<ParallelismConfig>) -> Self {
124        Self {
125            with_buf: with_buf.or(Some(false)),
126            collect_metrics: collect_metrics.or(Some(false)),
127            alg: alg.or(Some(DigestAlg::Blake3)),
128            parallelism: Some(parallelism.unwrap_or_default()),
129        }
130    }
131    /// Merge user-provided values with defaults 
132    pub fn with_defaults(self) -> Self { 
133        let defaults = ApiConfig::default(); 
134        
135        Self { 
136            with_buf: self.with_buf.or(defaults.with_buf), 
137            collect_metrics: self.collect_metrics.or(defaults.collect_metrics), 
138            alg: self.alg.or(defaults.alg), 
139            parallelism: self.parallelism.or(defaults.parallelism), 
140        } 
141    }
142}
143
144fn setup_enc_context(master_key: &MasterKey, header: &HeaderV1, config: ApiConfig)
145    -> Result<(EncryptContext, HybridParallelismProfile, Arc<AsyncLogManager>), StreamError> 
146{
147    let session_key = derive_session_key_32(&master_key, header).map_err(StreamError::Crypto)?;
148
149    let profile = HybridParallelismProfile::from_stream_header(header.clone(), config.parallelism)?;
150    let context = EncryptContext::new(header.clone(), profile.clone(), &session_key, config.alg.unwrap())
151        .map_err(StreamError::SegmentWorker)?;
152    let log_manager = Arc::new(AsyncLogManager::new("stream_v2_enc.log", 100)?);
153
154    Ok((context, profile, log_manager))
155}
156
157fn setup_dec_context(master_key: &MasterKey, header: &HeaderV1, config: ApiConfig)
158    -> Result<(DecryptContext, HybridParallelismProfile, Arc<AsyncLogManager>), StreamError> 
159{
160    let session_key = derive_session_key_32(&master_key, header).map_err(StreamError::Crypto)?;
161
162    let profile = HybridParallelismProfile::from_stream_header(header.clone(), config.parallelism)?;
163    let context = DecryptContext::from_stream_header(header.clone(), profile.clone(), &session_key, config.alg.unwrap())
164        .map_err(StreamError::SegmentWorker)?;
165    let log_manager = Arc::new(AsyncLogManager::new("stream_v2_dec.log", 100)?);
166
167    Ok((context, profile, log_manager))
168}
169
170/// 🔐 Encrypt stream (v2)
171pub fn encrypt_stream_v2(
172    input: InputSource,
173    output: OutputSink,
174    master_key: &MasterKey,
175    params: EncryptParams,
176    config: ApiConfig,
177) -> Result<TelemetrySnapshot, StreamError> {
178    validate_encrypt_params(&master_key, &params, None, None)?;
179
180    // Normalize with defaults 
181    let final_config = config.with_defaults();
182
183    let reader = open_input(input)?;
184    let (writer, maybe_buf) = open_output(output, final_config.with_buf)?;
185
186    // ---- Read stream header ----
187    let mut payload_reader = PayloadReader::new(reader);
188
189    let (crypto, profile, log_manager) = setup_enc_context(&master_key, &params.header, final_config)?;
190    let config_pipe = PipelineConfig::new(profile, maybe_buf.clone());
191
192    // Wrap in Arc before passing into pipeline
193    let crypto = Arc::new(crypto);
194
195    // Call pipeline
196    let mut snapshot = encrypt_pipeline(
197        &mut payload_reader,
198        writer,
199        crypto,
200        &config_pipe,
201        log_manager,
202    )?;
203
204    // --- Telemetry buffer extraction for tests --- 
205    if let Some(ref arc_buf) = maybe_buf { 
206        let buf = arc_buf.lock().unwrap(); 
207        snapshot.attach_output(buf.clone()); 
208        // clone Vec<u8> into snapshot.output 
209    }
210
211    Ok(snapshot)
212}
213
214/// 🔓 Decrypt stream (v2)
215pub fn decrypt_stream_v2(
216    input: InputSource,
217    output: OutputSink,
218    master_key: &MasterKey,
219    params: DecryptParams,
220    config: ApiConfig,
221) -> Result<TelemetrySnapshot, StreamError> {
222    //
223    validate_decrypt_params(&master_key, &params, None, None)?;
224
225    // Normalize with defaults 
226    let final_config = config.with_defaults();
227
228    let reader = open_input(input)?;
229    let (writer, maybe_buf) = open_output(output, final_config.with_buf)?;
230
231    // ---- Read stream header ----
232    // Assert reader is positioned correctly
233    let (header, mut payload_reader) = PayloadReader::with_header(reader)?;
234
235    let (crypto, profile, log_manager) = setup_dec_context(&master_key, &header, final_config)?;
236    let config_pipe = PipelineConfig::new(profile, maybe_buf.clone());
237    
238    // Wrap in Arc before passing into pipeline
239    let crypto = Arc::new(crypto);
240
241    // Call pipeline
242    let mut snapshot = decrypt_pipeline(
243        &mut payload_reader,
244        writer,
245        crypto,
246        &config_pipe,
247        log_manager,
248    )?;
249
250    // --- Telemetry buffer extraction for tests --- 
251    if let Some(ref arc_buf) = maybe_buf { 
252        let buf = arc_buf.lock().unwrap(); 
253        snapshot.attach_output(buf.clone()); 
254        // clone Vec<u8> into snapshot.output 
255    }
256
257    Ok(snapshot)
258}
259
260
261pub fn validate_encrypt_params(
262    master_key: &MasterKey,
263    params: &EncryptParams,
264    workers: Option<usize>,
265    queue_cap: Option<usize>,
266
267) -> Result<(), StreamError> {
268    // --- Master key length ---
269    MasterKey::validate(&master_key)?;
270
271    // --- Resolve defaults ---
272    let w  = workers.unwrap_or(DEFAULT_WORKERS);
273    let q  = queue_cap.unwrap_or(DEFAULT_QUEUE_CAP);
274
275    if !WORKERS_COUNT.contains(&w) {
276        return Err(StreamError::Validation(format!(
277            "invalid workers count: {w}, must be one of {:?}",
278            WORKERS_COUNT
279        )));
280    }
281    if !QUEUE_CAPS.contains(&q) {
282        return Err(StreamError::Validation(format!(
283            "invalid queue capacity: {q}, must be one of {:?}",
284            QUEUE_CAPS
285        )));
286    }
287
288    params.validate()?;
289    Ok(())
290}
291
292pub fn validate_decrypt_params(
293    master_key: &MasterKey,
294    params: &DecryptParams,
295    workers: Option<usize>,
296    queue_cap: Option<usize>,
297) -> Result<(), StreamError> {
298    if !MASTER_KEY_LENGTHS.contains(&master_key.len()) {
299        return Err(StreamError::Crypto(CryptoError::InvalidKeyLen {
300            expected: &MASTER_KEY_LENGTHS,
301            actual: master_key.len(),
302        }));
303    }
304
305    // --- Resolve defaults ---
306    let w  = workers.unwrap_or(DEFAULT_WORKERS);
307    let q  = queue_cap.unwrap_or(DEFAULT_QUEUE_CAP);
308
309    if !WORKERS_COUNT.contains(&w) {
310        return Err(StreamError::Validation(format!(
311            "invalid workers count: {w}, must be one of {:?}",
312            WORKERS_COUNT
313        )));
314    }
315    if !QUEUE_CAPS.contains(&q) {
316        return Err(StreamError::Validation(format!(
317            "invalid queue capacity: {q}, must be one of {:?}",
318            QUEUE_CAPS
319        )));
320    }
321
322    params.validate()?;
323    Ok(())
324}
325
326pub fn validate_dictionary(dict: Option<&[u8]>) -> Result<(), StreamError> {
327    match dict {
328        None => Ok(()), // no dictionary supplied
329        Some(d) if d.is_empty() => Ok(()), // empty Vec also means "no dictionary"
330        Some(d) => {
331            // Non-empty dictionary must pass validation
332            if !is_valid_dictionary(d) {
333                Err(StreamError::Validation("invalid dictionary payload".into()))
334            } else {
335                Ok(())
336            }
337        }
338    }
339}
340
341pub fn is_valid_dictionary(dict: &[u8]) -> bool {
342    // Replace with the actual validation logic:
343    // e.g. check header bytes, length constraints, codec id, etc.
344    if dict.len() < MIN_DICT_LEN || dict.len() > MAX_DICT_LEN {
345        return false;
346    }
347
348    // First 4 bytes to be a magic number
349    let magic = MAGIC_DICT;
350    dict.len() >= magic.len() && &dict[..magic.len()] == magic
351}