// rust_bert/pipelines/summarization.rs

1// Copyright 2020 The Facebook AI Research Team Authors
2// Copyright 2020-present, the HuggingFace Inc. team.
3// Copyright 2020 Guillaume Becquin
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//     http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14//! # Summarization pipeline
15//! Abstractive summarization of texts based on the BART encoder-decoder architecture
16//! Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
17//! By default, the dependencies for this model will be downloaded for a BART model finetuned on CNN/DM.
18//! Customized BART models can be loaded by overwriting the resources in the configuration.
19//! The dependencies will be downloaded to the user's home directory, under ~/.cache/.rustbert/bart-cnn
20//!
21//!
22//! ```no_run
23//! # fn main() -> anyhow::Result<()> {
24//! # use rust_bert::pipelines::generation_utils::LanguageGenerator;
25//! use rust_bert::pipelines::summarization::SummarizationModel;
26//! let mut model = SummarizationModel::new(Default::default())?;
27//!
28//! let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists
29//! from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
30//! from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
31//! a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
32//! habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
33//! used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
34//! passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
35//! weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
36//! contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
37//! and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
38//! but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
39//! \"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
40//! said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
41//! said Ryan Cloutier of the Harvard–Smithsonian Center for Astrophysics, who was not one of either study's authors.
42//! \"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
43//! a potentially habitable planet, but further observations will be required to say for sure. \"
44//! K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
45//! but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
46//! on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
47//! telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
48//! about exoplanets like K2-18b."];
49//!
50//! let output = model.summarize(&input);
51//! # Ok(())
52//! # }
53//! ```
54//! (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
55//!
56//! Example output: \
57//! ```no_run
58//! # let output =
59//! "Scientists have found water vapour on K2-18b, a planet 110 light-years from Earth.
60//!  This is the first such discovery in a planet in its star's habitable zone.
61//!  The planet is not too hot and not too cold for liquid water to exist."
62//! # ;
63//! ```
64
65use tch::{Device, Kind};
66
67use crate::bart::BartGenerator;
68use crate::common::error::RustBertError;
69use crate::pegasus::PegasusConditionalGenerator;
70use crate::pipelines::common::{ModelResource, ModelType, TokenizerOption};
71use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
72use crate::prophetnet::ProphetNetConditionalGenerator;
73use crate::resources::ResourceProvider;
74use crate::t5::T5Generator;
75
76use crate::longt5::LongT5Generator;
77#[cfg(feature = "onnx")]
78use crate::pipelines::onnx::ONNXConditionalGenerator;
79#[cfg(feature = "remote")]
80use crate::{
81    bart::{BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources},
82    resources::RemoteResource,
83};
84
/// # Configuration for text summarization
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
/// different set of default parameters and sets the device to place the model on.
/// Defaults documented below are the values set by [`SummarizationConfig::new`].
pub struct SummarizationConfig {
    /// Model type
    pub model_type: ModelType,
    /// Model weights resource (default: pretrained BART model on CNN-DM)
    pub model_resource: ModelResource,
    /// Config resource (default: pretrained BART model on CNN-DM)
    pub config_resource: Box<dyn ResourceProvider + Send>,
    /// Vocab resource (default: pretrained BART model on CNN-DM)
    pub vocab_resource: Box<dyn ResourceProvider + Send>,
    /// Merges resource (default: pretrained BART model on CNN-DM)
    pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
    /// Minimum sequence length (default: 56)
    pub min_length: i64,
    /// Maximum sequence length (default: 142)
    pub max_length: Option<i64>,
    /// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: false)
    pub do_sample: bool,
    /// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: true)
    pub early_stopping: bool,
    /// Number of beams for beam search (default: 3)
    pub num_beams: i64,
    /// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
    pub temperature: f64,
    /// Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 50)
    pub top_k: i64,
    /// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 1.0)
    pub top_p: f64,
    /// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
    pub repetition_penalty: f64,
    /// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
    pub length_penalty: f64,
    /// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 3)
    pub no_repeat_ngram_size: i64,
    /// Number of sequences to return for each prompt text (default: 1)
    pub num_return_sequences: i64,
    /// Number of beam groups for diverse beam generation. If provided and higher than 1, will split the beams into beam subgroups leading to more diverse generation. (default: None)
    pub num_beam_groups: Option<i64>,
    /// Diversity penalty for diverse beam search. High values will enforce more difference between beam groups (default: None)
    pub diversity_penalty: Option<f64>,
    /// Device to place the model on (default: CUDA/GPU when available)
    pub device: Device,
    /// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
    pub kind: Option<Kind>,
}
132
133impl SummarizationConfig {
134    /// Instantiate a new summarization configuration of the supplied type.
135    ///
136    /// # Arguments
137    ///
138    /// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
139    /// * model_resource - The `ModelResources` pointing to the model to load (e.g.  model.ot)
140    /// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
141    /// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g.  vocab.txt/vocab.json)
142    /// * merges_resource - The `ResourceProvider`  pointing to the tokenizer's merge file or SentencePiece model to load (e.g.  merges.txt).
143    pub fn new<RC, RV>(
144        model_type: ModelType,
145        model_resource: ModelResource,
146        config_resource: RC,
147        vocab_resource: RV,
148        merges_resource: Option<RV>,
149    ) -> SummarizationConfig
150    where
151        RC: ResourceProvider + Send + 'static,
152        RV: ResourceProvider + Send + 'static,
153    {
154        SummarizationConfig {
155            model_type,
156            model_resource,
157            config_resource: Box::new(config_resource),
158            vocab_resource: Box::new(vocab_resource),
159            merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
160            min_length: 56,
161            max_length: Some(142),
162            do_sample: false,
163            early_stopping: true,
164            num_beams: 3,
165            temperature: 1.0,
166            top_k: 50,
167            top_p: 1.0,
168            repetition_penalty: 1.0,
169            length_penalty: 1.0,
170            no_repeat_ngram_size: 3,
171            num_return_sequences: 1,
172            num_beam_groups: None,
173            diversity_penalty: None,
174            device: Device::cuda_if_available(),
175            kind: None,
176        }
177    }
178}
179
#[cfg(feature = "remote")]
impl Default for SummarizationConfig {
    /// Default configuration: a pretrained BART model fine-tuned on CNN/DailyMail,
    /// with all resources fetched remotely on first use.
    fn default() -> SummarizationConfig {
        let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
            BartModelResources::BART_CNN,
        )));
        let config_resource = RemoteResource::from_pretrained(BartConfigResources::BART_CNN);
        let vocab_resource = RemoteResource::from_pretrained(BartVocabResources::BART_CNN);
        let merges_resource = RemoteResource::from_pretrained(BartMergesResources::BART_CNN);
        SummarizationConfig::new(
            ModelType::Bart,
            model_resource,
            config_resource,
            vocab_resource,
            Some(merges_resource),
        )
    }
}
196
197impl From<SummarizationConfig> for GenerateConfig {
198    fn from(config: SummarizationConfig) -> GenerateConfig {
199        GenerateConfig {
200            model_type: config.model_type,
201            model_resource: config.model_resource,
202            config_resource: config.config_resource,
203            merges_resource: config.merges_resource,
204            vocab_resource: config.vocab_resource,
205            min_length: config.min_length,
206            max_length: config.max_length,
207            do_sample: config.do_sample,
208            early_stopping: config.early_stopping,
209            num_beams: config.num_beams,
210            temperature: config.temperature,
211            top_k: config.top_k,
212            top_p: config.top_p,
213            repetition_penalty: config.repetition_penalty,
214            length_penalty: config.length_penalty,
215            no_repeat_ngram_size: config.no_repeat_ngram_size,
216            num_return_sequences: config.num_return_sequences,
217            num_beam_groups: config.num_beam_groups,
218            diversity_penalty: config.diversity_penalty,
219            device: config.device,
220            kind: config.kind,
221        }
222    }
223}
224
/// # Abstraction that holds one particular summarization model, for any of the supported models
///
/// Each variant wraps the corresponding conditional generator. The variant is selected
/// by [`SummarizationOption::new`] from the configured `ModelType`, except that an ONNX
/// model resource takes precedence regardless of the model type.
pub enum SummarizationOption {
    /// Summarizer based on BART model
    Bart(BartGenerator),
    /// Summarizer based on T5 model
    T5(T5Generator),
    /// Summarizer based on LongT5 model
    LongT5(LongT5Generator),
    /// Summarizer based on ProphetNet model
    ProphetNet(ProphetNetConditionalGenerator),
    /// Summarizer based on Pegasus model
    Pegasus(PegasusConditionalGenerator),
    /// Summarizer based on ONNX model (only available with the `onnx` feature)
    #[cfg(feature = "onnx")]
    ONNX(ONNXConditionalGenerator),
}
241
242impl SummarizationOption {
243    pub fn new(config: SummarizationConfig) -> Result<Self, RustBertError> {
244        match (config.model_type, &config.model_resource) {
245            #[cfg(feature = "onnx")]
246            (_, &ModelResource::ONNX(_)) => Ok(SummarizationOption::ONNX(
247                ONNXConditionalGenerator::new(config.into(), None, None)?,
248            )),
249            (ModelType::Bart, _) => Ok(SummarizationOption::Bart(BartGenerator::new(
250                config.into(),
251            )?)),
252            (ModelType::T5, _) => Ok(SummarizationOption::T5(T5Generator::new(config.into())?)),
253            (ModelType::LongT5, _) => Ok(SummarizationOption::LongT5(LongT5Generator::new(
254                config.into(),
255            )?)),
256            (ModelType::ProphetNet, _) => Ok(SummarizationOption::ProphetNet(
257                ProphetNetConditionalGenerator::new(config.into())?,
258            )),
259            (ModelType::Pegasus, _) => Ok(SummarizationOption::Pegasus(
260                PegasusConditionalGenerator::new(config.into())?,
261            )),
262            _ => Err(RustBertError::InvalidConfigurationError(format!(
263                "Summarization not implemented for {:?}!",
264                config.model_type
265            ))),
266        }
267    }
268
269    pub fn new_with_tokenizer(
270        config: SummarizationConfig,
271        tokenizer: TokenizerOption,
272    ) -> Result<Self, RustBertError> {
273        match (config.model_type, &config.model_resource) {
274            #[cfg(feature = "onnx")]
275            (_, &ModelResource::ONNX(_)) => Ok(SummarizationOption::ONNX(
276                ONNXConditionalGenerator::new_with_tokenizer(config.into(), tokenizer, None, None)?,
277            )),
278            (ModelType::Bart, _) => Ok(SummarizationOption::Bart(
279                BartGenerator::new_with_tokenizer(config.into(), tokenizer)?,
280            )),
281            (ModelType::T5, _) => Ok(SummarizationOption::T5(T5Generator::new_with_tokenizer(
282                config.into(),
283                tokenizer,
284            )?)),
285            (ModelType::LongT5, _) => Ok(SummarizationOption::LongT5(
286                LongT5Generator::new_with_tokenizer(config.into(), tokenizer)?,
287            )),
288            (ModelType::ProphetNet, _) => Ok(SummarizationOption::ProphetNet(
289                ProphetNetConditionalGenerator::new_with_tokenizer(config.into(), tokenizer)?,
290            )),
291            (ModelType::Pegasus, _) => Ok(SummarizationOption::Pegasus(
292                PegasusConditionalGenerator::new_with_tokenizer(config.into(), tokenizer)?,
293            )),
294            _ => Err(RustBertError::InvalidConfigurationError(format!(
295                "Summarization not implemented for {:?}!",
296                config.model_type
297            ))),
298        }
299    }
300
301    /// Returns the `ModelType` for this SummarizationOption
302    pub fn model_type(&self) -> ModelType {
303        match *self {
304            Self::Bart(_) => ModelType::Bart,
305            Self::T5(_) => ModelType::T5,
306            Self::LongT5(_) => ModelType::LongT5,
307            Self::ProphetNet(_) => ModelType::ProphetNet,
308            Self::Pegasus(_) => ModelType::Pegasus,
309            #[cfg(feature = "onnx")]
310            Self::ONNX(_) => ModelType::ONNX,
311        }
312    }
313
314    /// Interface method to access tokenizer
315    pub fn get_tokenizer(&self) -> &TokenizerOption {
316        match self {
317            Self::Bart(model_ref) => model_ref.get_tokenizer(),
318            Self::T5(model_ref) => model_ref.get_tokenizer(),
319            Self::LongT5(model_ref) => model_ref.get_tokenizer(),
320            Self::ProphetNet(model_ref) => model_ref.get_tokenizer(),
321            Self::Pegasus(model_ref) => model_ref.get_tokenizer(),
322            #[cfg(feature = "onnx")]
323            Self::ONNX(model_ref) => model_ref.get_tokenizer(),
324        }
325    }
326
327    /// Interface method to access tokenizer
328    pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
329        match self {
330            Self::Bart(model_ref) => model_ref.get_tokenizer_mut(),
331            Self::T5(model_ref) => model_ref.get_tokenizer_mut(),
332            Self::LongT5(model_ref) => model_ref.get_tokenizer_mut(),
333            Self::ProphetNet(model_ref) => model_ref.get_tokenizer_mut(),
334            Self::Pegasus(model_ref) => model_ref.get_tokenizer_mut(),
335            #[cfg(feature = "onnx")]
336            Self::ONNX(model_ref) => model_ref.get_tokenizer_mut(),
337        }
338    }
339
340    /// Interface method to generate() of the particular models.
341    pub fn generate<S>(&self, prompt_texts: Option<&[S]>) -> Result<Vec<String>, RustBertError>
342    where
343        S: AsRef<str> + Send + Sync,
344    {
345        Ok(match *self {
346            Self::Bart(ref model) => model
347                .generate(prompt_texts, None)?
348                .into_iter()
349                .map(|output| output.text)
350                .collect(),
351            Self::T5(ref model) => model
352                .generate(prompt_texts, None)?
353                .into_iter()
354                .map(|output| output.text)
355                .collect(),
356            Self::LongT5(ref model) => model
357                .generate(prompt_texts, None)?
358                .into_iter()
359                .map(|output| output.text)
360                .collect(),
361            Self::ProphetNet(ref model) => model
362                .generate(prompt_texts, None)?
363                .into_iter()
364                .map(|output| output.text)
365                .collect(),
366            Self::Pegasus(ref model) => model
367                .generate(prompt_texts, None)?
368                .into_iter()
369                .map(|output| output.text)
370                .collect(),
371            #[cfg(feature = "onnx")]
372            Self::ONNX(ref model) => model
373                .generate(prompt_texts, None)?
374                .into_iter()
375                .map(|output| output.text)
376                .collect(),
377        })
378    }
379}
380
/// # SummarizationModel to perform summarization
pub struct SummarizationModel {
    // Wrapped summarization-capable generator (BART, T5, Pegasus, ...).
    model: SummarizationOption,
    // Optional task prefix prepended to every input before generation
    // ("summarize: " for T5, `None` for all other model types).
    prefix: Option<String>,
}
386
387impl SummarizationModel {
388    /// Build a new `SummarizationModel`
389    ///
390    /// # Arguments
391    ///
392    /// * `summarization_config` - `SummarizationConfig` object containing the resource references (model, vocabulary, configuration), summarization options and device placement (CPU/GPU)
393    ///
394    /// # Example
395    ///
396    /// ```no_run
397    /// # fn main() -> anyhow::Result<()> {
398    /// use rust_bert::pipelines::summarization::SummarizationModel;
399    ///
400    /// let mut summarization_model = SummarizationModel::new(Default::default())?;
401    /// # Ok(())
402    /// # }
403    /// ```
404    pub fn new(
405        summarization_config: SummarizationConfig,
406    ) -> Result<SummarizationModel, RustBertError> {
407        let prefix = match summarization_config.model_type {
408            ModelType::T5 => Some("summarize: ".to_string()),
409            _ => None,
410        };
411        let model = SummarizationOption::new(summarization_config)?;
412
413        Ok(SummarizationModel { model, prefix })
414    }
415
416    /// Build a new `SummarizationModel` with a provided tokenizer.
417    ///
418    /// # Arguments
419    ///
420    /// * `summarization_config` - `SummarizationConfig` object containing the resource references (model, vocabulary, configuration), summarization options and device placement (CPU/GPU)
421    /// * `tokenizer` - `TokenizerOption` tokenizer to use for summarization.
422    ///
423    /// # Example
424    ///
425    /// ```no_run
426    /// # fn main() -> anyhow::Result<()> {
427    /// use rust_bert::pipelines::common::{ModelType, TokenizerOption};
428    /// use rust_bert::pipelines::summarization::SummarizationModel;
429    /// let tokenizer = TokenizerOption::from_file(
430    ///     ModelType::Bart,
431    ///     "path/to/vocab.json",
432    ///     Some("path/to/merges.txt"),
433    ///     false,
434    ///     None,
435    ///     None,
436    /// )?;
437    /// let mut summarization_model =
438    ///     SummarizationModel::new_with_tokenizer(Default::default(), tokenizer)?;
439    /// # Ok(())
440    /// # }
441    /// ```
442    pub fn new_with_tokenizer(
443        summarization_config: SummarizationConfig,
444        tokenizer: TokenizerOption,
445    ) -> Result<SummarizationModel, RustBertError> {
446        let prefix = match summarization_config.model_type {
447            ModelType::T5 => Some("summarize: ".to_string()),
448            _ => None,
449        };
450        let model = SummarizationOption::new_with_tokenizer(summarization_config, tokenizer)?;
451
452        Ok(SummarizationModel { model, prefix })
453    }
454
455    /// Get a reference to the model tokenizer.
456    pub fn get_tokenizer(&self) -> &TokenizerOption {
457        self.model.get_tokenizer()
458    }
459
460    /// Get a mutable reference to the model tokenizer.
461    pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
462        self.model.get_tokenizer_mut()
463    }
464
465    /// Summarize texts provided
466    ///
467    /// # Arguments
468    ///
469    /// * `input` - `&[&str]` Array of texts to summarize.
470    ///
471    /// # Returns
472    /// * `Vec<String>` Summarized texts
473    ///
474    /// # Example
475    ///
476    /// ```no_run
477    /// # fn main() -> anyhow::Result<()> {
478    /// use rust_bert::pipelines::generation_utils::LanguageGenerator;
479    /// use rust_bert::pipelines::summarization::SummarizationModel;
480    /// let model = SummarizationModel::new(Default::default())?;
481    ///
482    /// let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists
483    /// from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
484    /// from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
485    /// a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
486    /// habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
487    /// used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
488    /// passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
489    /// weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
490    /// contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
491    /// and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
492    /// but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
493    /// \"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
494    /// said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
495    /// said Ryan Cloutier of the Harvard–Smithsonian Center for Astrophysics, who was not one of either study's authors.
496    /// \"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
497    /// a potentially habitable planet, but further observations will be required to say for sure. \"
498    /// K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
499    /// but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
500    /// on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
501    /// telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
502    /// about exoplanets like K2-18b."];
503    ///
504    /// let output = model.summarize(&input);
505    /// # Ok(())
506    /// # }
507    /// ```
508    /// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
509    pub fn summarize<S>(&self, texts: &[S]) -> Result<Vec<String>, RustBertError>
510    where
511        S: AsRef<str> + Send + Sync,
512    {
513        match &self.prefix {
514            None => self.model.generate(Some(texts)),
515            Some(prefix) => {
516                let texts = texts
517                    .iter()
518                    .map(|text| format!("{}{}", prefix, text.as_ref()))
519                    .collect::<Vec<String>>();
520                self.model.generate(Some(&texts))
521            }
522        }
523    }
524}
525
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    #[ignore] // compiling this body is enough to verify the model is `Send`
    fn summarization_model_is_send() {
        let config = SummarizationConfig::default();
        let _: Box<dyn Send> = Box::new(SummarizationModel::new(config));
    }
}