//! Text generation utilities for `torsh_text` (generation.rs):
//! sampling strategies, beam search decoding, and repetition/constraint
//! handling.
1use crate::{Result, TextError};
2// ✅ SciRS2 Policy Compliant - Using scirs2_core::random instead of direct rand
3use scirs2_core::random::Random;
4use scirs2_core::rngs::StdRng;
5use scirs2_core::RngExt;
6use torsh_tensor::Tensor;
7
8// ============================================================================
9// Text Generation Utilities
10// ============================================================================
11
/// Configuration for text generation.
///
/// Collects the knobs used by this module's decoders: length limits,
/// search strategy (greedy / sampling / beam search), sampling filters,
/// and repetition constraints.
#[derive(Debug, Clone)]
pub struct GenerationConfig {
    /// Maximum length of the generated sequence.
    pub max_length: usize,
    /// Minimum length that must be reached before EOS may stop generation.
    pub min_length: usize,
    /// If true, sample from the distribution; otherwise decode greedily.
    pub do_sample: bool,
    /// Stop beam search as soon as the hypothesis pool cannot improve.
    pub early_stopping: bool,
    /// Number of beams; 1 disables beam search.
    pub num_beams: usize,
    /// Softmax temperature used when sampling.
    pub temperature: f32,
    /// Keep only the `k` most likely tokens when sampling, if set.
    pub top_k: Option<usize>,
    /// Nucleus (top-p) sampling threshold, if set.
    pub top_p: Option<f32>,
    /// Penalty applied to already-generated tokens (1.0 = disabled).
    pub repetition_penalty: f32,
    /// Exponent used to length-normalize beam scores.
    pub length_penalty: f32,
    /// Forbid repeating n-grams of this size (0 = disabled).
    pub no_repeat_ngram_size: usize,
    /// Encoder-side n-gram constraint (currently unused in this module).
    pub encoder_no_repeat_ngram_size: usize,
    /// Token-id sequences that must never be generated (currently unused here).
    pub bad_words_ids: Vec<Vec<u32>>,
    /// Token-id sequences that must appear in the output (currently unused here).
    pub force_words_ids: Vec<Vec<u32>>,
    /// Padding token id, if the model defines one.
    pub pad_token_id: Option<u32>,
    /// Beginning-of-sequence token id, if defined.
    pub bos_token_id: Option<u32>,
    /// End-of-sequence token id, if defined.
    pub eos_token_id: Option<u32>,
    /// Token id used to start decoder input (encoder-decoder models).
    pub decoder_start_token_id: Option<u32>,
}

impl Default for GenerationConfig {
    /// Greedy, single-beam decoding capped at 50 tokens with every optional
    /// filter and constraint disabled.
    fn default() -> Self {
        Self {
            // Length limits.
            max_length: 50,
            min_length: 0,
            // Strategy: greedy, single beam.
            do_sample: false,
            early_stopping: false,
            num_beams: 1,
            // Sampling filters (all neutral / disabled).
            temperature: 1.0,
            top_k: None,
            top_p: None,
            // Repetition handling (neutral).
            repetition_penalty: 1.0,
            length_penalty: 1.0,
            no_repeat_ngram_size: 0,
            encoder_no_repeat_ngram_size: 0,
            bad_words_ids: Vec::new(),
            force_words_ids: Vec::new(),
            // Special tokens: unknown unless the caller supplies them.
            pad_token_id: None,
            bos_token_id: None,
            eos_token_id: None,
            decoder_start_token_id: None,
        }
    }
}
59
60// ============================================================================
61// Sampling Methods
62// ============================================================================
63
/// Stateful random sampler used by the generation pipeline.
///
/// Holds a seeded scirs2 RNG so that draws are reproducible across runs
/// with the same seed.
pub struct TextSampler {
    // ✅ SciRS2 Policy Compliant - Using scirs2_core::random instead of direct rand
    rng: Random<StdRng>,
}

