1use std::collections::HashMap;
7use std::fs::File;
8use std::io::{BufReader, Read, Seek, SeekFrom};
9use std::path::Path;
10use thiserror::Error;
11
12use crate::dataset::Dataset;
13use crate::error::{DataError, Result};
14use torsh_tensor::Tensor;
15
16#[derive(Error, Debug)]
17pub enum TFRecordError {
18 #[error("IO error: {0}")]
19 IoError(#[from] std::io::Error),
20 #[error("Invalid TFRecord format: {0}")]
21 FormatError(String),
22 #[error("CRC checksum mismatch")]
23 ChecksumError,
24 #[error("Protobuf parsing error: {0}")]
25 ProtobufError(String),
26 #[error("Feature not found: {0}")]
27 FeatureNotFound(String),
28 #[error("Unsupported feature type: {0}")]
29 UnsupportedFeatureType(String),
30}
31
32impl From<TFRecordError> for DataError {
33 fn from(err: TFRecordError) -> Self {
34 DataError::Other(err.to_string())
35 }
36}
37
38pub struct TFRecordReader {
40 reader: BufReader<File>,
41 records_read: usize,
42}
43
44impl TFRecordReader {
45 pub fn new<P: AsRef<Path>>(file_path: P) -> std::result::Result<Self, TFRecordError> {
47 let file = File::open(file_path)?;
48 let reader = BufReader::new(file);
49
50 Ok(Self {
51 reader,
52 records_read: 0,
53 })
54 }
55
56 pub fn read_next_record(&mut self) -> std::result::Result<Option<Vec<u8>>, TFRecordError> {
58 let mut length_bytes = [0u8; 8];
66 match self.reader.read_exact(&mut length_bytes) {
67 Ok(()) => {}
68 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
69 Err(e) => return Err(TFRecordError::IoError(e)),
70 }
71
72 let length = u64::from_le_bytes(length_bytes);
73
74 let mut length_crc_bytes = [0u8; 4];
76 self.reader.read_exact(&mut length_crc_bytes)?;
77 let _length_crc = u32::from_le_bytes(length_crc_bytes);
78
79 let mut data = vec![0u8; length as usize];
83 self.reader.read_exact(&mut data)?;
84
85 let mut data_crc_bytes = [0u8; 4];
87 self.reader.read_exact(&mut data_crc_bytes)?;
88 let _data_crc = u32::from_le_bytes(data_crc_bytes);
89
90 self.records_read += 1;
93 Ok(Some(data))
94 }
95
96 pub fn records_read(&self) -> usize {
98 self.records_read
99 }
100
101 pub fn reset(&mut self) -> std::result::Result<(), TFRecordError> {
103 self.reader.seek(SeekFrom::Start(0))?;
104 self.records_read = 0;
105 Ok(())
106 }
107}
108
/// Value stored under a single feature name of an [`Example`].
#[derive(Clone, Debug)]
pub enum FeatureValue {
    /// A list of byte strings.
    BytesList(Vec<Vec<u8>>),
    /// A list of 32-bit floats.
    FloatList(Vec<f32>),
    /// A list of 64-bit signed integers.
    Int64List(Vec<i64>),
}
117
118#[derive(Debug, Clone)]
119pub struct Example {
120 features: HashMap<String, FeatureValue>,
121}
122
123impl Example {
124 pub fn from_bytes(_data: &[u8]) -> std::result::Result<Self, TFRecordError> {
126 let mut features = HashMap::new();
130
131 features.insert(
134 "example_feature".to_string(),
135 FeatureValue::FloatList(vec![1.0, 2.0, 3.0]),
136 );
137
138 Ok(Example { features })
139 }
140
141 pub fn get_feature(&self, name: &str) -> Option<&FeatureValue> {
143 self.features.get(name)
144 }
145
146 pub fn feature_names(&self) -> Vec<&String> {
148 self.features.keys().collect()
149 }
150
151 pub fn feature_to_tensor<T: torsh_core::TensorElement>(
153 &self,
154 name: &str,
155 ) -> std::result::Result<Tensor<T>, TFRecordError> {
156 let feature = self
157 .get_feature(name)
158 .ok_or_else(|| TFRecordError::FeatureNotFound(name.to_string()))?;
159
160 match feature {
161 FeatureValue::FloatList(values) => {
162 let converted_values: Vec<T> = values
163 .iter()
164 .filter_map(|&v| T::from_f64(v as f64))
165 .collect();
166
167 if converted_values.len() != values.len() {
168 return Err(TFRecordError::FormatError(
169 "Type conversion failed".to_string(),
170 ));
171 }
172
173 let shape = vec![converted_values.len()];
174 Tensor::from_vec(converted_values, &shape)
175 .map_err(|e| TFRecordError::FormatError(e.to_string()))
176 }
177 FeatureValue::Int64List(values) => {
178 let converted_values: Vec<T> = values
179 .iter()
180 .filter_map(|&v| T::from_f64(v as f64))
181 .collect();
182
183 if converted_values.len() != values.len() {
184 return Err(TFRecordError::FormatError(
185 "Type conversion failed".to_string(),
186 ));
187 }
188
189 let shape = vec![converted_values.len()];
190 Tensor::from_vec(converted_values, &shape)
191 .map_err(|e| TFRecordError::FormatError(e.to_string()))
192 }
193 FeatureValue::BytesList(_) => Err(TFRecordError::UnsupportedFeatureType(
194 "BytesList not supported for tensor conversion".to_string(),
195 )),
196 }
197 }
198}
199
200pub struct TFRecordDataset {
202 _file_path: String,
203 records: Vec<Example>,
204 feature_names: Vec<String>,
205}
206
207impl TFRecordDataset {
208 pub fn new<P: AsRef<Path>>(file_path: P) -> Result<Self> {
210 let path_str = file_path.as_ref().to_string_lossy().to_string();
211 let mut reader = TFRecordReader::new(&file_path)?;
212
213 let mut records = Vec::new();
214 let mut feature_names = std::collections::HashSet::new();
215
216 while let Some(raw_data) = reader.read_next_record()? {
218 let example = Example::from_bytes(&raw_data)?;
219
220 for name in example.feature_names() {
222 feature_names.insert(name.clone());
223 }
224
225 records.push(example);
226 }
227
228 let feature_names: Vec<String> = feature_names.into_iter().collect();
229
230 Ok(Self {
231 _file_path: path_str,
232 records,
233 feature_names,
234 })
235 }
236
237 pub fn feature_names(&self) -> &[String] {
239 &self.feature_names
240 }
241
242 pub fn get_example(&self, index: usize) -> Option<&Example> {
244 self.records.get(index)
245 }
246
247 pub fn extract_feature<T: torsh_core::TensorElement>(
249 &self,
250 feature_name: &str,
251 ) -> Result<Vec<Tensor<T>>> {
252 let mut tensors = Vec::with_capacity(self.records.len());
253
254 for example in &self.records {
255 let tensor = example.feature_to_tensor::<T>(feature_name)?;
256 tensors.push(tensor);
257 }
258
259 Ok(tensors)
260 }
261
262 pub fn read_batch(&self, start_idx: usize, batch_size: usize) -> Vec<&Example> {
264 let end_idx = (start_idx + batch_size).min(self.records.len());
265
266 if start_idx >= self.records.len() {
267 return Vec::new();
268 }
269
270 self.records[start_idx..end_idx].iter().collect()
271 }
272}
273
274impl Dataset for TFRecordDataset {
275 type Item = Example;
276
277 fn len(&self) -> usize {
278 self.records.len()
279 }
280
281 fn get(&self, index: usize) -> torsh_core::error::Result<Self::Item> {
282 self.records.get(index).cloned().ok_or_else(|| {
283 DataError::Other(format!(
284 "Index {} out of bounds for dataset of size {}",
285 index,
286 self.records.len()
287 ))
288 .into()
289 })
290 }
291}
292
293pub struct TFRecordDatasetBuilder {
295 file_path: String,
296 feature_names: Option<Vec<String>>,
297 max_records: Option<usize>,
298}
299
300impl TFRecordDatasetBuilder {
301 pub fn new<P: AsRef<Path>>(file_path: P) -> Self {
303 Self {
304 file_path: file_path.as_ref().to_string_lossy().to_string(),
305 feature_names: None,
306 max_records: None,
307 }
308 }
309
310 pub fn features(mut self, feature_names: Vec<String>) -> Self {
312 self.feature_names = Some(feature_names);
313 self
314 }
315
316 pub fn max_records(mut self, max_records: usize) -> Self {
318 self.max_records = Some(max_records);
319 self
320 }
321
322 pub fn build(self) -> Result<TFRecordDataset> {
324 TFRecordDataset::new(&self.file_path)
325 }
326}
327
328pub mod tfrecord_utils {
330 use super::*;
331
332 pub fn is_tfrecord_file<P: AsRef<Path>>(file_path: P) -> bool {
334 match TFRecordReader::new(&file_path) {
335 Ok(mut reader) => {
336 matches!(reader.read_next_record(), Ok(Some(_)))
338 }
339 Err(_) => false,
340 }
341 }
342
343 pub fn count_records<P: AsRef<Path>>(
345 file_path: P,
346 ) -> std::result::Result<usize, TFRecordError> {
347 let mut reader = TFRecordReader::new(file_path)?;
348 let mut count = 0;
349
350 while (reader.read_next_record()?).is_some() {
351 count += 1;
352 }
353
354 Ok(count)
355 }
356
357 pub fn get_file_info<P: AsRef<Path>>(
359 file_path: P,
360 ) -> std::result::Result<HashMap<String, String>, TFRecordError> {
361 let mut info = HashMap::new();
362
363 let num_records = count_records(&file_path)?;
365 info.insert("num_records".to_string(), num_records.to_string());
366
367 let metadata = std::fs::metadata(&file_path)?;
369 info.insert("file_size_bytes".to_string(), metadata.len().to_string());
370
371 let mut reader = TFRecordReader::new(&file_path)?;
373 if let Some(raw_data) = reader.read_next_record()? {
374 match Example::from_bytes(&raw_data) {
375 Ok(example) => {
376 let feature_names: Vec<String> = example
377 .feature_names()
378 .iter()
379 .map(|s| (*s).clone())
380 .collect();
381 info.insert("feature_names".to_string(), feature_names.join(", "));
382 info.insert("num_features".to_string(), feature_names.len().to_string());
383 }
384 Err(_) => {
385 info.insert("parsing_status".to_string(), "failed".to_string());
386 }
387 }
388 }
389
390 Ok(info)
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// The builder stores the options it is given.
    #[test]
    fn test_tfrecord_dataset_builder() {
        let temp_file = NamedTempFile::new().unwrap();
        let requested = vec!["feature1".to_string(), "feature2".to_string()];

        let builder = TFRecordDatasetBuilder::new(temp_file.path())
            .features(requested)
            .max_records(100);

        assert!(builder.feature_names.is_some());
        assert_eq!(builder.max_records, Some(100));
    }

    /// Each `FeatureValue` variant keeps the values it was built from.
    #[test]
    fn test_feature_value() {
        let float_feature = FeatureValue::FloatList(vec![1.0, 2.0, 3.0]);
        let int_feature = FeatureValue::Int64List(vec![1, 2, 3]);
        let bytes_feature = FeatureValue::BytesList(vec![vec![1, 2, 3]]);

        if let FeatureValue::FloatList(values) = float_feature {
            assert_eq!(values.len(), 3);
        } else {
            panic!("Expected FloatList");
        }

        if let FeatureValue::Int64List(values) = int_feature {
            assert_eq!(values.len(), 3);
        } else {
            panic!("Expected Int64List");
        }

        if let FeatureValue::BytesList(values) = bytes_feature {
            assert_eq!(values.len(), 1);
        } else {
            panic!("Expected BytesList");
        }
    }

    /// An empty file holds no record, so it is not reported as a TFRecord file.
    #[test]
    fn test_tfrecord_utils() {
        let temp_file = NamedTempFile::new().unwrap();
        let detected = tfrecord_utils::is_tfrecord_file(temp_file.path());

        assert!(!detected);
    }
}