rust_bert/pipelines/summarization.rs
1// Copyright 2020 The Facebook AI Research Team Authors
2// Copyright 2020-present, the HuggingFace Inc. team.
3// Copyright 2020 Guillaume Becquin
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7// http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14//! # Summarization pipeline
//! Abstractive summarization of texts based on encoder-decoder architectures: BART by default, with T5, LongT5,
//! ProphetNet and Pegasus models also supported (plus ONNX exports when the `onnx` feature is enabled).
//! Includes techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
//! By default, the dependencies for this model will be downloaded for a BART model finetuned on CNN/DM.
//! Customized models can be loaded by overwriting the resources in the configuration.
//! The dependencies will be downloaded to the user's home directory, under ~/.cache/.rustbert/bart-cnn
20//!
21//!
22//! ```no_run
23//! # fn main() -> anyhow::Result<()> {
24//! # use rust_bert::pipelines::generation_utils::LanguageGenerator;
25//! use rust_bert::pipelines::summarization::SummarizationModel;
26//! let mut model = SummarizationModel::new(Default::default())?;
27//!
28//! let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists
29//! from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
30//! from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
31//! a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
32//! habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
33//! used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
34//! passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
35//! weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
36//! contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
37//! and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
38//! but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
39//! \"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
40//! said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
41//! said Ryan Cloutier of the Harvard–Smithsonian Center for Astrophysics, who was not one of either study's authors.
42//! \"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
43//! a potentially habitable planet, but further observations will be required to say for sure. \"
44//! K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
45//! but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
46//! on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
47//! telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
48//! about exoplanets like K2-18b."];
49//!
50//! let output = model.summarize(&input);
51//! # Ok(())
52//! # }
53//! ```
//! (News sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
55//!
56//! Example output: \
57//! ```no_run
58//! # let output =
59//! "Scientists have found water vapour on K2-18b, a planet 110 light-years from Earth.
60//! This is the first such discovery in a planet in its star's habitable zone.
61//! The planet is not too hot and not too cold for liquid water to exist."
62//! # ;
63//! ```
64
65use tch::{Device, Kind};
66
67use crate::bart::BartGenerator;
68use crate::common::error::RustBertError;
69use crate::pegasus::PegasusConditionalGenerator;
70use crate::pipelines::common::{ModelResource, ModelType, TokenizerOption};
71use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
72use crate::prophetnet::ProphetNetConditionalGenerator;
73use crate::resources::ResourceProvider;
74use crate::t5::T5Generator;
75
76use crate::longt5::LongT5Generator;
77#[cfg(feature = "onnx")]
78use crate::pipelines::onnx::ONNXConditionalGenerator;
79#[cfg(feature = "remote")]
80use crate::{
81 bart::{BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources},
82 resources::RemoteResource,
83};
84
85/// # Configuration for text summarization
86/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
87/// different set of default parameters and sets the device to place the model on.
88pub struct SummarizationConfig {
89 /// Model type
90 pub model_type: ModelType,
91 /// Model weights resource (default: pretrained BART model on CNN-DM)
92 pub model_resource: ModelResource,
93 /// Config resource (default: pretrained BART model on CNN-DM)
94 pub config_resource: Box<dyn ResourceProvider + Send>,
95 /// Vocab resource (default: pretrained BART model on CNN-DM)
96 pub vocab_resource: Box<dyn ResourceProvider + Send>,
97 /// Merges resource (default: pretrained BART model on CNN-DM)
98 pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
99 /// Minimum sequence length (default: 0)
100 pub min_length: i64,
101 /// Maximum sequence length (default: 20)
102 pub max_length: Option<i64>,
103 /// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
104 pub do_sample: bool,
105 /// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
106 pub early_stopping: bool,
107 /// Number of beams for beam search (default: 5)
108 pub num_beams: i64,
109 /// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
110 pub temperature: f64,
111 /// Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 0)
112 pub top_k: i64,
113 /// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
114 pub top_p: f64,
115 /// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
116 pub repetition_penalty: f64,
117 /// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
118 pub length_penalty: f64,
119 /// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 3)
120 pub no_repeat_ngram_size: i64,
121 /// Number of sequences to return for each prompt text (default: 1)
122 pub num_return_sequences: i64,
123 /// Number of beam groups for diverse beam generation. If provided and higher than 1, will split the beams into beam subgroups leading to more diverse generation.
124 pub num_beam_groups: Option<i64>,
125 /// Diversity penalty for diverse beam search. High values will enforce more difference between beam groups (default: 5.5)
126 pub diversity_penalty: Option<f64>,
127 /// Device to place the model on (default: CUDA/GPU when available)
128 pub device: Device,
129 /// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
130 pub kind: Option<Kind>,
131}
132
133impl SummarizationConfig {
134 /// Instantiate a new summarization configuration of the supplied type.
135 ///
136 /// # Arguments
137 ///
138 /// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
139 /// * model_resource - The `ModelResources` pointing to the model to load (e.g. model.ot)
140 /// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
141 /// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
142 /// * merges_resource - The `ResourceProvider` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt).
143 pub fn new<RC, RV>(
144 model_type: ModelType,
145 model_resource: ModelResource,
146 config_resource: RC,
147 vocab_resource: RV,
148 merges_resource: Option<RV>,
149 ) -> SummarizationConfig
150 where
151 RC: ResourceProvider + Send + 'static,
152 RV: ResourceProvider + Send + 'static,
153 {
154 SummarizationConfig {
155 model_type,
156 model_resource,
157 config_resource: Box::new(config_resource),
158 vocab_resource: Box::new(vocab_resource),
159 merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
160 min_length: 56,
161 max_length: Some(142),
162 do_sample: false,
163 early_stopping: true,
164 num_beams: 3,
165 temperature: 1.0,
166 top_k: 50,
167 top_p: 1.0,
168 repetition_penalty: 1.0,
169 length_penalty: 1.0,
170 no_repeat_ngram_size: 3,
171 num_return_sequences: 1,
172 num_beam_groups: None,
173 diversity_penalty: None,
174 device: Device::cuda_if_available(),
175 kind: None,
176 }
177 }
178}
179
180#[cfg(feature = "remote")]
181impl Default for SummarizationConfig {
182 fn default() -> SummarizationConfig {
183 SummarizationConfig::new(
184 ModelType::Bart,
185 ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
186 BartModelResources::BART_CNN,
187 ))),
188 RemoteResource::from_pretrained(BartConfigResources::BART_CNN),
189 RemoteResource::from_pretrained(BartVocabResources::BART_CNN),
190 Some(RemoteResource::from_pretrained(
191 BartMergesResources::BART_CNN,
192 )),
193 )
194 }
195}
196
197impl From<SummarizationConfig> for GenerateConfig {
198 fn from(config: SummarizationConfig) -> GenerateConfig {
199 GenerateConfig {
200 model_type: config.model_type,
201 model_resource: config.model_resource,
202 config_resource: config.config_resource,
203 merges_resource: config.merges_resource,
204 vocab_resource: config.vocab_resource,
205 min_length: config.min_length,
206 max_length: config.max_length,
207 do_sample: config.do_sample,
208 early_stopping: config.early_stopping,
209 num_beams: config.num_beams,
210 temperature: config.temperature,
211 top_k: config.top_k,
212 top_p: config.top_p,
213 repetition_penalty: config.repetition_penalty,
214 length_penalty: config.length_penalty,
215 no_repeat_ngram_size: config.no_repeat_ngram_size,
216 num_return_sequences: config.num_return_sequences,
217 num_beam_groups: config.num_beam_groups,
218 diversity_penalty: config.diversity_penalty,
219 device: config.device,
220 kind: config.kind,
221 }
222 }
223}
224
225/// # Abstraction that holds one particular summarization model, for any of the supported models
226pub enum SummarizationOption {
227 /// Summarizer based on BART model
228 Bart(BartGenerator),
229 /// Summarizer based on T5 model
230 T5(T5Generator),
231 /// Summarizer based on LongT5 model
232 LongT5(LongT5Generator),
233 /// Summarizer based on ProphetNet model
234 ProphetNet(ProphetNetConditionalGenerator),
235 /// Summarizer based on Pegasus model
236 Pegasus(PegasusConditionalGenerator),
237 /// Summarizer based on ONNX model
238 #[cfg(feature = "onnx")]
239 ONNX(ONNXConditionalGenerator),
240}
241
242impl SummarizationOption {
243 pub fn new(config: SummarizationConfig) -> Result<Self, RustBertError> {
244 match (config.model_type, &config.model_resource) {
245 #[cfg(feature = "onnx")]
246 (_, &ModelResource::ONNX(_)) => Ok(SummarizationOption::ONNX(
247 ONNXConditionalGenerator::new(config.into(), None, None)?,
248 )),
249 (ModelType::Bart, _) => Ok(SummarizationOption::Bart(BartGenerator::new(
250 config.into(),
251 )?)),
252 (ModelType::T5, _) => Ok(SummarizationOption::T5(T5Generator::new(config.into())?)),
253 (ModelType::LongT5, _) => Ok(SummarizationOption::LongT5(LongT5Generator::new(
254 config.into(),
255 )?)),
256 (ModelType::ProphetNet, _) => Ok(SummarizationOption::ProphetNet(
257 ProphetNetConditionalGenerator::new(config.into())?,
258 )),
259 (ModelType::Pegasus, _) => Ok(SummarizationOption::Pegasus(
260 PegasusConditionalGenerator::new(config.into())?,
261 )),
262 _ => Err(RustBertError::InvalidConfigurationError(format!(
263 "Summarization not implemented for {:?}!",
264 config.model_type
265 ))),
266 }
267 }
268
269 pub fn new_with_tokenizer(
270 config: SummarizationConfig,
271 tokenizer: TokenizerOption,
272 ) -> Result<Self, RustBertError> {
273 match (config.model_type, &config.model_resource) {
274 #[cfg(feature = "onnx")]
275 (_, &ModelResource::ONNX(_)) => Ok(SummarizationOption::ONNX(
276 ONNXConditionalGenerator::new_with_tokenizer(config.into(), tokenizer, None, None)?,
277 )),
278 (ModelType::Bart, _) => Ok(SummarizationOption::Bart(
279 BartGenerator::new_with_tokenizer(config.into(), tokenizer)?,
280 )),
281 (ModelType::T5, _) => Ok(SummarizationOption::T5(T5Generator::new_with_tokenizer(
282 config.into(),
283 tokenizer,
284 )?)),
285 (ModelType::LongT5, _) => Ok(SummarizationOption::LongT5(
286 LongT5Generator::new_with_tokenizer(config.into(), tokenizer)?,
287 )),
288 (ModelType::ProphetNet, _) => Ok(SummarizationOption::ProphetNet(
289 ProphetNetConditionalGenerator::new_with_tokenizer(config.into(), tokenizer)?,
290 )),
291 (ModelType::Pegasus, _) => Ok(SummarizationOption::Pegasus(
292 PegasusConditionalGenerator::new_with_tokenizer(config.into(), tokenizer)?,
293 )),
294 _ => Err(RustBertError::InvalidConfigurationError(format!(
295 "Summarization not implemented for {:?}!",
296 config.model_type
297 ))),
298 }
299 }
300
301 /// Returns the `ModelType` for this SummarizationOption
302 pub fn model_type(&self) -> ModelType {
303 match *self {
304 Self::Bart(_) => ModelType::Bart,
305 Self::T5(_) => ModelType::T5,
306 Self::LongT5(_) => ModelType::LongT5,
307 Self::ProphetNet(_) => ModelType::ProphetNet,
308 Self::Pegasus(_) => ModelType::Pegasus,
309 #[cfg(feature = "onnx")]
310 Self::ONNX(_) => ModelType::ONNX,
311 }
312 }
313
314 /// Interface method to access tokenizer
315 pub fn get_tokenizer(&self) -> &TokenizerOption {
316 match self {
317 Self::Bart(model_ref) => model_ref.get_tokenizer(),
318 Self::T5(model_ref) => model_ref.get_tokenizer(),
319 Self::LongT5(model_ref) => model_ref.get_tokenizer(),
320 Self::ProphetNet(model_ref) => model_ref.get_tokenizer(),
321 Self::Pegasus(model_ref) => model_ref.get_tokenizer(),
322 #[cfg(feature = "onnx")]
323 Self::ONNX(model_ref) => model_ref.get_tokenizer(),
324 }
325 }
326
327 /// Interface method to access tokenizer
328 pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
329 match self {
330 Self::Bart(model_ref) => model_ref.get_tokenizer_mut(),
331 Self::T5(model_ref) => model_ref.get_tokenizer_mut(),
332 Self::LongT5(model_ref) => model_ref.get_tokenizer_mut(),
333 Self::ProphetNet(model_ref) => model_ref.get_tokenizer_mut(),
334 Self::Pegasus(model_ref) => model_ref.get_tokenizer_mut(),
335 #[cfg(feature = "onnx")]
336 Self::ONNX(model_ref) => model_ref.get_tokenizer_mut(),
337 }
338 }
339
340 /// Interface method to generate() of the particular models.
341 pub fn generate<S>(&self, prompt_texts: Option<&[S]>) -> Result<Vec<String>, RustBertError>
342 where
343 S: AsRef<str> + Send + Sync,
344 {
345 Ok(match *self {
346 Self::Bart(ref model) => model
347 .generate(prompt_texts, None)?
348 .into_iter()
349 .map(|output| output.text)
350 .collect(),
351 Self::T5(ref model) => model
352 .generate(prompt_texts, None)?
353 .into_iter()
354 .map(|output| output.text)
355 .collect(),
356 Self::LongT5(ref model) => model
357 .generate(prompt_texts, None)?
358 .into_iter()
359 .map(|output| output.text)
360 .collect(),
361 Self::ProphetNet(ref model) => model
362 .generate(prompt_texts, None)?
363 .into_iter()
364 .map(|output| output.text)
365 .collect(),
366 Self::Pegasus(ref model) => model
367 .generate(prompt_texts, None)?
368 .into_iter()
369 .map(|output| output.text)
370 .collect(),
371 #[cfg(feature = "onnx")]
372 Self::ONNX(ref model) => model
373 .generate(prompt_texts, None)?
374 .into_iter()
375 .map(|output| output.text)
376 .collect(),
377 })
378 }
379}
380
381/// # SummarizationModel to perform summarization
382pub struct SummarizationModel {
383 model: SummarizationOption,
384 prefix: Option<String>,
385}
386
387impl SummarizationModel {
388 /// Build a new `SummarizationModel`
389 ///
390 /// # Arguments
391 ///
392 /// * `summarization_config` - `SummarizationConfig` object containing the resource references (model, vocabulary, configuration), summarization options and device placement (CPU/GPU)
393 ///
394 /// # Example
395 ///
396 /// ```no_run
397 /// # fn main() -> anyhow::Result<()> {
398 /// use rust_bert::pipelines::summarization::SummarizationModel;
399 ///
400 /// let mut summarization_model = SummarizationModel::new(Default::default())?;
401 /// # Ok(())
402 /// # }
403 /// ```
404 pub fn new(
405 summarization_config: SummarizationConfig,
406 ) -> Result<SummarizationModel, RustBertError> {
407 let prefix = match summarization_config.model_type {
408 ModelType::T5 => Some("summarize: ".to_string()),
409 _ => None,
410 };
411 let model = SummarizationOption::new(summarization_config)?;
412
413 Ok(SummarizationModel { model, prefix })
414 }
415
416 /// Build a new `SummarizationModel` with a provided tokenizer.
417 ///
418 /// # Arguments
419 ///
420 /// * `summarization_config` - `SummarizationConfig` object containing the resource references (model, vocabulary, configuration), summarization options and device placement (CPU/GPU)
421 /// * `tokenizer` - `TokenizerOption` tokenizer to use for summarization.
422 ///
423 /// # Example
424 ///
425 /// ```no_run
426 /// # fn main() -> anyhow::Result<()> {
427 /// use rust_bert::pipelines::common::{ModelType, TokenizerOption};
428 /// use rust_bert::pipelines::summarization::SummarizationModel;
429 /// let tokenizer = TokenizerOption::from_file(
430 /// ModelType::Bart,
431 /// "path/to/vocab.json",
432 /// Some("path/to/merges.txt"),
433 /// false,
434 /// None,
435 /// None,
436 /// )?;
437 /// let mut summarization_model =
438 /// SummarizationModel::new_with_tokenizer(Default::default(), tokenizer)?;
439 /// # Ok(())
440 /// # }
441 /// ```
442 pub fn new_with_tokenizer(
443 summarization_config: SummarizationConfig,
444 tokenizer: TokenizerOption,
445 ) -> Result<SummarizationModel, RustBertError> {
446 let prefix = match summarization_config.model_type {
447 ModelType::T5 => Some("summarize: ".to_string()),
448 _ => None,
449 };
450 let model = SummarizationOption::new_with_tokenizer(summarization_config, tokenizer)?;
451
452 Ok(SummarizationModel { model, prefix })
453 }
454
455 /// Get a reference to the model tokenizer.
456 pub fn get_tokenizer(&self) -> &TokenizerOption {
457 self.model.get_tokenizer()
458 }
459
460 /// Get a mutable reference to the model tokenizer.
461 pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
462 self.model.get_tokenizer_mut()
463 }
464
465 /// Summarize texts provided
466 ///
467 /// # Arguments
468 ///
469 /// * `input` - `&[&str]` Array of texts to summarize.
470 ///
471 /// # Returns
472 /// * `Vec<String>` Summarized texts
473 ///
474 /// # Example
475 ///
476 /// ```no_run
477 /// # fn main() -> anyhow::Result<()> {
478 /// use rust_bert::pipelines::generation_utils::LanguageGenerator;
479 /// use rust_bert::pipelines::summarization::SummarizationModel;
480 /// let model = SummarizationModel::new(Default::default())?;
481 ///
482 /// let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists
483 /// from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
484 /// from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
485 /// a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
486 /// habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
487 /// used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
488 /// passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
489 /// weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
490 /// contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
491 /// and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
492 /// but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
493 /// \"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
494 /// said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
495 /// said Ryan Cloutier of the Harvard–Smithsonian Center for Astrophysics, who was not one of either study's authors.
496 /// \"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
497 /// a potentially habitable planet, but further observations will be required to say for sure. \"
498 /// K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
499 /// but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
500 /// on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
501 /// telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
502 /// about exoplanets like K2-18b."];
503 ///
504 /// let output = model.summarize(&input);
505 /// # Ok(())
506 /// # }
507 /// ```
508 /// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
509 pub fn summarize<S>(&self, texts: &[S]) -> Result<Vec<String>, RustBertError>
510 where
511 S: AsRef<str> + Send + Sync,
512 {
513 match &self.prefix {
514 None => self.model.generate(Some(texts)),
515 Some(prefix) => {
516 let texts = texts
517 .iter()
518 .map(|text| format!("{}{}", prefix, text.as_ref()))
519 .collect::<Vec<String>>();
520 self.model.generate(Some(&texts))
521 }
522 }
523 }
524}
525
526#[cfg(test)]
527mod test {
528 use super::*;
529
530 #[test]
531 #[ignore] // no need to run, compilation is enough to verify it is Send
532 fn test() {
533 let config = SummarizationConfig::default();
534 let _: Box<dyn Send> = Box::new(SummarizationModel::new(config));
535 }
536}