impl Default for TextSampler {
    /// Deterministic sampler seeded with the fixed value 42, so
    /// default-constructed samplers produce identical token draws.
    fn default() -> Self {
        Self {
            rng: Random::seed(42),
        }
    }
}
76
impl TextSampler {
    /// Greedy sampling - always select the token with highest probability.
    ///
    /// Ties keep the earliest index. NOTE(review): `vocab_size` is read from
    /// the LAST dimension but elements are selected along dim 0, so every
    /// method in this impl effectively assumes a 1-D `(vocab,)` logits
    /// tensor — confirm against callers.
    pub fn greedy_sample(&self, logits: &Tensor) -> Result<u32> {
        let vocab_size = logits.shape().dims()[logits.shape().ndim() - 1];
        let mut max_idx = 0;
        let mut max_val = f32::NEG_INFINITY;

        // Simple implementation - could be optimized with tensor operations
        // (per-element select/item round trips through the tensor API).
        for i in 0..vocab_size {
            let val = logits.select(0, i as i64)?.item()?;
            if val > max_val {
                max_val = val;
                max_idx = i;
            }
        }

        Ok(max_idx as u32)
    }

    /// Temperature sampling: softmax over `logits / temperature`, then draw
    /// one token from the resulting distribution.
    ///
    /// `temperature <= 0.0` degenerates to greedy selection.
    pub fn temperature_sample(&mut self, logits: &Tensor, temperature: f32) -> Result<u32> {
        if temperature <= 0.0 {
            return self.greedy_sample(logits);
        }

        // Apply temperature scaling
        let scaled_logits = logits.div_scalar(temperature)?;

        // Apply softmax to get probabilities
        let probs = scaled_logits.softmax(-1)?;

        self.multinomial_sample(&probs)
    }

    /// Top-k sampling: restrict to the `k` highest logits, softmax over them
    /// (after optional temperature scaling), sample, and map the winner back
    /// to its original vocabulary index.
    ///
    /// NOTE(review): unlike `temperature_sample`, `temperature <= 0.0` here
    /// does NOT fall back to greedy — it samples from the unscaled top-k
    /// softmax. Confirm whether that asymmetry is intended.
    pub fn top_k_sample(&mut self, logits: &Tensor, k: usize, temperature: f32) -> Result<u32> {
        let vocab_size = logits.shape().dims()[logits.shape().ndim() - 1];
        let k = k.min(vocab_size);

        // Get top-k indices and values
        let (top_values, top_indices) = self.get_top_k(logits, k)?;

        // Apply temperature scaling
        let scaled_values = if temperature > 0.0 {
            top_values.div_scalar(temperature)?
        } else {
            top_values
        };

        // Apply softmax
        let probs = scaled_values.softmax(-1)?;

        // Sample from the distribution
        let local_idx = self.multinomial_sample(&probs)?;

        // Convert back to original vocabulary index
        let original_idx = top_indices.select(0, local_idx as i64)?.item()?;
        Ok(original_idx as u32)
    }

