Skip to main content

sheetkit_core/
vba.rs

1//! VBA project extraction from macro-enabled workbooks (.xlsm).
2//!
3//! `.xlsm` files contain a `xl/vbaProject.bin` entry which is an OLE2
4//! Compound Binary File (CFB) holding VBA source code. This module
5//! provides read-only access to the raw binary and to individual VBA
6//! module source code.
7
8use std::io::{Cursor, Read as _};
9
10use crate::error::{Error, Result};
11
/// Classification of a VBA module.
///
/// `Standard` corresponds to a procedural module record in the `dir`
/// stream; the other variants describe non-procedural modules.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VbaModuleType {
    /// A standard code module (`.bas`).
    Standard,
    /// A class module (`.cls`).
    Class,
    /// A UserForm module.
    // NOTE(review): none of the parsing code visible in this file produces
    // this variant; UserForms appear to be reported as `Class` -- confirm.
    Form,
    /// A document module (e.g. Sheet code-behind).
    Document,
    /// The ThisWorkbook module.
    ThisWorkbook,
}
26
/// A single VBA module with its name, source code, and type.
#[derive(Debug, Clone)]
pub struct VbaModule {
    /// Module name as recorded in the `dir` stream.
    pub name: String,
    /// Module source text after decompression and codepage decoding.
    pub source_code: String,
    /// Kind of module (standard, class, document, ...).
    pub module_type: VbaModuleType,
}
34
/// Result of extracting a VBA project from a `.xlsm` file.
///
/// Contains extracted modules and any non-fatal warnings encountered
/// during parsing (e.g., unreadable streams, decompression failures,
/// unsupported codepages).
#[derive(Debug, Clone)]
pub struct VbaProject {
    /// Modules that were read, decompressed and decoded successfully.
    pub modules: Vec<VbaModule>,
    /// Human-readable descriptions of non-fatal problems, such as modules
    /// that had to be skipped.
    pub warnings: Vec<String>,
}
45
/// Offset entry parsed from the `dir` stream for a single module.
struct ModuleEntry {
    /// Module name (MODULENAME, replaced by MODULENAMEUNICODE when that
    /// record is present and non-empty).
    name: String,
    /// Name of the module's stream inside the VBA storage (MODULESTREAMNAME).
    stream_name: String,
    /// Byte offset inside the module stream at which the compressed source
    /// text begins (MODULEOFFSET); 0 when the record is absent.
    text_offset: u32,
    /// Module kind derived from the MODULETYPE record and the module name.
    module_type: VbaModuleType,
}
53
/// Parsed metadata from the `dir` stream.
struct DirInfo {
    /// Module entries in the order they appear in the `dir` stream.
    entries: Vec<ModuleEntry>,
    /// Project codepage (PROJECTCODEPAGE); 1252 when the record is absent.
    codepage: u16,
}
59
60/// Extract VBA module source code from a `vbaProject.bin` binary blob.
61///
62/// Parses the OLE/CFB container, reads the `dir` stream to discover
63/// module metadata, then decompresses each module stream.
64///
65/// Returns a [`VbaProject`] containing extracted modules and any
66/// non-fatal warnings (e.g., modules that could not be read or
67/// decompressed, unsupported codepages).
68pub fn extract_vba_modules(vba_bin: &[u8]) -> Result<VbaProject> {
69    let cursor = Cursor::new(vba_bin);
70    let mut cfb = cfb::CompoundFile::open(cursor)
71        .map_err(|e| Error::Internal(format!("failed to open VBA project as CFB: {e}")))?;
72
73    // Find the VBA storage root. Typically `/VBA` or could be nested.
74    let vba_prefix = find_vba_prefix(&mut cfb)?;
75
76    // Read the `dir` stream to get module entries.
77    let dir_path = format!("{vba_prefix}dir");
78    let dir_data = read_cfb_stream(&mut cfb, &dir_path)?;
79
80    // The dir stream is compressed using MS-OVBA compression.
81    let decompressed_dir = decompress_vba_stream(&dir_data)?;
82
83    // Parse module entries and codepage from the decompressed dir stream.
84    let dir_info = parse_dir_stream(&decompressed_dir)?;
85
86    let mut modules = Vec::with_capacity(dir_info.entries.len());
87    let mut warnings = Vec::new();
88
89    for entry in dir_info.entries {
90        let stream_path = format!("{vba_prefix}{}", entry.stream_name);
91        let compressed_data = match read_cfb_stream(&mut cfb, &stream_path) {
92            Ok(data) => data,
93            Err(e) => {
94                warnings.push(format!(
95                    "skipped module '{}': failed to read stream '{}': {}",
96                    entry.name, stream_path, e
97                ));
98                continue;
99            }
100        };
101
102        // The module stream has text_offset bytes of "performance cache"
103        // (compiled code) followed by compressed source code.
104        if (entry.text_offset as usize) > compressed_data.len() {
105            warnings.push(format!(
106                "skipped module '{}': text_offset {} exceeds stream length {}",
107                entry.name,
108                entry.text_offset,
109                compressed_data.len()
110            ));
111            continue;
112        }
113        let source_compressed = &compressed_data[entry.text_offset as usize..];
114        let source_bytes = match decompress_vba_stream(source_compressed) {
115            Ok(b) => b,
116            Err(e) => {
117                warnings.push(format!(
118                    "skipped module '{}': decompression failed: {}",
119                    entry.name, e
120                ));
121                continue;
122            }
123        };
124
125        let source_code = decode_source_bytes(&source_bytes, dir_info.codepage, &mut warnings);
126
127        modules.push(VbaModule {
128            name: entry.name,
129            source_code,
130            module_type: entry.module_type,
131        });
132    }
133
134    Ok(VbaProject { modules, warnings })
135}
136
137/// Decode source bytes using the specified codepage.
138///
139/// Supports common codepages: 1252 (Western European), 932 (Japanese Shift-JIS),
140/// 949 (Korean), 936 (Simplified Chinese GBK), 65001 (UTF-8).
141/// For unrecognized codepages, falls back to UTF-8 lossy and emits a warning.
142fn decode_source_bytes(bytes: &[u8], codepage: u16, warnings: &mut Vec<String>) -> String {
143    match codepage {
144        65001 | 0 => String::from_utf8_lossy(bytes).into_owned(),
145        1252 => decode_single_byte(bytes, &WINDOWS_1252_HIGH),
146        932 => decode_shift_jis(bytes),
147        949 => decode_euc_kr(bytes),
148        936 => decode_gbk(bytes),
149        _ => {
150            warnings.push(format!(
151                "unsupported codepage {codepage}, falling back to UTF-8 lossy"
152            ));
153            String::from_utf8_lossy(bytes).into_owned()
154        }
155    }
156}
157
/// Windows-1252 mapping for the high half of the byte range (0x80..=0xFF),
/// indexed by `byte - 0x80`. Bytes 0x00..=0x7F are identical to ASCII and
/// are not part of the table.
///
/// Only 0x80..=0x9F differ from Latin-1, so those 32 entries are spelled
/// out and the remaining 96 (0xA0..=0xFF) are filled in at compile time
/// with the code point equal to the byte value.
static WINDOWS_1252_HIGH: [char; 128] = {
    // Mappings for bytes 0x80..=0x9F. The five bytes Windows-1252 leaves
    // undefined (0x81, 0x8D, 0x8F, 0x90, 0x9D) map to the C1 control with
    // the same value.
    const SPECIALS: [char; 32] = [
        '\u{20AC}', '\u{0081}', '\u{201A}', '\u{0192}', '\u{201E}', '\u{2026}', '\u{2020}',
        '\u{2021}', '\u{02C6}', '\u{2030}', '\u{0160}', '\u{2039}', '\u{0152}', '\u{008D}',
        '\u{017D}', '\u{008F}', '\u{0090}', '\u{2018}', '\u{2019}', '\u{201C}', '\u{201D}',
        '\u{2022}', '\u{2013}', '\u{2014}', '\u{02DC}', '\u{2122}', '\u{0161}', '\u{203A}',
        '\u{0153}', '\u{009D}', '\u{017E}', '\u{0178}',
    ];

    let mut table = ['\0'; 128];
    let mut index = 0;
    while index < 128 {
        table[index] = if index < 32 {
            SPECIALS[index]
        } else {
            // `u8 as char` yields the Latin-1 code point of the same value.
            (0x80 + index as u8) as char
        };
        index += 1;
    }
    table
};
178
/// Decode text stored in a single-byte codepage.
///
/// Bytes below 0x80 are taken as ASCII; every other byte is looked up in
/// `high_table`, which is indexed by `byte - 0x80`.
fn decode_single_byte(bytes: &[u8], high_table: &[char; 128]) -> String {
    bytes
        .iter()
        .map(|&byte| match byte {
            0x00..=0x7F => char::from(byte),
            _ => high_table[usize::from(byte - 0x80)],
        })
        .collect()
}
191
/// Decode Shift-JIS (codepage 932) bytes to a `String` on a best-effort basis.
///
/// * ASCII bytes (0x00..=0x7F) pass through unchanged.
/// * Half-width katakana (0xA1..=0xDF) map to U+FF61..=U+FF9F.
/// * Double-byte characters are NOT decoded, because no mapping table is
///   available here; each one becomes a single U+FFFD.
/// * 0x80, 0xA0 and 0xFD..=0xFF become U+FFFD.
///
/// A lead byte (0x81..=0x9F, 0xE0..=0xFC) consumes the following byte
/// unless that byte is an ASCII byte that can never be a trail byte
/// (below 0x40, or 0x7F). This keeps line breaks, quotes and other ASCII
/// punctuation intact when a sequence is truncated or corrupt.
fn decode_shift_jis(bytes: &[u8]) -> String {
    const REPLACEMENT: char = '\u{FFFD}';

    let mut out = String::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        let b = bytes[i];
        i += 1;
        match b {
            0x00..=0x7F => out.push(char::from(b)),
            0xA1..=0xDF => {
                // Half-width katakana occupy one contiguous Unicode run.
                let code_point = 0xFF61 + u32::from(b - 0xA1);
                out.push(char::from_u32(code_point).unwrap_or(REPLACEMENT));
            }
            0x81..=0x9F | 0xE0..=0xFC => {
                out.push(REPLACEMENT);
                // Swallow the trail byte, but never an ASCII byte that
                // cannot belong to a double-byte sequence.
                if let Some(&trail) = bytes.get(i) {
                    if trail >= 0x40 && trail != 0x7F {
                        i += 1;
                    }
                }
            }
            // 0x80, 0xA0 and 0xFD..=0xFF are treated as undecodable.
            _ => out.push(REPLACEMENT),
        }
    }
    out
}
222
/// Decode EUC-KR / codepage 949 bytes to a `String` on a best-effort basis.
///
/// ASCII bytes pass through unchanged. Double-byte characters are NOT
/// decoded, because no mapping table is available here; each one becomes
/// a single U+FFFD. 0x80 and 0xFF are not lead bytes and become U+FFFD on
/// their own.
///
/// A lead byte (0x81..=0xFE) consumes the following byte only when that
/// byte is a possible trail byte (0x41..=0x5A, 0x61..=0x7A) or is
/// non-ASCII. Other ASCII bytes (line breaks, digits, punctuation) are
/// left in place, so a truncated or corrupt sequence cannot eat source
/// text.
fn decode_euc_kr(bytes: &[u8]) -> String {
    const REPLACEMENT: char = '\u{FFFD}';

    let mut out = String::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        let b = bytes[i];
        i += 1;
        match b {
            0x00..=0x7F => out.push(char::from(b)),
            0x81..=0xFE => {
                out.push(REPLACEMENT);
                if let Some(&trail) = bytes.get(i) {
                    if matches!(trail, 0x41..=0x5A | 0x61..=0x7A | 0x80..=0xFF) {
                        i += 1;
                    }
                }
            }
            // 0x80 and 0xFF never start a double-byte sequence.
            _ => out.push(REPLACEMENT),
        }
    }
    out
}
243
/// Decode GBK / codepage 936 bytes to a `String` on a best-effort basis.
///
/// ASCII bytes pass through unchanged and the single byte 0x80 decodes to
/// the euro sign (U+20AC). Double-byte characters are NOT decoded, because
/// no mapping table is available here; each one becomes a single U+FFFD.
/// 0xFF is not a lead byte and becomes U+FFFD on its own.
///
/// A lead byte (0x81..=0xFE) consumes the following byte unless that byte
/// is an ASCII byte that can never be a trail byte (below 0x40, or 0x7F).
/// This keeps line breaks, quotes and digits intact when a sequence is
/// truncated or corrupt.
fn decode_gbk(bytes: &[u8]) -> String {
    const REPLACEMENT: char = '\u{FFFD}';

    let mut out = String::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        let b = bytes[i];
        i += 1;
        match b {
            0x00..=0x7F => out.push(char::from(b)),
            0x80 => out.push('\u{20AC}'),
            0x81..=0xFE => {
                out.push(REPLACEMENT);
                if let Some(&trail) = bytes.get(i) {
                    if matches!(trail, 0x40..=0x7E | 0x80..=0xFF) {
                        i += 1;
                    }
                }
            }
            // 0xFF never starts a double-byte sequence.
            _ => out.push(REPLACEMENT),
        }
    }
    out
}
264
265/// Find the VBA storage prefix inside the CFB container.
266/// Returns the path prefix ending with a separator (e.g. "VBA/").
267fn find_vba_prefix(cfb: &mut cfb::CompoundFile<Cursor<&[u8]>>) -> Result<String> {
268    // Collect all entries first to avoid borrow issues.
269    let entries: Vec<String> = cfb
270        .walk()
271        .map(|e| e.path().to_string_lossy().into_owned())
272        .collect();
273
274    // Look for a "dir" stream under a VBA storage.
275    for entry_path in &entries {
276        let normalized = entry_path.replace('\\', "/");
277        if normalized.ends_with("/dir") || normalized.ends_with("/DIR") {
278            let prefix = &normalized[..normalized.len() - 3];
279            return Ok(prefix.to_string());
280        }
281    }
282
283    // Try common paths directly.
284    for prefix in ["/VBA/", "VBA/", "/"] {
285        let dir_path = format!("{prefix}dir");
286        if cfb.is_stream(&dir_path) {
287            return Ok(prefix.to_string());
288        }
289    }
290
291    Err(Error::Internal(
292        "could not find VBA dir stream in vbaProject.bin".to_string(),
293    ))
294}
295
296/// Read a stream from the CFB container as raw bytes.
297fn read_cfb_stream(cfb: &mut cfb::CompoundFile<Cursor<&[u8]>>, path: &str) -> Result<Vec<u8>> {
298    let mut stream = cfb
299        .open_stream(path)
300        .map_err(|e| Error::Internal(format!("failed to open CFB stream '{path}': {e}")))?;
301    let mut data = Vec::new();
302    stream
303        .read_to_end(&mut data)
304        .map_err(|e| Error::Internal(format!("failed to read CFB stream '{path}': {e}")))?;
305    Ok(data)
306}
307
308/// Decompress a VBA compressed stream per MS-OVBA 2.4.1.
309///
310/// The format is:
311/// - 1 byte signature (0x01)
312/// - Sequence of compressed chunks, each starting with a 2-byte header
313/// - Each chunk contains a mix of literal bytes and copy tokens
314pub fn decompress_vba_stream(data: &[u8]) -> Result<Vec<u8>> {
315    if data.is_empty() {
316        return Ok(Vec::new());
317    }
318
319    if data[0] != 0x01 {
320        return Err(Error::Internal(format!(
321            "invalid VBA compression signature: expected 0x01, got 0x{:02X}",
322            data[0]
323        )));
324    }
325
326    let mut output = Vec::with_capacity(data.len() * 2);
327    let mut pos = 1; // skip signature byte
328
329    while pos < data.len() {
330        if pos + 1 >= data.len() {
331            break;
332        }
333
334        // Read chunk header (2 bytes, little-endian)
335        let header = u16::from_le_bytes([data[pos], data[pos + 1]]);
336        pos += 2;
337
338        let chunk_size = (header & 0x0FFF) as usize + 3;
339        let is_compressed = (header & 0x8000) != 0;
340
341        let chunk_end = (pos + chunk_size - 2).min(data.len());
342
343        if !is_compressed {
344            // Uncompressed chunk: raw bytes (4096 bytes max)
345            let raw_end = chunk_end.min(pos + 4096);
346            if raw_end > data.len() {
347                break;
348            }
349            output.extend_from_slice(&data[pos..raw_end]);
350            pos = chunk_end;
351            continue;
352        }
353
354        // Compressed chunk
355        let chunk_start_output = output.len();
356        while pos < chunk_end {
357            if pos >= data.len() {
358                break;
359            }
360
361            let flag_byte = data[pos];
362            pos += 1;
363
364            for bit_index in 0..8 {
365                if pos >= chunk_end {
366                    break;
367                }
368
369                if (flag_byte >> bit_index) & 1 == 0 {
370                    // Literal byte
371                    output.push(data[pos]);
372                    pos += 1;
373                } else {
374                    // Copy token (2 bytes, little-endian)
375                    if pos + 1 >= data.len() {
376                        pos = chunk_end;
377                        break;
378                    }
379                    let token = u16::from_le_bytes([data[pos], data[pos + 1]]);
380                    pos += 2;
381
382                    // Calculate the number of bits for the length and offset
383                    let decompressed_current = output.len() - chunk_start_output;
384                    let bit_count = max_bit_count(decompressed_current);
385                    let length_mask = 0xFFFF >> bit_count;
386                    let offset_mask = !length_mask;
387
388                    let length = ((token & length_mask) + 3) as usize;
389                    let offset = (((token & offset_mask) >> (16 - bit_count)) + 1) as usize;
390
391                    if offset > output.len() {
392                        // Invalid offset, skip
393                        break;
394                    }
395
396                    let copy_start = output.len() - offset;
397                    for i in 0..length {
398                        let byte = output[copy_start + (i % offset)];
399                        output.push(byte);
400                    }
401                }
402            }
403        }
404    }
405
406    Ok(output)
407}
408
/// Width, in bits, of the low field of a copy token.
///
/// `decompressed_current` is the number of bytes the current chunk has
/// produced so far. The result shrinks from 12 (up to 16 bytes) to 4
/// (more than 2048 bytes), which equals
/// `16 - max(ceil(log2(decompressed_current)), 4)`.
///
/// MS-OVBA 2.4.1.3.19.1 calls `max(ceil(log2(n)), 4)` the offset bit
/// count, so this function returns its complement: the number of low bits
/// that hold `length - 3`. Callers obtain the offset width as
/// `16 - max_bit_count(n)`.
fn max_bit_count(decompressed_current: usize) -> u16 {
    match decompressed_current {
        0..=16 => 12,
        17..=32 => 11,
        33..=64 => 10,
        65..=128 => 9,
        129..=256 => 8,
        257..=512 => 7,
        513..=1024 => 6,
        1025..=2048 => 5,
        _ => 4, // 2049 and above
    }
}
439
440/// Parse the decompressed `dir` stream to extract module entries and codepage.
441///
442/// The dir stream is a sequence of records with 2-byte IDs and 4-byte sizes.
443/// We look for MODULE_NAME, MODULE_STREAM_NAME, MODULE_OFFSET,
444/// MODULE_TYPE, and PROJECTCODEPAGE records.
445///
446/// MODULE_TYPE record 0x0021 indicates a procedural (standard) module.
447/// MODULE_TYPE record 0x0022 indicates a document/class module. When 0x0022
448/// is present, we refine the type to `Document`, `ThisWorkbook`, or `Class`
449/// based on the module name (since OOXML does not distinguish these subtypes
450/// at the record level).
451fn parse_dir_stream(data: &[u8]) -> Result<DirInfo> {
452    let mut pos = 0;
453    let mut modules = Vec::new();
454    let mut codepage: u16 = 1252; // Default to Windows-1252
455
456    // Current module being built
457    let mut current_name: Option<String> = None;
458    let mut current_stream_name: Option<String> = None;
459    let mut current_offset: u32 = 0;
460    let mut current_type = VbaModuleType::Standard;
461    let mut in_module = false;
462
463    while pos + 6 <= data.len() {
464        let record_id = u16::from_le_bytes([data[pos], data[pos + 1]]);
465        let record_size =
466            u32::from_le_bytes([data[pos + 2], data[pos + 3], data[pos + 4], data[pos + 5]])
467                as usize;
468        pos += 6;
469
470        if pos + record_size > data.len() {
471            break;
472        }
473
474        let record_data = &data[pos..pos + record_size];
475
476        match record_id {
477            // PROJECTCODEPAGE
478            0x0003 => {
479                if record_size >= 2 {
480                    codepage = u16::from_le_bytes([record_data[0], record_data[1]]);
481                }
482            }
483            // MODULENAME
484            0x0019 => {
485                if in_module {
486                    // Save previous module
487                    if let (Some(name), Some(stream)) =
488                        (current_name.take(), current_stream_name.take())
489                    {
490                        let refined_type = refine_module_type(&current_type, &name);
491                        modules.push(ModuleEntry {
492                            name,
493                            stream_name: stream,
494                            text_offset: current_offset,
495                            module_type: refined_type,
496                        });
497                    }
498                }
499                in_module = true;
500                current_name = Some(String::from_utf8_lossy(record_data).into_owned());
501                current_stream_name = None;
502                current_offset = 0;
503                current_type = VbaModuleType::Standard;
504            }
505            // MODULENAMEUNICODE
506            0x0047 => {
507                // UTF-16LE encoded name, prefer this over the ANSI name.
508                // Use only the even portion of the data; an odd trailing
509                // byte indicates a truncated record and is safely ignored.
510                if record_size >= 2 {
511                    let even_len = record_data.len() & !1;
512                    let u16_data: Vec<u16> = record_data[..even_len]
513                        .chunks_exact(2)
514                        .map(|c| u16::from_le_bytes([c[0], c[1]]))
515                        .collect();
516                    let name = String::from_utf16_lossy(&u16_data);
517                    // Remove trailing null if present
518                    let name = name.trim_end_matches('\0').to_string();
519                    if !name.is_empty() {
520                        current_name = Some(name);
521                    }
522                }
523            }
524            // MODULESTREAMNAME
525            0x001A => {
526                current_stream_name = Some(String::from_utf8_lossy(record_data).into_owned());
527                // The MODULENAMEUNICODE record for stream name follows with id 0x0032
528                // We handle it inline: skip the unicode record
529                if pos + record_size + 6 <= data.len() {
530                    let next_id =
531                        u16::from_le_bytes([data[pos + record_size], data[pos + record_size + 1]]);
532                    if next_id == 0x0032 {
533                        let next_size = u32::from_le_bytes([
534                            data[pos + record_size + 2],
535                            data[pos + record_size + 3],
536                            data[pos + record_size + 4],
537                            data[pos + record_size + 5],
538                        ]) as usize;
539                        // Skip the unicode stream name record
540                        pos += record_size + 6 + next_size;
541                        continue;
542                    }
543                }
544            }
545            // MODULEOFFSET
546            0x0031 => {
547                if record_size >= 4 {
548                    current_offset = u32::from_le_bytes([
549                        record_data[0],
550                        record_data[1],
551                        record_data[2],
552                        record_data[3],
553                    ]);
554                }
555            }
556            // MODULETYPE procedural (0x0021)
557            0x0021 => {
558                current_type = VbaModuleType::Standard;
559            }
560            // MODULETYPE document/class (0x0022)
561            0x0022 => {
562                // The dir stream only distinguishes procedural (0x0021) from
563                // non-procedural (0x0022). We refine 0x0022 into Document,
564                // ThisWorkbook, or Class based on the module name when the
565                // module is finalized.
566                current_type = VbaModuleType::Class;
567            }
568            // TERMINATOR for modules section (0x002B)
569            0x002B => {
570                // End of module list
571            }
572            _ => {}
573        }
574
575        pos += record_size;
576    }
577
578    // Save the last module if present
579    if in_module {
580        if let (Some(name), Some(stream)) = (current_name, current_stream_name) {
581            let refined_type = refine_module_type(&current_type, &name);
582            modules.push(ModuleEntry {
583                name,
584                stream_name: stream,
585                text_offset: current_offset,
586                module_type: refined_type,
587            });
588        }
589    }
590
591    Ok(DirInfo {
592        entries: modules,
593        codepage,
594    })
595}
596
597/// Refine the module type for non-procedural modules (0x0022) based on
598/// the module name. Procedural modules (0x0021) are always `Standard`.
599fn refine_module_type(base_type: &VbaModuleType, name: &str) -> VbaModuleType {
600    if *base_type == VbaModuleType::Standard {
601        return VbaModuleType::Standard;
602    }
603    let name_lower = name.to_lowercase();
604    if name_lower == "thisworkbook" {
605        VbaModuleType::ThisWorkbook
606    } else if name_lower.starts_with("sheet") {
607        VbaModuleType::Document
608    } else {
609        // Remains as Class (could be a class module or UserForm).
610        VbaModuleType::Class
611    }
612}
613
614#[cfg(test)]
615mod tests {
616    use super::*;
617
618    #[test]
619    fn test_decompress_empty_input() {
620        let result = decompress_vba_stream(&[]);
621        assert!(result.is_ok());
622        assert!(result.unwrap().is_empty());
623    }
624
625    #[test]
626    fn test_decompress_invalid_signature() {
627        let result = decompress_vba_stream(&[0x00, 0x01, 0x02]);
628        assert!(result.is_err());
629        let err_msg = result.unwrap_err().to_string();
630        assert!(err_msg.contains("invalid VBA compression signature"));
631    }
632
633    #[test]
634    fn test_decompress_uncompressed_chunk() {
635        // Signature byte + uncompressed chunk header (size=3 -> 3+3-2=4 bytes, bit 15 clear)
636        // Header: chunk_size = 3 bytes (field = 3-3 = 0), not compressed (bit 15 = 0)
637        // So header = 0x0000 means size=3, uncompressed
638        let mut data = vec![0x01]; // signature
639                                   // Uncompressed chunk: header with bit 15 clear, size field = N-3
640                                   // For 4 bytes of data: chunk_size = 4, field = 4-3 = 1
641        let header: u16 = 0x0001; // bit 15 = 0 (uncompressed), size = 1+3-2 = 2 (actual chunk payload = 2)
642                                  // Wait, let me recalculate.
643                                  // chunk_size = (header & 0x0FFF) + 3 = field + 3
644                                  // The chunk payload is chunk_size - 2 = field + 1 bytes
645                                  // For 3 bytes of payload: field = 2, header = 0x0002
646        data.extend_from_slice(&header.to_le_bytes());
647        data.extend_from_slice(b"AB");
648        // This should produce "AB" but limited to min(chunk_end, pos+4096)
649        let result = decompress_vba_stream(&data).unwrap();
650        assert_eq!(&result, b"AB");
651    }
652
653    #[test]
654    fn test_decompress_real_compressed_data() {
655        // Test with a known compressed sequence from the MS-OVBA spec example.
656        // Compressed representation of "aaaaaaaaaaaaaaa" (15 'a's)
657        // Signature: 0x01
658        // Chunk header: compressed, size field
659        // Flag byte: 0b00000011 = 0x03 (bit 0: literal, bit 1: copy token)
660        // Actually building a minimal valid compressed stream:
661        // Signature: 0x01
662        // Chunk header: size = N-3, compressed bit set
663        // Then flag + data
664
665        // A simpler approach: verify that decompression of a manually built stream works.
666        let mut compressed = vec![0x01u8];
667        // Build chunk: 1 literal 'a', then copy token referencing offset=1, length=3
668        // Flag byte: 0b00000010 = bit 0 literal, bit 1 copy
669        // Literal: b'a'
670        // Copy token with decompressed_current=1 -> bit_count=12, length_mask=0x000F
671        // offset=1, length=3 -> offset_field=(1-1)<<4=0, length_field=3-3=0
672        // token = 0x0000
673        let flag = 0x02u8; // bits: 0=literal, 1=copy, rest=0
674        let literal = b'a';
675        let copy_token: u16 = 0x0000; // offset=1, length=3
676
677        let mut chunk_payload = Vec::new();
678        chunk_payload.push(flag);
679        chunk_payload.push(literal);
680        chunk_payload.extend_from_slice(&copy_token.to_le_bytes());
681
682        let chunk_size = chunk_payload.len() + 2; // +2 for header
683        let header: u16 = 0x8000 | ((chunk_size as u16 - 3) & 0x0FFF);
684        compressed.extend_from_slice(&header.to_le_bytes());
685        compressed.extend_from_slice(&chunk_payload);
686
687        let result = decompress_vba_stream(&compressed).unwrap();
688        assert_eq!(&result, b"aaaa"); // 1 literal + 3 from copy
689    }
690
691    #[test]
692    fn test_max_bit_count() {
693        assert_eq!(max_bit_count(0), 12);
694        assert_eq!(max_bit_count(1), 12);
695        assert_eq!(max_bit_count(16), 12);
696        assert_eq!(max_bit_count(17), 11);
697        assert_eq!(max_bit_count(32), 11);
698        assert_eq!(max_bit_count(33), 10);
699        assert_eq!(max_bit_count(64), 10);
700        assert_eq!(max_bit_count(65), 9);
701        assert_eq!(max_bit_count(128), 9);
702        assert_eq!(max_bit_count(129), 8);
703        assert_eq!(max_bit_count(256), 8);
704        assert_eq!(max_bit_count(257), 7);
705        assert_eq!(max_bit_count(512), 7);
706        assert_eq!(max_bit_count(513), 6);
707        assert_eq!(max_bit_count(1024), 6);
708        assert_eq!(max_bit_count(1025), 5);
709        assert_eq!(max_bit_count(2048), 5);
710        assert_eq!(max_bit_count(2049), 4);
711        assert_eq!(max_bit_count(4096), 4);
712    }
713
714    #[test]
715    fn test_parse_dir_stream_empty() {
716        let result = parse_dir_stream(&[]);
717        assert!(result.is_ok());
718        let info = result.unwrap();
719        assert!(info.entries.is_empty());
720        assert_eq!(info.codepage, 1252);
721    }
722
723    #[test]
724    fn test_extract_vba_modules_invalid_cfb() {
725        let result = extract_vba_modules(b"not a CFB file");
726        assert!(result.is_err());
727        let err_msg = result.unwrap_err().to_string();
728        assert!(err_msg.contains("failed to open VBA project as CFB"));
729    }
730
731    #[test]
732    fn test_vba_module_type_clone() {
733        let t = VbaModuleType::Standard;
734        let t2 = t.clone();
735        assert_eq!(t, t2);
736    }
737
738    #[test]
739    fn test_vba_module_debug() {
740        let m = VbaModule {
741            name: "Module1".to_string(),
742            source_code: "Sub Test()\nEnd Sub".to_string(),
743            module_type: VbaModuleType::Standard,
744        };
745        let debug = format!("{:?}", m);
746        assert!(debug.contains("Module1"));
747    }
748
749    #[test]
750    fn test_vba_roundtrip_with_xlsm() {
751        use std::io::{Read as _, Write as _};
752
753        // Build a minimal CFB container with a VBA dir stream and a module
754        let vba_bin = build_test_vba_project();
755
756        // Create a valid xlsx using the Workbook API, then inject vbaProject.bin
757        let base_wb = crate::workbook::Workbook::new();
758        let base_buf = base_wb.save_to_buffer().unwrap();
759
760        // Rewrite the ZIP, adding the vbaProject.bin entry
761        let mut buf = Vec::new();
762        {
763            let base_cursor = std::io::Cursor::new(&base_buf);
764            let mut base_archive = zip::ZipArchive::new(base_cursor).unwrap();
765
766            let out_cursor = std::io::Cursor::new(&mut buf);
767            let mut zip = zip::ZipWriter::new(out_cursor);
768            let options = zip::write::SimpleFileOptions::default()
769                .compression_method(zip::CompressionMethod::Deflated);
770
771            for i in 0..base_archive.len() {
772                let mut entry = base_archive.by_index(i).unwrap();
773                let name = entry.name().to_string();
774                zip.start_file(&name, options).unwrap();
775                let mut data = Vec::new();
776                entry.read_to_end(&mut data).unwrap();
777                zip.write_all(&data).unwrap();
778            }
779
780            zip.start_file("xl/vbaProject.bin", options).unwrap();
781            zip.write_all(&vba_bin).unwrap();
782            zip.finish().unwrap();
783        }
784
785        // Open and extract
786        let wb = crate::workbook::Workbook::open_from_buffer(&buf).unwrap();
787
788        // Raw VBA project should be available
789        let raw = wb.get_vba_project();
790        assert!(raw.is_some(), "VBA project binary should be present");
791        assert_eq!(raw.unwrap(), vba_bin);
792    }
793
794    #[test]
795    fn test_xlsx_without_vba_returns_none() {
796        let wb = crate::workbook::Workbook::new();
797        assert!(wb.get_vba_project().is_none());
798        assert!(wb.get_vba_modules().unwrap().is_none());
799    }
800
801    #[test]
802    fn test_xlsx_roundtrip_no_vba() {
803        let wb = crate::workbook::Workbook::new();
804        let buf = wb.save_to_buffer().unwrap();
805        let wb2 = crate::workbook::Workbook::open_from_buffer(&buf).unwrap();
806        assert!(wb2.get_vba_project().is_none());
807    }
808
809    #[test]
810    fn test_get_vba_modules_from_test_project() {
811        use std::io::{Read as _, Write as _};
812
813        let vba_bin = build_test_vba_project();
814
815        // Create a valid xlsx, then inject vbaProject.bin
816        let base_wb = crate::workbook::Workbook::new();
817        let base_buf = base_wb.save_to_buffer().unwrap();
818
819        let mut buf = Vec::new();
820        {
821            let base_cursor = std::io::Cursor::new(&base_buf);
822            let mut base_archive = zip::ZipArchive::new(base_cursor).unwrap();
823
824            let out_cursor = std::io::Cursor::new(&mut buf);
825            let mut zip = zip::ZipWriter::new(out_cursor);
826            let options = zip::write::SimpleFileOptions::default()
827                .compression_method(zip::CompressionMethod::Deflated);
828
829            for i in 0..base_archive.len() {
830                let mut entry = base_archive.by_index(i).unwrap();
831                let name = entry.name().to_string();
832                zip.start_file(&name, options).unwrap();
833                let mut data = Vec::new();
834                entry.read_to_end(&mut data).unwrap();
835                zip.write_all(&data).unwrap();
836            }
837
838            zip.start_file("xl/vbaProject.bin", options).unwrap();
839            zip.write_all(&vba_bin).unwrap();
840            zip.finish().unwrap();
841        }
842
843        let wb = crate::workbook::Workbook::open_from_buffer(&buf).unwrap();
844        let project = wb.get_vba_modules().unwrap();
845        assert!(project.is_some(), "should have VBA modules");
846        let project = project.unwrap();
847        assert_eq!(project.modules.len(), 1);
848        assert_eq!(project.modules[0].name, "Module1");
849        assert_eq!(project.modules[0].module_type, VbaModuleType::Standard);
850        assert!(
851            project.modules[0].source_code.contains("Sub Hello()"),
852            "source should contain Sub Hello(), got: {}",
853            project.modules[0].source_code
854        );
855    }
856
857    #[test]
858    fn test_vba_project_preserved_in_save_roundtrip() {
859        use std::io::{Read as _, Write as _};
860
861        let vba_bin = build_test_vba_project();
862
863        let base_wb = crate::workbook::Workbook::new();
864        let base_buf = base_wb.save_to_buffer().unwrap();
865
866        let mut buf = Vec::new();
867        {
868            let base_cursor = std::io::Cursor::new(&base_buf);
869            let mut base_archive = zip::ZipArchive::new(base_cursor).unwrap();
870
871            let out_cursor = std::io::Cursor::new(&mut buf);
872            let mut zip = zip::ZipWriter::new(out_cursor);
873            let options = zip::write::SimpleFileOptions::default()
874                .compression_method(zip::CompressionMethod::Deflated);
875
876            for i in 0..base_archive.len() {
877                let mut entry = base_archive.by_index(i).unwrap();
878                let name = entry.name().to_string();
879                zip.start_file(&name, options).unwrap();
880                let mut data = Vec::new();
881                entry.read_to_end(&mut data).unwrap();
882                zip.write_all(&data).unwrap();
883            }
884
885            zip.start_file("xl/vbaProject.bin", options).unwrap();
886            zip.write_all(&vba_bin).unwrap();
887            zip.finish().unwrap();
888        }
889
890        // Open, then save again
891        let wb = crate::workbook::Workbook::open_from_buffer(&buf).unwrap();
892        let saved_buf = wb.save_to_buffer().unwrap();
893
894        // Re-open and verify VBA is preserved
895        let wb2 = crate::workbook::Workbook::open_from_buffer(&saved_buf).unwrap();
896        let raw = wb2.get_vba_project();
897        assert!(raw.is_some(), "VBA project should survive save roundtrip");
898        assert_eq!(raw.unwrap(), vba_bin);
899
900        // Modules should still be extractable
901        let project = wb2.get_vba_modules().unwrap().unwrap();
902        assert_eq!(project.modules.len(), 1);
903        assert_eq!(project.modules[0].name, "Module1");
904    }
905
906    /// Build a minimal CFB container that looks like a VBA project.
907    fn build_test_vba_project() -> Vec<u8> {
908        let mut buf = Vec::new();
909        let cursor = std::io::Cursor::new(&mut buf);
910        let mut cfb = cfb::CompoundFile::create(cursor).unwrap();
911
912        // Create VBA storage
913        cfb.create_storage("/VBA").unwrap();
914
915        // Build a minimal dir stream
916        let dir_data = build_minimal_dir_stream("Module1");
917
918        // Compress the dir stream
919        let compressed_dir = compress_for_test(&dir_data);
920
921        // Write dir stream
922        {
923            let mut stream = cfb.create_stream("/VBA/dir").unwrap();
924            std::io::Write::write_all(&mut stream, &compressed_dir).unwrap();
925        }
926
927        // Build module source: "Sub Hello()\nEnd Sub\n"
928        let source = b"Sub Hello()\r\nEnd Sub\r\n";
929        let compressed_source = compress_for_test(source);
930
931        // The module stream has 0 bytes of performance cache + compressed source.
932        // (text_offset = 0 in the dir stream)
933        {
934            let mut stream = cfb.create_stream("/VBA/Module1").unwrap();
935            std::io::Write::write_all(&mut stream, &compressed_source).unwrap();
936        }
937
938        // Create _VBA_PROJECT stream (required for validity, can be minimal)
939        {
940            let mut stream = cfb.create_stream("/VBA/_VBA_PROJECT").unwrap();
941            // Minimal header: version bytes
942            let header = [0xCC, 0x61, 0x00, 0x00, 0x00, 0x00, 0x00];
943            std::io::Write::write_all(&mut stream, &header).unwrap();
944        }
945
946        cfb.flush().unwrap();
947        buf
948    }
949
950    /// Build a minimal dir stream binary for one standard module.
951    fn build_minimal_dir_stream(module_name: &str) -> Vec<u8> {
952        let mut data = Vec::new();
953        let name_bytes = module_name.as_bytes();
954
955        // PROJECTSYSKIND record (0x0001): 4 bytes, value = 1 (Win32)
956        write_dir_record(&mut data, 0x0001, &1u32.to_le_bytes());
957
958        // PROJECTLCID record (0x0002): 4 bytes
959        write_dir_record(&mut data, 0x0002, &0x0409u32.to_le_bytes());
960
961        // PROJECTLCIDINVOKE record (0x0014): 4 bytes
962        write_dir_record(&mut data, 0x0014, &0x0409u32.to_le_bytes());
963
964        // PROJECTCODEPAGE record (0x0003): 2 bytes (1252 = Windows-1252)
965        write_dir_record(&mut data, 0x0003, &1252u16.to_le_bytes());
966
967        // PROJECTNAME record (0x0004)
968        write_dir_record(&mut data, 0x0004, b"VBAProject");
969
970        // PROJECTDOCSTRING record (0x0005): empty
971        write_dir_record(&mut data, 0x0005, &[]);
972        // Unicode variant (0x0040): empty
973        write_dir_record(&mut data, 0x0040, &[]);
974
975        // PROJECTHELPFILEPATH record (0x0006): empty
976        write_dir_record(&mut data, 0x0006, &[]);
977        // Unicode variant (0x003D): empty
978        write_dir_record(&mut data, 0x003D, &[]);
979
980        // PROJECTHELPCONTEXT (0x0007): 4 bytes
981        write_dir_record(&mut data, 0x0007, &0u32.to_le_bytes());
982
983        // PROJECTLIBFLAGS (0x0008): 4 bytes
984        write_dir_record(&mut data, 0x0008, &0u32.to_le_bytes());
985
986        // PROJECTVERSION (0x0009): 4 + 2 bytes (major + minor)
987        let mut version = Vec::new();
988        version.extend_from_slice(&1u32.to_le_bytes());
989        version.extend_from_slice(&0u16.to_le_bytes());
990        // Version record is special: id=0x0009, size=4 for major, then 2 bytes minor appended
991        write_dir_record(&mut data, 0x0009, &version);
992
993        // PROJECTCONSTANTS (0x000C): empty
994        write_dir_record(&mut data, 0x000C, &[]);
995        // Unicode variant (0x003C): empty
996        write_dir_record(&mut data, 0x003C, &[]);
997
998        // MODULES count record: id=0x000F, size=2
999        let module_count: u16 = 1;
1000        write_dir_record(&mut data, 0x000F, &module_count.to_le_bytes());
1001
1002        // PROJECTCOOKIE record (0x0013): 2 bytes
1003        write_dir_record(&mut data, 0x0013, &0u16.to_le_bytes());
1004
1005        // MODULE_NAME record (0x0019)
1006        write_dir_record(&mut data, 0x0019, name_bytes);
1007
1008        // MODULE_STREAM_NAME record (0x001A)
1009        write_dir_record(&mut data, 0x001A, name_bytes);
1010        // Unicode variant (0x0032)
1011        let name_utf16: Vec<u8> = module_name
1012            .encode_utf16()
1013            .flat_map(|c| c.to_le_bytes())
1014            .collect();
1015        write_dir_record(&mut data, 0x0032, &name_utf16);
1016
1017        // MODULE_OFFSET record (0x0031): 4 bytes (offset = 0)
1018        write_dir_record(&mut data, 0x0031, &0u32.to_le_bytes());
1019
1020        // MODULE_TYPE procedural (0x0021): 0 bytes
1021        write_dir_record(&mut data, 0x0021, &[]);
1022
1023        // MODULE_TERMINATOR (0x002B): 0 bytes
1024        write_dir_record(&mut data, 0x002B, &[]);
1025
1026        // End of modules
1027        // Global TERMINATOR (0x0010): 0 bytes
1028        write_dir_record(&mut data, 0x0010, &[]);
1029
1030        data
1031    }
1032
1033    fn write_dir_record(buf: &mut Vec<u8>, id: u16, data: &[u8]) {
1034        buf.extend_from_slice(&id.to_le_bytes());
1035        buf.extend_from_slice(&(data.len() as u32).to_le_bytes());
1036        buf.extend_from_slice(data);
1037    }
1038
1039    /// Minimal MS-OVBA "compression" that produces an uncompressed container.
1040    /// Signature 0x01 + one uncompressed chunk per 4096 bytes.
1041    fn compress_for_test(data: &[u8]) -> Vec<u8> {
1042        let mut result = vec![0x01u8]; // signature
1043        let mut pos = 0;
1044        while pos < data.len() {
1045            let chunk_len = (data.len() - pos).min(4096);
1046            let chunk_data = &data[pos..pos + chunk_len];
1047            // Chunk header: bit 15 = 0 (uncompressed), bits 0-11 = chunk_len + 2 - 3
1048            let header: u16 = (chunk_len as u16 + 2).wrapping_sub(3) & 0x0FFF;
1049            result.extend_from_slice(&header.to_le_bytes());
1050            result.extend_from_slice(chunk_data);
1051            // Pad to 4096 if needed
1052            for _ in chunk_len..4096 {
1053                result.push(0x00);
1054            }
1055            pos += chunk_len;
1056        }
1057        result
1058    }
1059}