serde_device_tree/
de.rs

1// Copyright (c) 2021 HUST IoT Security Lab
2// serde_device_tree is licensed under Mulan PSL v2.
3// You can use this software according to the terms and conditions of the Mulan PSL v2.
4// You may obtain a copy of Mulan PSL v2 at:
5//          http://license.coscl.org.cn/MulanPSL2
6// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
7// EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
8// MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
9// See the Mulan PSL v2 for more details.
10
11use crate::error::{Error, Result};
12use core::iter::Peekable;
13use serde::de;
14
15pub unsafe fn from_raw<'de, T>(ptr: *const u8) -> Result<T>
16where
17    T: de::Deserialize<'de>,
18{
19    // read header
20    let header = &*(ptr as *const Header);
21    let magic = u32::from_be(header.magic);
22    if magic != DEVICE_TREE_MAGIC {
23        let file_index =
24            (&header.magic as *const _ as usize) - (&header.magic as *const _ as usize);
25        return Err(Error::invalid_magic(magic, file_index));
26    }
27    let last_comp_version = u32::from_be(header.last_comp_version);
28    if last_comp_version > SUPPORTED_VERSION {
29        let file_index =
30            (&header.last_comp_version as *const _ as usize) - (&header.magic as *const _ as usize);
31        return Err(Error::incompatible_version(
32            last_comp_version,
33            SUPPORTED_VERSION,
34            file_index,
35        ));
36    }
37    let total_size = u32::from_be(header.total_size);
38    if total_size < HEADER_LEN {
39        let file_index =
40            (&header.total_size as *const _ as usize) - (&header.magic as *const _ as usize);
41        return Err(Error::header_too_short(total_size, HEADER_LEN, file_index));
42    }
43    let off_dt_struct = u32::from_be(header.off_dt_struct);
44    if off_dt_struct < HEADER_LEN {
45        let file_index =
46            (&header.off_dt_struct as *const _ as usize) - (&header.magic as *const _ as usize);
47        return Err(Error::structure_index_underflow(
48            off_dt_struct,
49            HEADER_LEN,
50            file_index,
51        ));
52    }
53    let size_dt_struct = u32::from_be(header.size_dt_struct);
54    if off_dt_struct + size_dt_struct > total_size {
55        let file_index =
56            (&header.size_dt_struct as *const _ as usize) - (&header.magic as *const _ as usize);
57        return Err(Error::structure_index_overflow(
58            off_dt_struct + size_dt_struct,
59            HEADER_LEN,
60            file_index,
61        ));
62    }
63    let off_dt_strings = u32::from_be(header.off_dt_strings);
64    if off_dt_strings < HEADER_LEN {
65        let file_index =
66            (&header.off_dt_strings as *const _ as usize) - (&header.magic as *const _ as usize);
67        return Err(Error::string_index_underflow(
68            off_dt_strings,
69            HEADER_LEN,
70            file_index,
71        ));
72    }
73    let size_dt_strings = u32::from_be(header.size_dt_strings);
74    if off_dt_struct + size_dt_strings > total_size {
75        let file_index =
76            (&header.size_dt_strings as *const _ as usize) - (&header.magic as *const _ as usize);
77        return Err(Error::string_index_overflow(
78            off_dt_strings,
79            HEADER_LEN,
80            file_index,
81        ));
82    }
83    let raw_data_len = (total_size - HEADER_LEN) as usize;
84    let ans_ptr = core::ptr::from_raw_parts(ptr as *const (), raw_data_len);
85    let device_tree: &DeviceTree = &*ans_ptr;
86    let tags = device_tree.tags();
87    let mut d = Deserializer {
88        tags: tags.peekable(),
89    };
90    let ret = T::deserialize(&mut d)?;
91    Ok(ret)
92}
93
/// Magic number stored (big-endian) in the first word of every device tree blob.
const DEVICE_TREE_MAGIC: u32 = 0xd00d_feed;