    /// Top-p (nucleus) sampling: keep the smallest prefix of the
    /// probability-sorted vocabulary whose cumulative mass exceeds `p`,
    /// renormalize, and sample from that nucleus.
    pub fn top_p_sample(&mut self, logits: &Tensor, p: f32, temperature: f32) -> Result<u32> {
        // Apply temperature scaling (temperature <= 0.0 leaves logits as-is).
        let scaled_logits = if temperature > 0.0 {
            logits.div_scalar(temperature)?
        } else {
            logits.clone()
        };

        // Apply softmax to get probabilities
        let probs = scaled_logits.softmax(-1)?;

        // Get sorted probabilities and indices
        let (sorted_probs, sorted_indices) = self.sort_descending(&probs)?;

        // Calculate cumulative probabilities
        let cumsum = self.cumulative_sum(&sorted_probs)?;

        // Find the cutoff point where cumulative probability exceeds p.
        // The token that crosses the threshold is INCLUDED (cutoff = i + 1);
        // if the threshold is never crossed the whole vocabulary is kept.
        let vocab_size = probs.shape().dims()[probs.shape().ndim() - 1];
        let mut cutoff = vocab_size;

        for i in 0..vocab_size {
            let cum_prob = cumsum.select(0, i as i64)?.item()?;
            if cum_prob > p {
                cutoff = i + 1;
                break;
            }
        }

        // Keep only the top-p tokens
        let nucleus_probs = sorted_probs.narrow(0, 0, cutoff)?;
        let nucleus_indices = sorted_indices.narrow(0, 0, cutoff)?;

        // Renormalize probabilities so the nucleus sums to 1
        let sum_tensor = nucleus_probs.sum()?;
        let renormalized_probs = nucleus_probs.div(&sum_tensor)?;

        // Sample from the nucleus
        let local_idx = self.multinomial_sample(&renormalized_probs)?;

        // Convert back to original vocabulary index
        let original_idx = nucleus_indices.select(0, local_idx as i64)?.item()?;
        Ok(original_idx as u32)
    }

    /// Combined top-k and top-p sampling.
    ///
    /// Applies top-k masking first (non-top-k logits set to -inf), then
    /// delegates to `top_p_sample` when `p` is set, otherwise to
    /// `temperature_sample`.
    pub fn top_k_top_p_sample(
        &mut self,
        logits: &Tensor,
        k: Option<usize>,
        p: Option<f32>,
        temperature: f32,
    ) -> Result<u32> {
        let mut working_logits = logits.clone();

        // Apply top-k filtering first if specified
        if let Some(k_val) = k {
            let vocab_size = working_logits.shape().dims()[working_logits.shape().ndim() - 1];
            if k_val < vocab_size {
                let (top_values, top_indices) = self.get_top_k(&working_logits, k_val)?;

                // Create new logits tensor filled with negative infinity
                // so masked tokens get zero probability after softmax.
                let mut new_logits_data = vec![f32::NEG_INFINITY; vocab_size];

                // Set top-k values
                for i in 0..k_val {
                    let idx = top_indices.select(0, i as i64)?.item()? as usize;
                    let val = top_values.select(0, i as i64)?.item()?;
                    if idx < vocab_size {
                        new_logits_data[idx] = val;
                    }
                }

                // NOTE(review): rebuilt tensor is forced onto CPU regardless
                // of the input tensor's device — confirm this is acceptable.
                working_logits = Tensor::from_data(
                    new_logits_data,
                    working_logits.shape().dims().to_vec(),
                    torsh_core::device::DeviceType::Cpu,
                )?;
            }
        }

        // Apply top-p filtering if specified
        if let Some(p_val) = p {
            return self.top_p_sample(&working_logits, p_val, temperature);
        }

        // Otherwise use temperature sampling
        self.temperature_sample(&working_logits, temperature)
    }

    // Helper methods

    /// Inverse-CDF draw: walk the probabilities accumulating mass until the
    /// uniform draw in [0, 1) is covered. Assumes `probs` is normalized;
    /// falls back to the last index if rounding leaves residual mass.
    fn multinomial_sample(&mut self, probs: &Tensor) -> Result<u32> {
        let vocab_size = probs.shape().dims()[probs.shape().ndim() - 1];
        let random_val: f32 = self.rng.random();

        let mut cumulative = 0.0;
        for i in 0..vocab_size {
            let prob = probs.select(0, i as i64)?.item()?;
            cumulative += prob;
            if random_val <= cumulative {
                return Ok(i as u32);
            }
        }

        // Fallback to last token if rounding errors occur
        Ok((vocab_size - 1) as u32)
    }

