rust_lstm/text.rs

//! Text generation utilities for character-level language models.
//!
//! Provides vocabulary management, character embeddings, and sampling strategies.

use std::collections::HashMap;
use ndarray::{Array1, Array2};
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::Uniform;
use crate::optimizers::Optimizer;

/// Character vocabulary for text generation tasks.
///
/// Maps characters to indices and vice versa.
#[derive(Clone, Debug)]
pub struct TextVocabulary {
    char_to_idx: HashMap<char, usize>,
    idx_to_char: HashMap<usize, char>,
    vocab_size: usize,
}

impl TextVocabulary {
    /// Create vocabulary from text, extracting the unique characters.
    ///
    /// Characters are sorted so index assignment is deterministic across runs.
    pub fn from_text(text: &str) -> Self {
        let mut chars: Vec<char> = text.chars().collect::<std::collections::HashSet<_>>()
            .into_iter().collect();
        chars.sort();

        let vocab_size = chars.len();
        let char_to_idx: HashMap<char, usize> = chars.iter()
            .enumerate()
            .map(|(i, &c)| (c, i))
            .collect();
        let idx_to_char: HashMap<usize, char> = chars.iter()
            .enumerate()
            .map(|(i, &c)| (i, c))
            .collect();

        Self { char_to_idx, idx_to_char, vocab_size }
    }

    /// Create vocabulary from an explicit character list.
    ///
    /// Characters are indexed in the order given; the list should contain no
    /// duplicates, since a repeated character would inflate `vocab_size`.
    pub fn from_chars(chars: &[char]) -> Self {
        let vocab_size = chars.len();
        let char_to_idx: HashMap<char, usize> = chars.iter()
            .enumerate()
            .map(|(i, &c)| (c, i))
            .collect();
        let idx_to_char: HashMap<usize, char> = chars.iter()
            .enumerate()
            .map(|(i, &c)| (i, c))
            .collect();

        Self { char_to_idx, idx_to_char, vocab_size }
    }

    /// Get index for a character.
    pub fn char_to_index(&self, ch: char) -> Option<usize> {
        self.char_to_idx.get(&ch).copied()
    }

    /// Get character for an index.
    pub fn index_to_char(&self, idx: usize) -> Option<char> {
        self.idx_to_char.get(&idx).copied()
    }

    /// Get vocabulary size.
    pub fn size(&self) -> usize {
        self.vocab_size
    }

    /// Check if character is in vocabulary.
    pub fn contains(&self, ch: char) -> bool {
        self.char_to_idx.contains_key(&ch)
    }

    /// Get all characters in vocabulary order.
    pub fn chars(&self) -> Vec<char> {
        let mut chars: Vec<_> = self.idx_to_char.iter().collect();
        chars.sort_by_key(|(idx, _)| *idx);
        chars.into_iter().map(|(_, &ch)| ch).collect()
    }

    /// Encode a string to indices, silently skipping characters not in the vocabulary.
    pub fn encode(&self, text: &str) -> Vec<usize> {
        text.chars()
            .filter_map(|ch| self.char_to_index(ch))
            .collect()
    }

    /// Decode indices to a string, silently skipping out-of-range indices.
    pub fn decode(&self, indices: &[usize]) -> String {
        indices.iter()
            .filter_map(|&idx| self.index_to_char(idx))
            .collect()
    }
}
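
// Illustrative usage sketch, not part of the original API: `demo_vocabulary_roundtrip`
// is a hypothetical helper showing the encode/decode round trip. Note that `encode`
// drops characters missing from the vocabulary, so the round trip is only lossless
// for in-vocabulary text.
#[allow(dead_code)]
fn demo_vocabulary_roundtrip() {
    let vocab = TextVocabulary::from_text("hello world");
    let ids = vocab.encode("hello");
    assert_eq!(vocab.decode(&ids), "hello");
    // '?' is not in the vocabulary, so it is skipped rather than erroring.
    assert_eq!(vocab.decode(&vocab.encode("held?")), "held");
}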

/// Gradients for the character embedding layer.
#[derive(Clone, Debug)]
pub struct EmbeddingGradients {
    pub weight: Array2<f64>,
}

/// Trainable character embedding layer.
///
/// Maps character indices to dense vectors.
#[derive(Clone, Debug)]
pub struct CharacterEmbedding {
    pub weight: Array2<f64>, // (vocab_size, embed_dim)
    vocab_size: usize,
    embed_dim: usize,
    input_cache: Option<Vec<usize>>,
}

impl CharacterEmbedding {
    /// Create a new embedding with uniform random initialization scaled by 1/sqrt(embed_dim).
    pub fn new(vocab_size: usize, embed_dim: usize) -> Self {
        let scale = (1.0 / embed_dim as f64).sqrt();
        let weight = Array2::random((vocab_size, embed_dim), Uniform::new(-scale, scale));

        Self {
            weight,
            vocab_size,
            embed_dim,
            input_cache: None,
        }
    }

    /// Create an embedding with zero initialization.
    pub fn new_zeros(vocab_size: usize, embed_dim: usize) -> Self {
        Self {
            weight: Array2::zeros((vocab_size, embed_dim)),
            vocab_size,
            embed_dim,
            input_cache: None,
        }
    }

    /// Create an embedding from existing weights.
    pub fn from_weights(weight: Array2<f64>) -> Self {
        let (vocab_size, embed_dim) = weight.dim();
        Self {
            weight,
            vocab_size,
            embed_dim,
            input_cache: None,
        }
    }

    /// Get embedding dimension.
    pub fn embed_dim(&self) -> usize {
        self.embed_dim
    }

    /// Get vocabulary size.
    pub fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    /// Look up a single character's embedding.
    pub fn lookup(&self, char_idx: usize) -> Array1<f64> {
        assert!(char_idx < self.vocab_size, "Index {} out of vocabulary size {}", char_idx, self.vocab_size);
        self.weight.row(char_idx).to_owned()
    }

    /// Forward pass for a sequence of indices.
    /// Returns a (seq_len, embed_dim) matrix and caches the indices for `backward`.
    pub fn forward(&mut self, char_indices: &[usize]) -> Array2<f64> {
        self.input_cache = Some(char_indices.to_vec());

        let seq_len = char_indices.len();
        let mut output = Array2::zeros((seq_len, self.embed_dim));

        for (i, &idx) in char_indices.iter().enumerate() {
            assert!(idx < self.vocab_size, "Index {} out of vocabulary size {}", idx, self.vocab_size);
            output.row_mut(i).assign(&self.weight.row(idx));
        }

        output
    }

    /// Backward pass: accumulate gradients into the rows selected in `forward`.
    /// `grad_output` shape: (seq_len, embed_dim).
    pub fn backward(&self, grad_output: &Array2<f64>) -> EmbeddingGradients {
        let indices = self.input_cache.as_ref().expect("No cached input for backward pass");

        let mut weight_grad = Array2::zeros((self.vocab_size, self.embed_dim));

        for (i, &idx) in indices.iter().enumerate() {
            for j in 0..self.embed_dim {
                weight_grad[[idx, j]] += grad_output[[i, j]];
            }
        }

        EmbeddingGradients { weight: weight_grad }
    }

    /// Update parameters with the given optimizer.
    pub fn update_parameters<O: Optimizer>(&mut self, gradients: &EmbeddingGradients, optimizer: &mut O, prefix: &str) {
        optimizer.update(&format!("{}_weight", prefix), &mut self.weight, &gradients.weight);
    }

    /// Get the number of parameters.
    pub fn num_parameters(&self) -> usize {
        self.weight.len()
    }
}
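
// Illustrative sketch, not part of the original API: `demo_embedding_step` is a
// hypothetical helper walking through one forward/backward pass. To avoid assuming
// any concrete `Optimizer` implementation from this crate, it applies a plain
// gradient-descent update by hand instead of calling `update_parameters`.
#[allow(dead_code)]
fn demo_embedding_step() {
    let mut emb = CharacterEmbedding::new(4, 2);
    let out = emb.forward(&[0, 2, 2]); // shape (3, 2)

    // Pretend upstream layers produced a gradient of all ones.
    let grads = emb.backward(&Array2::ones(out.dim()));
    // Index 2 appeared twice, so its row accumulates a gradient of 2.0.
    assert_eq!(grads.weight[[2, 0]], 2.0);

    // Manual SGD step with a made-up learning rate of 0.1.
    emb.weight = &emb.weight - &(&grads.weight * 0.1);
}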

/// Sample an index from logits with temperature scaling.
///
/// Higher temperatures produce more random samples; lower temperatures are closer to greedy.
pub fn sample_with_temperature(logits: &Array1<f64>, temperature: f64) -> usize {
    assert!(temperature > 0.0, "Temperature must be positive");

    // Scale logits by temperature
    let scaled: Vec<f64> = logits.iter().map(|&x| x / temperature).collect();

    // Softmax with numerical stability
    let max_val = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exp_vals: Vec<f64> = scaled.iter().map(|&x| (x - max_val).exp()).collect();
    let sum: f64 = exp_vals.iter().sum();
    let probs: Vec<f64> = exp_vals.iter().map(|&x| x / sum).collect();

    // Sample from the distribution by inverse transform
    let mut rng_val = rand::random::<f64>();
    for (i, &prob) in probs.iter().enumerate() {
        rng_val -= prob;
        if rng_val <= 0.0 {
            return i;
        }
    }

    // Fallback for floating-point rounding
    probs.len() - 1
}
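
// Illustrative sketch, not part of the original API: `demo_temperature` is a
// hypothetical helper showing the effect of the temperature parameter. Near zero
// the sampler is effectively greedy; large values flatten the distribution.
#[allow(dead_code)]
fn demo_temperature() {
    let logits = ndarray::arr1(&[1.0, 4.0, 2.0]);
    // At a very low temperature the highest logit dominates almost surely.
    assert_eq!(sample_with_temperature(&logits, 0.01), 1);
    // At a high temperature any of the three indices can be drawn.
    assert!(sample_with_temperature(&logits, 10.0) < 3);
}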

/// Sample from the top-k most likely tokens.
///
/// Filters to the k highest-logit tokens before applying temperature and sampling.
pub fn sample_top_k(logits: &Array1<f64>, k: usize, temperature: f64) -> usize {
    assert!(k > 0, "k must be positive");
    assert!(temperature > 0.0, "Temperature must be positive");

    let k = k.min(logits.len());

    // Get indices sorted by logit value (descending)
    let mut indexed: Vec<(usize, f64)> = logits.iter().enumerate().map(|(i, &v)| (i, v)).collect();
    indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

    // Keep the top k
    let top_k: Vec<(usize, f64)> = indexed.into_iter().take(k).collect();

    // Apply temperature and softmax to the top k only
    let scaled: Vec<f64> = top_k.iter().map(|(_, v)| v / temperature).collect();
    let max_val = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exp_vals: Vec<f64> = scaled.iter().map(|&x| (x - max_val).exp()).collect();
    let sum: f64 = exp_vals.iter().sum();
    let probs: Vec<f64> = exp_vals.iter().map(|&x| x / sum).collect();

    // Sample by inverse transform
    let mut rng_val = rand::random::<f64>();
    for (i, &prob) in probs.iter().enumerate() {
        rng_val -= prob;
        if rng_val <= 0.0 {
            return top_k[i].0;
        }
    }

    // Fallback for floating-point rounding
    top_k[k - 1].0
}
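
// Illustrative sketch, not part of the original API: `demo_top_k_greedy` is a
// hypothetical helper. With k = 1, top-k sampling keeps only the highest logit,
// so it degenerates to greedy decoding regardless of temperature.
#[allow(dead_code)]
fn demo_top_k_greedy() {
    let logits = ndarray::arr1(&[1.0, 5.0, 2.0, 0.5]);
    assert_eq!(sample_top_k(&logits, 1, 2.0), argmax(&logits));
}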

/// Nucleus (top-p) sampling.
///
/// Samples from the smallest set of tokens whose cumulative probability reaches p.
pub fn sample_nucleus(logits: &Array1<f64>, p: f64, temperature: f64) -> usize {
    assert!(p > 0.0 && p <= 1.0, "p must be in (0, 1]");
    assert!(temperature > 0.0, "Temperature must be positive");

    // Apply temperature and softmax
    let scaled: Vec<f64> = logits.iter().map(|&x| x / temperature).collect();
    let max_val = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exp_vals: Vec<f64> = scaled.iter().map(|&x| (x - max_val).exp()).collect();
    let sum: f64 = exp_vals.iter().sum();
    let probs: Vec<f64> = exp_vals.iter().map(|&x| x / sum).collect();

    // Sort by probability (descending)
    let mut indexed: Vec<(usize, f64)> = probs.iter().enumerate().map(|(i, &v)| (i, v)).collect();
    indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

    // Build the nucleus: take tokens until cumulative probability reaches p
    let mut cumulative = 0.0;
    let mut nucleus: Vec<(usize, f64)> = Vec::new();
    for (idx, prob) in indexed {
        cumulative += prob;
        nucleus.push((idx, prob));
        if cumulative >= p {
            break;
        }
    }

    // Renormalize nucleus probabilities
    let nucleus_sum: f64 = nucleus.iter().map(|(_, prob)| prob).sum();
    let nucleus_probs: Vec<f64> = nucleus.iter().map(|(_, prob)| prob / nucleus_sum).collect();

    // Sample from the nucleus by inverse transform
    let mut rng_val = rand::random::<f64>();
    for (i, &prob) in nucleus_probs.iter().enumerate() {
        rng_val -= prob;
        if rng_val <= 0.0 {
            return nucleus[i].0;
        }
    }

    // Fallback for floating-point rounding
    nucleus.last().map(|(idx, _)| *idx).unwrap_or(0)
}
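
// Illustrative sketch, not part of the original API: `demo_nucleus_small_p` is a
// hypothetical helper. A small p keeps only the single most likely token in the
// nucleus, so sampling becomes deterministic for these logits.
#[allow(dead_code)]
fn demo_nucleus_small_p() {
    let logits = ndarray::arr1(&[0.0, 3.0, 1.0]);
    // p(index 1) is about 0.84, which already exceeds p = 0.05, so the nucleus is {1}.
    assert_eq!(sample_nucleus(&logits, 0.05, 1.0), 1);
}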

/// Get the argmax of the logits (greedy decoding).
pub fn argmax(logits: &Array1<f64>) -> usize {
    logits.iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
        .map(|(idx, _)| idx)
        .unwrap_or(0)
}

/// Apply a numerically stable softmax to logits.
pub fn softmax(logits: &Array1<f64>) -> Array1<f64> {
    let max_val = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exp_vals: Array1<f64> = logits.mapv(|x| (x - max_val).exp());
    let sum: f64 = exp_vals.sum();
    exp_vals / sum
}
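
// Illustrative sketch, not part of the original API: `demo_generation_loop` is a
// hypothetical helper showing how these utilities compose into a character
// generation loop. `model` is a stand-in closure that maps an embedding to logits;
// a real caller would run an LSTM forward pass there instead.
#[allow(dead_code)]
fn demo_generation_loop(
    vocab: &TextVocabulary,
    emb: &CharacterEmbedding,
    mut model: impl FnMut(&Array1<f64>) -> Array1<f64>,
    seed: char,
    steps: usize,
) -> String {
    let mut idx = vocab.char_to_index(seed).expect("seed must be in vocabulary");
    let mut out = String::new();
    for _ in 0..steps {
        let logits = model(&emb.lookup(idx));
        idx = sample_with_temperature(&logits, 0.8);
        if let Some(ch) = vocab.index_to_char(idx) {
            out.push(ch);
        }
    }
    out
}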

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr1;

    #[test]
    fn test_vocabulary_from_text() {
        let vocab = TextVocabulary::from_text("hello");
        assert_eq!(vocab.size(), 4); // h, e, l, o
        assert!(vocab.contains('h'));
        assert!(vocab.contains('l'));
        assert!(!vocab.contains('x'));
    }

    #[test]
    fn test_vocabulary_encode_decode() {
        let vocab = TextVocabulary::from_text("abc");
        let encoded = vocab.encode("cab");
        let decoded = vocab.decode(&encoded);
        assert_eq!(decoded, "cab");
    }

    #[test]
    fn test_embedding_forward() {
        let mut emb = CharacterEmbedding::new(10, 8);
        let output = emb.forward(&[0, 3, 5]);
        assert_eq!(output.shape(), &[3, 8]);
    }

    #[test]
    fn test_embedding_lookup() {
        let emb = CharacterEmbedding::new(10, 8);
        let vec = emb.lookup(5);
        assert_eq!(vec.len(), 8);
    }

    #[test]
    fn test_sample_with_temperature() {
        let logits = arr1(&[1.0, 2.0, 3.0]);
        let idx = sample_with_temperature(&logits, 1.0);
        assert!(idx < 3);
    }

    #[test]
    fn test_sample_top_k() {
        let logits = arr1(&[1.0, 5.0, 2.0, 0.5]);
        let idx = sample_top_k(&logits, 2, 1.0);
        // Should only sample from indices 1 or 2 (top 2)
        assert!(idx == 1 || idx == 2);
    }

    #[test]
    fn test_sample_nucleus() {
        let logits = arr1(&[0.0, 10.0, 0.0]); // Very peaked distribution
        let idx = sample_nucleus(&logits, 0.9, 1.0);
        assert_eq!(idx, 1); // With p = 0.9 the nucleus holds only index 1, so this is deterministic
    }

    #[test]
    fn test_argmax() {
        let logits = arr1(&[1.0, 5.0, 2.0]);
        assert_eq!(argmax(&logits), 1);
    }

    #[test]
    fn test_softmax() {
        let logits = arr1(&[1.0, 2.0, 3.0]);
        let probs = softmax(&logits);
        assert!((probs.sum() - 1.0).abs() < 1e-6);
    }
}
403}