Skip to main content

tract_libcli/
export.rs

1use crate::annotations::{Annotations, NodeQId};
2use crate::model::Model;
3use serde::Serialize;
4use std::collections::HashMap;
5use tract_core::internal::*;
6
7#[derive(Clone, Debug, Default, Serialize)]
8pub struct GraphPerfInfo {
9    nodes: Vec<Node>,
10    profiling_info: Option<ProfilingInfo>,
11}
12
13#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize)]
14pub struct NodeQIdSer(pub Vec<(usize, String)>, pub usize);
15
16#[derive(Clone, Debug, Serialize)]
17pub struct Node {
18    qualified_id: NodeQIdSer,
19    op_name: String,
20    node_name: String,
21
22    #[serde(skip_serializing_if = "HashMap::is_empty")]
23    cost: HashMap<String, String>,
24
25    #[serde(skip_serializing_if = "Option::is_none")]
26    secs_per_iter: Option<f64>,
27}
28
29#[derive(Clone, Debug, Serialize)]
30pub struct ProfilingInfo {
31    iterations: usize,
32    secs_per_iter: f64,
33}
34
35impl GraphPerfInfo {
36    pub fn from(model: &dyn Model, annotations: &Annotations) -> GraphPerfInfo {
37        let nodes = annotations
38            .tags
39            .iter()
40            .map(|(id, node)| Node {
41                qualified_id: NodeQIdSer(id.0.iter().cloned().collect(), id.1),
42                cost: node.cost.iter().map(|(k, v)| (format!("{k:?}"), format!("{v}"))).collect(),
43                node_name: id.model(model).unwrap().node_name(id.1).to_string(),
44                op_name: id.model(model).unwrap().node_op_name(id.1).to_string(),
45                secs_per_iter: node.profile.map(|s| s.as_secs_f64()),
46            })
47            .collect();
48        let profiling_info = annotations.profile_summary.as_ref().map(|summary| ProfilingInfo {
49            secs_per_iter: summary.entire.as_secs_f64(),
50            iterations: summary.iters,
51        });
52        GraphPerfInfo { nodes, profiling_info }
53    }
54}
55
56// -- audit-json --
57
58#[derive(Serialize)]
59pub struct AuditModel {
60    properties: HashMap<String, String>,
61    assertions: Vec<String>,
62    inputs: Vec<AuditModelIo>,
63    outputs: Vec<AuditModelIo>,
64    nodes: Vec<AuditNode>,
65}
66
67#[derive(Serialize)]
68struct AuditModelIo {
69    name: String,
70    node: usize,
71    slot: usize,
72    fact: String,
73}
74
75#[derive(Serialize)]
76struct AuditNode {
77    id: usize,
78    name: String,
79    op: String,
80    #[serde(skip_serializing_if = "Vec::is_empty")]
81    info: Vec<String>,
82    inputs: Vec<AuditOutletRef>,
83    outputs: Vec<AuditNodeOutput>,
84    #[serde(skip_serializing_if = "HashMap::is_empty")]
85    cost: HashMap<String, String>,
86}
87
88#[derive(Serialize)]
89struct AuditOutletRef {
90    node: usize,
91    slot: usize,
92}
93
94#[derive(Serialize)]
95struct AuditNodeOutput {
96    fact: String,
97    successors: Vec<AuditInletRef>,
98}
99
100#[derive(Serialize)]
101struct AuditInletRef {
102    node: usize,
103    slot: usize,
104}
105
106pub fn audit_json(
107    model: &dyn Model,
108    annotations: &Annotations,
109    writer: impl std::io::Write,
110) -> TractResult<()> {
111    let properties: HashMap<String, String> =
112        model.properties().iter().map(|(k, v)| (k.clone(), format!("{v:?}"))).collect();
113
114    let scope = model.symbols();
115    let assertions: Vec<String> = scope.all_assertions().iter().map(|a| format!("{a}")).collect();
116
117    let inputs: Vec<AuditModelIo> = model
118        .input_outlets()
119        .iter()
120        .map(|o| {
121            Ok(AuditModelIo {
122                name: model.node_name(o.node).to_string(),
123                node: o.node,
124                slot: o.slot,
125                fact: model.outlet_fact_format(*o),
126            })
127        })
128        .collect::<TractResult<_>>()?;
129
130    let outputs: Vec<AuditModelIo> = model
131        .output_outlets()
132        .iter()
133        .map(|o| {
134            Ok(AuditModelIo {
135                name: model.node_name(o.node).to_string(),
136                node: o.node,
137                slot: o.slot,
138                fact: model.outlet_fact_format(*o),
139            })
140        })
141        .collect::<TractResult<_>>()?;
142
143    let nodes: Vec<AuditNode> = (0..model.nodes_len())
144        .map(|id| {
145            let op = model.node_op(id);
146            let info = op.info().unwrap_or_default();
147            let node_inputs: Vec<AuditOutletRef> = model
148                .node_inputs(id)
149                .iter()
150                .map(|o| AuditOutletRef { node: o.node, slot: o.slot })
151                .collect();
152            let node_outputs: Vec<AuditNodeOutput> = (0..model.node_output_count(id))
153                .map(|slot| {
154                    let outlet = OutletId::new(id, slot);
155                    let fact = model.outlet_fact_format(outlet);
156                    let successors: Vec<AuditInletRef> = model
157                        .outlet_successors(outlet)
158                        .iter()
159                        .map(|inlet| AuditInletRef { node: inlet.node, slot: inlet.slot })
160                        .collect();
161                    AuditNodeOutput { fact, successors }
162                })
163                .collect();
164            let cost: HashMap<String, String> = annotations
165                .tags
166                .get(&NodeQId(tvec!(), id))
167                .map(|tags| {
168                    tags.cost.iter().map(|(k, v)| (format!("{k:?}"), format!("{v}"))).collect()
169                })
170                .unwrap_or_default();
171            AuditNode {
172                id,
173                name: model.node_name(id).to_string(),
174                op: model.node_op_name(id).to_string(),
175                info,
176                inputs: node_inputs,
177                outputs: node_outputs,
178                cost,
179            }
180        })
181        .collect();
182
183    let audit = AuditModel { properties, assertions, inputs, outputs, nodes };
184    serde_json::to_writer_pretty(writer, &audit)?;
185    Ok(())
186}