    /// Return the `k` largest values and their original indices, both as
    /// tensors of length `k`, sorted descending.
    ///
    /// NOTE(review): indices are stored as f32 and cast to the input dtype,
    /// which loses precision for vocabularies larger than 2^24 — confirm.
    fn get_top_k(&self, tensor: &Tensor, k: usize) -> Result<(Tensor, Tensor)> {
        // Simplified implementation - in practice would use more efficient sorting
        let vocab_size = tensor.shape().dims()[tensor.shape().ndim() - 1];
        let mut values_and_indices: Vec<(f32, usize)> = Vec::new();

        for i in 0..vocab_size {
            let val = tensor.select(0, i as i64)?.item()?;
            values_and_indices.push((val, i));
        }

        // Descending sort; NaN comparisons are treated as Equal.
        values_and_indices
            .sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
        values_and_indices.truncate(k);

        let values: Vec<f32> = values_and_indices.iter().map(|(v, _)| *v).collect();
        let indices: Vec<f32> = values_and_indices.iter().map(|(_, i)| *i as f32).collect();

        let values_tensor = Tensor::from_vec(values, &[k])?.to_dtype(tensor.dtype())?;
        let indices_tensor = Tensor::from_vec(indices, &[k])?.to_dtype(tensor.dtype())?;

        Ok((values_tensor, indices_tensor))
    }

    /// Full descending sort of a 1-D tensor; returns (sorted values,
    /// original indices). Same f32-index caveat as `get_top_k`.
    fn sort_descending(&self, tensor: &Tensor) -> Result<(Tensor, Tensor)> {
        let vocab_size = tensor.shape().dims()[tensor.shape().ndim() - 1];
        let mut values_and_indices: Vec<(f32, usize)> = Vec::new();

        for i in 0..vocab_size {
            let val = tensor.select(0, i as i64)?.item()?;
            values_and_indices.push((val, i));
        }

        // Descending sort; NaN comparisons are treated as Equal.
        values_and_indices
            .sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));

        let values: Vec<f32> = values_and_indices.iter().map(|(v, _)| *v).collect();
        let indices: Vec<f32> = values_and_indices.iter().map(|(_, i)| *i as f32).collect();

        let values_tensor = Tensor::from_vec(values, &[vocab_size])?.to_dtype(tensor.dtype())?;
        let indices_tensor = Tensor::from_vec(indices, &[vocab_size])?.to_dtype(tensor.dtype())?;

        Ok((values_tensor, indices_tensor))
    }

    /// Prefix-sum of a 1-D tensor, returned as a tensor of the same length
    /// and dtype.
    fn cumulative_sum(&self, tensor: &Tensor) -> Result<Tensor> {
        let size = tensor.shape().dims()[tensor.shape().ndim() - 1];
        let mut cumsum = Vec::new();
        let mut running_sum = 0.0;

        for i in 0..size {
            let val = tensor.select(0, i as i64)?.item()?;
            running_sum += val;
            cumsum.push(running_sum);
        }

        Ok(Tensor::from_vec(cumsum, &[size])?.to_dtype(tensor.dtype())?)
    }
}
304
305// ============================================================================
306// Beam Search
307// ============================================================================
308
/// A (possibly partial) candidate sequence tracked during beam search,
/// together with its cumulative log-probability score.
#[derive(Debug, Clone)]
pub struct BeamHypothesis {
    pub tokens: Vec<u32>,
    pub score: f32,
    pub length: usize,
}

impl BeamHypothesis {
    /// Build a hypothesis from its token sequence and cumulative score;
    /// `length` is derived from the token count.
    pub fn new(tokens: Vec<u32>, score: f32) -> Self {
        Self {
            length: tokens.len(),
            tokens,
            score,
        }
    }

