tiny_recursive_rs/data/
numpy_dataset.rs

1/// NumPy dataset loader for TinyRecursiveModels puzzle data (.npy format)
2use candle_core::{Result, Tensor, Device};
3use ndarray::{Array1, Array2, ArrayView1};
4use ndarray_npy::ReadNpyExt;
5use serde::{Deserialize, Serialize};
6use std::fs::File;
7use std::io::BufReader;
8use std::path::Path;
9
/// Metadata from dataset.json
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct DatasetMetadata {
    /// Number of distinct token values in the dataset.
    pub vocab_size: usize,
    /// Length of each input/label sequence.
    pub seq_len: usize,
    /// Total number of examples; falls back to 0 when absent from the JSON.
    #[serde(default)]
    pub num_examples: usize,
    /// Human-readable dataset description; falls back to "" when absent.
    #[serde(default)]
    pub description: String,
}
20
/// Dataset loaded from NumPy .npy files
pub struct NumpyDataset {
    inputs: Array2<i32>,        // Input token grid, one row per example: [N, seq_len]
    labels: Array2<i32>,        // Target token grid, same shape as `inputs`: [N, seq_len]
    puzzle_ids: Vec<i32>,       // [M] puzzle identifiers; may be empty (file is optional)
    metadata: DatasetMetadata,  // Parsed dataset.json, or defaults when the file is missing
}
28
29impl NumpyDataset {
30    /// Load from directory containing .npy files and dataset.json
31    pub fn from_directory<P: AsRef<Path>>(path: P) -> crate::Result<Self> {
32        let dir = path.as_ref();
33
34        log::info!("Loading NumPy dataset from: {:?}", dir);
35
36        // Load metadata
37        let metadata_path = dir.join("dataset.json");
38        let metadata: DatasetMetadata = if metadata_path.exists() {
39            let file = File::open(&metadata_path)?;
40            let reader = BufReader::new(file);
41            serde_json::from_reader(reader)?
42        } else {
43            log::warn!("dataset.json not found, using defaults");
44            DatasetMetadata {
45                vocab_size: 256,
46                seq_len: 64,
47                num_examples: 0,
48                description: "Unknown".to_string(),
49            }
50        };
51
52        // Load inputs (Python saves as i64, need to cast to i32)
53        let inputs_path = dir.join("all__inputs.npy");
54        let inputs_i64 = <Array2<i64> as ReadNpyExt>::read_npy(File::open(&inputs_path)?)
55            .map_err(|e| std::io::Error::new(
56                std::io::ErrorKind::InvalidData,
57                format!("Failed to read all__inputs.npy: {}", e)
58            ))?;
59        let inputs = inputs_i64.mapv(|x| x as i32);
60
61        log::info!("Loaded inputs: shape {:?}", inputs.shape());
62
63        // Load labels (Python saves as i64, need to cast to i32)
64        let labels_path = dir.join("all__labels.npy");
65        let labels_i64 = <Array2<i64> as ReadNpyExt>::read_npy(File::open(&labels_path)?)
66            .map_err(|e| std::io::Error::new(
67                std::io::ErrorKind::InvalidData,
68                format!("Failed to read all__labels.npy: {}", e)
69            ))?;
70        let labels = labels_i64.mapv(|x| x as i32);
71
72        log::info!("Loaded labels: shape {:?}", labels.shape());
73
74        // Load puzzle identifiers (optional)
75        let puzzle_ids_path = dir.join("all__puzzle_identifiers.npy");
76        let puzzle_ids: Vec<i32> = if puzzle_ids_path.exists() {
77            let ids = <Array1<i32> as ReadNpyExt>::read_npy(File::open(&puzzle_ids_path)?)
78                .map_err(|e| std::io::Error::new(
79                    std::io::ErrorKind::InvalidData,
80                    format!("Failed to read all__puzzle_identifiers.npy: {}", e)
81                ))?;
82            ids.to_vec()
83        } else {
84            log::warn!("all__puzzle_identifiers.npy not found, using empty vector");
85            Vec::new()
86        };
87
88        // Validate shapes
89        if inputs.shape() != labels.shape() {
90            return Err(crate::TRMError::Config(format!(
91                "Shape mismatch: inputs {:?} != labels {:?}",
92                inputs.shape(),
93                labels.shape()
94            )));
95        }
96
97        let num_examples = inputs.nrows();
98        let seq_len = inputs.ncols();
99
100        log::info!(
101            "Dataset loaded: {} examples, seq_len={}, vocab_size={}",
102            num_examples,
103            seq_len,
104            metadata.vocab_size
105        );
106
107        Ok(Self {
108            inputs,
109            labels,
110            puzzle_ids,
111            metadata,
112        })
113    }
114
115    /// Get number of examples
116    pub fn len(&self) -> usize {
117        self.inputs.nrows()
118    }
119
120    /// Check if empty
121    pub fn is_empty(&self) -> bool {
122        self.inputs.nrows() == 0
123    }
124
125    /// Get vocabulary size
126    pub fn vocab_size(&self) -> usize {
127        self.metadata.vocab_size
128    }
129
130    /// Get sequence length
131    pub fn seq_len(&self) -> usize {
132        self.inputs.ncols()
133    }
134
135    /// Get metadata
136    pub fn metadata(&self) -> &DatasetMetadata {
137        &self.metadata
138    }
139
140    /// Get input at index
141    pub fn get_input(&self, idx: usize) -> ArrayView1<i32> {
142        self.inputs.row(idx)
143    }
144
145    /// Get label at index
146    pub fn get_label(&self, idx: usize) -> ArrayView1<i32> {
147        self.labels.row(idx)
148    }
149
150    /// Get puzzle ID at index (if available)
151    pub fn get_puzzle_id(&self, idx: usize) -> Option<i32> {
152        if idx < self.puzzle_ids.len() {
153            Some(self.puzzle_ids[idx])
154        } else {
155            None
156        }
157    }
158}
159
/// Data loader for NumPy puzzle datasets
pub struct NumpyDataLoader {
    // Underlying dataset the loader draws batches from.
    dataset: NumpyDataset,
    // Maximum examples per batch; the final batch of an epoch may be smaller.
    batch_size: usize,
    // Cursor into `indices`; marks the start of the next batch.
    current_idx: usize,
    // Permutation of example indices defining iteration order for this epoch.
    indices: Vec<usize>,
    // Whether to reshuffle `indices` at construction and on reset().
    shuffle: bool,
}
168
169impl NumpyDataLoader {
170    /// Create new data loader
171    pub fn new(dataset: NumpyDataset, batch_size: usize, shuffle: bool) -> Self {
172        let num_samples = dataset.len();
173        let mut indices: Vec<usize> = (0..num_samples).collect();
174
175        if shuffle {
176            use rand::seq::SliceRandom;
177            let mut rng = rand::thread_rng();
178            indices.shuffle(&mut rng);
179        }
180
181        Self {
182            dataset,
183            batch_size,
184            current_idx: 0,
185            indices,
186            shuffle,
187        }
188    }
189
190    /// Get next batch (input_ids, target_ids)
191    pub fn next_batch(&mut self, device: &Device) -> Result<Option<(Tensor, Tensor)>> {
192        if self.current_idx >= self.indices.len() {
193            return Ok(None);
194        }
195
196        let end_idx = (self.current_idx + self.batch_size).min(self.indices.len());
197        let batch_indices = &self.indices[self.current_idx..end_idx];
198        let actual_batch_size = batch_indices.len();
199
200        // Collect sequences for this batch
201        let mut input_data = Vec::new();
202        let mut target_data = Vec::new();
203
204        for &idx in batch_indices {
205            let input = self.dataset.get_input(idx);
206            let target = self.dataset.get_label(idx);
207
208            // Convert i32 to u32 for Candle
209            input_data.extend(input.iter().map(|&x| x as u32));
210            target_data.extend(target.iter().map(|&x| x as u32));
211        }
212
213        self.current_idx = end_idx;
214
215        // Convert to tensors
216        let seq_len = self.dataset.seq_len();
217        let input_tensor = Tensor::from_vec(
218            input_data,
219            (actual_batch_size, seq_len),
220            device,
221        )?.to_dtype(candle_core::DType::U32)?;
222
223        let target_tensor = Tensor::from_vec(
224            target_data,
225            (actual_batch_size, seq_len),
226            device,
227        )?.to_dtype(candle_core::DType::U32)?;
228
229        Ok(Some((input_tensor, target_tensor)))
230    }
231
232    /// Reset loader for new epoch
233    pub fn reset(&mut self) {
234        self.current_idx = 0;
235
236        if self.shuffle {
237            use rand::seq::SliceRandom;
238            let mut rng = rand::thread_rng();
239            self.indices.shuffle(&mut rng);
240        }
241    }
242
243    /// Get number of batches
244    pub fn num_batches(&self) -> usize {
245        (self.dataset.len() + self.batch_size - 1) / self.batch_size
246    }
247
248    /// Get dataset reference
249    pub fn dataset(&self) -> &NumpyDataset {
250        &self.dataset
251    }
252}
253
254// Implement BatchDataLoader trait
255impl super::BatchDataLoader for NumpyDataLoader {
256    fn next_batch(&mut self, device: &Device) -> Result<Option<(Tensor, Tensor)>> {
257        NumpyDataLoader::next_batch(self, device)
258    }
259
260    fn reset(&mut self) {
261        NumpyDataLoader::reset(self)
262    }
263
264    fn num_batches(&self) -> usize {
265        NumpyDataLoader::num_batches(self)
266    }
267}
268
269#[cfg(test)]
270mod tests {
271    use super::*;
272
273    #[test]
274    fn test_metadata_deserialization() {
275        let json = r#"{
276            "vocab_size": 11,
277            "seq_len": 81,
278            "num_examples": 1000000,
279            "description": "Sudoku-Extreme"
280        }"#;
281
282        let metadata: DatasetMetadata = serde_json::from_str(json).unwrap();
283        assert_eq!(metadata.vocab_size, 11);
284        assert_eq!(metadata.seq_len, 81);
285        assert_eq!(metadata.num_examples, 1000000);
286    }
287}