1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
use tract_hir::internal::*;

use nom::IResult;
use nom::{
    branch::alt,
    bytes::complete::*,
    character::complete::*,
    combinator::*,
    number::complete::{le_i32, le_i64},
    sequence::*,
};

use std::collections::HashMap;

use crate::model::{Component, KaldiProtoModel};

use tract_itertools::Itertools;

mod bin;
mod components;
mod config_lines;
mod descriptor;
mod text;

pub fn nnet3(slice: &[u8]) -> TractResult<KaldiProtoModel> {
    let (_, (config, components)) = parse_top_level(slice).map_err(|e| match e {
        nom::Err::Error(err) => format_err!(
            "Parsing kaldi enveloppe at: {:?}",
            err.0.iter().take(120).map(|b| format!("{}", *b as char)).join("")
        ),
        e => format_err!("{:?}", e),
    })?;
    let config_lines = config_lines::parse_config(config)?;
    Ok(KaldiProtoModel { config_lines, components, adjust_final_offset: 0 })
}

pub fn if_then_else<'a, T>(
    condition: bool,
    then: impl Fn(&'a [u8]) -> IResult<&'a [u8], T>,
    otherwise: impl Fn(&'a [u8]) -> IResult<&'a [u8], T>,
) -> impl Fn(&'a [u8]) -> IResult<&'a [u8], T> {
    map(pair(cond(condition, then), cond(!condition, otherwise)), |(a, b)| a.or(b).unwrap())
}

/// Parses the `<Nnet3>` ... `</Nnet3>` envelope.
///
/// Steps, in order:
/// 1. an optional two-byte marker `0x00 0x42`; its presence sets the `bin`
///    flag that is forwarded to the count and component parsers;
/// 2. the `<Nnet3>` opening tag;
/// 3. everything up to (not including) `<NumComponents>`, which must be valid
///    UTF-8 and is returned as the configuration text;
/// 4. the component count, followed by that many `<ComponentName>` / component
///    pairs;
/// 5. the `</Nnet3>` closing tag.
///
/// Returns the remaining input, the configuration text and the components
/// keyed by name. A component whose name repeats replaces the earlier entry.
fn parse_top_level(i: &[u8]) -> IResult<&[u8], (&str, HashMap<String, Component>)> {
    let (i, bin) = map(opt(tag([0, 0x42])), |marker| Option::is_some(&marker))(i)?;
    let (i, _) = open(i, "Nnet3")?;
    let (i, config_lines) = map_res(take_until("<NumComponents>"), std::str::from_utf8)(i)?;
    let (mut rest, count) = num_components(bin, i)?;

    // The component parser only depends on `bin`, so build it once.
    let parse_component = component(bin);
    let mut components = HashMap::new();
    for _ in 0..count {
        let (after_name, comp_name) = component_name(rest)?;
        debug!("Parsing component {}", comp_name);
        let (after_body, comp) = parse_component(after_name)?;
        components.insert(comp_name.to_owned(), comp);
        rest = after_body;
    }

    close(rest, "Nnet3").map(|(rest, ())| (rest, (config_lines, components)))
}

fn num_components(bin: bool, i: &[u8]) -> IResult<&[u8], usize> {
    let (i, _) = open(i, "NumComponents")?;
    let (i, n) = multispaced(integer(bin))(i)?;
    Ok((i, n as usize))
}

/// Builds a parser for a single component: `<Klass>` attributes `</Klass>`.
///
/// The opening tag gives the component class name. The attributes are read by
/// `bin::attributes` when `bin` is true (which also receives the class name)
/// and by `text::attributes` otherwise. The closing tag must carry the same
/// class name as the opening one.
fn component(bin: bool) -> impl Fn(&[u8]) -> IResult<&[u8], Component> {
    move |input: &[u8]| {
        let (rest, klass) = open_any(input)?;
        let (rest, attributes) = match bin {
            true => bin::attributes(rest, klass)?,
            false => text::attributes(rest)?,
        };
        close(rest, klass)
            .map(|(rest, ())| (rest, Component { klass: klass.to_string(), attributes }))
    }
}

fn component_name(i: &[u8]) -> IResult<&[u8], &str> {
    multispaced(delimited(|i| open(i, "ComponentName"), name, multispace0))(i)
}

pub fn open<'a>(i: &'a [u8], t: &str) -> IResult<&'a [u8], ()> {
    map(multispaced(tuple((tag("<"), tag(t.as_bytes()), tag(">")))), |_| ())(i)
}

pub fn close<'a>(i: &'a [u8], t: &str) -> IResult<&'a [u8], ()> {
    map(multispaced(tuple((tag("</"), tag(t.as_bytes()), tag(">")))), |_| ())(i)
}

pub fn open_any(i: &[u8]) -> IResult<&[u8], &str> {
    multispaced(delimited(tag("<"), name, tag(">")))(i)
}

pub fn name(i: &[u8]) -> IResult<&[u8], &str> {
    map_res(
        recognize(pair(
            alpha1,
            nom::multi::many0(nom::branch::alt((alphanumeric1, tag("."), tag("_"), tag("-")))),
        )),
        std::str::from_utf8,
    )(i)
}

/// Builds a parser for a signed 32-bit integer.
///
/// * When `bin` is true the value is size-prefixed: a byte `4` followed by a
///   little-endian `i32`, or a byte `8` followed by a little-endian `i64`.
/// * When `bin` is false the value is text: an optional `-` followed by ASCII
///   digits.
///
/// # Errors
///
/// A 64-bit value that does not fit in an `i32` fails the parse (nom `MapRes`
/// error) instead of being silently truncated. Text that is empty, not a
/// number, or out of the `i32` range fails the same way.
pub fn integer<'a>(bin: bool) -> impl Fn(&'a [u8]) -> IResult<&'a [u8], i32> {
    if_then_else(
        bin,
        alt((
            preceded(tag([4]), le_i32),
            preceded(
                tag([8]),
                // Checked narrowing; fully qualified so no extra import is
                // needed on pre-2021 editions.
                map_res(le_i64, |wide: i64| <i32 as std::convert::TryFrom<i64>>::try_from(wide)),
            ),
        )),
        map_res(
            map_res(
                recognize(pair(opt(tag("-")), take_while(nom::character::is_digit))),
                std::str::from_utf8,
            ),
            |s| s.parse::<i32>(),
        ),
    )
}

pub fn spaced<I, O, E: nom::error::ParseError<I>, F>(it: F) -> impl Fn(I) -> nom::IResult<I, O, E>
where
    I: nom::InputTakeAtPosition,
    <I as nom::InputTakeAtPosition>::Item: nom::AsChar + Clone,
    F: Fn(I) -> nom::IResult<I, O, E>,
{
    delimited(space0, it, space0)
}

pub fn multispaced<I, O, E: nom::error::ParseError<I>, F>(
    it: F,
) -> impl Fn(I) -> nom::IResult<I, O, E>
where
    I: nom::InputTakeAtPosition,
    <I as nom::InputTakeAtPosition>::Item: nom::AsChar + Clone,
    F: Fn(I) -> nom::IResult<I, O, E>,
{
    delimited(multispace0, it, multispace0)
}