1use std::io::{BufRead, BufReader, BufWriter, Read, Write};
4
5use hashbrown::HashMap;
6use regex::Regex;
7
8use crate::errors::{Result, VibratoError};
9use crate::trainer::TrainerConfig;
10use crate::utils;
11
12#[allow(clippy::too_many_arguments)]
31pub fn generate_bigram_info(
32 feature_def_rdr: impl Read,
33 right_id_def_rdr: impl Read,
34 left_id_def_rdr: impl Read,
35 model_def_rdr: impl Read,
36 cost_factor: f64,
37 bigram_right_wtr: impl Write,
38 bigram_left_wtr: impl Write,
39 bigram_cost_wtr: impl Write,
40) -> Result<()> {
41 let mut left_features = HashMap::new();
42 let mut right_features = HashMap::new();
43
44 let mut feature_extractor = TrainerConfig::parse_feature_config(feature_def_rdr)?;
45
46 let id_feature_re = Regex::new(r"^([0-9]+) (.*)$").unwrap();
47 let model_re = Regex::new(r"^([0-9\-\.]+)\t(.*)$").unwrap();
48
49 let right_id_def_rdr = BufReader::new(right_id_def_rdr);
55 for line in right_id_def_rdr.lines() {
56 let line = line?;
57 if let Some(cap) = id_feature_re.captures(&line) {
58 let id = cap.get(1).unwrap().as_str().parse::<usize>()?;
59 let feature_str = cap.get(2).unwrap().as_str();
60 let feature_spl = utils::parse_csv_row(feature_str);
61 if id == 0 && feature_spl.first().is_some_and(|s| s != "BOS/EOS") {
62 return Err(VibratoError::invalid_format(
63 "right_id_def_rdr",
64 "ID 0 must be BOS/EOS",
65 ));
66 }
67 let feature_ids = feature_extractor.extract_left_feature_ids(&feature_spl);
68 left_features.insert(id, feature_ids);
69 } else {
70 return Err(VibratoError::invalid_format(
71 "right_id_def_rdr",
72 "each line must be a pair of an ID and features",
73 ));
74 }
75 }
76 let left_id_def_rdr = BufReader::new(left_id_def_rdr);
78 for line in left_id_def_rdr.lines() {
79 let line = line?;
80 if let Some(cap) = id_feature_re.captures(&line) {
81 let id = cap.get(1).unwrap().as_str().parse::<usize>()?;
82 let feature_str = cap.get(2).unwrap().as_str();
83 let feature_spl = utils::parse_csv_row(feature_str);
84 if id == 0 && feature_spl.first().is_some_and(|s| s != "BOS/EOS") {
85 return Err(VibratoError::invalid_format(
86 "left_id_def_rdr",
87 "ID 0 must be BOS/EOS",
88 ));
89 }
90 let feature_ids = feature_extractor.extract_right_feature_ids(&feature_spl);
91 right_features.insert(id, feature_ids);
92 } else {
93 return Err(VibratoError::invalid_format(
94 "left_id_def_rdr",
95 "each line must be a pair of an ID and features",
96 ));
97 }
98 }
99 let model_def_rdr = BufReader::new(model_def_rdr);
101 let mut bigram_cost_wtr = BufWriter::new(bigram_cost_wtr);
102 for line in model_def_rdr.lines() {
103 let line = line?;
104 if let Some(cap) = model_re.captures(&line) {
105 let weight = cap.get(1).unwrap().as_str().parse::<f64>()?;
106 let cost = -(weight * cost_factor) as i32;
107 if cost == 0 {
108 continue;
109 }
110 let feature_str = cap.get(2).unwrap().as_str().replace("BOS/EOS", "");
111 let mut spl = feature_str.split('/');
112 let left_feat_str = spl.next();
113 let right_feat_str = spl.next();
114 if let (Some(left_feat_str), Some(right_feat_str)) = (left_feat_str, right_feat_str) {
115 let left_id = if left_feat_str.is_empty() {
116 String::new()
117 } else if let Some(id) = feature_extractor.left_feature_ids().get(left_feat_str) {
118 id.to_string()
119 } else {
120 continue;
121 };
122 let right_id = if right_feat_str.is_empty() {
123 String::new()
124 } else if let Some(id) = feature_extractor.right_feature_ids().get(right_feat_str) {
125 id.to_string()
126 } else {
127 continue;
128 };
129 writeln!(&mut bigram_cost_wtr, "{left_id}/{right_id}\t{cost}")?;
130 }
131 }
132 }
133
134 let mut bigram_right_wtr = BufWriter::new(bigram_right_wtr);
135 for id in 1..left_features.len() {
136 write!(&mut bigram_right_wtr, "{id}\t")?;
137 if let Some(features) = left_features.get(&id) {
138 for (i, feat_id) in features.iter().enumerate() {
139 if i != 0 {
140 write!(&mut bigram_right_wtr, ",")?;
141 }
142 if let Some(feat_id) = feat_id {
143 write!(&mut bigram_right_wtr, "{}", feat_id.get())?;
144 } else {
145 write!(&mut bigram_right_wtr, "*")?;
146 }
147 }
148 } else {
149 return Err(VibratoError::invalid_format(
150 "right_id_def_rdr",
151 format!("feature ID {id} is undefined"),
152 ));
153 }
154 writeln!(&mut bigram_right_wtr)?;
155 }
156
157 let mut bigram_left_wtr = BufWriter::new(bigram_left_wtr);
158 for id in 1..right_features.len() {
159 write!(&mut bigram_left_wtr, "{id}\t")?;
160 if let Some(features) = right_features.get(&id) {
161 for (i, feat_id) in features.iter().enumerate() {
162 if i != 0 {
163 write!(&mut bigram_left_wtr, ",")?;
164 }
165 if let Some(feat_id) = feat_id {
166 write!(&mut bigram_left_wtr, "{}", feat_id.get())?;
167 } else {
168 write!(&mut bigram_left_wtr, "*")?;
169 }
170 }
171 writeln!(&mut bigram_left_wtr)?;
172 } else {
173 return Err(VibratoError::invalid_format(
174 "left_id_def_rdr",
175 format!("feature ID {id} is undefined"),
176 ));
177 }
178 }
179
180 Ok(())
181}