1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
5pub enum GenerationStrategy {
6 #[default]
8 Greedy,
9 Sampling { temperature: f32 },
11 TopK { k: usize, temperature: f32 },
13 TopP { p: f32, temperature: f32 },
15 BeamSearch { num_beams: usize },
17 ContrastiveSearch { penalty_alpha: f32, top_k: usize },
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct GenerationConfig {
24 pub strategy: GenerationStrategy,
25 pub max_length: Option<usize>,
26 pub max_new_tokens: Option<usize>,
27 pub min_length: Option<usize>,
28 pub do_sample: bool,
29 pub early_stopping: bool,
30 pub num_return_sequences: usize,
31 pub pad_token_id: Option<usize>,
32 pub eos_token_id: Option<usize>,
33 pub bos_token_id: Option<usize>,
34 pub repetition_penalty: f32,
35 pub length_penalty: f32,
36 pub no_repeat_ngram_size: Option<usize>,
37 pub use_cache: bool,
38 pub streaming: bool,
39 pub guided_generation: Option<GuidedGenerationConfig>,
40 pub watermarking: Option<WatermarkingConfig>,
41}
42
43impl Default for GenerationConfig {
44 fn default() -> Self {
45 Self {
46 strategy: GenerationStrategy::default(),
47 max_length: Some(100),
48 max_new_tokens: None,
49 min_length: Some(1),
50 do_sample: false,
51 early_stopping: false,
52 num_return_sequences: 1,
53 pad_token_id: None,
54 eos_token_id: None,
55 bos_token_id: None,
56 repetition_penalty: 1.0,
57 length_penalty: 1.0,
58 no_repeat_ngram_size: None,
59 use_cache: true,
60 streaming: false,
61 guided_generation: None,
62 watermarking: None,
63 }
64 }
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct GuidedGenerationConfig {
70 pub regex_pattern: Option<String>,
71 pub grammar: Option<String>,
72 pub json_schema: Option<String>,
73 pub choice_list: Option<Vec<String>>,
74 pub max_violations: Option<usize>,
75 pub backtrack_on_violation: bool,
76 pub cfg: Option<CFGConfig>,
77}
78
79impl Default for GuidedGenerationConfig {
80 fn default() -> Self {
81 Self {
82 regex_pattern: None,
83 grammar: None,
84 json_schema: None,
85 choice_list: None,
86 max_violations: Some(3),
87 backtrack_on_violation: true,
88 cfg: None,
89 }
90 }
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct CFGConfig {
96 pub guidance_scale: f32,
97 pub unconditional_prompt: Option<String>,
98 pub negative_prompt: Option<String>,
99 pub dynamic_thresholding: bool,
100 pub threshold_percentile: f32,
101}
102
103impl Default for CFGConfig {
104 fn default() -> Self {
105 Self {
106 guidance_scale: 7.5,
107 unconditional_prompt: None,
108 negative_prompt: None,
109 dynamic_thresholding: false,
110 threshold_percentile: 0.95,
111 }
112 }
113}
114
115#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
117pub enum WatermarkingAlgorithm {
118 GreenList,
119 SoftRedList,
120 ExponentialMinimum,
121}
122
123#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct WatermarkingConfig {
126 pub algorithm: WatermarkingAlgorithm,
127 pub gamma: f32,
128 pub delta: f32,
129 pub hash_key: Option<u64>,
130 pub vocab_size: usize,
131 pub context_length: usize,
132 pub detection_threshold: f32,
133}
134
135impl Default for WatermarkingConfig {
136 fn default() -> Self {
137 Self {
138 algorithm: WatermarkingAlgorithm::GreenList,
139 gamma: 0.5,
140 delta: 2.0,
141 hash_key: None,
142 vocab_size: 50257, context_length: 4,
144 detection_threshold: 4.0,
145 }
146 }
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct AssistedGenerationConfig {
152 pub draft_model_name: String,
153 pub candidate_length: usize,
154 pub acceptance_threshold: f32,
155 pub max_draft_tokens: usize,
156 pub use_dynamic_speculation: bool,
157 pub speculation_depth: usize,
158}
159
160impl Default for AssistedGenerationConfig {
161 fn default() -> Self {
162 Self {
163 draft_model_name: "distilbert-base-uncased".to_string(),
164 candidate_length: 5,
165 acceptance_threshold: 0.8,
166 max_draft_tokens: 20,
167 use_dynamic_speculation: true,
168 speculation_depth: 3,
169 }
170 }
171}
172
173pub mod builder {
175 use super::*;
176 use crate::errors::Result;
177 use crate::patterns::{validators, StandardConfig, ValidatedBuilder};
178
179 impl StandardConfig for GenerationConfig {
181 fn validate(&self) -> Result<()> {
182 if let (Some(max_length), Some(max_new_tokens)) = (self.max_length, self.max_new_tokens)
184 {
185 if max_new_tokens > max_length {
186 return Err(crate::errors::TrustformersError::config_error(
187 "max_new_tokens cannot be greater than max_length",
188 "generation_config_validation",
189 ));
190 }
191 }
192
193 if let (Some(min_length), Some(max_length)) = (self.min_length, self.max_length) {
195 if min_length > max_length {
196 return Err(crate::errors::TrustformersError::config_error(
197 "min_length cannot be greater than max_length",
198 "generation_config_validation",
199 ));
200 }
201 }
202
203 validators::positive(self.repetition_penalty, "repetition_penalty")?;
205
206 validators::positive(self.length_penalty, "length_penalty")?;
208
209 validators::positive(self.num_return_sequences, "num_return_sequences")?;
211
212 match self.strategy {
214 GenerationStrategy::Sampling { temperature } => {
215 validators::positive(temperature, "temperature")?;
216 },
217 GenerationStrategy::TopK { k, temperature } => {
218 validators::positive(k, "k")?;
219 validators::positive(temperature, "temperature")?;
220 },
221 GenerationStrategy::TopP { p, temperature } => {
222 validators::numeric_range(p, 0.0, 1.0, "p")?;
223 validators::positive(temperature, "temperature")?;
224 },
225 GenerationStrategy::BeamSearch { num_beams } => {
226 validators::positive(num_beams, "num_beams")?;
227 },
228 GenerationStrategy::ContrastiveSearch {
229 penalty_alpha,
230 top_k,
231 } => {
232 validators::positive(penalty_alpha, "penalty_alpha")?;
233 validators::positive(top_k, "top_k")?;
234 },
235 GenerationStrategy::Greedy => {}, }
237
238 Ok(())
239 }
240 }
241
242 impl StandardConfig for GuidedGenerationConfig {}
243 impl StandardConfig for CFGConfig {}
244 impl StandardConfig for WatermarkingConfig {}
245 impl StandardConfig for AssistedGenerationConfig {}
246
247 pub type GenerationConfigBuilder = ValidatedBuilder<GenerationConfig>;
249
250 impl GenerationConfigBuilder {
251 pub fn with_validation() -> Self {
253 ValidatedBuilder::new().add_validator(|config: &GenerationConfig| config.validate())
254 }
255
256 pub fn strategy(mut self, strategy: GenerationStrategy) -> Self {
258 self.data_mut().strategy = strategy;
259 self
260 }
261
262 pub fn max_length(mut self, max_length: usize) -> Self {
264 self.data_mut().max_length = Some(max_length);
265 self
266 }
267
268 pub fn max_new_tokens(mut self, max_new_tokens: usize) -> Self {
270 self.data_mut().max_new_tokens = Some(max_new_tokens);
271 self
272 }
273
274 pub fn min_length(mut self, min_length: usize) -> Self {
276 self.data_mut().min_length = Some(min_length);
277 self
278 }
279
280 pub fn enable_sampling(mut self, do_sample: bool) -> Self {
282 self.data_mut().do_sample = do_sample;
283 self
284 }
285
286 pub fn early_stopping(mut self, early_stopping: bool) -> Self {
288 self.data_mut().early_stopping = early_stopping;
289 self
290 }
291
292 pub fn num_return_sequences(mut self, num_sequences: usize) -> Self {
294 self.data_mut().num_return_sequences = num_sequences;
295 self
296 }
297
298 pub fn pad_token_id(mut self, token_id: usize) -> Self {
300 self.data_mut().pad_token_id = Some(token_id);
301 self
302 }
303
304 pub fn eos_token_id(mut self, token_id: usize) -> Self {
306 self.data_mut().eos_token_id = Some(token_id);
307 self
308 }
309
310 pub fn bos_token_id(mut self, token_id: usize) -> Self {
312 self.data_mut().bos_token_id = Some(token_id);
313 self
314 }
315
316 pub fn repetition_penalty(mut self, penalty: f32) -> Self {
318 self.data_mut().repetition_penalty = penalty;
319 self
320 }
321
322 pub fn length_penalty(mut self, penalty: f32) -> Self {
324 self.data_mut().length_penalty = penalty;
325 self
326 }
327
328 pub fn no_repeat_ngram_size(mut self, size: usize) -> Self {
330 self.data_mut().no_repeat_ngram_size = Some(size);
331 self
332 }
333
334 pub fn use_cache(mut self, use_cache: bool) -> Self {
336 self.data_mut().use_cache = use_cache;
337 self
338 }
339
340 pub fn streaming(mut self, streaming: bool) -> Self {
342 self.data_mut().streaming = streaming;
343 self
344 }
345
346 pub fn guided_generation(mut self, config: GuidedGenerationConfig) -> Self {
348 self.data_mut().guided_generation = Some(config);
349 self
350 }
351
352 pub fn watermarking(mut self, config: WatermarkingConfig) -> Self {
354 self.data_mut().watermarking = Some(config);
355 self
356 }
357
358 pub fn greedy(mut self) -> Self {
360 self.data_mut().strategy = GenerationStrategy::Greedy;
361 self.data_mut().do_sample = false;
362 self
363 }
364
365 pub fn sampling_with_temperature(mut self, temperature: f32) -> Self {
367 self.data_mut().strategy = GenerationStrategy::Sampling { temperature };
368 self.data_mut().do_sample = true;
369 self
370 }
371
372 pub fn top_k_sampling(mut self, k: usize, temperature: f32) -> Self {
374 self.data_mut().strategy = GenerationStrategy::TopK { k, temperature };
375 self.data_mut().do_sample = true;
376 self
377 }
378
379 pub fn top_p_sampling(mut self, p: f32, temperature: f32) -> Self {
381 self.data_mut().strategy = GenerationStrategy::TopP { p, temperature };
382 self.data_mut().do_sample = true;
383 self
384 }
385
386 pub fn beam_search(mut self, num_beams: usize) -> Self {
388 self.data_mut().strategy = GenerationStrategy::BeamSearch { num_beams };
389 self.data_mut().do_sample = false;
390 self
391 }
392 }
393
394 pub fn generation_config() -> GenerationConfigBuilder {
396 GenerationConfigBuilder::with_validation()
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::builder::*;
    use super::*;
    use crate::patterns::builder::Builder;

    /// A greedy configuration built through the fluent API keeps the values
    /// it was given.
    #[test]
    fn test_generation_config_builder() {
        let built = generation_config()
            .greedy()
            .max_length(100)
            .early_stopping(true)
            .build();
        let config = built.expect("operation failed in test");

        assert_eq!(config.strategy, GenerationStrategy::Greedy);
        assert_eq!(config.max_length, Some(100));
        assert!(config.early_stopping);
    }

    /// `min_length` above `max_length` must be rejected when building.
    #[test]
    fn test_generation_config_validation() {
        let outcome = generation_config()
            .min_length(200)
            .max_length(100)
            .build();

        assert!(outcome.is_err());
    }

    /// The sampling shortcut stores the temperature, and the other setters
    /// keep their values.
    #[test]
    fn test_sampling_config() {
        let built = generation_config()
            .sampling_with_temperature(0.8)
            .max_new_tokens(50)
            .repetition_penalty(1.1)
            .build();
        let config = built.expect("operation failed in test");

        match config.strategy {
            GenerationStrategy::Sampling { temperature } => assert_eq!(temperature, 0.8),
            other => panic!("Expected sampling strategy but got {:?}", other),
        }

        assert_eq!(config.max_new_tokens, Some(50));
        assert_eq!(config.repetition_penalty, 1.1);
    }

    /// The beam-search shortcut stores the requested number of beams.
    #[test]
    fn test_beam_search_config() {
        let built = generation_config()
            .beam_search(4)
            .max_length(200)
            .length_penalty(0.8)
            .build();
        let config = built.expect("operation failed in test");

        match config.strategy {
            GenerationStrategy::BeamSearch { num_beams } => assert_eq!(num_beams, 4),
            other => panic!("Expected beam search strategy but got {:?}", other),
        }
    }
}