1use tract_hir::internal::*;
2
3use nom::IResult;
4use nom::{
5 branch::alt,
6 bytes::complete::*,
7 character::complete::*,
8 combinator::*,
9 number::complete::{le_i32, le_i64},
10 sequence::*,
11};
12
13use std::collections::HashMap;
14
15use crate::model::{Component, KaldiProtoModel};
16
17use tract_itertools::Itertools;
18
19mod bin;
20mod components;
21mod config_lines;
22mod descriptor;
23mod text;
24
25pub fn nnet3(slice: &[u8]) -> TractResult<KaldiProtoModel> {
26 let (_, (config, components)) = parse_top_level(slice).map_err(|e| match e {
27 nom::Err::Error(err) => format_err!(
28 "Parsing kaldi enveloppe at: {:?}",
29 err.input.iter().take(120).map(|b| format!("{}", *b as char)).join("")
30 ),
31 e => format_err!("{:?}", e),
32 })?;
33 let config_lines = config_lines::parse_config(config)?;
34 Ok(KaldiProtoModel { config_lines, components, adjust_final_offset: 0 })
35}
36
37pub fn if_then_else<'a, T>(
38 condition: bool,
39 then: impl FnMut(&'a [u8]) -> IResult<&'a [u8], T>,
40 otherwise: impl FnMut(&'a [u8]) -> IResult<&'a [u8], T>,
41) -> impl FnMut(&'a [u8]) -> IResult<&'a [u8], T> {
42 map(pair(cond(condition, then), cond(!condition, otherwise)), |(a, b)| a.or(b).unwrap())
43}
44
45fn parse_top_level(i: &[u8]) -> IResult<&[u8], (&str, HashMap<String, Component>)> {
46 let (i, bin) = map(opt(tag([0, 0x42])), |o| Option::is_some(&o))(i)?;
47 let (i, _) = open(i, "Nnet3")?;
48 let (i, config_lines) = map_res(take_until("<NumComponents>"), std::str::from_utf8)(i)?;
49 let (i, num_components) = num_components(bin, i)?;
50 let mut components = HashMap::new();
51 let mut i = i;
52 for _ in 0..num_components {
53 let (new_i, name) = component_name(i)?;
54 debug!("Parsing component {}", name);
55 let (new_i, comp) = component(bin)(new_i)?;
56 i = new_i;
57 components.insert(name.to_owned(), comp);
58 }
59 let (i, _) = close(i, "Nnet3")?;
60 Ok((i, (config_lines, components)))
61}
62
63fn num_components(bin: bool, i: &[u8]) -> IResult<&[u8], usize> {
64 let (i, _) = open(i, "NumComponents")?;
65 let (i, n) = multispaced(integer(bin))(i)?;
66 Ok((i, n as usize))
67}
68
69fn component(bin: bool) -> impl Fn(&[u8]) -> IResult<&[u8], Component> {
70 move |i: &[u8]| {
71 let (i, klass) = open_any(i)?;
72 let (i, attributes) = if bin { bin::attributes(i, klass)? } else { text::attributes(i)? };
73 let (i, _) = close(i, klass)?;
74 Ok((i, Component { klass: klass.to_string(), attributes }))
75 }
76}
77
78fn component_name(i: &[u8]) -> IResult<&[u8], &str> {
79 multispaced(delimited(|i| open(i, "ComponentName"), name, multispace0))(i)
80}
81
82pub fn open<'a>(i: &'a [u8], t: &str) -> IResult<&'a [u8], ()> {
83 map(multispaced(tuple((tag("<"), tag(t.as_bytes()), tag(">")))), |_| ())(i)
84}
85
86pub fn close<'a>(i: &'a [u8], t: &str) -> IResult<&'a [u8], ()> {
87 map(multispaced(tuple((tag("</"), tag(t.as_bytes()), tag(">")))), |_| ())(i)
88}
89
90pub fn open_any(i: &[u8]) -> IResult<&[u8], &str> {
91 multispaced(delimited(tag("<"), name, tag(">")))(i)
92}
93
94pub fn name(i: &[u8]) -> IResult<&[u8], &str> {
95 map_res(
96 recognize(pair(
97 alpha1,
98 nom::multi::many0(nom::branch::alt((alphanumeric1, tag("."), tag("_"), tag("-")))),
99 )),
100 std::str::from_utf8,
101 )(i)
102}
103
104pub fn integer<'a>(bin: bool) -> impl FnMut(&'a [u8]) -> IResult<&'a [u8], i32> {
105 if_then_else(
106 bin,
107 alt((preceded(tag([4]), le_i32), preceded(tag([8]), map(le_i64, |i| i as i32)))),
108 map_res(
109 map_res(
110 recognize(pair(opt(tag("-")), take_while(nom::character::is_digit))),
111 std::str::from_utf8,
112 ),
113 |s| s.parse::<i32>(),
114 ),
115 )
116}
117
118pub fn spaced<I, O, E: nom::error::ParseError<I>, F>(it: F) -> impl FnMut(I) -> nom::IResult<I, O, E>
119where
120 I: nom::InputTakeAtPosition,
121 <I as nom::InputTakeAtPosition>::Item: nom::AsChar + Clone,
122 F: FnMut(I) -> nom::IResult<I, O, E>,
123{
124 delimited(space0, it, space0)
125}
126
127pub fn multispaced<I, O, E: nom::error::ParseError<I>, F>(
128 it: F,
129) -> impl FnMut(I) -> nom::IResult<I, O, E>
130where
131 I: nom::InputTakeAtPosition,
132 <I as nom::InputTakeAtPosition>::Item: nom::AsChar + Clone,
133 F: FnMut(I) -> nom::IResult<I, O, E>,
134{
135 delimited(multispace0, it, multispace0)
136}