Skip to main content

trustformers_core/generation/
config.rs

1use serde::{Deserialize, Serialize};
2
3/// Generation strategies available in TrustformeRS
4#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
5pub enum GenerationStrategy {
6    /// Greedy decoding - always select highest probability token
7    #[default]
8    Greedy,
9    /// Random sampling with temperature
10    Sampling { temperature: f32 },
11    /// Top-k sampling
12    TopK { k: usize, temperature: f32 },
13    /// Top-p (nucleus) sampling
14    TopP { p: f32, temperature: f32 },
15    /// Beam search
16    BeamSearch { num_beams: usize },
17    /// Contrastive search
18    ContrastiveSearch { penalty_alpha: f32, top_k: usize },
19}
20
21/// Configuration for text generation
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct GenerationConfig {
24    pub strategy: GenerationStrategy,
25    pub max_length: Option<usize>,
26    pub max_new_tokens: Option<usize>,
27    pub min_length: Option<usize>,
28    pub do_sample: bool,
29    pub early_stopping: bool,
30    pub num_return_sequences: usize,
31    pub pad_token_id: Option<usize>,
32    pub eos_token_id: Option<usize>,
33    pub bos_token_id: Option<usize>,
34    pub repetition_penalty: f32,
35    pub length_penalty: f32,
36    pub no_repeat_ngram_size: Option<usize>,
37    pub use_cache: bool,
38    pub streaming: bool,
39    pub guided_generation: Option<GuidedGenerationConfig>,
40    pub watermarking: Option<WatermarkingConfig>,
41}
42
43impl Default for GenerationConfig {
44    fn default() -> Self {
45        Self {
46            strategy: GenerationStrategy::default(),
47            max_length: Some(100),
48            max_new_tokens: None,
49            min_length: Some(1),
50            do_sample: false,
51            early_stopping: false,
52            num_return_sequences: 1,
53            pad_token_id: None,
54            eos_token_id: None,
55            bos_token_id: None,
56            repetition_penalty: 1.0,
57            length_penalty: 1.0,
58            no_repeat_ngram_size: None,
59            use_cache: true,
60            streaming: false,
61            guided_generation: None,
62            watermarking: None,
63        }
64    }
65}
66
67/// Configuration for guided generation (constrained generation)
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct GuidedGenerationConfig {
70    pub regex_pattern: Option<String>,
71    pub grammar: Option<String>,
72    pub json_schema: Option<String>,
73    pub choice_list: Option<Vec<String>>,
74    pub max_violations: Option<usize>,
75    pub backtrack_on_violation: bool,
76    pub cfg: Option<CFGConfig>,
77}
78
79impl Default for GuidedGenerationConfig {
80    fn default() -> Self {
81        Self {
82            regex_pattern: None,
83            grammar: None,
84            json_schema: None,
85            choice_list: None,
86            max_violations: Some(3),
87            backtrack_on_violation: true,
88            cfg: None,
89        }
90    }
91}
92
93/// Configuration for Classifier-Free Guidance
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct CFGConfig {
96    pub guidance_scale: f32,
97    pub unconditional_prompt: Option<String>,
98    pub negative_prompt: Option<String>,
99    pub dynamic_thresholding: bool,
100    pub threshold_percentile: f32,
101}
102
103impl Default for CFGConfig {
104    fn default() -> Self {
105        Self {
106            guidance_scale: 7.5,
107            unconditional_prompt: None,
108            negative_prompt: None,
109            dynamic_thresholding: false,
110            threshold_percentile: 0.95,
111        }
112    }
113}
114
115/// Watermarking algorithms available
116#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
117pub enum WatermarkingAlgorithm {
118    GreenList,
119    SoftRedList,
120    ExponentialMinimum,
121}
122
123/// Configuration for text watermarking
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct WatermarkingConfig {
126    pub algorithm: WatermarkingAlgorithm,
127    pub gamma: f32,
128    pub delta: f32,
129    pub hash_key: Option<u64>,
130    pub vocab_size: usize,
131    pub context_length: usize,
132    pub detection_threshold: f32,
133}
134
135impl Default for WatermarkingConfig {
136    fn default() -> Self {
137        Self {
138            algorithm: WatermarkingAlgorithm::GreenList,
139            gamma: 0.5,
140            delta: 2.0,
141            hash_key: None,
142            vocab_size: 50257, // GPT-2 default
143            context_length: 4,
144            detection_threshold: 4.0,
145        }
146    }
147}
148
149/// Configuration for assisted generation
150#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct AssistedGenerationConfig {
152    pub draft_model_name: String,
153    pub candidate_length: usize,
154    pub acceptance_threshold: f32,
155    pub max_draft_tokens: usize,
156    pub use_dynamic_speculation: bool,
157    pub speculation_depth: usize,
158}
159
160impl Default for AssistedGenerationConfig {
161    fn default() -> Self {
162        Self {
163            draft_model_name: "distilbert-base-uncased".to_string(),
164            candidate_length: 5,
165            acceptance_threshold: 0.8,
166            max_draft_tokens: 20,
167            use_dynamic_speculation: true,
168            speculation_depth: 3,
169        }
170    }
171}
172
173/// Builder for GenerationConfig using standardized patterns
174pub mod builder {
175    use super::*;
176    use crate::errors::Result;
177    use crate::patterns::{validators, StandardConfig, ValidatedBuilder};
178
179    // Implement StandardConfig for our configuration types
180    impl StandardConfig for GenerationConfig {
181        fn validate(&self) -> Result<()> {
182            // Validate max_length and max_new_tokens relationship
183            if let (Some(max_length), Some(max_new_tokens)) = (self.max_length, self.max_new_tokens)
184            {
185                if max_new_tokens > max_length {
186                    return Err(crate::errors::TrustformersError::config_error(
187                        "max_new_tokens cannot be greater than max_length",
188                        "generation_config_validation",
189                    ));
190                }
191            }
192
193            // Validate min_length vs max_length
194            if let (Some(min_length), Some(max_length)) = (self.min_length, self.max_length) {
195                if min_length > max_length {
196                    return Err(crate::errors::TrustformersError::config_error(
197                        "min_length cannot be greater than max_length",
198                        "generation_config_validation",
199                    ));
200                }
201            }
202
203            // Validate repetition penalty
204            validators::positive(self.repetition_penalty, "repetition_penalty")?;
205
206            // Validate length penalty
207            validators::positive(self.length_penalty, "length_penalty")?;
208
209            // Validate num_return_sequences
210            validators::positive(self.num_return_sequences, "num_return_sequences")?;
211
212            // Validate strategy-specific parameters
213            match self.strategy {
214                GenerationStrategy::Sampling { temperature } => {
215                    validators::positive(temperature, "temperature")?;
216                },
217                GenerationStrategy::TopK { k, temperature } => {
218                    validators::positive(k, "k")?;
219                    validators::positive(temperature, "temperature")?;
220                },
221                GenerationStrategy::TopP { p, temperature } => {
222                    validators::numeric_range(p, 0.0, 1.0, "p")?;
223                    validators::positive(temperature, "temperature")?;
224                },
225                GenerationStrategy::BeamSearch { num_beams } => {
226                    validators::positive(num_beams, "num_beams")?;
227                },
228                GenerationStrategy::ContrastiveSearch {
229                    penalty_alpha,
230                    top_k,
231                } => {
232                    validators::positive(penalty_alpha, "penalty_alpha")?;
233                    validators::positive(top_k, "top_k")?;
234                },
235                GenerationStrategy::Greedy => {}, // No validation needed
236            }
237
238            Ok(())
239        }
240    }
241
242    impl StandardConfig for GuidedGenerationConfig {}
243    impl StandardConfig for CFGConfig {}
244    impl StandardConfig for WatermarkingConfig {}
245    impl StandardConfig for AssistedGenerationConfig {}
246
247    /// Builder for GenerationConfig
248    pub type GenerationConfigBuilder = ValidatedBuilder<GenerationConfig>;
249
250    impl GenerationConfigBuilder {
251        /// Create a new GenerationConfig builder with validation
252        pub fn with_validation() -> Self {
253            ValidatedBuilder::new().add_validator(|config: &GenerationConfig| config.validate())
254        }
255
256        /// Set the generation strategy
257        pub fn strategy(mut self, strategy: GenerationStrategy) -> Self {
258            self.data_mut().strategy = strategy;
259            self
260        }
261
262        /// Set max length
263        pub fn max_length(mut self, max_length: usize) -> Self {
264            self.data_mut().max_length = Some(max_length);
265            self
266        }
267
268        /// Set max new tokens
269        pub fn max_new_tokens(mut self, max_new_tokens: usize) -> Self {
270            self.data_mut().max_new_tokens = Some(max_new_tokens);
271            self
272        }
273
274        /// Set min length
275        pub fn min_length(mut self, min_length: usize) -> Self {
276            self.data_mut().min_length = Some(min_length);
277            self
278        }
279
280        /// Enable sampling
281        pub fn enable_sampling(mut self, do_sample: bool) -> Self {
282            self.data_mut().do_sample = do_sample;
283            self
284        }
285
286        /// Enable early stopping
287        pub fn early_stopping(mut self, early_stopping: bool) -> Self {
288            self.data_mut().early_stopping = early_stopping;
289            self
290        }
291
292        /// Set number of return sequences
293        pub fn num_return_sequences(mut self, num_sequences: usize) -> Self {
294            self.data_mut().num_return_sequences = num_sequences;
295            self
296        }
297
298        /// Set pad token ID
299        pub fn pad_token_id(mut self, token_id: usize) -> Self {
300            self.data_mut().pad_token_id = Some(token_id);
301            self
302        }
303
304        /// Set EOS token ID
305        pub fn eos_token_id(mut self, token_id: usize) -> Self {
306            self.data_mut().eos_token_id = Some(token_id);
307            self
308        }
309
310        /// Set BOS token ID
311        pub fn bos_token_id(mut self, token_id: usize) -> Self {
312            self.data_mut().bos_token_id = Some(token_id);
313            self
314        }
315
316        /// Set repetition penalty
317        pub fn repetition_penalty(mut self, penalty: f32) -> Self {
318            self.data_mut().repetition_penalty = penalty;
319            self
320        }
321
322        /// Set length penalty
323        pub fn length_penalty(mut self, penalty: f32) -> Self {
324            self.data_mut().length_penalty = penalty;
325            self
326        }
327
328        /// Set no repeat ngram size
329        pub fn no_repeat_ngram_size(mut self, size: usize) -> Self {
330            self.data_mut().no_repeat_ngram_size = Some(size);
331            self
332        }
333
334        /// Enable/disable caching
335        pub fn use_cache(mut self, use_cache: bool) -> Self {
336            self.data_mut().use_cache = use_cache;
337            self
338        }
339
340        /// Enable/disable streaming
341        pub fn streaming(mut self, streaming: bool) -> Self {
342            self.data_mut().streaming = streaming;
343            self
344        }
345
346        /// Set guided generation config
347        pub fn guided_generation(mut self, config: GuidedGenerationConfig) -> Self {
348            self.data_mut().guided_generation = Some(config);
349            self
350        }
351
352        /// Set watermarking config
353        pub fn watermarking(mut self, config: WatermarkingConfig) -> Self {
354            self.data_mut().watermarking = Some(config);
355            self
356        }
357
358        /// Quick setup for greedy decoding
359        pub fn greedy(mut self) -> Self {
360            self.data_mut().strategy = GenerationStrategy::Greedy;
361            self.data_mut().do_sample = false;
362            self
363        }
364
365        /// Quick setup for sampling with temperature
366        pub fn sampling_with_temperature(mut self, temperature: f32) -> Self {
367            self.data_mut().strategy = GenerationStrategy::Sampling { temperature };
368            self.data_mut().do_sample = true;
369            self
370        }
371
372        /// Quick setup for top-k sampling
373        pub fn top_k_sampling(mut self, k: usize, temperature: f32) -> Self {
374            self.data_mut().strategy = GenerationStrategy::TopK { k, temperature };
375            self.data_mut().do_sample = true;
376            self
377        }
378
379        /// Quick setup for top-p sampling
380        pub fn top_p_sampling(mut self, p: f32, temperature: f32) -> Self {
381            self.data_mut().strategy = GenerationStrategy::TopP { p, temperature };
382            self.data_mut().do_sample = true;
383            self
384        }
385
386        /// Quick setup for beam search
387        pub fn beam_search(mut self, num_beams: usize) -> Self {
388            self.data_mut().strategy = GenerationStrategy::BeamSearch { num_beams };
389            self.data_mut().do_sample = false;
390            self
391        }
392    }
393
394    // Convenience function for creating a builder
395    pub fn generation_config() -> GenerationConfigBuilder {
396        GenerationConfigBuilder::with_validation()
397    }
398}
399
#[cfg(test)]
mod tests {
    use super::builder::*;
    use super::*;
    use crate::patterns::builder::Builder;