    /// Length-normalized score: `score / length^length_penalty`.
    /// A penalty of 0.0 leaves the raw score unchanged.
    pub fn normalized_score(&self, length_penalty: f32) -> f32 {
        let denominator = (self.length as f32).powf(length_penalty);
        self.score / denominator
    }
}
330
/// Beam search decoder: expands the best `num_beams` partial sequences each
/// step and collects finished hypotheses.
pub struct BeamSearchDecoder {
    num_beams: usize,          // beams kept alive per step
    max_length: usize,         // maximum number of expansion steps
    length_penalty: f32,       // exponent for length-normalizing scores
    early_stopping: bool,      // allow stopping once the pool cannot improve
    eos_token_id: Option<u32>, // token that marks a finished sequence, if any
}
338
339impl BeamSearchDecoder {
340    pub fn new(
341        num_beams: usize,
342        max_length: usize,
343        length_penalty: f32,
344        early_stopping: bool,
345        eos_token_id: Option<u32>,
346    ) -> Self {
347        Self {
348            num_beams,
349            max_length,
350            length_penalty,
351            early_stopping,
352            eos_token_id,
353        }
354    }
355
356    pub fn search(
357        &self,
358        initial_tokens: Vec<u32>,
359        vocab_size: usize,
360        get_logits: impl Fn(&[u32]) -> Result<Tensor>,
361    ) -> Result<Vec<BeamHypothesis>> {
362        let mut beam_hypotheses = BeamHypothesesPool::new(
363            self.num_beams,
364            self.max_length,
365            self.length_penalty,
366            self.early_stopping,
367        );
368
369        // Initialize beams
370        let mut beams: Vec<BeamHypothesis> = vec![BeamHypothesis::new(initial_tokens.clone(), 0.0)];
371
372        for _step in 0..self.max_length {
373            let mut all_candidates = Vec::new();
374
375            for beam in &beams {
376                if let Some(eos_id) = self.eos_token_id {
377                    if beam.tokens.last() == Some(&eos_id) {
378                        // This beam has ended, add to final hypotheses
379                        beam_hypotheses.add(beam.clone());
380                        continue;
381                    }
382                }
383
384                // Get logits for current beam
385                let logits = get_logits(&beam.tokens)?;
386                let log_probs = logits.log_softmax(-1)?;
387
388                // Get top-k candidates for this beam
389                for token_id in 0..vocab_size.min(self.num_beams * 2) {
390                    let token_log_prob = log_probs.select(0, token_id as i64)?.item()?;
391                    let new_score = beam.score + token_log_prob;
392
393                    let mut new_tokens = beam.tokens.clone();
394                    new_tokens.push(token_id as u32);
395
396                    all_candidates.push(BeamHypothesis::new(new_tokens, new_score));
397                }
398            }
399
400            // Sort candidates by score and keep top num_beams
401            all_candidates.sort_by(|a, b| {
402                b.normalized_score(self.length_penalty)
403                    .partial_cmp(&a.normalized_score(self.length_penalty))
404                    .unwrap_or(std::cmp::Ordering::Equal)
405            });
406            beams = all_candidates.into_iter().take(self.num_beams).collect();
407
408            // Check early stopping
409            if self.early_stopping
410                && beam_hypotheses.is_done(
411                    beams
412                        .iter()
413                        .map(|b| b.normalized_score(self.length_penalty))
414                        .fold(f32::NEG_INFINITY, f32::max),
415                )
416            {
417                break;
418            }
419
420            // Remove finished beams
421            beams.retain(|beam| {
422                if let Some(eos_id) = self.eos_token_id {
423                    beam.tokens.last() != Some(&eos_id)
424                } else {
425                    true
426                }
427            });
428
429            if beams.is_empty() {
430                break;
431            }
432        }
433
434        // Add remaining beams to hypotheses
435        for beam in beams {
436            beam_hypotheses.add(beam);
437        }
438
439        Ok(beam_hypotheses.finalize())
440    }
441}
442
/// Bounded collection of finished beam hypotheses, kept sorted best-first.
struct BeamHypothesesPool {
    hypotheses: Vec<BeamHypothesis>, // invariant: sorted by normalized score, descending
    max_hypotheses: usize,           // capacity bound (num_beams at the call site)
    max_length: usize,               // bounds achievable scores in `is_done`
    length_penalty: f32,             // exponent used for score normalization
    early_stopping: bool,            // `is_done` always returns false when unset
}
450
impl BeamHypothesesPool {
    /// Create an empty pool keeping at most `max_hypotheses` finished beams.
    fn new(
        max_hypotheses: usize,
        max_length: usize,
        length_penalty: f32,
        early_stopping: bool,
    ) -> Self {
        Self {
            hypotheses: Vec::new(),
            max_hypotheses,
            max_length,
            length_penalty,
            early_stopping,
        }
    }

