// trustformers_core/generation/streaming.rs
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;

/// Pull-based interface for consuming generated items one at a time.
pub trait GenerationStreamTrait {
    /// The kind of element handed out by [`GenerationStreamTrait::next`].
    type Item;

    /// Returns the next item, or `None` if the implementor has nothing to
    /// yield at the moment of the call.
    fn next(&mut self) -> Option<Self::Item>;
}

10#[derive(Debug, Clone)]
12pub struct GenerationToken {
13 pub token_id: usize,
14 pub token_str: String,
15 pub logprobs: Option<f32>,
16 pub is_finished: bool,
17 pub finish_reason: Option<FinishReason>,
18}
19
20impl GenerationToken {
21 pub fn new(
22 token_id: usize,
23 token_str: String,
24 logprobs: Option<f32>,
25 is_finished: bool,
26 ) -> Self {
27 Self {
28 token_id,
29 token_str,
30 logprobs,
31 is_finished,
32 finish_reason: None,
33 }
34 }
35
36 pub fn with_finish_reason(mut self, reason: FinishReason) -> Self {
37 self.finish_reason = Some(reason);
38 self.is_finished = true;
39 self
40 }
41}
42
43#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
45pub enum FinishReason {
46 MaxLength,
47 EosToken,
48 StopSequence,
49 UserStopped,
50 ConstraintViolation,
51 Error,
52}
53
54pub struct GenerationStream {
56 tokens: VecDeque<GenerationToken>,
57 finished: bool,
58}
59
60impl GenerationStream {
61 pub fn new() -> Self {
62 Self {
63 tokens: VecDeque::new(),
64 finished: false,
65 }
66 }
67
68 pub fn push_token(&mut self, token: GenerationToken) {
69 self.finished = token.is_finished;
70 self.tokens.push_back(token);
71 }
72
73 pub fn finish(&mut self, reason: FinishReason) {
74 self.finished = true;
75 if let Some(last_token) = self.tokens.back_mut() {
76 last_token.is_finished = true;
77 last_token.finish_reason = Some(reason);
78 }
79 }
80
81 pub fn is_finished(&self) -> bool {
82 self.finished
83 }
84
85 pub fn len(&self) -> usize {
86 self.tokens.len()
87 }
88
89 pub fn is_empty(&self) -> bool {
90 self.tokens.is_empty()
91 }
92}
93
94impl Default for GenerationStream {
95 fn default() -> Self {
96 Self::new()
97 }
98}
99
100impl GenerationStreamTrait for GenerationStream {
101 type Item = GenerationToken;
102
103 fn next(&mut self) -> Option<Self::Item> {
104 self.tokens.pop_front()
106 }
107}