1use std::collections::HashMap;
6use ndarray::{Array1, Array2};
7use ndarray_rand::RandomExt;
8use ndarray_rand::rand_distr::Uniform;
9use crate::optimizers::Optimizer;
10
/// Bidirectional mapping between characters and contiguous indices
/// `0..vocab_size`, used to encode text for the embedding layer.
#[derive(Clone, Debug)]
pub struct TextVocabulary {
    char_to_idx: HashMap<char, usize>,
    idx_to_char: HashMap<usize, char>,
    vocab_size: usize,
}

impl TextVocabulary {
    /// Builds a vocabulary from the distinct characters of `text`.
    ///
    /// Characters are deduplicated and sorted, so index assignment is
    /// deterministic for a given set of input characters.
    pub fn from_text(text: &str) -> Self {
        let mut chars: Vec<char> = text
            .chars()
            .collect::<std::collections::HashSet<_>>()
            .into_iter()
            .collect();
        chars.sort_unstable();
        // Delegate so the index-assignment logic lives in exactly one place.
        Self::from_chars(&chars)
    }

    /// Builds a vocabulary from an explicit, ordered character list; index `i`
    /// maps to `chars[i]`. Duplicates would silently collapse in
    /// `char_to_idx`, so callers should pass unique characters.
    pub fn from_chars(chars: &[char]) -> Self {
        let vocab_size = chars.len();
        let char_to_idx: HashMap<char, usize> = chars
            .iter()
            .enumerate()
            .map(|(i, &c)| (c, i))
            .collect();
        let idx_to_char: HashMap<usize, char> = chars
            .iter()
            .enumerate()
            .map(|(i, &c)| (i, c))
            .collect();

        Self { char_to_idx, idx_to_char, vocab_size }
    }

    /// Returns the index for `ch`, or `None` if it is out of vocabulary.
    pub fn char_to_index(&self, ch: char) -> Option<usize> {
        self.char_to_idx.get(&ch).copied()
    }

    /// Returns the character at `idx`, or `None` if the index is out of range.
    pub fn index_to_char(&self, idx: usize) -> Option<char> {
        self.idx_to_char.get(&idx).copied()
    }

    /// Number of distinct characters in the vocabulary.
    pub fn size(&self) -> usize {
        self.vocab_size
    }

    /// Whether `ch` is part of the vocabulary.
    pub fn contains(&self, ch: char) -> bool {
        self.char_to_idx.contains_key(&ch)
    }

    /// Returns all characters in index order (index 0 first).
    pub fn chars(&self) -> Vec<char> {
        // Indices are contiguous 0..vocab_size by construction, so walking the
        // range avoids collecting and sorting the map entries.
        (0..self.vocab_size)
            .filter_map(|idx| self.index_to_char(idx))
            .collect()
    }

    /// Encodes `text` to indices, silently dropping out-of-vocabulary chars.
    pub fn encode(&self, text: &str) -> Vec<usize> {
        text.chars()
            .filter_map(|ch| self.char_to_index(ch))
            .collect()
    }