// Token values that introduce each entry of the structure block.

/// Starts a node; followed by its NUL-terminated name, padded to 4 bytes.
const FDT_BEGIN_NODE: u32 = 0x0000_0001;
/// Closes the most recently opened node.
const FDT_END_NODE: u32 = 0x0000_0002;
/// A property: value length, name offset into the strings block, then the padded value.
const FDT_PROP: u32 = 0x0000_0003;
/// Filler token that the parser skips.
const FDT_NOP: u32 = 0x0000_0004;
/// Marks the end of the structure block; iteration stops here.
const FDT_END: u32 = 0x0000_0009;

/// Highest `last_comp_version` this parser accepts.
const SUPPORTED_VERSION: u32 = 17;
103
/// Fixed-size header at the very start of a flattened device tree blob.
///
/// The layout mirrors the on-disk format word for word, so field order must not
/// change. Every field holds a big-endian value; decode it with `u32::from_be`.
#[derive(Clone, Debug)]
#[repr(C)]
struct Header {
    magic: u32,
    total_size: u32,
    off_dt_struct: u32,
    off_dt_strings: u32,
    off_mem_rsvmap: u32,
    version: u32,
    last_comp_version: u32,
    boot_cpuid_phys: u32,
    size_dt_strings: u32,
    size_dt_struct: u32,
}

