rust_bert/pipelines/
conversation.rs

1// Copyright 2019-present Microsoft
2// Copyright 2020-present, the HuggingFace Inc. team.
3// Copyright 2020 Guillaume Becquin
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//     http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14//! # Multi-turn dialogue
15//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT).
16//! This pipeline allows the generation of single or multi-turn conversations between a human and a model.
//! The DialoGPT page states that
18//! > The human evaluation results indicate that the response generated from DialoGPT is comparable to human response quality
19//! > under a single-turn conversation Turing test. ([DialoGPT repository](https://github.com/microsoft/DialoGPT))
20//!
21//!
//! The dependencies will be downloaded to the user's home directory, under ~/.cache/.rustbert/dialogpt-medium
//! The following illustrates how to run a two-turn conversation using a conversation manager:
24//! ```no_run
25//! # fn main() -> anyhow::Result<()> {
26//! use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
27//! let conversation_model = ConversationModel::new(Default::default())?;
28//! let mut conversation_manager = ConversationManager::new();
29//!
30//! let conversation_id =
31//!     conversation_manager.create("Going to the movies tonight - any suggestions?");
32//! let output = conversation_model.generate_responses(&mut conversation_manager);
33//!
34//! let _ = conversation_manager
35//!     .get(&conversation_id)
36//!     .unwrap()
37//!     .add_user_input("Is it an action movie?")?;
38//!
39//! let output = conversation_model.generate_responses(&mut conversation_manager);
40//!
41//! # Ok(())
42//! # }
43//! ```
44//!
45//! Example output: \
46//! ```no_run
47//! # let output = [
48//! "{a0cb3c15-9a5a-4a34-958d-95eddac0215a: \"The Big Lebowski\"}",
49//! "{a0cb3c15-9a5a-4a34-958d-95eddac0215a: \"It's a comedy.\"}"
50//! # ];
51//! ```
52//!
53//! # Disclaimer
54//! The authors of this repository are not responsible for any generation
55//! from the 3rd party utilization of the pretrained system.
56use crate::common::error::RustBertError;
57use crate::gpt2::GPT2Generator;
58use crate::pipelines::common::{ModelResource, ModelType, TokenizerOption};
59use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
60use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
61use crate::resources::ResourceProvider;
62use std::collections::HashMap;
63use tch::{Device, Kind, Tensor};
64use uuid::Uuid;
65
66#[cfg(feature = "remote")]
67use crate::{
68    gpt2::{Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources},
69    resources::RemoteResource,
70};
71
72/// # Configuration for multi-turn classification
73/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
74/// different set of default parameters and sets the device to place the model on.
75pub struct ConversationConfig {
76    /// Model type
77    pub model_type: ModelType,
78    /// Model weights resource (default: DialoGPT-medium)
79    pub model_resource: ModelResource,
80    /// Config resource (default: DialoGPT-medium)
81    pub config_resource: Box<dyn ResourceProvider + Send>,
82    /// Vocab resource (default: DialoGPT-medium)
83    pub vocab_resource: Box<dyn ResourceProvider + Send>,
84    /// Merges resource (default: DialoGPT-medium)
85    pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
86    /// Minimum sequence length (default: 0)
87    pub min_length: i64,
88    /// Maximum sequence length (default: 20)
89    pub max_length: Option<i64>,
90    /// Minimum free length available for generated responses (default: 32)
91    pub min_length_for_response: i64,
92    /// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
93    pub do_sample: bool,
94    /// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
95    pub early_stopping: bool,
96    /// Number of beams for beam search (default: 5)
97    pub num_beams: i64,
98    /// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
99    pub temperature: f64,
100    /// Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 0)
101    pub top_k: i64,
102    /// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
103    pub top_p: f64,
104    /// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
105    pub repetition_penalty: f64,
106    /// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
107    pub length_penalty: f64,
108    /// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 3)
109    pub no_repeat_ngram_size: i64,
110    /// Number of sequences to return for each prompt text (default: 1)
111    pub num_return_sequences: i64,
112    /// Number of beam groups for diverse beam generation. If provided and higher than 1, will split the beams into beam subgroups leading to more diverse generation.
113    pub num_beam_groups: Option<i64>,
114    /// Diversity penalty for diverse beam search. High values will enforce more difference between beam groups (default: 5.5)
115    pub diversity_penalty: Option<f64>,
116    /// Device to place the model on (default: CUDA/GPU when available)
117    pub device: Device,
118    /// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
119    pub kind: Option<Kind>,
120}
121
#[cfg(feature = "remote")]
impl Default for ConversationConfig {
    /// Default conversation settings.
    ///
    /// Model: DialoGPT-medium (GPT2 architecture) from the pretrained remote resources.
    /// Decoding: sampling-based (top-k 50, top-p 0.9).
    /// Length: sequences of at most 1000 tokens, with 64 tokens reserved for each response.
    fn default() -> Self {
        // Pretrained DialoGPT-medium artefacts: weights, configuration, vocabulary and merges.
        let weights = RemoteResource::from_pretrained(Gpt2ModelResources::DIALOGPT_MEDIUM);
        let model_config = RemoteResource::from_pretrained(Gpt2ConfigResources::DIALOGPT_MEDIUM);
        let vocab = RemoteResource::from_pretrained(Gpt2VocabResources::DIALOGPT_MEDIUM);
        let merges = RemoteResource::from_pretrained(Gpt2MergesResources::DIALOGPT_MEDIUM);

        Self {
            model_type: ModelType::GPT2,
            model_resource: ModelResource::Torch(Box::new(weights)),
            config_resource: Box::new(model_config),
            vocab_resource: Box::new(vocab),
            merges_resource: Some(Box::new(merges)),
            // Length constraints
            min_length: 0,
            max_length: Some(1000),
            min_length_for_response: 64,
            // Decoding strategy
            do_sample: true,
            early_stopping: false,
            num_beams: 1,
            temperature: 1.0,
            top_k: 50,
            top_p: 0.9,
            // Penalties and output shape
            repetition_penalty: 1.0,
            length_penalty: 1.0,
            no_repeat_ngram_size: 0,
            num_return_sequences: 1,
            num_beam_groups: None,
            diversity_penalty: None,
            // Placement and precision
            device: Device::cuda_if_available(),
            kind: None,
        }
    }
}
159
160impl From<ConversationConfig> for GenerateConfig {
161    fn from(config: ConversationConfig) -> GenerateConfig {
162        GenerateConfig {
163            model_type: config.model_type,
164            model_resource: config.model_resource,
165            config_resource: config.config_resource,
166            merges_resource: config.merges_resource,
167            vocab_resource: config.vocab_resource,
168            min_length: config.min_length,
169            max_length: config.max_length,
170            do_sample: config.do_sample,
171            early_stopping: config.early_stopping,
172            num_beams: config.num_beams,
173            temperature: config.temperature,
174            top_k: config.top_k,
175            top_p: config.top_p,
176            repetition_penalty: config.repetition_penalty,
177            length_penalty: config.length_penalty,
178            no_repeat_ngram_size: config.no_repeat_ngram_size,
179            num_return_sequences: config.num_return_sequences,
180            num_beam_groups: config.num_beam_groups,
181            diversity_penalty: config.diversity_penalty,
182            device: config.device,
183            kind: config.kind,
184        }
185    }
186}
187
/// A single dialogue tracked by the conversation pipeline.
///
/// It records what the user said and what the system answered in previous turns. It also
/// holds the token ids accumulated so far as context, plus an optional user input that is
/// still waiting to be processed.
#[derive(Debug, Clone)]
pub struct Conversation {
    /// User inputs that have already been processed, oldest first
    pub past_user_inputs: Vec<String>,
    /// Responses generated by the system, oldest first
    pub generated_responses: Vec<String>,
    /// Pending user input waiting to be processed, if any
    pub new_user_input: Option<String>,
    /// Token ids of the past inputs and generated responses (one entry per turn), used as context when generating the next turn
    pub history: Vec<Vec<i64>>,
}
202
203impl Conversation {
204    /// Build a new `Conversation` with an initial user input
205    ///
206    /// # Arguments
207    ///
208    /// * `text` - `String` with the initial user input to start a conversation
209    ///
210    /// # Example
211    ///
212    /// ```no_run
213    /// use rust_bert::pipelines::conversation::Conversation;
214    ///
215    /// let conversation = Conversation::new("Hi there!");
216    /// ```
217    pub fn new(text: &str) -> Conversation {
218        Conversation {
219            past_user_inputs: vec![],
220            generated_responses: vec![],
221            new_user_input: Some(text.to_string()),
222            history: vec![],
223        }
224    }
225
226    /// Build a new `Conversation` placeholder without user input
227    ///
228    /// # Example
229    ///
230    /// ```no_run
231    /// use rust_bert::pipelines::conversation::Conversation;
232    ///
233    /// let conversation = Conversation::new_empty();
234    /// ```
235    pub fn new_empty() -> Conversation {
236        Conversation {
237            past_user_inputs: vec![],
238            generated_responses: vec![],
239            new_user_input: None,
240            history: vec![],
241        }
242    }
243
244    /// Adds a new user input to the conversation. This method returns an error if an unprocessed
245    /// user input already exists
246    ///
247    /// # Arguments
248    ///
249    /// * `text` - `&str` with the additional user input to continue a conversation
250    ///
251    /// # Example
252    ///
253    /// ```no_run
254    /// use rust_bert::pipelines::conversation::Conversation;
255    ///
256    /// let mut conversation = Conversation::new_empty();
257    /// conversation.add_user_input("Hi there!").unwrap();
258    /// ```
259    pub fn add_user_input(&mut self, text: &str) -> Result<(), RustBertError> {
260        if self.new_user_input.is_some() {
261            Err(RustBertError::ValueError(
262                "User input already provided for this conversation".into(),
263            ))
264        } else {
265            self.new_user_input = Some(text.to_string());
266            Ok(())
267        }
268    }
269
270    /// Adds a new user input to the conversation. If an unprocessed user input already exists,
271    /// its contents are overwritten by the new value provided.
272    ///
273    /// # Arguments
274    ///
275    /// * `text` - `&str` with the additional user input to continue a conversation
276    ///
277    /// # Returns
278    ///
279    /// * `Option<String>` containing overwritten string if applicable
280    ///
281    /// # Example
282    ///
283    /// ```no_run
284    /// use rust_bert::pipelines::conversation::Conversation;
285    ///
286    /// let mut conversation = Conversation::new_empty();
287    /// conversation
288    ///     .add_user_input("This input will not be used")
289    ///     .unwrap();
290    /// let unused_string = conversation.add_user_input_with_overwrite("Hi there!");
291    /// ```
292    pub fn add_user_input_with_overwrite(&mut self, text: &str) -> Option<String> {
293        let old_user_input = if self.new_user_input.is_some() {
294            self.new_user_input.clone()
295        } else {
296            None
297        };
298        self.new_user_input = Some(text.to_string());
299        old_user_input
300    }
301
302    /// Returns `true` if the conversation contains new user inputs to process
303    ///
304    /// # Returns
305    ///
306    /// * `bool` flag indicating if the conversation contains new inputs to process
307    ///
308    /// # Example
309    ///
310    /// ```no_run
311    /// use rust_bert::pipelines::conversation::Conversation;
312    ///
313    /// let mut conversation = Conversation::new_empty();
314    /// let false_value = conversation.contains_new_input();
315    /// conversation
316    ///     .add_user_input("This input will not be used")
317    ///     .unwrap();
318    /// let true_value = conversation.contains_new_input();
319    /// ```
320    pub fn contains_new_input(&self) -> bool {
321        self.new_user_input.is_some()
322    }
323
324    /// Marks the conversation as processed and moves the user input that was up for
325    /// processing to the past user inputs.
326    ///
327    /// # Example
328    ///
329    /// ```no_run
330    /// use rust_bert::pipelines::conversation::Conversation;
331    ///
332    /// let mut conversation = Conversation::new_empty();
333    /// let false_value = conversation.contains_new_input();
334    /// conversation
335    ///     .add_user_input("This input will not be used")
336    ///     .unwrap();
337    /// let true_value = conversation.contains_new_input();
338    /// conversation.mark_processed();
339    /// let false_value = conversation.contains_new_input();
340    /// assert_eq!(conversation.past_user_inputs.len(), 1usize);
341    /// ```
342    pub fn mark_processed(&mut self) {
343        if self.new_user_input.is_some() {
344            self.past_user_inputs
345                .push(self.new_user_input.clone().unwrap());
346            self.new_user_input = None;
347        }
348    }
349
350    /// Returns the last user input provided (including non-processed inputs).
351    ///
352    /// # Returns
353    ///
354    /// * `Option<&str>` representation of the last user input provided
355    ///
356    /// # Example
357    ///
358    /// ```no_run
359    /// use rust_bert::pipelines::conversation::Conversation;
360    ///
361    /// let mut conversation = Conversation::new_empty();
362    /// let none_value = conversation.get_last_input();
363    /// conversation
364    ///     .add_user_input("This input will not be used")
365    ///     .unwrap();
366    /// let last_provided_input = conversation.get_last_input();
367    /// assert_eq!(last_provided_input, Some("This input will not be used"));
368    /// ```
369    pub fn get_last_input(&self) -> Option<&str> {
370        if self.new_user_input.is_some() {
371            Some(self.new_user_input.as_ref().unwrap().as_str())
372        } else if !self.past_user_inputs.is_empty() {
373            Some(self.past_user_inputs.last().unwrap().as_str())
374        } else {
375            None
376        }
377    }
378
379    /// Returns the last response generated by the system.
380    ///
381    /// # Returns
382    ///
383    /// * `Option<&str>` representation of the last response generated by the system.
384    ///
385    /// # Example
386    ///
387    /// ```no_run
388    /// use rust_bert::pipelines::conversation::Conversation;
389    ///
390    /// let mut conversation = Conversation::new("Hi There");
391    /// let non_value = conversation.get_last_response();
392    /// ```
393    pub fn get_last_response(&self) -> Option<&str> {
394        if !self.generated_responses.is_empty() {
395            Some(self.generated_responses.last().unwrap().as_str())
396        } else {
397            None
398        }
399    }
400
401    fn append(&mut self, text: &str, ids: &[i64]) {
402        match &self.new_user_input {
403            Some(_) => {
404                self.mark_processed();
405                if self.past_user_inputs.len() >= self.generated_responses.len() {
406                    self.generated_responses.push(text.to_string());
407                } else {
408                    let _ = self.add_user_input(text);
409                }
410            }
411            None => {
412                let _ = self.add_user_input(text);
413            }
414        }
415        self.history.push(ids.to_vec());
416    }
417
418    /// Initializes a conversation form a prior state. It is assumed that a conversation always
419    /// start from a user interaction.
420    ///
421    /// # Arguments
422    /// - texts: sequence of strings, alternating between past user inputs and past generated responses.
423    /// - ids: sequence of sequence of ids, alternating between past user inputs and past generated responses.
424    ///
425    /// These can be generated via a `ConversationModel`'s `encode_prompts`.
426    ///
427    /// # Example:
428    ///
429    /// ```no_run
430    /// # fn main() -> anyhow::Result<()> {
431    /// use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
432    /// use rust_bert::pipelines::generation_utils::LanguageGenerator;
433    /// let model = ConversationModel::new(Default::default())?;
434    ///
435    /// let mut conversation_manager = ConversationManager::new();
436    /// let history = [
437    ///     "Going to the movies tonight - any suggestions?",
438    ///     "The Big Lebowski",
439    ///     "Is it an action movie?",
440    /// ];
441    /// let encoded_history = model.encode_prompts(&history);
442    ///
443    /// let conversation_1_id = conversation_manager.create_empty();
444    /// let _ = conversation_manager
445    ///     .get(&conversation_1_id)
446    ///     .unwrap()
447    ///     .load_from_history(&history, &encoded_history);
448    /// # Ok(())
449    /// # }
450    /// ```
451    pub fn load_from_history<S, SI>(&mut self, texts: &[S], ids: &[SI])
452    where
453        S: AsRef<str>,
454        SI: AsRef<[i64]>,
455    {
456        for (round_text, round_ids) in texts.iter().zip(ids.iter()) {
457            self.append(round_text.as_ref(), round_ids.as_ref());
458        }
459
460        if texts.len() / 2 == 1 {
461            self.history.pop();
462        }
463    }
464}
465
466/// Data structure allowing the management of conversations and main input to the dialogue model.
467/// It contains a `HashMap` of conversations with `UUID` keys
468#[derive(Debug)]
469pub struct ConversationManager {
470    conversations: HashMap<Uuid, Conversation>,
471}
472
473impl ConversationManager {
474    /// Build a new `ConversationManager`
475    ///
476    /// # Example
477    ///
478    /// ```no_run
479    /// use rust_bert::pipelines::conversation::ConversationManager;
480    ///
481    /// let conversation_manager = ConversationManager::new();
482    /// ```
483    pub fn new() -> ConversationManager {
484        ConversationManager {
485            conversations: HashMap::new(),
486        }
487    }
488
489    /// Returns a list of the active conversations (containing new inputs to be processed by the model)
490    ///
491    /// # Returns
492    ///
493    /// * `(Vec<&Uuid>, Vec<&mut Conversation>)` Tuple of vectors with the active `UUID` and `Conversations`
494    ///
495    /// # Example
496    ///
497    /// ```no_run
498    /// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
499    ///
500    /// let mut conversation_manager = ConversationManager::new();
501    ///
502    /// let conversation = Conversation::new("Hi there!");
503    /// let empty_conversation = Conversation::new_empty();
504    /// let conversation_id = conversation_manager.add(conversation);
505    /// let empty_conversation_id = conversation_manager.add(empty_conversation);
506    ///
507    /// let active_conversations = conversation_manager.get_active_conversations();
508    /// assert_eq!(active_conversations.0.len(), 1usize);
509    /// ```
510    pub fn get_active_conversations(&mut self) -> (Vec<&Uuid>, Vec<&mut Conversation>) {
511        let mut active_uuid = vec![];
512        let mut active_conversations = vec![];
513        for (uuid, conversation) in self.conversations.iter_mut() {
514            if conversation.new_user_input.is_some() {
515                active_uuid.push(uuid);
516                active_conversations.push(conversation)
517            }
518        }
519        (active_uuid, active_conversations)
520    }
521
522    /// Returns a mutable reference to the conversation wih the provided UUID
523    ///
524    /// # Arguments
525    ///
526    /// * `uuid` - `&Uuid` of the conversation to retrieve
527    ///
528    /// # Returns
529    ///
530    /// * `Option<&mut Conversation>` Optional mutable reference to the conversation matching the UUID provided
531    ///
532    /// # Example
533    ///
534    /// ```no_run
535    /// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
536    ///
537    /// let mut conversation_manager = ConversationManager::new();
538    ///
539    /// let conversation = Conversation::new("Hi there!");
540    /// let conversation_id = conversation_manager.add(conversation);
541    ///
542    /// let conversation_ref = conversation_manager.get(&conversation_id);
543    /// ```
544    pub fn get(&mut self, uuid: &Uuid) -> Option<&mut Conversation> {
545        self.conversations.get_mut(uuid)
546    }
547
548    /// Returns a HashMap containing references to all conversations stored in the manager
549    ///
550    /// # Example
551    ///
552    /// ```no_run
553    /// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
554    ///
555    /// let mut conversation_manager = ConversationManager::new();
556    ///
557    /// let conversation = Conversation::new("Hi there!");
558    /// let conversation_id = conversation_manager.add(conversation);
559    ///
560    /// let all_conversations = conversation_manager.get_all();
561    /// ```
562    pub fn get_all(&mut self) -> HashMap<&Uuid, &Conversation> {
563        let mut output = HashMap::with_capacity(self.conversations.len());
564        for (uuid, conversation) in self.conversations.iter() {
565            output.insert(uuid, conversation);
566        }
567        output
568    }
569
570    /// Creates a conversation and add it to the conversation manager
571    ///
572    /// # Arguments
573    ///
574    /// * `text` - `&str` string slice with an original user input
575    ///
576    /// # Returns
577    ///
578    /// * `Uuid` for the conversation created
579    ///
580    /// # Example
581    ///
582    /// ```no_run
583    /// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
584    ///
585    /// let mut conversation_manager = ConversationManager::new();
586    ///
587    /// let conversation_id = conversation_manager.create("Hi there!");
588    /// ```
589    pub fn create(&mut self, text: &str) -> Uuid {
590        let conversation = Conversation::new(text);
591        self.add(conversation)
592    }
593
594    /// Creates an empty conversation and add it to the conversation manager
595    ///
596    /// # Returns
597    ///
598    /// * `Uuid` for the conversation created
599    ///
600    /// # Example
601    ///
602    /// ```no_run
603    /// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
604    ///
605    /// let mut conversation_manager = ConversationManager::new();
606    ///
607    /// let conversation_id = conversation_manager.create_empty();
608    /// ```
609    pub fn create_empty(&mut self) -> Uuid {
610        let conversation = Conversation::new_empty();
611        self.add(conversation)
612    }
613
614    /// Adds an existing conversation to the conversation manager
615    ///
616    /// # Arguments
617    ///
618    /// * `conversation` - `Conversation` to be added to the conversation manager
619    ///
620    /// # Returns
621    ///
622    /// * `Uuid` for the conversation created
623    ///
624    /// # Example
625    ///
626    /// ```no_run
627    /// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
628    ///
629    /// let mut conversation_manager = ConversationManager::new();
630    ///
631    /// let conversation = Conversation::new("Hi there!");
632    /// let conversation_id = conversation_manager.add(conversation);
633    /// ```
634    pub fn add(&mut self, conversation: Conversation) -> Uuid {
635        let mut uuid = Uuid::new_v4();
636        while self.conversations.contains_key(&uuid) {
637            uuid = Uuid::new_v4();
638        }
639        self.conversations.insert(uuid, conversation);
640        uuid
641    }
642
643    /// Deregister a conversation from the conversation manager
644    ///
645    /// # Arguments
646    ///
647    /// * `uuid` - `&Uuid` of the conversation to deregister from the conversation manager
648    ///
649    /// # Returns
650    ///
651    /// * `Option<Conversation>` de-registered conversation
652    ///
653    /// # Example
654    ///
655    /// ```no_run
656    /// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
657    ///
658    /// let mut conversation_manager = ConversationManager::new();
659    ///
660    /// let conversation_id = conversation_manager.create("Hi there!");
661    /// conversation_manager.remove(&conversation_id);
662    /// ```
663    pub fn remove(&mut self, uuid: &Uuid) -> Option<Conversation> {
664        self.conversations.remove(uuid)
665    }
666
667    /// Clear all conversations from the conversation manager, and returns the conversations and their
668    /// former UUID.
669    ///
670    /// # Returns
671    ///
672    /// * `HashMap<Uuid, Conversation>` de-registered conversations
673    ///
674    /// # Example
675    ///
676    /// ```no_run
677    /// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
678    ///
679    /// let mut conversation_manager = ConversationManager::new();
680    ///
681    /// let conversation_id = conversation_manager.create("Hi there!");
682    /// let conversations = conversation_manager.clear();
683    /// ```
684    pub fn clear(&mut self) -> HashMap<Uuid, Conversation> {
685        let mut output = HashMap::with_capacity(self.conversations.len());
686        for (uuid, conversation) in self.conversations.iter() {
687            output.insert(*uuid, conversation.clone());
688        }
689        self.conversations = HashMap::new();
690        output
691    }
692}
693
694impl Default for ConversationManager {
695    fn default() -> Self {
696        Self::new()
697    }
698}
699
700/// # Abstraction that holds one particular conversation model, for any of the supported models
701pub enum ConversationOption {
702    /// Conversation based on GPT2 model
703    GPT2(GPT2Generator),
704}
705
706impl ConversationOption {
707    pub fn new(config: ConversationConfig) -> Result<Self, RustBertError> {
708        match config.model_type {
709            ModelType::GPT2 => Ok(ConversationOption::GPT2(GPT2Generator::new(config.into())?)),
710            _ => Err(RustBertError::InvalidConfigurationError(
711                "GPT2 is currently the only supported model for conversation generation"
712                    .to_string(),
713            )),
714        }
715    }
716
717    pub fn new_with_tokenizer(
718        config: ConversationConfig,
719        tokenizer: TokenizerOption,
720    ) -> Result<Self, RustBertError> {
721        match config.model_type {
722            ModelType::GPT2 => Ok(ConversationOption::GPT2(GPT2Generator::new_with_tokenizer(
723                config.into(),
724                tokenizer,
725            )?)),
726            _ => Err(RustBertError::InvalidConfigurationError(
727                "GPT2 is currently the only supported model for conversation generation"
728                    .to_string(),
729            )),
730        }
731    }
732
733    pub fn get_eos_id(&self) -> Result<i64, RustBertError> {
734        match self {
735            Self::GPT2(model_ref) => {
736                Ok(*model_ref.get_eos_ids().as_ref().unwrap().first().unwrap())
737            }
738        }
739    }
740
741    /// Get a reference to the model tokenizer.
742    pub fn get_tokenizer(&self) -> &TokenizerOption {
743        match self {
744            Self::GPT2(model_ref) => model_ref._get_tokenizer(),
745        }
746    }
747
748    /// Get a mutable reference to the model tokenizer.
749    pub fn get_tokenizer_mut(&mut self) -> &TokenizerOption {
750        match self {
751            Self::GPT2(model_ref) => model_ref._get_tokenizer_mut(),
752        }
753    }
754
755    /// Returns the `ModelType` for this ConversationOption
756    pub fn model_type(&self) -> ModelType {
757        match *self {
758            Self::GPT2(_) => ModelType::GPT2,
759        }
760    }
761
762    /// Interface method to generate_from_ids_and_past() of the particular models.
763    pub fn generate_from_ids_and_past(
764        &self,
765        input_ids: Tensor,
766        attention_mask: Option<Tensor>,
767    ) -> Result<Vec<Vec<i64>>, RustBertError> {
768        Ok(match *self {
769            Self::GPT2(ref model) => model
770                .generate_from_ids_and_past(input_ids, attention_mask, None)?
771                .into_iter()
772                .map(|output| output.indices)
773                .collect(),
774        })
775    }
776}
777
778/// # Conversation model
779/// Processes a ConversationManager and generate system responses for active conversations.
780pub struct ConversationModel {
781    model: ConversationOption,
782    eos_token_id: i64,
783    max_allowed_context_length: Option<i64>,
784    device: Device,
785}
786
787impl ConversationModel {
788    /// Build a new `ConversationModel`
789    ///
790    /// # Arguments
791    ///
792    /// * `conversation_config` - `ConversationConfig` object containing the resource references (model, vocabulary, configuration), conversation options and device placement (CPU/GPU)
793    ///
794    /// # Example
795    ///
796    /// ```no_run
797    /// # fn main() -> anyhow::Result<()> {
798    /// use rust_bert::pipelines::conversation::ConversationModel;
799    ///
800    /// let conversation_model = ConversationModel::new(Default::default())?;
801    /// # Ok(())
802    /// # }
803    /// ```
804    pub fn new(
805        conversation_config: ConversationConfig,
806    ) -> Result<ConversationModel, RustBertError> {
807        let max_allowed_length = conversation_config
808            .max_length
809            .map(|max_length| max_length - conversation_config.min_length_for_response);
810        let device = conversation_config.device;
811        let model = ConversationOption::new(conversation_config)?;
812        let eos_token_id = model.get_eos_id()?;
813        Ok(ConversationModel {
814            model,
815            eos_token_id,
816            max_allowed_context_length: max_allowed_length,
817            device,
818        })
819    }
820
821    /// Build a new `ConversationModel` with a provided tokenizer.
822    ///
823    /// # Arguments
824    ///
825    /// * `conversation_config` - `ConversationConfig` object containing the resource references (model, vocabulary, configuration), conversation options and device placement (CPU/GPU)
826    /// * `tokenizer` - `TokenizerOption` tokenizer to use for conversation
827    ///
828    /// # Example
829    ///
830    /// ```no_run
831    /// # fn main() -> anyhow::Result<()> {
832    /// use rust_bert::pipelines::common::{ModelType, TokenizerOption};
833    /// use rust_bert::pipelines::conversation::ConversationModel;
834    /// let tokenizer = TokenizerOption::from_file(
835    ///     ModelType::GPT2,
836    ///     "path/to/vocab.json",
837    ///     Some("path/to/merges.txt"),
838    ///     false,
839    ///     None,
840    ///     None,
841    /// )?;
842    /// let conversation_model = ConversationModel::new_with_tokenizer(Default::default(), tokenizer)?;
843    /// # Ok(())
844    /// # }
845    /// ```
846    pub fn new_with_tokenizer(
847        conversation_config: ConversationConfig,
848        tokenizer: TokenizerOption,
849    ) -> Result<ConversationModel, RustBertError> {
850        let max_allowed_length = conversation_config
851            .max_length
852            .map(|max_length| max_length - conversation_config.min_length_for_response);
853        let device = conversation_config.device;
854        let model = ConversationOption::new_with_tokenizer(conversation_config, tokenizer)?;
855        let eos_token_id = model.get_eos_id()?;
856        Ok(ConversationModel {
857            model,
858            eos_token_id,
859            max_allowed_context_length: max_allowed_length,
860            device,
861        })
862    }
863
864    /// Perform a multi-turn conversation based on user input
865    ///
866    /// # Arguments
867    ///
868    /// * `conversation_manager` - `&mut ConversationManager` Conversation manager keeping track of active conversations
869    ///
870    /// # Returns
871    /// * `HashMap<&Uuid, &str>` Responses from the model for each active conversation, referenced by Uuid
872    ///
873    /// # Example
874    ///
875    /// ```no_run
876    /// # fn main() -> anyhow::Result<()> {
877    /// use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
878    /// use rust_bert::pipelines::generation_utils::LanguageGenerator;
879    /// let model = ConversationModel::new(Default::default())?;
880    ///
881    /// let mut conversation_manager = ConversationManager::new();
882    /// conversation_manager.create("Hello, how are you?");
883    ///
884    /// let output = model.generate_responses(&mut conversation_manager);
885    /// # Ok(())
886    /// # }
887    /// ```
888    pub fn generate_responses<'a>(
889        &self,
890        conversation_manager: &'a mut ConversationManager,
891    ) -> Result<HashMap<&'a Uuid, &'a str>, RustBertError> {
892        let (active_uuid, active_conversations) = conversation_manager.get_active_conversations();
893        let updated_conversations = if !active_uuid.is_empty() {
894            let texts = active_conversations
895                .iter()
896                .map(|c| c.new_user_input.as_ref().unwrap().as_str())
897                .collect::<Vec<&str>>();
898
899            let history = active_conversations
900                .iter()
901                .map(|c| c.history.iter().flatten().copied().collect())
902                .collect::<Vec<Vec<i64>>>();
903
904            let prompt_ids = self.encode_prompts(texts.as_ref());
905            let (input_tensor, attention_mask) =
906                self.concat_input_history(prompt_ids.as_ref(), history);
907            let input_length = *input_tensor.size().last().unwrap() as usize;
908            let mut generated = self
909                .model
910                .generate_from_ids_and_past(input_tensor, Some(attention_mask))?;
911            let removed_padding_quantities = self.clean_padding_indices(&mut generated);
912
913            let mut output = HashMap::with_capacity(active_uuid.len());
914
915            for (
916                ((conversation, (generated_sequence, conversation_promp_ids)), uuid),
917                removed_padding,
918            ) in active_conversations
919                .into_iter()
920                .zip(generated.into_iter().zip(prompt_ids.into_iter()))
921                .zip(active_uuid.into_iter())
922                .zip(removed_padding_quantities.into_iter())
923            {
924                let generated_response = &generated_sequence[input_length - removed_padding.0..];
925                conversation
926                    .generated_responses
927                    .push(
928                        self.model
929                            .get_tokenizer()
930                            .decode(generated_response, true, true),
931                    );
932                conversation.history.push(conversation_promp_ids);
933                conversation.history.push(generated_response.to_vec());
934                conversation.mark_processed();
935                output.insert(uuid, conversation.get_last_response().unwrap());
936            }
937            output
938        } else {
939            HashMap::new()
940        };
941        Ok(updated_conversations)
942    }
943
944    fn clean_padding_indices(&self, model_output: &mut Vec<Vec<i64>>) -> Vec<(usize, usize)> {
945        // In case inputs are sent as batch, this cleans the padding indices in the history for shorter outputs
946        let pad_token = self
947            .model
948            .get_tokenizer()
949            .get_pad_id()
950            .unwrap_or(self.eos_token_id);
951        let mut removed_tokens = Vec::with_capacity(model_output.len());
952        for sequence_history in model_output {
953            let index_end = sequence_history
954                .iter()
955                .rev()
956                .position(|&r| r != pad_token)
957                .unwrap();
958            let index_start = sequence_history
959                .iter()
960                .position(|&r| r != pad_token)
961                .unwrap();
962            if index_end > 0 {
963                sequence_history.drain(sequence_history.len() - index_end + 1..);
964            }
965            sequence_history.drain(..index_start);
966            removed_tokens.push((index_start, index_end));
967        }
968        removed_tokens
969    }
970
971    fn concat_input_history(
972        &self,
973        inputs: &[Vec<i64>],
974        history: Vec<Vec<i64>>,
975    ) -> (Tensor, Tensor) {
976        // Concatenates the history token indices with new user input
977        let pad_token = self
978            .model
979            .get_tokenizer()
980            .get_pad_id()
981            .unwrap_or(self.eos_token_id);
982
983        assert_eq!(
984            inputs.len(),
985            history.len(),
986            "Length of inputs should equal length of history"
987        );
988
989        let mut concatenated_inputs = Vec::with_capacity(inputs.len());
990        for (input, history) in inputs.iter().zip(history.iter()) {
991            let mut concatenated_element = Vec::with_capacity(input.len() + history.len());
992            concatenated_element.extend_from_slice(history);
993            concatenated_element.extend_from_slice(input);
994            concatenated_inputs.push(concatenated_element);
995        }
996
997        let truncated_concatenated_inputs = concatenated_inputs
998            .iter()
999            .map(|input| match self.max_allowed_context_length {
1000                Some(max_allowed_context_length)
1001                    if input.len() > max_allowed_context_length as usize =>
1002                {
1003                    let start = self.get_truncated_input_index(
1004                        input,
1005                        max_allowed_context_length as usize,
1006                        pad_token,
1007                    );
1008                    &input[start..]
1009                }
1010                _ => input.as_slice(),
1011            })
1012            .collect::<Vec<&[i64]>>();
1013
1014        let max_len = truncated_concatenated_inputs
1015            .iter()
1016            .map(|input| input.len())
1017            .max()
1018            .unwrap();
1019
1020        let attention_mask = Tensor::ones(
1021            [inputs.len() as i64, max_len as i64],
1022            (Kind::Int8, self.device),
1023        );
1024
1025        let concatenated_inputs = truncated_concatenated_inputs
1026            .into_iter()
1027            .enumerate()
1028            .map(|(input_idx, input)| {
1029                let _ = attention_mask
1030                    .get(input_idx as i64)
1031                    .slice(0, 0, (max_len - input.len()) as i64, 1)
1032                    .fill_(0);
1033                let mut padded_input = vec![pad_token; max_len - input.len()];
1034                padded_input.extend(input);
1035                padded_input
1036            })
1037            .map(|tokens| Tensor::from_slice(&tokens).to(self.device))
1038            .collect::<Vec<Tensor>>();
1039
1040        (Tensor::stack(&concatenated_inputs, 0), attention_mask)
1041    }
1042
1043    fn get_truncated_input_index(
1044        &self,
1045        history: &[i64],
1046        max_length: usize,
1047        pad_token: i64,
1048    ) -> usize {
1049        let start_length = history.len();
1050        let eos_indices: Vec<usize> = history
1051            .iter()
1052            .enumerate()
1053            .filter(|(i, &e)| {
1054                (e == pad_token)
1055                    & (*i != start_length - 1)
1056                    & ((start_length as isize - max_length as isize - *i as isize) < 0)
1057            })
1058            .map(|(i, _)| i + 1)
1059            .collect();
1060
1061        // Return the position of the first EOS index that fits the max length requirement.
1062        // If it does not exist, no solution exists and truncate text at a non-EOS position
1063        *eos_indices.first().unwrap_or(&(start_length - max_length))
1064    }
1065
1066    /// Encodes prompts into Vectors of indices to be processed by the model. This method may be used to
1067    /// initialize the history of a conversation with a prior state.
1068    ///
1069    /// # Example:
1070    ///
1071    /// ```no_run
1072    /// # fn main() -> anyhow::Result<()> {
1073    /// use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
1074    /// use rust_bert::pipelines::generation_utils::LanguageGenerator;
1075    /// let model = ConversationModel::new(Default::default())?;
1076    /// let history = [
1077    ///     "Going to the movies tonight - any suggestions?",
1078    ///     "The Big Lebowski",
1079    ///     "Is it an action movie?",
1080    /// ];
1081    /// let encoded_history = model.encode_prompts(&history);
1082    /// # Ok(())
1083    /// # }
1084    /// ```
1085    pub fn encode_prompts(&self, texts: &[&str]) -> Vec<Vec<i64>> {
1086        // Encode the user prompt into token ids
1087        let tokens = self.model.get_tokenizer().tokenize_list(texts);
1088
1089        tokens
1090            .into_iter()
1091            .map(|prompt_tokens| {
1092                self.model
1093                    .get_tokenizer()
1094                    .convert_tokens_to_ids(&prompt_tokens)
1095            })
1096            .map(|mut tokens| {
1097                if let Some(max_allowed_context_length) = self.max_allowed_context_length {
1098                    tokens.truncate(max_allowed_context_length as usize - 1);
1099                }
1100                tokens.push(self.eos_token_id);
1101                tokens
1102            })
1103            .collect::<Vec<Vec<i64>>>()
1104    }
1105}
1106
#[cfg(test)]
mod test {
    use super::*;

    /// Checks at compile time that `ConversationModel`, and the `Result` returned by its
    /// constructor, are `Send`.
    ///
    /// The check uses only trait bounds and never builds a model, so it needs no downloaded
    /// resources and runs as part of the regular test suite.
    #[test]
    fn test() {
        fn assert_send<T: Send>() {}
        assert_send::<ConversationModel>();
        assert_send::<Result<ConversationModel, RustBertError>>();
    }
}