    /// Decodes indices back to a `String`, silently dropping invalid indices.
    pub fn decode(&self, indices: &[usize]) -> String {
        indices.iter()
            .filter_map(|&idx| self.index_to_char(idx))
            .collect()
    }
}
97
/// Gradient container produced by `CharacterEmbedding::backward`.
#[derive(Clone, Debug)]
pub struct EmbeddingGradients {
    /// d(loss)/d(weight): same shape as the embedding matrix (vocab_size, embed_dim).
    pub weight: Array2<f64>,
}
103
104#[derive(Clone, Debug)]
108pub struct CharacterEmbedding {
109 pub weight: Array2<f64>, vocab_size: usize,
111 embed_dim: usize,
112 input_cache: Option<Vec<usize>>,
113}
114
115impl CharacterEmbedding {
116 pub fn new(vocab_size: usize, embed_dim: usize) -> Self {
118 let scale = (1.0 / embed_dim as f64).sqrt();
119 let weight = Array2::random((vocab_size, embed_dim), Uniform::new(-scale, scale));
120
121 Self {
122 weight,
123 vocab_size,
124 embed_dim,
125 input_cache: None,
126 }
127 }
128
129 pub fn new_zeros(vocab_size: usize, embed_dim: usize) -> Self {
131 Self {
132 weight: Array2::zeros((vocab_size, embed_dim)),
133 vocab_size,
134 embed_dim,
135 input_cache: None,
136 }
137 }
138
139 pub fn from_weights(weight: Array2<f64>) -> Self {
141 let (vocab_size, embed_dim) = weight.dim();
142 Self {
143 weight,
144 vocab_size,
145 embed_dim,
146 input_cache: None,
147 }
148 }
149
150 pub fn embed_dim(&self) -> usize {
152 self.embed_dim
153 }
154
155 pub fn vocab_size(&self) -> usize {
157 self.vocab_size
158 }
159
160 pub fn lookup(&self, char_idx: usize) -> Array1<f64> {
162 assert!(char_idx < self.vocab_size, "Index {} out of vocabulary size {}", char_idx, self.vocab_size);
163 self.weight.row(char_idx).to_owned()
164 }
165
166 pub fn forward(&mut self, char_indices: &[usize]) -> Array2<f64> {
169 self.input_cache = Some(char_indices.to_vec());
170
171 let seq_len = char_indices.len();
172 let mut output = Array2::zeros((seq_len, self.embed_dim));
173
174 for (i, &idx) in char_indices.iter().enumerate() {
175 assert!(idx < self.vocab_size, "Index {} out of vocabulary size {}", idx, self.vocab_size);
176 output.row_mut(i).assign(&self.weight.row(idx));
177 }
178
179 output
180 }
181
182 pub fn backward(&self, grad_output: &Array2<f64>) -> EmbeddingGradients {
185 let indices = self.input_cache.as_ref().expect("No cached input for backward pass");
186
187 let mut weight_grad = Array2::zeros((self.vocab_size, self.embed_dim));
188
189 for (i, &idx) in indices.iter().enumerate() {
190 for j in 0..self.embed_dim {
191 weight_grad[[idx, j]] += grad_output[[i, j]];
192 }
193 }
194
195 EmbeddingGradients { weight: weight_grad }
196 }
197
198 pub fn update_parameters<O: Optimizer>(&mut self, gradients: &EmbeddingGradients, optimizer: &mut O, prefix: &str) {
200 optimizer.update(&format!("{}_weight", prefix), &mut self.weight, &gradients.weight);
201 }
202
203 pub fn num_parameters(&self) -> usize {
205 self.weight.len()
206 }
207}
208
209pub fn sample_with_temperature(logits: &Array1<f64>, temperature: f64) -> usize {
213 assert!(temperature > 0.0, "Temperature must be positive");
214
215 let scaled: Vec<f64> = logits.iter().map(|&x| x / temperature).collect();
217
218 let max_val = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
220 let exp_vals: Vec<f64> = scaled.iter().map(|&x| (x - max_val).exp()).collect();
221 let sum: f64 = exp_vals.iter().sum();
222 let probs: Vec<f64> = exp_vals.iter().map(|&x| x / sum).collect();
223
224 let mut rng_val = rand::random::<f64>();
226 for (i, &prob) in probs.iter().enumerate() {
227 rng_val -= prob;
228 if rng_val <= 0.0 {
229 return i;
230 }
231 }
232
233 probs.len() - 1
234}
235
236pub fn sample_top_k(logits: &Array1<f64>, k: usize, temperature: f64) -> usize {
240 assert!(k > 0, "k must be positive");
241 assert!(temperature > 0.0, "Temperature must be positive");
242
243 let k = k.min(logits.len());
244
245 let mut indexed: Vec<(usize, f64)> = logits.iter().enumerate().map(|(i, &v)| (i, v)).collect();
247 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
248
249 let top_k: Vec<(usize, f64)> = indexed.into_iter().take(k).collect();
251
252 let scaled: Vec<f64> = top_k.iter().map(|(_, v)| v / temperature).collect();
254 let max_val = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
255 let exp_vals: Vec<f64> = scaled.iter().map(|&x| (x - max_val).exp()).collect();
256 let sum: f64 = exp_vals.iter().sum();
257 let probs: Vec<f64> = exp_vals.iter().map(|&x| x / sum).collect();
258
259 let mut rng_val = rand::random::<f64>();
261 for (i, &prob) in probs.iter().enumerate() {
262 rng_val -= prob;
263 if rng_val <= 0.0 {
264 return top_k[i].0;
265 }
266 }
267
268 top_k[k - 1].0
269}
270
271pub fn sample_nucleus(logits: &Array1<f64>, p: f64, temperature: f64) -> usize {
275 assert!(p > 0.0 && p <= 1.0, "p must be in (0, 1]");
276 assert!(temperature > 0.0, "Temperature must be positive");
277
278 let scaled: Vec<f64> = logits.iter().map(|&x| x / temperature).collect();
280 let max_val = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
281 let exp_vals: Vec<f64> = scaled.iter().map(|&x| (x - max_val).exp()).collect();
282 let sum: f64 = exp_vals.iter().sum();
283 let probs: Vec<f64> = exp_vals.iter().map(|&x| x / sum).collect();
284
285 let mut indexed: Vec<(usize, f64)> = probs.iter().enumerate().map(|(i, &v)| (i, v)).collect();
287 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
288
289 let mut cumulative = 0.0;
291 let mut nucleus: Vec<(usize, f64)> = Vec::new();
292 for (idx, prob) in indexed {
293 cumulative += prob;
294 nucleus.push((idx, prob));
295 if cumulative >= p {
296 break;
297 }
298 }
299
300 let nucleus_sum: f64 = nucleus.iter().map(|(_, prob)| prob).sum();
302 let nucleus_probs: Vec<f64> = nucleus.iter().map(|(_, prob)| prob / nucleus_sum).collect();
303
304 let mut rng_val = rand::random::<f64>();
306 for (i, &prob) in nucleus_probs.iter().enumerate() {
307 rng_val -= prob;
308 if rng_val <= 0.0 {
309 return nucleus[i].0;
310 }
311 }
312
313 nucleus.last().map(|(idx, _)| *idx).unwrap_or(0)
314}
315
316pub fn argmax(logits: &Array1<f64>) -> usize {
318 logits.iter()
319 .enumerate()
320 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
321 .map(|(idx, _)| idx)
322 .unwrap_or(0)
323}
324
325pub fn softmax(logits: &Array1<f64>) -> Array1<f64> {
327 let max_val = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
328 let exp_vals: Array1<f64> = logits.mapv(|x| (x - max_val).exp());
329 let sum: f64 = exp_vals.sum();
330 exp_vals / sum
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr1;

    #[test]
    fn test_vocabulary_from_text() {
        // "hello" contains four distinct characters: e, h, l, o.
        let vocab = TextVocabulary::from_text("hello");
        assert_eq!(vocab.size(), 4);
        assert!(vocab.contains('h'));
        assert!(vocab.contains('l'));
        assert!(!vocab.contains('x'));
    }

    #[test]
    fn test_vocabulary_encode_decode() {
        // A round-trip through encode then decode must reproduce the input.
        let vocab = TextVocabulary::from_text("abc");
        let round_trip = vocab.decode(&vocab.encode("cab"));
        assert_eq!(round_trip, "cab");
    }

    #[test]
    fn test_embedding_forward() {
        let mut embedding = CharacterEmbedding::new(10, 8);
        let out = embedding.forward(&[0, 3, 5]);
        assert_eq!(out.shape(), &[3, 8]);
    }

    #[test]
    fn test_embedding_lookup() {
        let embedding = CharacterEmbedding::new(10, 8);
        assert_eq!(embedding.lookup(5).len(), 8);
    }

    #[test]
    fn test_sample_with_temperature() {
        let logits = arr1(&[1.0, 2.0, 3.0]);
        let sampled = sample_with_temperature(&logits, 1.0);
        assert!(sampled < 3);
    }

    #[test]
    fn test_sample_top_k() {
        // With k = 2, only the two highest logits (indices 1 and 2) survive.
        let logits = arr1(&[1.0, 5.0, 2.0, 0.5]);
        let sampled = sample_top_k(&logits, 2, 1.0);
        assert!(sampled == 1 || sampled == 2);
    }

    #[test]
    fn test_sample_nucleus() {
        // Index 1 dominates the probability mass, so a 0.9 nucleus holds only it.
        let logits = arr1(&[0.0, 10.0, 0.0]);
        assert_eq!(sample_nucleus(&logits, 0.9, 1.0), 1);
    }

    #[test]
    fn test_argmax() {
        let logits = arr1(&[1.0, 5.0, 2.0]);
        assert_eq!(argmax(&logits), 1);
    }

    #[test]
    fn test_softmax() {
        // Softmax output must sum to 1 (within floating-point tolerance).
        let probs = softmax(&arr1(&[1.0, 2.0, 3.0]));
        assert!((probs.sum() - 1.0).abs() < 1e-6);
    }
}