Skip to main content

trustformers_core/generation/streaming.rs

1use serde::{Deserialize, Serialize};
2use std::collections::VecDeque;
3
/// Simplified, synchronous pull interface used instead of a full async stream trait.
///
/// Implementors hand out one item per call to [`GenerationStreamTrait::next`]. Returning `None`
/// only says that no item is available from this call; whether the underlying producer is done
/// is up to the implementor to report separately.
pub trait GenerationStreamTrait {
    /// Element type handed out by the stream.
    type Item;

    /// Takes the next available item out of the stream, or returns `None` if there is none.
    fn next(&mut self) -> Option<Self::Item>;
}
9
10/// Streaming generation result
11#[derive(Debug, Clone)]
12pub struct GenerationToken {
13    pub token_id: usize,
14    pub token_str: String,
15    pub logprobs: Option<f32>,
16    pub is_finished: bool,
17    pub finish_reason: Option<FinishReason>,
18}
19
20impl GenerationToken {
21    pub fn new(
22        token_id: usize,
23        token_str: String,
24        logprobs: Option<f32>,
25        is_finished: bool,
26    ) -> Self {
27        Self {
28            token_id,
29            token_str,
30            logprobs,
31            is_finished,
32            finish_reason: None,
33        }
34    }
35
36    pub fn with_finish_reason(mut self, reason: FinishReason) -> Self {
37        self.finish_reason = Some(reason);
38        self.is_finished = true;
39        self
40    }
41}
42
43/// Reasons why generation finished
44#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
45pub enum FinishReason {
46    MaxLength,
47    EosToken,
48    StopSequence,
49    UserStopped,
50    ConstraintViolation,
51    Error,
52}
53
54/// Streaming generation iterator
55pub struct GenerationStream {
56    tokens: VecDeque<GenerationToken>,
57    finished: bool,
58}
59
60impl GenerationStream {
61    pub fn new() -> Self {
62        Self {
63            tokens: VecDeque::new(),
64            finished: false,
65        }
66    }
67
68    pub fn push_token(&mut self, token: GenerationToken) {
69        self.finished = token.is_finished;
70        self.tokens.push_back(token);
71    }
72
73    pub fn finish(&mut self, reason: FinishReason) {
74        self.finished = true;
75        if let Some(last_token) = self.tokens.back_mut() {
76            last_token.is_finished = true;
77            last_token.finish_reason = Some(reason);
78        }
79    }
80
81    pub fn is_finished(&self) -> bool {
82        self.finished
83    }
84
85    pub fn len(&self) -> usize {
86        self.tokens.len()
87    }
88
89    pub fn is_empty(&self) -> bool {
90        self.tokens.is_empty()
91    }
92}
93
94impl Default for GenerationStream {
95    fn default() -> Self {
96        Self::new()
97    }
98}
99
100impl GenerationStreamTrait for GenerationStream {
101    type Item = GenerationToken;
102
103    fn next(&mut self) -> Option<Self::Item> {
104        // Would be pending in async context if !self.finished
105        self.tokens.pop_front()
106    }
107}