1use tch::{Device, Kind};
35
36use crate::common::error::RustBertError;
37use crate::gpt2::GPT2Generator;
38use crate::gpt_j::GptJGenerator;
39use crate::gpt_neo::GptNeoGenerator;
40use crate::openai_gpt::OpenAIGenerator;
41use crate::pipelines::common::{ModelResource, ModelType, TokenizerOption};
42use crate::pipelines::generation_utils::{GenerateConfig, GenerateOptions, LanguageGenerator};
43use crate::reformer::ReformerGenerator;
44use crate::resources::ResourceProvider;
45use crate::t5::T5Generator;
46use crate::xlnet::XLNetGenerator;
47
48#[cfg(feature = "onnx")]
49use crate::pipelines::onnx::ONNXCausalGenerator;
50#[cfg(feature = "remote")]
51use crate::{
52 gpt2::{Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources},
53 resources::RemoteResource,
54};
55
/// Configuration for the text generation pipeline.
///
/// Bundles the model resources together with the decoding hyper-parameters;
/// converted into a `GenerateConfig` (via `From`) before being handed to the
/// underlying language generator.
pub struct TextGenerationConfig {
    /// Model type used for generation (e.g. GPT2, GPTNeo, XLNet, T5, ...)
    pub model_type: ModelType,
    /// Model weights resource (Torch or ONNX)
    pub model_resource: ModelResource,
    /// Model configuration resource
    pub config_resource: Box<dyn ResourceProvider + Send>,
    /// Vocabulary resource
    pub vocab_resource: Box<dyn ResourceProvider + Send>,
    /// Merges resource (`None` for tokenizers that do not use a merges file)
    pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
    /// Minimum length of generated sequences
    pub min_length: i64,
    /// Maximum length of generated sequences (`None` leaves it unbounded)
    pub max_length: Option<i64>,
    /// Sample from the token distribution instead of greedy/beam-only decoding
    pub do_sample: bool,
    /// Allow beam search to stop early once enough hypotheses are finished
    pub early_stopping: bool,
    /// Number of beams for beam search
    pub num_beams: i64,
    /// Softmax temperature applied before sampling
    pub temperature: f64,
    /// Top-k filtering value for sampling (0 disables top-k filtering)
    pub top_k: i64,
    /// Top-p (nucleus) filtering value for sampling
    pub top_p: f64,
    /// Penalty applied to already-generated tokens
    pub repetition_penalty: f64,
    /// Exponential length penalty used when scoring beam hypotheses
    pub length_penalty: f64,
    /// Size of n-grams forbidden from repeating (0 disables the constraint)
    pub no_repeat_ngram_size: i64,
    /// Number of generated sequences returned per input
    pub num_return_sequences: i64,
    /// Number of beam groups for diverse beam search (optional)
    pub num_beam_groups: Option<i64>,
    /// Diversity penalty for diverse beam search (optional)
    pub diversity_penalty: Option<f64>,
    /// Device on which the model is placed (CPU or CUDA)
    pub device: Device,
    /// Optional tensor kind (precision) override for the model weights
    pub kind: Option<Kind>,
}
103
104impl TextGenerationConfig {
105 pub fn new<RC, RV>(
115 model_type: ModelType,
116 model_resource: ModelResource,
117 config_resource: RC,
118 vocab_resource: RV,
119 merges_resource: Option<RV>,
120 ) -> TextGenerationConfig
121 where
122 RC: ResourceProvider + Send + 'static,
123 RV: ResourceProvider + Send + 'static,
124 {
125 TextGenerationConfig {
126 model_type,
127 model_resource,
128 config_resource: Box::new(config_resource),
129 vocab_resource: Box::new(vocab_resource),
130 merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
131 min_length: 0,
132 max_length: Some(56),
133 do_sample: true,
134 early_stopping: true,
135 num_beams: 5,
136 temperature: 1.0,
137 top_k: 0,
138 top_p: 0.9,
139 repetition_penalty: 1.0,
140 length_penalty: 1.0,
141 no_repeat_ngram_size: 0,
142 num_return_sequences: 1,
143 num_beam_groups: None,
144 diversity_penalty: None,
145 device: Device::cuda_if_available(),
146 kind: None,
147 }
148 }
149}
150
#[cfg(feature = "remote")]
impl Default for TextGenerationConfig {
    /// Default configuration: remote GPT2-medium weights with their matching
    /// configuration, vocabulary and merges files.
    fn default() -> TextGenerationConfig {
        let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
            Gpt2ModelResources::GPT2_MEDIUM,
        )));
        let config_resource = RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2_MEDIUM);
        let vocab_resource = RemoteResource::from_pretrained(Gpt2VocabResources::GPT2_MEDIUM);
        let merges_resource = RemoteResource::from_pretrained(Gpt2MergesResources::GPT2_MEDIUM);
        TextGenerationConfig::new(
            ModelType::GPT2,
            model_resource,
            config_resource,
            vocab_resource,
            Some(merges_resource),
        )
    }
}
167
168impl From<TextGenerationConfig> for GenerateConfig {
169 fn from(config: TextGenerationConfig) -> GenerateConfig {
170 GenerateConfig {
171 model_type: config.model_type,
172 model_resource: config.model_resource,
173 config_resource: config.config_resource,
174 merges_resource: config.merges_resource,
175 vocab_resource: config.vocab_resource,
176 min_length: config.min_length,
177 max_length: config.max_length,
178 do_sample: config.do_sample,
179 early_stopping: config.early_stopping,
180 num_beams: config.num_beams,
181 temperature: config.temperature,
182 top_k: config.top_k,
183 top_p: config.top_p,
184 repetition_penalty: config.repetition_penalty,
185 length_penalty: config.length_penalty,
186 no_repeat_ngram_size: config.no_repeat_ngram_size,
187 num_return_sequences: config.num_return_sequences,
188 num_beam_groups: config.num_beam_groups,
189 diversity_penalty: config.diversity_penalty,
190 device: config.device,
191 kind: config.kind,
192 }
193 }
194}
195
/// Abstraction over the concrete generator backends supported by the text
/// generation pipeline.
pub enum TextGenerationOption {
    /// Generator based on GPT2
    GPT2(GPT2Generator),
    /// Generator based on OpenAI GPT
    GPT(OpenAIGenerator),
    /// Generator based on GPT-Neo
    GPTNeo(GptNeoGenerator),
    /// Generator based on GPT-J
    GPTJ(GptJGenerator),
    /// Generator based on XLNet
    XLNet(XLNetGenerator),
    /// Generator based on Reformer
    Reformer(ReformerGenerator),
    /// Generator based on T5
    T5(T5Generator),
    /// ONNX causal generator (requires the `onnx` feature)
    #[cfg(feature = "onnx")]
    ONNX(ONNXCausalGenerator),
}
216
217impl TextGenerationOption {
218 pub fn new(config: TextGenerationConfig) -> Result<Self, RustBertError> {
219 match (config.model_type, &config.model_resource) {
220 #[cfg(feature = "onnx")]
221 (_, &ModelResource::ONNX(_)) => Ok(TextGenerationOption::ONNX(
222 ONNXCausalGenerator::new(config.into(), None, None)?,
223 )),
224 (ModelType::GPT2, _) => Ok(TextGenerationOption::GPT2(GPT2Generator::new(
225 config.into(),
226 )?)),
227 (ModelType::OpenAiGpt, _) => Ok(TextGenerationOption::GPT(OpenAIGenerator::new(
228 config.into(),
229 )?)),
230 (ModelType::XLNet, _) => Ok(TextGenerationOption::XLNet(XLNetGenerator::new(
231 config.into(),
232 )?)),
233 (ModelType::Reformer, _) => Ok(TextGenerationOption::Reformer(ReformerGenerator::new(
234 config.into(),
235 )?)),
236 (ModelType::GPTNeo, _) => Ok(TextGenerationOption::GPTNeo(GptNeoGenerator::new(
237 config.into(),
238 )?)),
239 (ModelType::GPTJ, _) => Ok(TextGenerationOption::GPTJ(GptJGenerator::new(
240 config.into(),
241 )?)),
242 (ModelType::T5, _) => Ok(TextGenerationOption::T5(T5Generator::new(config.into())?)),
243 _ => Err(RustBertError::InvalidConfigurationError(format!(
244 "Text generation not implemented for {:?}!",
245 config.model_type
246 ))),
247 }
248 }
249
250 pub fn new_with_tokenizer(
251 config: TextGenerationConfig,
252 tokenizer: TokenizerOption,
253 ) -> Result<Self, RustBertError> {
254 match (config.model_type, &config.model_resource) {
255 #[cfg(feature = "onnx")]
256 (_, &ModelResource::ONNX(_)) => Ok(TextGenerationOption::ONNX(
257 ONNXCausalGenerator::new_with_tokenizer(config.into(), tokenizer, None, None)?,
258 )),
259 (ModelType::GPT2, _) => Ok(TextGenerationOption::GPT2(
260 GPT2Generator::new_with_tokenizer(config.into(), tokenizer)?,
261 )),
262 (ModelType::OpenAiGpt, _) => Ok(TextGenerationOption::GPT(
263 OpenAIGenerator::new_with_tokenizer(config.into(), tokenizer)?,
264 )),
265 (ModelType::XLNet, _) => Ok(TextGenerationOption::XLNet(
266 XLNetGenerator::new_with_tokenizer(config.into(), tokenizer)?,
267 )),
268 (ModelType::Reformer, _) => Ok(TextGenerationOption::Reformer(
269 ReformerGenerator::new_with_tokenizer(config.into(), tokenizer)?,
270 )),
271 (ModelType::GPTNeo, _) => Ok(TextGenerationOption::GPTNeo(
272 GptNeoGenerator::new_with_tokenizer(config.into(), tokenizer)?,
273 )),
274 (ModelType::GPTJ, _) => Ok(TextGenerationOption::GPTJ(
275 GptJGenerator::new_with_tokenizer(config.into(), tokenizer)?,
276 )),
277 (ModelType::T5, _) => Ok(TextGenerationOption::T5(T5Generator::new_with_tokenizer(
278 config.into(),
279 tokenizer,
280 )?)),
281 _ => Err(RustBertError::InvalidConfigurationError(format!(
282 "Text generation not implemented for {:?}!",
283 config.model_type
284 ))),
285 }
286 }
287
288 pub fn model_type(&self) -> ModelType {
290 match *self {
291 Self::GPT(_) => ModelType::OpenAiGpt,
292 Self::GPT2(_) => ModelType::GPT2,
293 Self::GPTNeo(_) => ModelType::GPTNeo,
294 Self::GPTJ(_) => ModelType::GPTJ,
295 Self::XLNet(_) => ModelType::XLNet,
296 Self::Reformer(_) => ModelType::Reformer,
297 Self::T5(_) => ModelType::T5,
298 #[cfg(feature = "onnx")]
299 Self::ONNX(_) => ModelType::ONNX,
300 }
301 }
302 pub fn get_tokenizer(&self) -> &TokenizerOption {
304 match self {
305 Self::GPT(model_ref) => model_ref.get_tokenizer(),
306 Self::GPT2(model_ref) => model_ref.get_tokenizer(),
307 Self::GPTNeo(model_ref) => model_ref.get_tokenizer(),
308 Self::GPTJ(model_ref) => model_ref.get_tokenizer(),
309 Self::XLNet(model_ref) => model_ref.get_tokenizer(),
310 Self::Reformer(model_ref) => model_ref.get_tokenizer(),
311 Self::T5(model_ref) => model_ref.get_tokenizer(),
312 #[cfg(feature = "onnx")]
313 Self::ONNX(model_ref) => model_ref.get_tokenizer(),
314 }
315 }
316
317 pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
319 match self {
320 Self::GPT(model_ref) => model_ref.get_tokenizer_mut(),
321 Self::GPT2(model_ref) => model_ref.get_tokenizer_mut(),
322 Self::GPTNeo(model_ref) => model_ref.get_tokenizer_mut(),
323 Self::GPTJ(model_ref) => model_ref.get_tokenizer_mut(),
324 Self::XLNet(model_ref) => model_ref.get_tokenizer_mut(),
325 Self::Reformer(model_ref) => model_ref.get_tokenizer_mut(),
326 Self::T5(model_ref) => model_ref.get_tokenizer_mut(),
327 #[cfg(feature = "onnx")]
328 Self::ONNX(model_ref) => model_ref.get_tokenizer_mut(),
329 }
330 }
331
332 pub fn generate_indices<S>(
334 &self,
335 prompt_texts: Option<&[S]>,
336 min_length: Option<i64>,
337 max_length: Option<i64>,
338 ) -> Result<Vec<Vec<i64>>, RustBertError>
339 where
340 S: AsRef<str> + Send + Sync,
341 {
342 let generate_options = Some(GenerateOptions {
343 min_length,
344 max_length,
345 ..Default::default()
346 });
347 Ok(match *self {
348 Self::GPT(ref model) => model
349 .generate_indices(prompt_texts, generate_options)?
350 .into_iter()
351 .map(|output| output.indices)
352 .collect(),
353 Self::GPT2(ref model) => model
354 .generate_indices(prompt_texts, generate_options)?
355 .into_iter()
356 .map(|output| output.indices)
357 .collect(),
358 Self::GPTNeo(ref model) => model
359 .generate_indices(prompt_texts, generate_options)?
360 .into_iter()
361 .map(|output| output.indices)
362 .collect(),
363 Self::GPTJ(ref model) => model
364 .generate_indices(prompt_texts, generate_options)?
365 .into_iter()
366 .map(|output| output.indices)
367 .collect(),
368 Self::XLNet(ref model) => model
369 .generate_indices(prompt_texts, generate_options)?
370 .into_iter()
371 .map(|output| output.indices)
372 .collect(),
373 Self::Reformer(ref model) => model
374 .generate_indices(prompt_texts, generate_options)?
375 .into_iter()
376 .map(|output| output.indices)
377 .collect(),
378 Self::T5(ref model) => model
379 .generate_indices(prompt_texts, generate_options)?
380 .into_iter()
381 .map(|output| output.indices)
382 .collect(),
383 #[cfg(feature = "onnx")]
384 Self::ONNX(ref model) => model
385 .generate_indices(prompt_texts, generate_options)?
386 .into_iter()
387 .map(|output| output.indices)
388 .collect(),
389 })
390 }
391
392 pub fn half(&mut self) -> Result<(), RustBertError> {
393 match self {
394 Self::GPT(model_ref) => model_ref.half(),
395 Self::GPT2(model_ref) => model_ref.half(),
396 Self::GPTNeo(model_ref) => model_ref.half(),
397 Self::GPTJ(model_ref) => model_ref.half(),
398 Self::XLNet(model_ref) => model_ref.half(),
399 Self::Reformer(model_ref) => model_ref.half(),
400 Self::T5(model_ref) => model_ref.half(),
401 #[cfg(feature = "onnx")]
402 Self::ONNX(_) => Err(RustBertError::OrtError(
403 "Type casting not supported for ONNX models.".to_string(),
404 )),
405 }
406 }
407
408 pub fn float(&mut self) -> Result<(), RustBertError> {
409 match self {
410 Self::GPT(model_ref) => model_ref.float(),
411 Self::GPT2(model_ref) => model_ref.float(),
412 Self::GPTNeo(model_ref) => model_ref.float(),
413 Self::GPTJ(model_ref) => model_ref.float(),
414 Self::XLNet(model_ref) => model_ref.float(),
415 Self::Reformer(model_ref) => model_ref.float(),
416 Self::T5(model_ref) => model_ref.float(),
417 #[cfg(feature = "onnx")]
418 Self::ONNX(_) => Err(RustBertError::OrtError(
419 "Type casting not supported for ONNX models.".to_string(),
420 )),
421 }
422 }
423
424 pub fn set_device(&mut self, device: Device) -> Result<(), RustBertError> {
425 match self {
426 Self::GPT(model_ref) => model_ref.set_device(device),
427 Self::GPT2(model_ref) => model_ref.set_device(device),
428 Self::GPTNeo(model_ref) => model_ref.set_device(device),
429 Self::GPTJ(model_ref) => model_ref.set_device(device),
430 Self::XLNet(model_ref) => model_ref.set_device(device),
431 Self::Reformer(model_ref) => model_ref.set_device(device),
432 Self::T5(model_ref) => model_ref.set_device(device),
433 #[cfg(feature = "onnx")]
434 Self::ONNX(_) => Err(RustBertError::OrtError(
435 "Device assignment not supported for ONNX models.".to_string(),
436 )),
437 }
438 }
439}
440
/// Text generation pipeline model: pairs a generator backend with an optional
/// prompt prefix and the length bounds used at generation time.
pub struct TextGenerationModel {
    /// Underlying generator backend (GPT2, XLNet, T5, ONNX, ...)
    model: TextGenerationOption,
    /// Optional text prepended to every prompt (set for XLNet models)
    prefix: Option<String>,
    /// Token length of `prefix`, used to extend the length budget and to
    /// strip the prefix tokens from decoded outputs
    prefix_length: Option<i64>,
    /// Minimum generation length (prefix tokens excluded)
    min_length: i64,
    /// Maximum generation length (prefix tokens excluded); `None` = unbounded
    max_length: Option<i64>,
}
449
450impl TextGenerationModel {
451 pub fn new(
469 generation_config: TextGenerationConfig,
470 ) -> Result<TextGenerationModel, RustBertError> {
471 let (prefix, min_length, max_length) =
472 TextGenerationModel::get_prefix_min_max_length(&generation_config);
473 let model = TextGenerationOption::new(generation_config)?;
474 let prefix_length = prefix
475 .as_ref()
476 .map(|prefix| model.get_tokenizer().tokenize(prefix).len() as i64);
477 Ok(TextGenerationModel {
478 model,
479 prefix,
480 prefix_length,
481 min_length,
482 max_length,
483 })
484 }
485
486 pub fn new_with_tokenizer(
513 generation_config: TextGenerationConfig,
514 tokenizer: TokenizerOption,
515 ) -> Result<TextGenerationModel, RustBertError> {
516 let (prefix, min_length, max_length) =
517 TextGenerationModel::get_prefix_min_max_length(&generation_config);
518 let model = TextGenerationOption::new_with_tokenizer(generation_config, tokenizer)?;
519 let prefix_length = prefix
520 .as_ref()
521 .map(|prefix| model.get_tokenizer().tokenize(prefix).len() as i64);
522 Ok(TextGenerationModel {
523 model,
524 prefix,
525 prefix_length,
526 min_length,
527 max_length,
528 })
529 }
530
531 fn get_prefix_min_max_length(
532 generation_config: &TextGenerationConfig,
533 ) -> (Option<String>, i64, Option<i64>) {
534 let prefix = match generation_config.model_type {
535 ModelType::XLNet => Some(
536 "In 1991, the remains of Russian Tsar Nicholas II and his family \
537(except for Alexei and Maria) are discovered. \
538The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the \
539remainder of the story. 1883 Western Siberia, \
540a young Grigori Rasputin is asked by his father and a group of men to perform magic. \
541Rasputin has a vision and denounces one of the men as a horse thief. Although his \
542father initially slaps him for making such an accusation, Rasputin watches as the \
543man is chased outside and beaten. Twenty years later, Rasputin sees a vision of \
544the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, \
545with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
546 .to_string(),
547 ),
548 _ => None,
549 };
550
551 let min_length = generation_config.min_length;
552 let max_length = generation_config.max_length;
553 (prefix, min_length, max_length)
554 }
555
556 pub fn get_tokenizer(&self) -> &TokenizerOption {
557 self.model.get_tokenizer()
558 }
559
560 pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
561 self.model.get_tokenizer_mut()
562 }
563
564 pub fn half(&mut self) -> Result<(), RustBertError> {
565 self.model.half()
566 }
567
568 pub fn float(&mut self) -> Result<(), RustBertError> {
569 self.model.float()
570 }
571
572 pub fn set_device(&mut self, device: Device) -> Result<(), RustBertError> {
573 self.model.set_device(device)
574 }
575
576 pub fn generate<'a, S>(
603 &self,
604 texts: &[S],
605 prefix: impl Into<Option<&'a str>>,
606 ) -> Result<Vec<String>, RustBertError>
607 where
608 S: AsRef<str> + Send + Sync,
609 {
610 let (prefix, prefix_length) = match (prefix.into(), &self.prefix) {
611 (Some(query_prefix), _) => (
612 Some(query_prefix),
613 Some(self.model.get_tokenizer().tokenize(query_prefix).len() as i64),
614 ),
615 (None, Some(pipeline_prefix)) => (Some(pipeline_prefix.as_str()), self.prefix_length),
616 (None, None) => (None, None),
617 };
618 let generated_indices = match (prefix, prefix_length) {
619 (None, _) => self.model.generate_indices(Some(texts), None, None),
620 (Some(prefix), Some(prefix_length)) => {
621 let texts = texts
622 .as_ref()
623 .iter()
624 .map(|text| format!("{} {}", prefix, text.as_ref()))
625 .collect::<Vec<String>>();
626 self.model.generate_indices(
627 Some(&texts),
628 Some(self.min_length + prefix_length),
629 self.max_length.map(|max_length| max_length + prefix_length),
630 )
631 }
632 _ => Err(RustBertError::ValueError(
633 "Prefix length not defined but prefix provided!".to_string(),
634 )),
635 }?;
636
637 let mut output = Vec::with_capacity(generated_indices.len());
638 for generated_sequence in generated_indices {
639 output.push(self.model.get_tokenizer().decode(
640 &generated_sequence[prefix_length.unwrap_or(0) as usize..],
641 true,
642 true,
643 ));
644 }
645 Ok(output)
646 }
647}
648
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    // Ignored at runtime: instantiating the default model would fetch remote
    // resources. Compiling the test is the point — boxing the result as
    // `Box<dyn Send>` asserts at compile time that `TextGenerationModel`
    // (wrapped in its construction `Result`) is `Send`.
    #[ignore] fn test() {
        let config = TextGenerationConfig::default();
        let _: Box<dyn Send> = Box::new(TextGenerationModel::new(config));
    }
}