tract_nnef/ast/
dump_doc.rs

1use crate::ast::dump::Dumper;
2use crate::ast::*;
3use std::path::Path;
4use tract_core::internal::*;
5
6pub struct DocDumper<'a> {
7    w: &'a mut dyn std::io::Write,
8}
9
10impl DocDumper<'_> {
11    pub fn new(w: &mut dyn std::io::Write) -> DocDumper {
12        DocDumper { w }
13    }
14
15    pub fn registry(&mut self, registry: &Registry) -> TractResult<()> {
16        // Write registry docstrings.
17        for d in registry.docstrings.iter().flatten() {
18            writeln!(self.w, "# {d}")?;
19        }
20        writeln!(self.w)?;
21        // Generate and write unit element wise op.
22        for unit_el_wise_op in registry.unit_element_wise_ops.iter() {
23            // we are assuming function names will not exhibit crazy node name weirdness, so we can
24            // dispense with escaping
25            writeln!(
26                self.w,
27                "fragment {}( x: tensor<scalar> ) -> (y: tensor<scalar>);",
28                &unit_el_wise_op.0 .0
29            )?;
30        }
31        writeln!(self.w)?;
32
33        // Generate and write element wise op.
34        for el_wise_op in registry.element_wise_ops.iter() {
35            let fragment_decl = FragmentDecl {
36                id: el_wise_op.0.clone(),
37                generic_decl: None,
38                parameters: el_wise_op.3.clone(),
39                results: vec![Result_ { id: "output".into(), spec: TypeName::Any.tensor() }],
40            };
41            Dumper::new(&Nnef::default(), self.w).with_doc().fragment_decl(&fragment_decl)?;
42        }
43        // Generate and write Primitive declarations.
44        for primitive in registry.primitives.values().sorted_by_key(|v| &v.decl.id) {
45            primitive.docstrings.iter().flatten().try_for_each(|d| writeln!(self.w, "# {d}"))?;
46
47            Dumper::new(&Nnef::default(), self.w).with_doc().fragment_decl(&primitive.decl)?;
48            writeln!(self.w, ";\n")?;
49        }
50
51        // Generate and write fragment declarations
52        Dumper::new(&Nnef::default(), self.w)
53            .with_doc()
54            .fragments(registry.fragments.values().cloned().collect::<Vec<_>>().as_slice())?;
55
56        Ok(())
57    }
58
59    pub fn registry_to_path(path: impl AsRef<Path>, registry: &Registry) -> TractResult<()> {
60        let mut file = std::fs::File::create(path.as_ref())
61            .with_context(|| anyhow!("Error while creating file at path: {:?}", path.as_ref()))?;
62        DocDumper::new(&mut file).registry(registry)
63    }
64
65    pub fn to_directory(path: impl AsRef<Path>, nnef: &Nnef) -> TractResult<()> {
66        for registry in nnef.registries.iter() {
67            let registry_file = path.as_ref().join(format!("{}.nnef", registry.id.0));
68            let mut file = std::fs::File::create(&registry_file).with_context(|| {
69                anyhow!("Error while creating file at path: {:?}", registry_file)
70            })?;
71            DocDumper::new(&mut file).registry(registry)?;
72        }
73        Ok(())
74    }
75}
76
#[cfg(test)]
mod test {
    use super::*;
    use temp_dir::TempDir;

    /// Dumps the documentation of the core and resource registries and checks that every
    /// registry produced its own, non-empty `.nnef` file.
    #[test]
    fn doc_example() -> TractResult<()> {
        let d = TempDir::new()?;
        let nnef = crate::nnef().with_tract_core().with_tract_resource();
        DocDumper::to_directory(d.path(), &nnef)?;
        // Even a registry with nothing to document gets its blank separator lines, so an empty
        // file means the dump silently wrote nothing.
        for registry in nnef.registries.iter() {
            let file = d.path().join(format!("{}.nnef", registry.id.0));
            let len = std::fs::metadata(&file)
                .with_context(|| anyhow!("Missing documentation file: {:?}", file))?
                .len();
            assert!(len > 0, "Empty documentation file: {file:?}");
        }
        Ok(())
    }

    /// Checks the exact documentation emitted for a registry holding a single primitive.
    #[test]
    fn doc_registry() -> TractResult<()> {
        let mut registry = Registry::new("test_doc")
            .with_doc("test_doc registry gather all the needed primitives")
            .with_doc("to test the documentation dumper");
        registry.register_primitive(
            "tract_primitive",
            &[TypeName::Integer.tensor().named("input")],
            &[("output", TypeName::Scalar.tensor())],
            |_, _| panic!("No deserialization needed"),
        );
        let mut docbytes = vec![];
        let mut dumper = DocDumper::new(&mut docbytes);
        dumper.registry(&registry)?;
        let docstring = String::from_utf8(docbytes)?;
        let expected = concat!(
            "# test_doc registry gather all the needed primitives\n",
            "# to test the documentation dumper\n",
            // blank line closing the docstrings section
            "\n",
            // blank line closing the (empty) unit element-wise ops section
            "\n",
            "fragment tract_primitive(\n",
            "    input: tensor<integer>\n",
            ") -> (output: tensor<scalar>);\n",
            "\n",
        );
        assert_eq!(docstring, expected);
        Ok(())
    }
}