tract_kaldi/
parser.rs

1use tract_hir::internal::*;
2
3use nom::IResult;
4use nom::{
5    branch::alt,
6    bytes::complete::*,
7    character::complete::*,
8    combinator::*,
9    number::complete::{le_i32, le_i64},
10    sequence::*,
11};
12
13use std::collections::HashMap;
14
15use crate::model::{Component, KaldiProtoModel};
16
17use tract_itertools::Itertools;
18
19mod bin;
20mod components;
21mod config_lines;
22mod descriptor;
23mod text;
24
25pub fn nnet3(slice: &[u8]) -> TractResult<KaldiProtoModel> {
26    let (_, (config, components)) = parse_top_level(slice).map_err(|e| match e {
27        nom::Err::Error(err) => format_err!(
28            "Parsing kaldi enveloppe at: {:?}",
29            err.input.iter().take(120).map(|b| format!("{}", *b as char)).join("")
30        ),
31        e => format_err!("{:?}", e),
32    })?;
33    let config_lines = config_lines::parse_config(config)?;
34    Ok(KaldiProtoModel { config_lines, components, adjust_final_offset: 0 })
35}
36
37pub fn if_then_else<'a, T>(
38    condition: bool,
39    then: impl FnMut(&'a [u8]) -> IResult<&'a [u8], T>,
40    otherwise: impl FnMut(&'a [u8]) -> IResult<&'a [u8], T>,
41) -> impl FnMut(&'a [u8]) -> IResult<&'a [u8], T> {
42    map(pair(cond(condition, then), cond(!condition, otherwise)), |(a, b)| a.or(b).unwrap())
43}
44
45fn parse_top_level(i: &[u8]) -> IResult<&[u8], (&str, HashMap<String, Component>)> {
46    let (i, bin) = map(opt(tag([0, 0x42])), |o| Option::is_some(&o))(i)?;
47    let (i, _) = open(i, "Nnet3")?;
48    let (i, config_lines) = map_res(take_until("<NumComponents>"), std::str::from_utf8)(i)?;
49    let (i, num_components) = num_components(bin, i)?;
50    let mut components = HashMap::new();
51    let mut i = i;
52    for _ in 0..num_components {
53        let (new_i, name) = component_name(i)?;
54        debug!("Parsing component {}", name);
55        let (new_i, comp) = component(bin)(new_i)?;
56        i = new_i;
57        components.insert(name.to_owned(), comp);
58    }
59    let (i, _) = close(i, "Nnet3")?;
60    Ok((i, (config_lines, components)))
61}
62
63fn num_components(bin: bool, i: &[u8]) -> IResult<&[u8], usize> {
64    let (i, _) = open(i, "NumComponents")?;
65    let (i, n) = multispaced(integer(bin))(i)?;
66    Ok((i, n as usize))
67}
68
69fn component(bin: bool) -> impl Fn(&[u8]) -> IResult<&[u8], Component> {
70    move |i: &[u8]| {
71        let (i, klass) = open_any(i)?;
72        let (i, attributes) = if bin { bin::attributes(i, klass)? } else { text::attributes(i)? };
73        let (i, _) = close(i, klass)?;
74        Ok((i, Component { klass: klass.to_string(), attributes }))
75    }
76}
77
78fn component_name(i: &[u8]) -> IResult<&[u8], &str> {
79    multispaced(delimited(|i| open(i, "ComponentName"), name, multispace0))(i)
80}
81
82pub fn open<'a>(i: &'a [u8], t: &str) -> IResult<&'a [u8], ()> {
83    map(multispaced(tuple((tag("<"), tag(t.as_bytes()), tag(">")))), |_| ())(i)
84}
85
86pub fn close<'a>(i: &'a [u8], t: &str) -> IResult<&'a [u8], ()> {
87    map(multispaced(tuple((tag("</"), tag(t.as_bytes()), tag(">")))), |_| ())(i)
88}
89
90pub fn open_any(i: &[u8]) -> IResult<&[u8], &str> {
91    multispaced(delimited(tag("<"), name, tag(">")))(i)
92}
93
94pub fn name(i: &[u8]) -> IResult<&[u8], &str> {
95    map_res(
96        recognize(pair(
97            alpha1,
98            nom::multi::many0(nom::branch::alt((alphanumeric1, tag("."), tag("_"), tag("-")))),
99        )),
100        std::str::from_utf8,
101    )(i)
102}
103
104pub fn integer<'a>(bin: bool) -> impl FnMut(&'a [u8]) -> IResult<&'a [u8], i32> {
105    if_then_else(
106        bin,
107        alt((preceded(tag([4]), le_i32), preceded(tag([8]), map(le_i64, |i| i as i32)))),
108        map_res(
109            map_res(
110                recognize(pair(opt(tag("-")), take_while(nom::character::is_digit))),
111                std::str::from_utf8,
112            ),
113            |s| s.parse::<i32>(),
114        ),
115    )
116}
117
118pub fn spaced<I, O, E: nom::error::ParseError<I>, F>(it: F) -> impl FnMut(I) -> nom::IResult<I, O, E>
119where
120    I: nom::InputTakeAtPosition,
121    <I as nom::InputTakeAtPosition>::Item: nom::AsChar + Clone,
122    F: FnMut(I) -> nom::IResult<I, O, E>,
123{
124    delimited(space0, it, space0)
125}
126
127pub fn multispaced<I, O, E: nom::error::ParseError<I>, F>(
128    it: F,
129) -> impl FnMut(I) -> nom::IResult<I, O, E>
130where
131    I: nom::InputTakeAtPosition,
132    <I as nom::InputTakeAtPosition>::Item: nom::AsChar + Clone,
133    F: FnMut(I) -> nom::IResult<I, O, E>,
134{
135    delimited(multispace0, it, multispace0)
136}