tract_nnef/ast/
dump_doc.rs1use crate::ast::dump::Dumper;
2use crate::ast::*;
3use std::path::Path;
4use tract_core::internal::*;
5
/// Writes human-readable NNEF documentation (registry docstrings and fragment
/// declarations) to an arbitrary byte sink.
pub struct DocDumper<'a> {
    /// Destination of the generated documentation text.
    w: &'a mut (dyn std::io::Write + 'a),
}
9
10impl DocDumper<'_> {
11 pub fn new(w: &mut dyn std::io::Write) -> DocDumper {
12 DocDumper { w }
13 }
14
15 pub fn registry(&mut self, registry: &Registry) -> TractResult<()> {
16 for d in registry.docstrings.iter().flatten() {
18 writeln!(self.w, "# {d}")?;
19 }
20 writeln!(self.w)?;
21 for unit_el_wise_op in registry.unit_element_wise_ops.iter() {
23 writeln!(
26 self.w,
27 "fragment {}( x: tensor<scalar> ) -> (y: tensor<scalar>);",
28 &unit_el_wise_op.0 .0
29 )?;
30 }
31 writeln!(self.w)?;
32
33 for el_wise_op in registry.element_wise_ops.iter() {
35 let fragment_decl = FragmentDecl {
36 id: el_wise_op.0.clone(),
37 generic_decl: None,
38 parameters: el_wise_op.3.clone(),
39 results: vec![Result_ { id: "output".into(), spec: TypeName::Any.tensor() }],
40 };
41 Dumper::new(&Nnef::default(), self.w).with_doc().fragment_decl(&fragment_decl)?;
42 }
43 for primitive in registry.primitives.values().sorted_by_key(|v| &v.decl.id) {
45 primitive.docstrings.iter().flatten().try_for_each(|d| writeln!(self.w, "# {d}"))?;
46
47 Dumper::new(&Nnef::default(), self.w).with_doc().fragment_decl(&primitive.decl)?;
48 writeln!(self.w, ";\n")?;
49 }
50
51 Dumper::new(&Nnef::default(), self.w)
53 .with_doc()
54 .fragments(registry.fragments.values().cloned().collect::<Vec<_>>().as_slice())?;
55
56 Ok(())
57 }
58
59 pub fn registry_to_path(path: impl AsRef<Path>, registry: &Registry) -> TractResult<()> {
60 let mut file = std::fs::File::create(path.as_ref())
61 .with_context(|| anyhow!("Error while creating file at path: {:?}", path.as_ref()))?;
62 DocDumper::new(&mut file).registry(registry)
63 }
64
65 pub fn to_directory(path: impl AsRef<Path>, nnef: &Nnef) -> TractResult<()> {
66 for registry in nnef.registries.iter() {
67 let registry_file = path.as_ref().join(format!("{}.nnef", registry.id.0));
68 let mut file = std::fs::File::create(®istry_file).with_context(|| {
69 anyhow!("Error while creating file at path: {:?}", registry_file)
70 })?;
71 DocDumper::new(&mut file).registry(registry)?;
72 }
73 Ok(())
74 }
75}
76
#[cfg(test)]
mod test {
    use super::*;
    use temp_dir::TempDir;

    /// Dumps the core and resource registries to a temporary directory and checks that
    /// one documentation file was produced per registry.
    #[test]
    fn doc_example() -> TractResult<()> {
        let d = TempDir::new()?;
        let nnef = crate::nnef().with_tract_core().with_tract_resource();
        DocDumper::to_directory(d.path(), &nnef)?;
        for registry in nnef.registries.iter() {
            let registry_file = d.path().join(format!("{}.nnef", registry.id.0));
            assert!(registry_file.is_file(), "missing documentation file {registry_file:?}");
        }
        Ok(())
    }

    /// Checks the exact text produced for a registry holding docstrings and one primitive.
    #[test]
    fn doc_registry() -> TractResult<()> {
        let mut registry = Registry::new("test_doc")
            .with_doc("test_doc registry gather all the needed primitives")
            .with_doc("to test the documentation dumper");
        registry.register_primitive(
            "tract_primitive",
            &[TypeName::Integer.tensor().named("input")],
            &[("output", TypeName::Scalar.tensor())],
            |_, _| panic!("No deserialization needed"),
        );
        let mut docbytes = vec![];
        let mut dumper = DocDumper::new(&mut docbytes);
        dumper.registry(&registry)?;
        let docstring = String::from_utf8(docbytes)?;
        assert_eq!(
            docstring,
            r#"# test_doc registry gather all the needed primitives
# to test the documentation dumper


fragment tract_primitive(
    input: tensor<integer>
) -> (output: tensor<scalar>);

"#
        );
        Ok(())
    }
}