tiny_recursive_rs/data/numpy_dataset.rs

use candle_core::{Device, Result, Tensor};
use ndarray::{Array1, Array2, ArrayView1};
use ndarray_npy::ReadNpyExt;
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::BufReader;
use std::path::Path;

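/// Metadata describing a dataset, deserialized from `dataset.json`.
/// `num_examples` and `description` may be omitted from the JSON; they
/// default to `0` and `""` via `#[serde(default)]`.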
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct DatasetMetadata {
    pub vocab_size: usize,
    pub seq_len: usize,
    #[serde(default)]
    pub num_examples: usize,
    #[serde(default)]
    pub description: String,
}

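/// An in-memory dataset backed by NumPy `.npy` files: token `inputs` and
/// `labels` as `(num_examples, seq_len)` matrices, plus optional per-example
/// puzzle identifiers.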
pub struct NumpyDataset {
    inputs: Array2<i32>,
    labels: Array2<i32>,
    puzzle_ids: Vec<i32>,
    metadata: DatasetMetadata,
}

impl NumpyDataset {
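    /// Loads a dataset from a directory containing `all__inputs.npy`,
    /// `all__labels.npy`, and optionally `dataset.json` and
    /// `all__puzzle_identifiers.npy`.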
    pub fn from_directory<P: AsRef<Path>>(path: P) -> crate::Result<Self> {
        let dir = path.as_ref();

        log::info!("Loading NumPy dataset from: {:?}", dir);

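        // Metadata is optional on disk; fall back to conservative defaults
        // so a directory containing only the .npy files still loads.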
        let metadata_path = dir.join("dataset.json");
        let metadata: DatasetMetadata = if metadata_path.exists() {
            let file = File::open(&metadata_path)?;
            let reader = BufReader::new(file);
            serde_json::from_reader(reader)?
        } else {
            log::warn!("dataset.json not found, using defaults");
            DatasetMetadata {
                vocab_size: 256,
                seq_len: 64,
                num_examples: 0,
                description: "Unknown".to_string(),
            }
        };

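        // The .npy files store i64 values; narrow them to i32, which is
        // plenty for token ids and halves the resident size.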
        let inputs_path = dir.join("all__inputs.npy");
        let inputs_i64 = <Array2<i64> as ReadNpyExt>::read_npy(File::open(&inputs_path)?)
            .map_err(|e| std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("Failed to read all__inputs.npy: {}", e),
            ))?;
        let inputs = inputs_i64.mapv(|x| x as i32);

        log::info!("Loaded inputs: shape {:?}", inputs.shape());

        let labels_path = dir.join("all__labels.npy");
        let labels_i64 = <Array2<i64> as ReadNpyExt>::read_npy(File::open(&labels_path)?)
            .map_err(|e| std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("Failed to read all__labels.npy: {}", e),
            ))?;
        let labels = labels_i64.mapv(|x| x as i32);

        log::info!("Loaded labels: shape {:?}", labels.shape());

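        // Puzzle identifiers are optional; when the file is absent the
        // vector stays empty and `get_puzzle_id` returns `None`.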
        let puzzle_ids_path = dir.join("all__puzzle_identifiers.npy");
        let puzzle_ids: Vec<i32> = if puzzle_ids_path.exists() {
            let ids = <Array1<i32> as ReadNpyExt>::read_npy(File::open(&puzzle_ids_path)?)
                .map_err(|e| std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!("Failed to read all__puzzle_identifiers.npy: {}", e),
                ))?;
            ids.to_vec()
        } else {
            log::warn!("all__puzzle_identifiers.npy not found, using empty vector");
            Vec::new()
        };

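        // Inputs and labels must be aligned row-for-row.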
        if inputs.shape() != labels.shape() {
            return Err(crate::TRMError::Config(format!(
                "Shape mismatch: inputs {:?} != labels {:?}",
                inputs.shape(),
                labels.shape()
            )));
        }

        let num_examples = inputs.nrows();
        let seq_len = inputs.ncols();

        log::info!(
            "Dataset loaded: {} examples, seq_len={}, vocab_size={}",
            num_examples,
            seq_len,
            metadata.vocab_size
        );

        Ok(Self {
            inputs,
            labels,
            puzzle_ids,
            metadata,
        })
    }

    pub fn len(&self) -> usize {
        self.inputs.nrows()
    }

    pub fn is_empty(&self) -> bool {
        self.inputs.nrows() == 0
    }

    pub fn vocab_size(&self) -> usize {
        self.metadata.vocab_size
    }

    pub fn seq_len(&self) -> usize {
        self.inputs.ncols()
    }

    pub fn metadata(&self) -> &DatasetMetadata {
        &self.metadata
    }

    pub fn get_input(&self, idx: usize) -> ArrayView1<i32> {
        self.inputs.row(idx)
    }

    pub fn get_label(&self, idx: usize) -> ArrayView1<i32> {
        self.labels.row(idx)
    }

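    /// Returns the puzzle identifier for `idx`, or `None` when the
    /// identifiers file was absent (or `idx` is out of range).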
    pub fn get_puzzle_id(&self, idx: usize) -> Option<i32> {
        self.puzzle_ids.get(idx).copied()
    }
}

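/// Iterates over a `NumpyDataset` in fixed-size batches, optionally
/// randomizing the example order on construction and on every `reset`.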
pub struct NumpyDataLoader {
    dataset: NumpyDataset,
    batch_size: usize,
    current_idx: usize,
    indices: Vec<usize>,
    shuffle: bool,
}

impl NumpyDataLoader {
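    /// Creates a loader over `dataset`; when `shuffle` is true the example
    /// order is randomized up front using the thread-local RNG.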
    pub fn new(dataset: NumpyDataset, batch_size: usize, shuffle: bool) -> Self {
        let num_samples = dataset.len();
        let mut indices: Vec<usize> = (0..num_samples).collect();

        if shuffle {
            use rand::seq::SliceRandom;
            let mut rng = rand::thread_rng();
            indices.shuffle(&mut rng);
        }

        Self {
            dataset,
            batch_size,
            current_idx: 0,
            indices,
            shuffle,
        }
    }

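    /// Returns the next `(inputs, targets)` pair as `U32` tensors of shape
    /// `(batch, seq_len)`, or `Ok(None)` once the epoch is exhausted. The
    /// final batch may be smaller than `batch_size`.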
    pub fn next_batch(&mut self, device: &Device) -> Result<Option<(Tensor, Tensor)>> {
        if self.current_idx >= self.indices.len() {
            return Ok(None);
        }

        let end_idx = (self.current_idx + self.batch_size).min(self.indices.len());
        let batch_indices = &self.indices[self.current_idx..end_idx];
        let actual_batch_size = batch_indices.len();
        let seq_len = self.dataset.seq_len();

        // Flatten the selected rows into row-major u32 buffers.
        let mut input_data = Vec::with_capacity(actual_batch_size * seq_len);
        let mut target_data = Vec::with_capacity(actual_batch_size * seq_len);

        for &idx in batch_indices {
            let input = self.dataset.get_input(idx);
            let target = self.dataset.get_label(idx);

            input_data.extend(input.iter().map(|&x| x as u32));
            target_data.extend(target.iter().map(|&x| x as u32));
        }

        self.current_idx = end_idx;

        // `Tensor::from_vec` on a `Vec<u32>` already produces a U32 tensor,
        // so no separate `to_dtype` conversion is needed.
        let input_tensor = Tensor::from_vec(input_data, (actual_batch_size, seq_len), device)?;
        let target_tensor = Tensor::from_vec(target_data, (actual_batch_size, seq_len), device)?;

        Ok(Some((input_tensor, target_tensor)))
    }

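    /// Rewinds the loader to the first batch, reshuffling the iteration
    /// order when shuffling is enabled.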
    pub fn reset(&mut self) {
        self.current_idx = 0;

        if self.shuffle {
            use rand::seq::SliceRandom;
            let mut rng = rand::thread_rng();
            self.indices.shuffle(&mut rng);
        }
    }

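    /// Number of batches per epoch, rounding up so a final partial batch
    /// is counted.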
    pub fn num_batches(&self) -> usize {
        (self.dataset.len() + self.batch_size - 1) / self.batch_size
    }

    pub fn dataset(&self) -> &NumpyDataset {
        &self.dataset
    }
}

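// The generic `BatchDataLoader` trait (defined in the parent module) simply
// delegates to the inherent methods above.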
impl super::BatchDataLoader for NumpyDataLoader {
    fn next_batch(&mut self, device: &Device) -> Result<Option<(Tensor, Tensor)>> {
        NumpyDataLoader::next_batch(self, device)
    }

    fn reset(&mut self) {
        NumpyDataLoader::reset(self)
    }

    fn num_batches(&self) -> usize {
        NumpyDataLoader::num_batches(self)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_metadata_deserialization() {
        let json = r#"{
            "vocab_size": 11,
            "seq_len": 81,
            "num_examples": 1000000,
            "description": "Sudoku-Extreme"
        }"#;

        let metadata: DatasetMetadata = serde_json::from_str(json).unwrap();
        assert_eq!(metadata.vocab_size, 11);
        assert_eq!(metadata.seq_len, 81);
        assert_eq!(metadata.num_examples, 1000000);
    }

286 }
287}