/// Size of [`Header`] in bytes: ten 32-bit words.
const HEADER_LEN: u32 = core::mem::size_of::<Header>() as u32;
120
121#[derive(Debug)]
122struct DeviceTree {
123    header: Header,
124    data: [u8],
125}
126
127impl DeviceTree {
128    pub fn tags(&self) -> Tags {
129        let structure_addr = (u32::from_be(self.header.off_dt_struct) - HEADER_LEN) as usize;
130        let structure_len = u32::from_be(self.header.size_dt_struct) as usize;
131        let strings_addr = (u32::from_be(self.header.off_dt_strings) - HEADER_LEN) as usize;
132        let strings_len = u32::from_be(self.header.size_dt_strings) as usize;
133        Tags {
134            structure: &self.data[structure_addr..structure_addr + structure_len],
135            string_table: &self.data[strings_addr..strings_addr + strings_len],
136            cur: 0,
137            offset_from_file_begin: structure_addr,
138        }
139    }
140}
141
/// Cursor over the structure block of a device tree.
///
/// Property names are resolved through the strings block.
#[derive(Clone, Debug)]
struct Tags<'a> {
    /// Bytes of the structure block.
    structure: &'a [u8],
    /// Bytes of the strings block; property names are NUL-terminated entries in it.
    string_table: &'a [u8],
    /// Read position inside `structure`.
    cur: usize,
    /// Added to `cur` when reporting positions.
    ///
    /// Meant to be the offset of `structure[0]` from the beginning of the blob.
    offset_from_file_begin: usize,
}
149
/// Rounds `val` up to the next multiple of 4, the size of a `u32` word.
///
/// Structure-block items are padded to this boundary.
#[inline]
fn align_up_u32(val: usize) -> usize {
    // Add 3, then clear the two low bits.
    (val + 3) & !3
}
154
155impl<'a> Tags<'a> {
156    #[inline]
157    fn file_index(&self) -> usize {
158        self.cur + self.offset_from_file_begin
159    }
160    #[inline]
161    fn read_cur_u32(&mut self) -> u32 {
162        let ans = u32::from_be_bytes([
163            self.structure[self.cur],
164            self.structure[self.cur + 1],
165            self.structure[self.cur + 2],
166            self.structure[self.cur + 3],
167        ]);
168        self.cur += 4;
169        ans
170    }
171    #[inline]
172    fn read_string0_align(&mut self) -> Result<&'a [u8]> {
173        let begin = self.cur;
174        while self.cur < self.structure.len() {
175            if self.structure[self.cur] == b'\0' {
176                let end = self.cur;
177                self.cur = align_up_u32(end + 1);
178                return Ok(&self.structure[begin..end]);
179            }
180            self.cur += 1;
181        }
182        Err(Error::string_eof_unpexpected(self.file_index()))
183    }
184    #[inline]
185    fn read_slice_align(&mut self, len: u32) -> Result<&'a [u8]> {
186        let begin = self.cur;
187        let end = self.cur + len as usize;
188        if end > self.structure.len() {
189            let remaining_length = self.structure.len() as u32 - begin as u32;
190            return Err(Error::slice_eof_unpexpected(
191                len,
192                remaining_length,
193                self.file_index(),
194            ));
195        }
196        self.cur = align_up_u32(end);
197        Ok(&self.structure[begin..end])
198    }
199    #[inline]
200    fn read_table_string(&mut self, pos: u32) -> Result<&'a [u8]> {
201        let begin = pos as usize;
202        if begin >= self.string_table.len() {
203            let bound_offset = self.string_table.len() as u32;
204            return Err(Error::table_string_offset(
205                pos,
206                bound_offset,
207                self.file_index(),
208            ));
209        }
210        let mut cur = begin;
211        while cur < self.string_table.len() {
212            if self.string_table[cur] == b'\0' {
213                return Ok(&self.string_table[begin..cur]);
214            }
215            cur += 1;
216        }
217        return Err(Error::table_string_offset(
218            pos,
219            cur as u32,
220            self.file_index(),
221        ));
222    }
223}
224
225impl<'a> Iterator for Tags<'a> {
226    type Item = Result<(Tag<'a>, usize)>; // Tag, byte index from file begin
227    fn next(&mut self) -> Option<Self::Item> {
228        if self.cur > self.structure.len() - core::mem::size_of::<u32>() {
229            return Some(Err(Error::tag_eof_unexpected(
230                self.cur as u32,
231                self.structure.len() as u32,
232                self.file_index(),
233            )));
234        }
235        let ans = loop {
236            match self.read_cur_u32() {
237                FDT_BEGIN_NODE => match self.read_string0_align() {
238                    Ok(name) => {
239                        // println!("cur = {}", self.cur)
240                        break Some(Ok(Tag::Begin(name)));
241                    }
242                    Err(e) => break Some(Err(e)),
243                },
244                FDT_PROP => {
245                    let val_size = self.read_cur_u32();
246                    let name_offset = self.read_cur_u32();
247                    // println!("size {}, off {}", val_size, name_offset);
248                    // get value slice
249                    let val = match self.read_slice_align(val_size) {
250                        Ok(slice) => slice,
251                        Err(e) => break Some(Err(e)),
252                    };
253
254                    // lookup name in strings table
255                    let prop_name = match self.read_table_string(name_offset) {
256                        Ok(slice) => slice,
257                        Err(e) => break Some(Err(e)),
258                    };
259                    break Some(Ok(Tag::Prop(val, prop_name)));
260                }
261                FDT_END_NODE => break Some(Ok(Tag::End)),
262                FDT_NOP => self.cur += 4,
263                FDT_END => break None,
264                invalid => break Some(Err(Error::invalid_tag_id(invalid, self.file_index()))),
265            }
266        };
267        match ans {
268            Some(Ok(tag)) => Some(Ok((tag, self.file_index()))),
269            Some(Err(e)) => Some(Err(e)),
270            None => None,
271        }
272    }
273}
274
/// One decoded item of the device tree structure block.
#[derive(Debug, Clone, Copy)]
pub enum Tag<'a> {
    /// Start of a node, carrying the node's name (without its NUL terminator).
    Begin(&'a [u8]),
    /// A property, as `(value bytes, property name from the strings block)`.
    Prop(&'a [u8], &'a [u8]),
    /// End of the most recently opened node.
    End,
}
281
282#[derive(Debug, Clone)]
283pub struct Deserializer<'a> {
284    tags: Peekable<Tags<'a>>,
285}
286
287impl<'a> Deserializer<'a> {
288    fn next_tag(&mut self) -> Result<Option<(Tag<'a>, usize)>> {
289        self.tags.next().transpose()
290    }
291    fn peek_tag(&mut self) -> Result<Option<Tag<'a>>> {
292        match self.tags.peek() {
293            Some(Ok((t, _i))) => Ok(Some(*t)),
294            Some(Err(e)) => Err(e.clone()),
295            None => Ok(None),
296        }
297    }
298    fn peek_tag_index(&mut self) -> Result<Option<&(Tag<'a>, usize)>> {
299        match self.tags.peek() {
300            Some(Ok(t)) => Ok(Some(t)),
301            Some(Err(e)) => Err(e.clone()),
302            None => Ok(None),
303        }
304    }
305    fn eat_tag(&mut self) -> Result<()> {
306        match self.tags.next() {
307            Some(Ok(_t)) => Ok(()),
308            Some(Err(e)) => Err(e),
309            None => Ok(()),
310        }
311    }
312}
313
314impl<'de, 'b> de::Deserializer<'de> for &'b mut Deserializer<'de> {
315    type Error = Error;
316
317    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
318    where
319        V: de::Visitor<'de>,
320    {
321        match self.peek_tag()? {
322            Some(Tag::Prop(_, value_slice)) => {
323                if value_slice.len() == 0 {
324                    self.deserialize_bool(visitor)
325                } else if value_slice.len() == 4 {
326                    self.deserialize_u32(visitor)
327                } else {
328                    self.deserialize_bytes(visitor) // by default, it's bytes
329                }
330            }
331            Some(Tag::Begin(_name_slice)) => self.deserialize_map(visitor),
332            Some(Tag::End) => unreachable!(),
333            _ => todo!(),
334        }
335    }
336
337    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
338    where
339        V: de::Visitor<'de>,
340    {
341        match self.peek_tag_index()? {
342            Some((Tag::Prop(value_slice, _name_slice), _file_index)) => {
343                if value_slice.len() == 0 {
344                    self.eat_tag()?;
345                    visitor.visit_bool(true)
346                } else {
347                    panic!()
348                }
349            }
350            _ => panic!(),
351        }
352    }
353
354    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
355    where
356        V: de::Visitor<'de>,
357    {
358        let _ = visitor;
359        todo!()
360    }
361
362    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
363    where
364        V: de::Visitor<'de>,
365    {
366        let _ = visitor;
367        todo!()
368    }
369
370    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
371    where
372        V: de::Visitor<'de>,
373    {
374        let _ = visitor;
375        todo!()
376    }
377
378    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
379    where
380        V: de::Visitor<'de>,
381    {
382        let _ = visitor;
383        todo!()
384    }
385
386    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
387    where
388        V: de::Visitor<'de>,
389    {
390        let _ = visitor;
391        todo!()
392    }
393
394    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
395    where
396        V: de::Visitor<'de>,
397    {
398        let _ = visitor;
399        todo!()
400    }
401
402    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
403    where
404        V: de::Visitor<'de>,
405    {
406        match self.peek_tag_index()? {
407            Some((Tag::Prop(value_slice, _name_slice), file_index)) => {
408                let value = match value_slice {
409                    [a, b, c, d] => u32::from_be_bytes([*a, *b, *c, *d]),
410                    _ => return Err(Error::invalid_serde_type_length(4, *file_index)),
411                };
412                self.eat_tag()?;
413                visitor.visit_u32(value)
414            }
415            _ => todo!(),
416        }
417    }
418
419    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
420    where
421        V: de::Visitor<'de>,
422    {
423        let _ = visitor;
424        todo!()
425    }
426
427    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
428    where
429        V: de::Visitor<'de>,
430    {
431        let _ = visitor;
432        todo!()
433    }
434
435    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
436    where
437        V: de::Visitor<'de>,
438    {
439        let _ = visitor;
440        todo!()
441    }
442
443    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
444    where
445        V: de::Visitor<'de>,
446    {
447        let _ = visitor;
448        todo!()
449    }
450
451    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
452    where
453        V: de::Visitor<'de>,
454    {
455        match self.peek_tag_index()? {
456            Some((Tag::Prop(value_slice, _name_slice), file_index)) => {
457                let s =
458                    core::str::from_utf8(value_slice).map_err(|e| Error::utf8(e, *file_index))?;
459                let value = visitor.visit_borrowed_str(s)?;
460                self.eat_tag()?;
461                Ok(value)
462            }
463            _ => todo!(),
464        }
465    }
466
467    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
468    where
469        V: de::Visitor<'de>,
470    {
471        let _ = visitor;
472        todo!()
473    }
474
475    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
476    where
477        V: de::Visitor<'de>,
478    {
479        match self.peek_tag()? {
480            Some(Tag::Prop(value_slice, _name_slice)) => {
481                let value = visitor.visit_borrowed_bytes(value_slice)?;
482                self.eat_tag()?;
483                Ok(value)
484            }
485            _ => todo!(),
486        }
487    }
488
489    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
490    where
491        V: de::Visitor<'de>,
492    {
493        let _ = visitor;
494        todo!()
495    }
496
497    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
498    where
499        V: de::Visitor<'de>,
500    {
501        visitor.visit_some(self)
502    }
503
504    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
505    where
506        V: de::Visitor<'de>,
507    {
508        let _ = visitor;
509        todo!()
510    }
511
512    fn deserialize_unit_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value>
513    where
514        V: de::Visitor<'de>,
515    {
516        let _ = (name, visitor);
517        todo!()
518    }
519
520    fn deserialize_newtype_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value>
521    where
522        V: de::Visitor<'de>,
523    {
524        let _ = (name, visitor);
525        todo!()
526    }
527
528    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
529    where
530        V: de::Visitor<'de>,
531    {
532        let _ = visitor;
533        todo!()
534    }
535
536    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value>
537    where
538        V: de::Visitor<'de>,
539    {
540        let _ = (len, visitor);
541        todo!()
542    }
543
544    fn deserialize_tuple_struct<V>(
545        self,
546        name: &'static str,
547        len: usize,
548        visitor: V,
549    ) -> Result<V::Value>
550    where
551        V: de::Visitor<'de>,
552    {
553        let _ = (name, len, visitor);
554        todo!()
555    }
556
557    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
558    where
559        V: de::Visitor<'de>,
560    {
561        if let Some((Tag::Begin(_name_slice), _file_index)) = self.next_tag()? {
562            let ret = visitor.visit_map(MapVisitor::new(self))?;
563            if let Some((Tag::End, _file_index)) = self.next_tag()? {
564                Ok(ret)
565            } else {
566                Err(Error::expected_struct_end())
567            }
568        } else {
569            Err(Error::expected_struct_begin())
570        }
571    }
572
573    fn deserialize_struct<V>(
574        self,
575        name: &'static str,
576        fields: &'static [&'static str],
577        visitor: V,
578    ) -> Result<V::Value>
579    where
580        V: de::Visitor<'de>,
581    {
582        let _ = (name, fields);
583        self.deserialize_map(visitor)
584    }
585
586    fn deserialize_enum<V>(
587        self,
588        name: &'static str,
589        variants: &'static [&'static str],
590        visitor: V,
591    ) -> Result<V::Value>
592    where
593        V: de::Visitor<'de>,
594    {
595        let _ = (name, variants, visitor);
596        todo!()
597    }
598
599    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value>
600    where
601        V: de::Visitor<'de>,
602    {
603        if let Some((Tag::Begin(name_slice), file_index)) = self.peek_tag_index()? {
604            let s = core::str::from_utf8(name_slice).map_err(|e| Error::utf8(e, *file_index))?;
605            visitor.visit_str(s)
606        } else {
607            todo!()
608        }
609    }
610
611    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value>
612    where
613        V: de::Visitor<'de>,
614    {
615        if let Some(tag) = self.peek_tag()? {
616            match tag {
617                Tag::Begin(_) => {
618                    self.eat_tag()?;
619                    let mut depth = 0;
620                    while let Some((tag, _file_index)) = self.next_tag()? {
621                        match tag {
622                            Tag::Begin(_) => depth += 1,
623                            Tag::End => {
624                                if depth == 0 {
625                                    break;
626                                } else {
627                                    depth -= 1
628                                }
629                            }
630                            Tag::Prop(_, _) => {}
631                        }
632                    }
633                }
634                Tag::End => unreachable!(),
635                Tag::Prop(_, _) => self.eat_tag()?,
636            }
637        }
638        visitor.visit_unit()
639    }
640}
641
642struct MapVisitor<'de, 'b> {
643    de: &'b mut Deserializer<'de>,
644}
645
646impl<'de, 'b> MapVisitor<'de, 'b> {
647    fn new(de: &'b mut Deserializer<'de>) -> Self {
648        Self { de }
649    }
650}
651
652impl<'de, 'b> de::MapAccess<'de> for MapVisitor<'de, 'b> {
653    type Error = Error;
654
655    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
656    where
657        K: de::DeserializeSeed<'de>,
658    {
659        match self.de.peek_tag()? {
660            Some(Tag::Prop(_value_slice, name_slice)) => seed
661                .deserialize(serde::de::value::BorrowedBytesDeserializer::new(name_slice))
662                .map(Some),
663            Some(Tag::Begin(name_slice)) => seed
664                .deserialize(serde::de::value::BorrowedBytesDeserializer::new(name_slice))
665                .map(Some),
666            Some(Tag::End) => Ok(None),
667            None => return Err(Error::no_remaining_tags()),
668        }
669    }
670
671    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
672    where
673        V: de::DeserializeSeed<'de>,
674    {
675        match self.de.peek_tag()? {
676            Some(Tag::Prop(_value_slice, _name_slice)) => seed.deserialize(&mut *self.de),
677            Some(Tag::Begin(_name_slice)) => seed.deserialize(&mut *self.de),
678            Some(Tag::End) => panic!(),
679            None => return Err(Error::no_remaining_tags()),
680        }
681    }
682}
683
#[cfg(test)]
mod tests {
    // Take `format!` from `std` when it is available and from `alloc` only
    // otherwise, so enabling both features does not import the name twice.
    #[cfg(all(feature = "alloc", not(feature = "std")))]
    use alloc::format;
    #[cfg(any(feature = "std", feature = "alloc"))]
    use serde_derive::Deserialize;
    #[cfg(feature = "std")]
    use std::format;

    /// Header-sized (40-byte) buffer with the 4-byte alignment that `Header` needs.
    #[cfg(any(feature = "std", feature = "alloc"))]
    #[repr(C, align(4))]
    struct AlignedHeader([u8; 40]);

    #[cfg(any(feature = "std", feature = "alloc"))]
    #[test]
    fn error_invalid_magic() {
        // A complete, aligned header whose magic is not 0xD00DFEED. Handing
        // `from_raw` a full header keeps the test sound even if the implementation
        // reads the whole header before checking the magic.
        let mut blob = AlignedHeader([0; 40]);
        blob.0[..4].copy_from_slice(&[0x11, 0x22, 0x33, 0x44]); // not device tree blob format
        let ptr = blob.0.as_ptr();

        #[derive(Debug, Deserialize)]
        struct Tree {}

        let ans: Result<Tree, _> = unsafe { super::from_raw(ptr) };
        let err = ans.unwrap_err();
        assert_eq!(
            "Error(invalid magic, value: 287454020, index: 0)",
            format!("{}", err)
        );
    }
}