tokio_ar/
header.rs

1use crate::archive::{
2    BSD_SORTED_SYMBOL_LOOKUP_TABLE_ID, BSD_SYMBOL_LOOKUP_TABLE_ID,
3    GNU_NAME_TABLE_ID, GNU_SYMBOL_LOOKUP_TABLE_ID, Variant,
4};
5use crate::error::annotate;
6use std::cmp::min;
7use std::collections::HashMap;
8use std::fs::Metadata;
9use std::io::{Error, ErrorKind, Result};
10use std::str;
11use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
12
13#[cfg(unix)]
14use std::os::unix::fs::MetadataExt;
15
/// Length in bytes of the fixed-size part of an archive entry header:
/// identifier (16) + mtime (12) + owner ID (6) + group ID (6) + mode (8) +
/// size (10) + the two-byte "`\n" terminator.
const ENTRY_HEADER_LEN: usize = 16 + 12 + 6 + 6 + 8 + 10 + 2;
17
18/// Representation of an archive entry header.
19#[derive(Clone, Debug, Eq, PartialEq)]
20pub struct Header {
21    identifier: Vec<u8>,
22    mtime: u64,
23    uid: u32,
24    gid: u32,
25    mode: u32,
26    size: u64,
27}
28
29impl Header {
30    /// Creates a header with the given file identifier and size, and all
31    /// other fields set to zero.
32    pub fn new(identifier: Vec<u8>, size: u64) -> Header {
33        Header { identifier, mtime: 0, uid: 0, gid: 0, mode: 0, size }
34    }
35
36    /// Creates a header with the given file identifier and all other fields
37    /// set from the given filesystem metadata.
38    #[cfg(unix)]
39    pub fn from_metadata(identifier: Vec<u8>, meta: &Metadata) -> Header {
40        Header {
41            identifier,
42            mtime: meta.mtime() as u64,
43            uid: meta.uid(),
44            gid: meta.gid(),
45            mode: meta.mode(),
46            size: meta.len(),
47        }
48    }
49
50    #[cfg(not(unix))]
51    pub fn from_metadata(identifier: Vec<u8>, meta: &Metadata) -> Header {
52        Header::new(identifier, meta.len())
53    }
54
55    /// Returns the file identifier.
56    pub fn identifier(&self) -> &[u8] {
57        &self.identifier
58    }
59
60    /// Sets the file identifier.
61    pub fn set_identifier(&mut self, identifier: Vec<u8>) {
62        self.identifier = identifier;
63    }
64
65    /// Returns the last modification time in Unix time format.
66    pub fn mtime(&self) -> u64 {
67        self.mtime
68    }
69
70    /// Sets the last modification time in Unix time format.
71    pub fn set_mtime(&mut self, mtime: u64) {
72        self.mtime = mtime;
73    }
74
75    /// Returns the value of the owner's user ID field.
76    pub fn uid(&self) -> u32 {
77        self.uid
78    }
79
80    /// Sets the value of the owner's user ID field.
81    pub fn set_uid(&mut self, uid: u32) {
82        self.uid = uid;
83    }
84
85    /// Returns the value of the group's user ID field.
86    pub fn gid(&self) -> u32 {
87        self.gid
88    }
89
90    /// Returns the value of the group's user ID field.
91    pub fn set_gid(&mut self, gid: u32) {
92        self.gid = gid;
93    }
94
95    /// Returns the mode bits for this file.
96    pub fn mode(&self) -> u32 {
97        self.mode
98    }
99
100    /// Sets the mode bits for this file.
101    pub fn set_mode(&mut self, mode: u32) {
102        self.mode = mode;
103    }
104
105    /// Returns the length of the file, in bytes.
106    pub fn size(&self) -> u64 {
107        self.size
108    }
109
110    /// Sets the length of the file, in bytes.
111    pub fn set_size(&mut self, size: u64) {
112        self.size = size;
113    }
114
115    /// Parses and returns the next header and its length.  Returns `Ok(None)`
116    /// if we are at EOF.
117    pub(crate) async fn read<R>(
118        reader: &mut R,
119        variant: &mut Variant,
120        name_table: &mut Vec<u8>,
121    ) -> Result<Option<(Header, u64)>>
122    where
123        R: AsyncRead + Unpin,
124    {
125        let mut buffer = [0; 60];
126        let bytes_read = reader.read(&mut buffer).await?;
127        if bytes_read == 0 {
128            return Ok(None);
129        } else if bytes_read < buffer.len()
130            && let Err(error) =
131                reader.read_exact(&mut buffer[bytes_read..]).await
132        {
133            if error.kind() == ErrorKind::UnexpectedEof {
134                let msg = "unexpected EOF in the middle of archive entry \
135                               header";
136                return Err(Error::new(ErrorKind::UnexpectedEof, msg));
137            } else {
138                let msg = "failed to read archive entry header";
139                return Err(annotate(error, msg));
140            }
141        }
142        let mut identifier = buffer[0..16].to_vec();
143        while identifier.last() == Some(&b' ') {
144            identifier.pop();
145        }
146        let mut size = parse_number("file size", &buffer[48..58], 10)?;
147        let mut header_len = ENTRY_HEADER_LEN as u64;
148        if *variant != Variant::BSD && identifier.starts_with(b"/") {
149            *variant = Variant::GNU;
150            if identifier == GNU_SYMBOL_LOOKUP_TABLE_ID {
151                tokio::io::copy(
152                    &mut reader.take(size),
153                    &mut tokio::io::sink(),
154                )
155                .await?;
156                return Ok(Some((Header::new(identifier, size), header_len)));
157            } else if identifier == GNU_NAME_TABLE_ID.as_bytes() {
158                *name_table = vec![0; size as usize];
159                reader.read_exact(name_table as &mut [u8]).await.map_err(
160                    |err| annotate(err, "failed to read name table"),
161                )?;
162                return Ok(Some((Header::new(identifier, size), header_len)));
163            }
164            let start = parse_number("GNU filename index", &buffer[1..16], 10)?
165                as usize;
166            if start > name_table.len() {
167                let msg = "GNU filename index out of range";
168                return Err(Error::new(ErrorKind::InvalidData, msg));
169            }
170            let end = match name_table[start..]
171                .iter()
172                .position(|&ch| ch == b'/' || ch == b'\x00')
173            {
174                Some(len) => start + len,
175                None => name_table.len(),
176            };
177            identifier = name_table[start..end].to_vec();
178        } else if *variant != Variant::BSD && identifier.ends_with(b"/") {
179            *variant = Variant::GNU;
180            identifier.pop();
181        }
182        let mtime = parse_number_permitting_minus_one(
183            "timestamp",
184            &buffer[16..28],
185            10,
186        )?;
187        let uid = if *variant == Variant::GNU {
188            parse_number_permitting_empty("owner ID", &buffer[28..34], 10)?
189        } else {
190            parse_number("owner ID", &buffer[28..34], 10)?
191        } as u32;
192        let gid = if *variant == Variant::GNU {
193            parse_number_permitting_empty("group ID", &buffer[34..40], 10)?
194        } else {
195            parse_number("group ID", &buffer[34..40], 10)?
196        } as u32;
197        let mode = parse_number("file mode", &buffer[40..48], 8)? as u32;
198        if *variant != Variant::GNU && identifier.starts_with(b"#1/") {
199            *variant = Variant::BSD;
200            let padded_length =
201                parse_number("BSD filename length", &buffer[3..16], 10)?;
202            if size < padded_length {
203                let msg = format!(
204                    "Entry size ({}) smaller than extended \
205                                   entry identifier length ({})",
206                    size, padded_length
207                );
208                return Err(Error::new(ErrorKind::InvalidData, msg));
209            }
210            size -= padded_length;
211            header_len += padded_length;
212            let mut id_buffer = vec![0; padded_length as usize];
213            let bytes_read = reader.read(&mut id_buffer).await?;
214            if bytes_read < id_buffer.len()
215                && let Err(error) =
216                    reader.read_exact(&mut id_buffer[bytes_read..]).await
217            {
218                if error.kind() == ErrorKind::UnexpectedEof {
219                    let msg = "unexpected EOF in the middle of extended \
220                                   entry identifier";
221                    return Err(Error::new(ErrorKind::UnexpectedEof, msg));
222                } else {
223                    let msg = "failed to read extended entry identifier";
224                    return Err(annotate(error, msg));
225                }
226            }
227            while id_buffer.last() == Some(&0) {
228                id_buffer.pop();
229            }
230            identifier = id_buffer;
231            if identifier == BSD_SYMBOL_LOOKUP_TABLE_ID
232                || identifier == BSD_SORTED_SYMBOL_LOOKUP_TABLE_ID
233            {
234                tokio::io::copy(
235                    &mut reader.take(size),
236                    &mut tokio::io::sink(),
237                )
238                .await?;
239                return Ok(Some((Header::new(identifier, size), header_len)));
240            }
241        }
242        Ok(Some((
243            Header { identifier, mtime, uid, gid, mode, size },
244            header_len,
245        )))
246    }
247
248    pub(crate) async fn write<W: AsyncWrite + Unpin>(
249        &self,
250        writer: &mut W,
251    ) -> Result<()> {
252        if self.identifier.len() > 16 || self.identifier.contains(&b' ') {
253            let padding_length = (4 - self.identifier.len() % 4) % 4;
254            let padded_length = self.identifier.len() + padding_length;
255            writer
256                .write_all(
257                    format!(
258                        "#1/{:<13}{:<12}{:<6.6}{:<6.6}{:<8o}{:<10}`\n",
259                        padded_length,
260                        cap_mtime(self.mtime),
261                        self.uid.to_string(),
262                        self.gid.to_string(),
263                        cap_mode(self.mode),
264                        self.size + padded_length as u64
265                    )
266                    .as_bytes(),
267                )
268                .await?;
269            writer.write_all(&self.identifier).await?;
270            writer.write_all(&vec![0; padding_length]).await?;
271        } else {
272            writer.write_all(&self.identifier).await?;
273            writer.write_all(&vec![b' '; 16 - self.identifier.len()]).await?;
274            writer
275                .write_all(
276                    format!(
277                        "{:<12}{:<6.6}{:<6.6}{:<8o}{:<10}`\n",
278                        cap_mtime(self.mtime),
279                        self.uid.to_string(),
280                        self.gid.to_string(),
281                        cap_mode(self.mode),
282                        self.size
283                    )
284                    .as_bytes(),
285                )
286                .await?;
287        }
288        Ok(())
289    }
290
291    pub(crate) async fn write_gnu<W>(
292        &self,
293        writer: &mut W,
294        names: &HashMap<Vec<u8>, usize>,
295    ) -> Result<()>
296    where
297        W: AsyncWrite + Unpin,
298    {
299        if self.identifier.len() > 15 {
300            let offset = names[&self.identifier];
301            writer.write_all(format!("/{:<15}", offset).as_bytes()).await?;
302        } else {
303            writer.write_all(&self.identifier).await?;
304            writer.write_all(b"/").await?;
305            writer.write_all(&vec![b' '; 15 - self.identifier.len()]).await?;
306        }
307        writer
308            .write_all(
309                format!(
310                    "{:<12}{:<6.6}{:<6.6}{:<8o}{:<10}`\n",
311                    cap_mtime(self.mtime),
312                    self.uid.to_string(),
313                    self.gid.to_string(),
314                    cap_mode(self.mode),
315                    self.size
316                )
317                .as_bytes(),
318            )
319            .await?;
320        Ok(())
321    }
322}
323
/// Clamps `mtime` to the largest value that fits in the 12-character
/// decimal timestamp field of an entry header.
fn cap_mtime(mtime: u64) -> u64 {
    const LARGEST_TIMESTAMP: u64 = 999_999_999_999;
    min(LARGEST_TIMESTAMP, mtime)
}
327
/// Keeps only the low 24 bits of `mode`, which is exactly what fits in the
/// 8-character octal mode field of an entry header (3 bits per digit).
fn cap_mode(mode: u32) -> u32 {
    const OCTAL_FIELD_BITS: u32 = 8 * 3;
    mode & ((1 << OCTAL_FIELD_BITS) - 1)
}
331
/// Parses a numeric entry-header field in the given radix.
///
/// Trailing whitespace (the field's padding) is ignored.  For radix 2, 8 and
/// 16, any leading `0b`/`0o`/`0x` prefix is stripped first.  Anything else
/// that is not a valid number yields an `InvalidData` error naming the field.
fn parse_number(field_name: &str, bytes: &[u8], radix: u32) -> Result<u64> {
    let prefix = match radix {
        2 => Some("0b"),
        8 => Some("0o"),
        16 => Some("0x"),
        _ => None,
    };
    str::from_utf8(bytes)
        .ok()
        .map(|text| match prefix {
            Some(marker) => text.trim_start_matches(marker),
            None => text,
        })
        .and_then(|digits| u64::from_str_radix(digits.trim_end(), radix).ok())
        .ok_or_else(|| {
            let msg = format!(
                "Invalid {} field in entry header ({:?})",
                field_name,
                String::from_utf8_lossy(bytes)
            );
            Error::new(ErrorKind::InvalidData, msg)
        })
}
351
/*
 * Parses a numeric header field like parse_number() does for radix 10
 * (trailing padding ignored), except that the literal "-1", which MS tools
 * may emit for mtime, is accepted and mapped to 0.  No 0b/0o/0x prefix is
 * stripped.
 */