    /// Insert a finished hypothesis, preserving the descending (best-first)
    /// order of `hypotheses` and the `max_hypotheses` bound.
    fn add(&mut self, hypothesis: BeamHypothesis) {
        let score = hypothesis.normalized_score(self.length_penalty);

        // Insert in sorted order.
        // `binary_search_by` expects a comparator consistent with ascending
        // order; comparing `score` against each probe (target vs. probe,
        // rather than probe vs. target) inverts it, which is exactly what
        // keeps the list sorted DESCENDING. NaN comparisons fall back to
        // Equal, so NaN scores land at an arbitrary position.
        let insert_pos = self
            .hypotheses
            .binary_search_by(|h| {
                score
                    .partial_cmp(&h.normalized_score(self.length_penalty))
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .unwrap_or_else(|e| e);

        self.hypotheses.insert(insert_pos, hypothesis);

        // Keep only top hypotheses (worst are at the tail).
        if self.hypotheses.len() > self.max_hypotheses {
            self.hypotheses.truncate(self.max_hypotheses);
        }
    }

    /// Whether beam search may stop: the pool is full and even the best live
    /// beam (raw score `best_sum_logprobs`) could not beat the worst kept
    /// hypothesis.
    fn is_done(&self, best_sum_logprobs: f32) -> bool {
        if !self.early_stopping {
            return false;
        }

        if self.hypotheses.len() < self.max_hypotheses {
            return false;
        }

        // Worst kept score is the last element (list is sorted descending).
        let worst_score = self
            .hypotheses
            .last()
            .map(|h| h.normalized_score(self.length_penalty))
            .unwrap_or(f32::NEG_INFINITY);
        // Upper bound on what the live beam could still achieve, normalized
        // at the maximum length. NOTE(review): this bound presumes
        // non-positive log-prob scores and `length_penalty >= 0` — confirm.
        let best_possible_score =
            best_sum_logprobs / (self.max_length as f32).powf(self.length_penalty);

        worst_score >= best_possible_score
    }

    /// Consume the pool, returning hypotheses sorted best-first.
    fn finalize(mut self) -> Vec<BeamHypothesis> {
        self.hypotheses.sort_by(|a, b| {
            b.normalized_score(self.length_penalty)
                .partial_cmp(&a.normalized_score(self.length_penalty))
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        self.hypotheses
    }
}
517
518// ============================================================================
519// Repetition and Constraint Handling
520// ============================================================================
521
522pub struct RepetitionPenalty;
523
524impl RepetitionPenalty {
525    pub fn apply(logits: &Tensor, generated_tokens: &[u32], penalty: f32) -> Result<Tensor> {
526        if penalty == 1.0 {
527            return Ok(logits.clone());
528        }
529
530        let mut penalized_logits = logits.clone();
531
532        // Apply penalty to repeated tokens
533        for &token in generated_tokens {
534            let current_logit = penalized_logits.select(0, token as i64)?.item()?;
535            let penalized_value = if current_logit > 0.0 {
536                current_logit / penalty
537            } else {
538                current_logit * penalty
539            };
540
541            // Use index_select and scatter to simulate index_put
542            let _token_tensor = Tensor::from_vec(vec![token as i64], &[1])?;
543            let _penalty_tensor = Tensor::scalar(penalized_value)?;
544
545            // Simplified approach: directly modify the logit at the token position
546            let vocab_size = penalized_logits.shape().dims()[0];
547            let mut logits_vec = penalized_logits.to_vec()?;
548            logits_vec[token as usize] = penalized_value;
549            penalized_logits = Tensor::from_vec(logits_vec, &[vocab_size])?;
550        }
551
552        Ok(penalized_logits)
553    }
554}
555
556pub struct NGramRepetitionFilter {
557    no_repeat_ngram_size: usize,
558}
559
560impl NGramRepetitionFilter {
561    pub fn new(no_repeat_ngram_size: usize) -> Self {
562        Self {
563            no_repeat_ngram_size,
564        }
565    }
566
567    pub fn filter_logits(&self, logits: &Tensor, generated_tokens: &[u32]) -> Result<Tensor> {
568        if self.no_repeat_ngram_size == 0 || generated_tokens.len() < self.no_repeat_ngram_size {
569            return Ok(logits.clone());
570        }
571
572        let mut filtered_logits = logits.clone();
573        let _vocab_size = logits.shape().dims()[logits.shape().ndim() - 1];
574
575        // Extract n-grams from generated sequence
576        let mut banned_tokens = std::collections::HashSet::new();
577
578        for i in 0..generated_tokens.len() - self.no_repeat_ngram_size + 1 {
579            let ngram = &generated_tokens[i..i + self.no_repeat_ngram_size - 1];
580
581            // Check if current context matches this n-gram prefix
582            let current_context =
583                &generated_tokens[generated_tokens.len() - self.no_repeat_ngram_size + 1..];
584
585            if ngram == current_context {
586                // Ban the token that would complete this n-gram
587                let banned_token = generated_tokens[i + self.no_repeat_ngram_size - 1];
588                banned_tokens.insert(banned_token);
589            }
590        }
591
592        // Set banned tokens to negative infinity
593        let vocab_size = filtered_logits.shape().dims()[0];
594        let mut logits_vec = filtered_logits.to_vec()?;
595        for banned_token in banned_tokens {
596            if (banned_token as usize) < logits_vec.len() {
597                logits_vec[banned_token as usize] = f32::NEG_INFINITY;
598            }
599        }
600        filtered_logits = Tensor::from_vec(logits_vec, &[vocab_size])?;
601
602        Ok(filtered_logits)
603    }
604}
605
606// ============================================================================
607// Text Generation Pipeline
608// ============================================================================
609
/// High-level generation pipeline combining a sampler, optional beam search,
/// and repetition constraints.
pub struct TextGenerator {
    sampler: TextSampler,                        // RNG-backed token sampler
    beam_decoder: Option<BeamSearchDecoder>,     // present when config.num_beams > 1
    repetition_penalty: RepetitionPenalty,       // unit marker; applied via RepetitionPenalty::apply
    ngram_filter: Option<NGramRepetitionFilter>, // present when config.no_repeat_ngram_size > 0
}
616
617impl TextGenerator {
618    pub fn new(config: &GenerationConfig) -> Self {
619        let beam_decoder = if config.num_beams > 1 {
620            Some(BeamSearchDecoder::new(
621                config.num_beams,
622                config.max_length,
623                config.length_penalty,
624                config.early_stopping,
625                config.eos_token_id,
626            ))
627        } else {
628            None
629        };
630
631        let ngram_filter = if config.no_repeat_ngram_size > 0 {
632            Some(NGramRepetitionFilter::new(config.no_repeat_ngram_size))
633        } else {
634            None
635        };
636
637        Self {
638            sampler: TextSampler::default(),
639            beam_decoder,
640            repetition_penalty: RepetitionPenalty,
641            ngram_filter,
642        }
643    }
644
645    pub fn generate(
646        &mut self,
647        initial_tokens: Vec<u32>,
648        vocab_size: usize,
649        config: &GenerationConfig,
650        get_logits: impl Fn(&[u32]) -> Result<Tensor> + Clone,
651    ) -> Result<Vec<Vec<u32>>> {
652        if config.num_beams > 1 {
653            // Beam search
654            if let Some(ref decoder) = self.beam_decoder {
655                let hypotheses = decoder.search(initial_tokens, vocab_size, get_logits)?;
656                Ok(hypotheses.into_iter().map(|h| h.tokens).collect())
657            } else {
658                Err(TextError::ModelError(
659                    "Beam decoder not initialized".to_string(),
660                ))
661            }
662        } else {
663            // Sampling-based generation
664            let result =
665                self.generate_with_sampling(initial_tokens, vocab_size, config, get_logits)?;
666            Ok(vec![result])
667        }
668    }
669
670    fn generate_with_sampling(
671        &mut self,
672        mut tokens: Vec<u32>,
673        _vocab_size: usize,
674        config: &GenerationConfig,
675        get_logits: impl Fn(&[u32]) -> Result<Tensor>,
676    ) -> Result<Vec<u32>> {
677        for _ in 0..config.max_length {
678            // Get logits for current sequence
679            let mut logits = get_logits(&tokens)?;
680
681            // Apply repetition penalty
682            if config.repetition_penalty != 1.0 {
683                logits = RepetitionPenalty::apply(&logits, &tokens, config.repetition_penalty)?;
684            }
685
686            // Apply n-gram repetition filter
687            if let Some(ref filter) = self.ngram_filter {
688                logits = filter.filter_logits(&logits, &tokens)?;
689            }
690
691            // Sample next token
692            let next_token = if config.do_sample {
693                self.sampler.top_k_top_p_sample(
694                    &logits,
695                    config.top_k,
696                    config.top_p,
697                    config.temperature,
698                )?
699            } else {
700                self.sampler.greedy_sample(&logits)?
701            };
702
703            tokens.push(next_token);
704
705            // Check for EOS token
706            if let Some(eos_id) = config.eos_token_id {
707                if next_token == eos_id {
708                    break;
709                }
710            }
711
712            // Check minimum length
713            if tokens.len() >= config.min_length {
714                if let Some(eos_id) = config.eos_token_id {
715                    if next_token == eos_id {
716                        break;
717                    }
718                }
719            }
720        }
721
722        Ok(tokens)
723    }
724}
725
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::{device::DeviceType as Device, dtype::DType};

    /// Constructing the default sampler must not panic.
    #[test]
    fn test_text_sampler_creation() {
        let _sampler = TextSampler::default();
    }

    /// The default configuration is greedy, single-beam, 50-token decoding.
    #[test]
    fn test_generation_config_default() {
        let defaults = GenerationConfig::default();
        assert!(!defaults.do_sample);
        assert_eq!(defaults.num_beams, 1);
        assert_eq!(defaults.max_length, 50);
    }

    /// `new` records the token list and score, and derives the length.
    #[test]
    fn test_beam_hypothesis() {
        let expected_tokens = vec![1, 2, 3];
        let expected_score = -1.5;
        let hyp = BeamHypothesis::new(expected_tokens.clone(), expected_score);

        assert_eq!(hyp.length, 3);
        assert_eq!(hyp.score, expected_score);
        assert_eq!(hyp.tokens, expected_tokens);
    }

    /// Greedy decoding selects the argmax logit — index 2 in this fixture.
    #[test]
    fn test_greedy_sampling() {
        let _device = Device::Cpu;

        // Token 2 carries the largest logit.
        let logits = Tensor::from_vec(vec![0.1, 0.2, 0.9, 0.3], &[4])
            .unwrap()
            .to_dtype(DType::F32)
            .unwrap();

        let sampler = TextSampler::default();
        assert_eq!(sampler.greedy_sample(&logits).unwrap(), 2);
    }
}