1use std::collections::hash_map::Entry;
2use std::fmt::Debug;
3
4use flatbuffers::FlatBufferBuilder;
5use tract_core::internal::*;
6
7use crate::registry::Registry;
8use crate::tensors::{flat_tensor_to_tract_fact, flat_tensor_uses_per_axis_q};
9use crate::tflite;
10use crate::tflite::{Buffer, BufferArgs};
11
12pub struct Tflite(Registry);
13
14impl Debug for Tflite {
15 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16 write!(f, "tract-TfLite-framework")
17 }
18}
19
20impl Default for Tflite {
21 fn default() -> Self {
22 let mut registry = Registry::default();
23 crate::ops::register_all(&mut registry);
24 Tflite(registry)
25 }
26}
27
28#[derive(Clone, Debug)]
29pub struct TfliteProtoModel(Vec<u8>);
30
31impl TfliteProtoModel {
32 fn new(buf: Vec<u8>) -> TractResult<TfliteProtoModel> {
33 let _ = tflite::root_as_model(&buf)?;
34 Ok(TfliteProtoModel(buf))
35 }
36
37 pub fn root(&self) -> tflite::Model<'_> {
38 unsafe { tflite::root_as_model_unchecked(&self.0) }
39 }
40}
41
42fn write_model<'fb>(
43 registry: &Registry,
44 model: &TypedModel,
45) -> TractResult<FlatBufferBuilder<'fb>> {
46 let mut model = model.clone();
47 crate::rewriter::rewrite_for_tflite(&mut model).context("Pre-dump rewrite")?;
48 let mut builder = flatbuffers::FlatBufferBuilder::new();
49 let mut op_codes = vec![];
50 let sentinel = Buffer::create(&mut builder, &BufferArgs { data: None });
51 let mut buffers = vec![sentinel];
52 crate::ser::ModelBuilder {
53 registry,
54 builder: &mut builder,
55 op_codes: &mut op_codes,
56 buffers: &mut buffers,
57 }
58 .write_model(&model)?;
59 Ok(builder)
60}
61
62impl Tflite {
63 pub fn write(&self, model: &TypedModel, mut w: impl std::io::Write) -> TractResult<()> {
64 let builder = write_model(&self.0, model)?;
65 w.write_all(builder.finished_data())?;
66 Ok(())
67 }
68}
69
70impl Framework<TfliteProtoModel, TypedModel> for Tflite {
71 fn proto_model_for_read(
72 &self,
73 reader: &mut dyn std::io::Read,
74 ) -> tract_core::prelude::TractResult<TfliteProtoModel> {
75 let mut buf = vec![];
76 reader.read_to_end(&mut buf)?;
77 TfliteProtoModel::new(buf)
78 }
79
80 fn model_for_proto_model_with_model_template(
81 &self,
82 proto: &TfliteProtoModel,
83 mut target: TypedModel,
84 ) -> TractResult<TypedModel> {
85 let root = proto.root();
86 let main = &root.subgraphs().context("No subgraphs in Tflite model")?.get(0);
87 let mut mapping = HashMap::new();
88 for input in main.inputs().context("No inputs in Tflite model")? {
89 if !flat_tensor_uses_per_axis_q(main, input) {
90 let (fact, name) = flat_tensor_to_tract_fact(&root, main, input)?;
91 let it = target.add_source(name, fact)?;
92 mapping.insert(input, it);
93 }
94 }
95 for op in main.operators().context("No operators in Tflite model")? {
96 for input in op.inputs().context("No input in Tflite operator")? {
97 if let Entry::Vacant(slot) = mapping.entry(input) {
98 let (fact, name) = flat_tensor_to_tract_fact(&root, main, input)?;
99 let value = fact.konst.with_context(|| format!("Error in TF file for operator {op:?}. No prior computation nor constant for input {input}"))?;
100 let konst = target.add_const(name, value)?;
101 slot.insert(konst);
102 }
103 }
104 self.0.deser_op(&root, main, &op, &mut target, &mut mapping).with_context(|| {
105 format!("Translating proto-op from Tflite into tract op: {op:#?}")
106 })?;
107 }
108 let outputs: TVec<_> = main
109 .outputs()
110 .context("No outputs in Tflite model")?
111 .iter()
112 .map(|o| mapping[&o])
113 .collect();
114 target.set_output_outlets(&outputs)?;
115 Ok(target)
116 }
117}