fn parse_number_permitting_minus_one(
    field_name: &str,
    bytes: &[u8],
    radix: u32,
) -> Result<u64> {
    let parsed = str::from_utf8(bytes).ok().and_then(|text| {
        match text.trim_end() {
            "-1" => Some(0),
            digits => u64::from_str_radix(digits, radix).ok(),
        }
    });
    parsed.ok_or_else(|| {
        let msg = format!(
            "Invalid {} field in entry header ({:?})",
            field_name,
            String::from_utf8_lossy(bytes)
        );
        Error::new(ErrorKind::InvalidData, msg)
    })
}
376
/*
 * Parses a numeric header field like parse_number() does for radix 10
 * (trailing padding ignored), except that a field that is empty or all
 * whitespace, as MS tools emit for UID/GID, is accepted and mapped to 0.
 * No 0b/0o/0x prefix is stripped.
 */
fn parse_number_permitting_empty(
    field_name: &str,
    bytes: &[u8],
    radix: u32,
) -> Result<u64> {
    let value = str::from_utf8(bytes)
        .ok()
        .map(str::trim_end)
        .and_then(|digits| {
            if digits.is_empty() {
                Some(0)
            } else {
                u64::from_str_radix(digits, radix).ok()
            }
        });
    match value {
        Some(number) => Ok(number),
        None => {
            let msg = format!(
                "Invalid {} field in entry header ({:?})",
                field_name,
                String::from_utf8_lossy(bytes)
            );
            Err(Error::new(ErrorKind::InvalidData, msg))
        }
    }
}