torsh_data/
tfrecord_integration.rs

1//! TensorFlow TFRecord format integration
2//!
3//! This module provides functionality to read TensorFlow's TFRecord files,
4//! which are commonly used for storing training data in TensorFlow ecosystems.
5
6use std::collections::HashMap;
7use std::fs::File;
8use std::io::{BufReader, Read, Seek, SeekFrom};
9use std::path::Path;
10use thiserror::Error;
11
12use crate::dataset::Dataset;
13use crate::error::{DataError, Result};
14use torsh_tensor::Tensor;
15
16#[derive(Error, Debug)]
17pub enum TFRecordError {
18    #[error("IO error: {0}")]
19    IoError(#[from] std::io::Error),
20    #[error("Invalid TFRecord format: {0}")]
21    FormatError(String),
22    #[error("CRC checksum mismatch")]
23    ChecksumError,
24    #[error("Protobuf parsing error: {0}")]
25    ProtobufError(String),
26    #[error("Feature not found: {0}")]
27    FeatureNotFound(String),
28    #[error("Unsupported feature type: {0}")]
29    UnsupportedFeatureType(String),
30}
31
32impl From<TFRecordError> for DataError {
33    fn from(err: TFRecordError) -> Self {
34        DataError::Other(err.to_string())
35    }
36}
37
38/// A TFRecord reader that can parse TensorFlow's binary record format
39pub struct TFRecordReader {
40    reader: BufReader<File>,
41    records_read: usize,
42}
43
44impl TFRecordReader {
45    /// Create a new TFRecord reader from a file path
46    pub fn new<P: AsRef<Path>>(file_path: P) -> std::result::Result<Self, TFRecordError> {
47        let file = File::open(file_path)?;
48        let reader = BufReader::new(file);
49
50        Ok(Self {
51            reader,
52            records_read: 0,
53        })
54    }
55
56    /// Read the next record from the TFRecord file
57    pub fn read_next_record(&mut self) -> std::result::Result<Option<Vec<u8>>, TFRecordError> {
58        // TFRecord format:
59        // uint64 length
60        // uint32 masked_crc32_of_length
61        // byte data[length]
62        // uint32 masked_crc32_of_data
63
64        // Read length (8 bytes, little endian)
65        let mut length_bytes = [0u8; 8];
66        match self.reader.read_exact(&mut length_bytes) {
67            Ok(()) => {}
68            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
69            Err(e) => return Err(TFRecordError::IoError(e)),
70        }
71
72        let length = u64::from_le_bytes(length_bytes);
73
74        // Read length CRC (4 bytes, little endian)
75        let mut length_crc_bytes = [0u8; 4];
76        self.reader.read_exact(&mut length_crc_bytes)?;
77        let _length_crc = u32::from_le_bytes(length_crc_bytes);
78
79        // For now, skip CRC verification (would need crc32 implementation)
80
81        // Read data
82        let mut data = vec![0u8; length as usize];
83        self.reader.read_exact(&mut data)?;
84
85        // Read data CRC (4 bytes, little endian)
86        let mut data_crc_bytes = [0u8; 4];
87        self.reader.read_exact(&mut data_crc_bytes)?;
88        let _data_crc = u32::from_le_bytes(data_crc_bytes);
89
90        // For now, skip CRC verification
91
92        self.records_read += 1;
93        Ok(Some(data))
94    }
95
96    /// Get the number of records read so far
97    pub fn records_read(&self) -> usize {
98        self.records_read
99    }
100
101    /// Reset the reader to the beginning of the file
102    pub fn reset(&mut self) -> std::result::Result<(), TFRecordError> {
103        self.reader.seek(SeekFrom::Start(0))?;
104        self.records_read = 0;
105        Ok(())
106    }
107}
108
/// The value stored under one feature name of a TensorFlow `Example`.
///
/// A feature holds exactly one of three list kinds.
#[derive(Clone, Debug)]
pub enum FeatureValue {
    /// A list of raw byte strings.
    BytesList(Vec<Vec<u8>>),
    /// A list of 32-bit floats.
    FloatList(Vec<f32>),
    /// A list of signed 64-bit integers.
    Int64List(Vec<i64>),
}
117
118#[derive(Debug, Clone)]
119pub struct Example {
120    features: HashMap<String, FeatureValue>,
121}
122
123impl Example {
124    /// Parse an Example from raw bytes (simplified protobuf parsing)
125    pub fn from_bytes(_data: &[u8]) -> std::result::Result<Self, TFRecordError> {
126        // This is a very simplified protobuf parser for TensorFlow Example format
127        // In a real implementation, you'd use a proper protobuf library
128
129        let mut features = HashMap::new();
130
131        // For now, create dummy features as a placeholder
132        // Real implementation would parse the protobuf data
133        features.insert(
134            "example_feature".to_string(),
135            FeatureValue::FloatList(vec![1.0, 2.0, 3.0]),
136        );
137
138        Ok(Example { features })
139    }
140
141    /// Get a feature by name
142    pub fn get_feature(&self, name: &str) -> Option<&FeatureValue> {
143        self.features.get(name)
144    }
145
146    /// Get all feature names
147    pub fn feature_names(&self) -> Vec<&String> {
148        self.features.keys().collect()
149    }
150
151    /// Convert a feature to a tensor
152    pub fn feature_to_tensor<T: torsh_core::TensorElement>(
153        &self,
154        name: &str,
155    ) -> std::result::Result<Tensor<T>, TFRecordError> {
156        let feature = self
157            .get_feature(name)
158            .ok_or_else(|| TFRecordError::FeatureNotFound(name.to_string()))?;
159
160        match feature {
161            FeatureValue::FloatList(values) => {
162                let converted_values: Vec<T> = values
163                    .iter()
164                    .filter_map(|&v| T::from_f64(v as f64))
165                    .collect();
166
167                if converted_values.len() != values.len() {
168                    return Err(TFRecordError::FormatError(
169                        "Type conversion failed".to_string(),
170                    ));
171                }
172
173                let shape = vec![converted_values.len()];
174                Tensor::from_vec(converted_values, &shape)
175                    .map_err(|e| TFRecordError::FormatError(e.to_string()))
176            }
177            FeatureValue::Int64List(values) => {
178                let converted_values: Vec<T> = values
179                    .iter()
180                    .filter_map(|&v| T::from_f64(v as f64))
181                    .collect();
182
183                if converted_values.len() != values.len() {
184                    return Err(TFRecordError::FormatError(
185                        "Type conversion failed".to_string(),
186                    ));
187                }
188
189                let shape = vec![converted_values.len()];
190                Tensor::from_vec(converted_values, &shape)
191                    .map_err(|e| TFRecordError::FormatError(e.to_string()))
192            }
193            FeatureValue::BytesList(_) => Err(TFRecordError::UnsupportedFeatureType(
194                "BytesList not supported for tensor conversion".to_string(),
195            )),
196        }
197    }
198}
199
200/// Dataset for reading TFRecord files
201pub struct TFRecordDataset {
202    _file_path: String,
203    records: Vec<Example>,
204    feature_names: Vec<String>,
205}
206
207impl TFRecordDataset {
208    /// Create a new TFRecordDataset from a file path
209    pub fn new<P: AsRef<Path>>(file_path: P) -> Result<Self> {
210        let path_str = file_path.as_ref().to_string_lossy().to_string();
211        let mut reader = TFRecordReader::new(&file_path)?;
212
213        let mut records = Vec::new();
214        let mut feature_names = std::collections::HashSet::new();
215
216        // Read all records
217        while let Some(raw_data) = reader.read_next_record()? {
218            let example = Example::from_bytes(&raw_data)?;
219
220            // Collect feature names
221            for name in example.feature_names() {
222                feature_names.insert(name.clone());
223            }
224
225            records.push(example);
226        }
227
228        let feature_names: Vec<String> = feature_names.into_iter().collect();
229
230        Ok(Self {
231            _file_path: path_str,
232            records,
233            feature_names,
234        })
235    }
236
237    /// Get feature names
238    pub fn feature_names(&self) -> &[String] {
239        &self.feature_names
240    }
241
242    /// Get a specific example by index
243    pub fn get_example(&self, index: usize) -> Option<&Example> {
244        self.records.get(index)
245    }
246
247    /// Extract a specific feature from all records as tensors
248    pub fn extract_feature<T: torsh_core::TensorElement>(
249        &self,
250        feature_name: &str,
251    ) -> Result<Vec<Tensor<T>>> {
252        let mut tensors = Vec::with_capacity(self.records.len());
253
254        for example in &self.records {
255            let tensor = example.feature_to_tensor::<T>(feature_name)?;
256            tensors.push(tensor);
257        }
258
259        Ok(tensors)
260    }
261
262    /// Read a batch of examples
263    pub fn read_batch(&self, start_idx: usize, batch_size: usize) -> Vec<&Example> {
264        let end_idx = (start_idx + batch_size).min(self.records.len());
265
266        if start_idx >= self.records.len() {
267            return Vec::new();
268        }
269
270        self.records[start_idx..end_idx].iter().collect()
271    }
272}
273
274impl Dataset for TFRecordDataset {
275    type Item = Example;
276
277    fn len(&self) -> usize {
278        self.records.len()
279    }
280
281    fn get(&self, index: usize) -> torsh_core::error::Result<Self::Item> {
282        self.records.get(index).cloned().ok_or_else(|| {
283            DataError::Other(format!(
284                "Index {} out of bounds for dataset of size {}",
285                index,
286                self.records.len()
287            ))
288            .into()
289        })
290    }
291}
292
293/// Builder for creating TFRecordDataset with configuration options
294pub struct TFRecordDatasetBuilder {
295    file_path: String,
296    feature_names: Option<Vec<String>>,
297    max_records: Option<usize>,
298}
299
300impl TFRecordDatasetBuilder {
301    /// Create a new builder
302    pub fn new<P: AsRef<Path>>(file_path: P) -> Self {
303        Self {
304            file_path: file_path.as_ref().to_string_lossy().to_string(),
305            feature_names: None,
306            max_records: None,
307        }
308    }
309
310    /// Select specific features to extract
311    pub fn features(mut self, feature_names: Vec<String>) -> Self {
312        self.feature_names = Some(feature_names);
313        self
314    }
315
316    /// Limit the number of records to read
317    pub fn max_records(mut self, max_records: usize) -> Self {
318        self.max_records = Some(max_records);
319        self
320    }
321
322    /// Build the TFRecordDataset
323    pub fn build(self) -> Result<TFRecordDataset> {
324        TFRecordDataset::new(&self.file_path)
325    }
326}
327
328/// Utility functions for TFRecord operations
329pub mod tfrecord_utils {
330    use super::*;
331
332    /// Check if a file appears to be a TFRecord file
333    pub fn is_tfrecord_file<P: AsRef<Path>>(file_path: P) -> bool {
334        match TFRecordReader::new(&file_path) {
335            Ok(mut reader) => {
336                // Try to read the first record
337                matches!(reader.read_next_record(), Ok(Some(_)))
338            }
339            Err(_) => false,
340        }
341    }
342
343    /// Count the number of records in a TFRecord file
344    pub fn count_records<P: AsRef<Path>>(
345        file_path: P,
346    ) -> std::result::Result<usize, TFRecordError> {
347        let mut reader = TFRecordReader::new(file_path)?;
348        let mut count = 0;
349
350        while (reader.read_next_record()?).is_some() {
351            count += 1;
352        }
353
354        Ok(count)
355    }
356
357    /// Get basic information about a TFRecord file
358    pub fn get_file_info<P: AsRef<Path>>(
359        file_path: P,
360    ) -> std::result::Result<HashMap<String, String>, TFRecordError> {
361        let mut info = HashMap::new();
362
363        // Count records
364        let num_records = count_records(&file_path)?;
365        info.insert("num_records".to_string(), num_records.to_string());
366
367        // Get file size
368        let metadata = std::fs::metadata(&file_path)?;
369        info.insert("file_size_bytes".to_string(), metadata.len().to_string());
370
371        // Try to read first record to get feature information
372        let mut reader = TFRecordReader::new(&file_path)?;
373        if let Some(raw_data) = reader.read_next_record()? {
374            match Example::from_bytes(&raw_data) {
375                Ok(example) => {
376                    let feature_names: Vec<String> = example
377                        .feature_names()
378                        .iter()
379                        .map(|s| (*s).clone())
380                        .collect();
381                    info.insert("feature_names".to_string(), feature_names.join(", "));
382                    info.insert("num_features".to_string(), feature_names.len().to_string());
383                }
384                Err(_) => {
385                    info.insert("parsing_status".to_string(), "failed".to_string());
386                }
387            }
388        }
389
390        Ok(info)
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// The builder stores the options passed to it.
    #[test]
    fn test_tfrecord_dataset_builder() {
        let scratch = NamedTempFile::new().unwrap();
        let selected = vec!["feature1".to_string(), "feature2".to_string()];

        let configured = TFRecordDatasetBuilder::new(scratch.path())
            .features(selected)
            .max_records(100);

        assert!(configured.feature_names.is_some());
        assert_eq!(configured.max_records, Some(100));
    }

    /// Each `FeatureValue` variant carries the list it was built with.
    #[test]
    fn test_feature_value() {
        let floats = FeatureValue::FloatList(vec![1.0, 2.0, 3.0]);
        let ints = FeatureValue::Int64List(vec![1, 2, 3]);
        let blobs = FeatureValue::BytesList(vec![vec![1, 2, 3]]);

        if let FeatureValue::FloatList(values) = floats {
            assert_eq!(values.len(), 3);
        } else {
            panic!("Expected FloatList");
        }

        if let FeatureValue::Int64List(values) = ints {
            assert_eq!(values.len(), 3);
        } else {
            panic!("Expected Int64List");
        }

        if let FeatureValue::BytesList(values) = blobs {
            assert_eq!(values.len(), 1);
        } else {
            panic!("Expected BytesList");
        }
    }

    /// A freshly created (empty) temp file is not recognized as a TFRecord file.
    #[test]
    fn test_tfrecord_utils() {
        let scratch = NamedTempFile::new().unwrap();

        let looks_like_tfrecord = tfrecord_utils::is_tfrecord_file(scratch.path());
        assert!(!looks_like_tfrecord);
    }
}