    // Greedy quick-setup plus explicit length and early-stopping settings.
    #[test]
    fn test_generation_config_builder() {
        let built = generation_config()
            .greedy()
            .max_length(100)
            .early_stopping(true)
            .build()
            .expect("operation failed in test");

        assert!(built.early_stopping);
        assert_eq!(built.max_length, Some(100));
        assert_eq!(built.strategy, GenerationStrategy::Greedy);
    }

    // min_length (200) greater than max_length (100) must be rejected.
    #[test]
    fn test_generation_config_validation() {
        let outcome = generation_config().min_length(200).max_length(100).build();
        assert!(outcome.is_err());
    }

    // Temperature sampling with a new-token cap and a repetition penalty.
    #[test]
    fn test_sampling_config() {
        let built = generation_config()
            .sampling_with_temperature(0.8)
            .max_new_tokens(50)
            .repetition_penalty(1.1)
            .build()
            .expect("operation failed in test");

        match built.strategy {
            GenerationStrategy::Sampling { temperature } => assert_eq!(temperature, 0.8),
            other => panic!("Expected sampling strategy but got {:?}", other),
        }

        assert_eq!(built.max_new_tokens, Some(50));
        assert_eq!(built.repetition_penalty, 1.1);
    }

    // Beam search quick-setup keeps the requested beam count.
    #[test]
    fn test_beam_search_config() {
        let built = generation_config()
            .beam_search(4)
            .max_length(200)
            .length_penalty(0.8)
            .build()
            .expect("operation failed in test");

        match built.strategy {
            GenerationStrategy::BeamSearch { num_beams } => assert_eq!(num_beams, 4),
            other => panic!(
                "Expected beam search strategy but got {:?}",
                other
            ),
        }
    }
}