rust_bert/pipelines/conversation.rs
// Copyright 2019-present Microsoft
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! # Multi-turn dialogue
//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT).
//! This pipeline allows the generation of single or multi-turn conversations between a human and a model.
//! The DialoGPT page states that
//! > The human evaluation results indicate that the response generated from DialoGPT is comparable to human response quality
//! > under a single-turn conversation Turing test. ([DialoGPT repository](https://github.com/microsoft/DialoGPT))
//!
//! The dependencies will be downloaded to the user's home directory, under ~/.cache/.rustbert/dialogpt-medium.
//! The following illustrates how to run a 2-turn conversation using a conversation manager:
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
//! let conversation_model = ConversationModel::new(Default::default())?;
//! let mut conversation_manager = ConversationManager::new();
//!
//! let conversation_id =
//!     conversation_manager.create("Going to the movies tonight - any suggestions?");
//! let output = conversation_model.generate_responses(&mut conversation_manager);
//!
//! let _ = conversation_manager
//!     .get(&conversation_id)
//!     .unwrap()
//!     .add_user_input("Is it an action movie?")?;
//!
//! let output = conversation_model.generate_responses(&mut conversation_manager);
//!
//! # Ok(())
//! # }
//! ```
//!
//! Example output: \
//! ```no_run
//! # let output = [
//! "{a0cb3c15-9a5a-4a34-958d-95eddac0215a: \"The Big Lebowski\"}",
//! "{a0cb3c15-9a5a-4a34-958d-95eddac0215a: \"It's a comedy.\"}"
//! # ];
//! ```
//!
//! # Disclaimer
//! The authors of this repository are not responsible for any generation
//! from the 3rd party utilization of the pretrained system.

use crate::common::error::RustBertError;
use crate::gpt2::GPT2Generator;
use crate::pipelines::common::{ModelResource, ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use crate::resources::ResourceProvider;
use std::collections::HashMap;
use tch::{Device, Kind, Tensor};
use uuid::Uuid;

#[cfg(feature = "remote")]
use crate::{
    gpt2::{Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources},
    resources::RemoteResource,
};
/// # Configuration for multi-turn conversation
/// Contains information regarding the model to load, mirrors the `GenerateConfig` with a
/// different set of default parameters, and sets the device to place the model on.
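///
/// # Example
///
/// A minimal sketch of overriding selected defaults through struct update syntax (the
/// `Default` implementation shown below requires the `remote` feature):
///
/// ```no_run
/// use rust_bert::pipelines::conversation::ConversationConfig;
/// use tch::Device;
///
/// let config = ConversationConfig {
///     max_length: Some(512),
///     do_sample: false,
///     device: Device::Cpu,
///     ..Default::default()
/// };
/// ```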
pub struct ConversationConfig {
    /// Model type
    pub model_type: ModelType,
    /// Model weights resource (default: DialoGPT-medium)
    pub model_resource: ModelResource,
    /// Config resource (default: DialoGPT-medium)
    pub config_resource: Box<dyn ResourceProvider + Send>,
    /// Vocab resource (default: DialoGPT-medium)
    pub vocab_resource: Box<dyn ResourceProvider + Send>,
    /// Merges resource (default: DialoGPT-medium)
    pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
    /// Minimum sequence length (default: 0)
    pub min_length: i64,
    /// Maximum sequence length (default: 1000)
    pub max_length: Option<i64>,
    /// Minimum free length available for generated responses (default: 64)
    pub min_length_for_response: i64,
    /// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
    pub do_sample: bool,
    /// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
    pub early_stopping: bool,
    /// Number of beams for beam search (default: 1)
    pub num_beams: i64,
    /// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
    pub temperature: f64,
    /// Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 50)
    pub top_k: i64,
    /// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
    pub top_p: f64,
    /// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
    pub repetition_penalty: f64,
    /// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
    pub length_penalty: f64,
    /// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 0)
    pub no_repeat_ngram_size: i64,
    /// Number of sequences to return for each prompt text (default: 1)
    pub num_return_sequences: i64,
    /// Number of beam groups for diverse beam generation. If provided and higher than 1, will split the beams into beam subgroups leading to more diverse generation.
    pub num_beam_groups: Option<i64>,
    /// Diversity penalty for diverse beam search. High values will enforce more difference between beam groups (default: None)
    pub diversity_penalty: Option<f64>,
    /// Device to place the model on (default: CUDA/GPU when available)
    pub device: Device,
    /// Model weights precision. If not provided, will default to full precision on CPU, or the loaded weights precision otherwise
    pub kind: Option<Kind>,
}
121
122#[cfg(feature = "remote")]
123impl Default for ConversationConfig {
124 fn default() -> ConversationConfig {
125 ConversationConfig {
126 model_type: ModelType::GPT2,
127 model_resource: ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
128 Gpt2ModelResources::DIALOGPT_MEDIUM,
129 ))),
130 config_resource: Box::new(RemoteResource::from_pretrained(
131 Gpt2ConfigResources::DIALOGPT_MEDIUM,
132 )),
133 vocab_resource: Box::new(RemoteResource::from_pretrained(
134 Gpt2VocabResources::DIALOGPT_MEDIUM,
135 )),
136 merges_resource: Some(Box::new(RemoteResource::from_pretrained(
137 Gpt2MergesResources::DIALOGPT_MEDIUM,
138 ))),
139 min_length: 0,
140 max_length: Some(1000),
141 min_length_for_response: 64,
142 do_sample: true,
143 early_stopping: false,
144 num_beams: 1,
145 temperature: 1.0,
146 top_k: 50,
147 top_p: 0.9,
148 repetition_penalty: 1.0,
149 length_penalty: 1.0,
150 no_repeat_ngram_size: 0,
151 num_return_sequences: 1,
152 num_beam_groups: None,
153 diversity_penalty: None,
154 device: Device::cuda_if_available(),
155 kind: None,
156 }
157 }
158}

impl From<ConversationConfig> for GenerateConfig {
    fn from(config: ConversationConfig) -> GenerateConfig {
        GenerateConfig {
            model_type: config.model_type,
            model_resource: config.model_resource,
            config_resource: config.config_resource,
            merges_resource: config.merges_resource,
            vocab_resource: config.vocab_resource,
            min_length: config.min_length,
            max_length: config.max_length,
            do_sample: config.do_sample,
            early_stopping: config.early_stopping,
            num_beams: config.num_beams,
            temperature: config.temperature,
            top_k: config.top_k,
            top_p: config.top_p,
            repetition_penalty: config.repetition_penalty,
            length_penalty: config.length_penalty,
            no_repeat_ngram_size: config.no_repeat_ngram_size,
            num_return_sequences: config.num_return_sequences,
            num_beam_groups: config.num_beam_groups,
            diversity_penalty: config.diversity_penalty,
            device: config.device,
            kind: config.kind,
        }
    }
}

#[derive(Debug, Clone)]
/// Data structure keeping track of a conversation in the system. It contains past user inputs and
/// generated answers, a history of the tokens generated and a placeholder for new user inputs to be
/// processed by the system if submitted for prediction
pub struct Conversation {
    /// Past user inputs that have already been processed
    pub past_user_inputs: Vec<String>,
    /// Past system generated responses
    pub generated_responses: Vec<String>,
    /// New user input that needs to be processed
    pub new_user_input: Option<String>,
    /// History of the tokens passed as an input and generated so far used as context for next turn generation
    pub history: Vec<Vec<i64>>,
}

impl Conversation {
    /// Build a new `Conversation` with an initial user input
    ///
    /// # Arguments
    ///
    /// * `text` - `String` with the initial user input to start a conversation
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::pipelines::conversation::Conversation;
    ///
    /// let conversation = Conversation::new("Hi there!");
    /// ```
    pub fn new(text: &str) -> Conversation {
        Conversation {
            past_user_inputs: vec![],
            generated_responses: vec![],
            new_user_input: Some(text.to_string()),
            history: vec![],
        }
    }

    /// Build a new `Conversation` placeholder without user input
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::pipelines::conversation::Conversation;
    ///
    /// let conversation = Conversation::new_empty();
    /// ```
    pub fn new_empty() -> Conversation {
        Conversation {
            past_user_inputs: vec![],
            generated_responses: vec![],
            new_user_input: None,
            history: vec![],
        }
    }

    /// Adds a new user input to the conversation. This method returns an error if an unprocessed
    /// user input already exists
    ///
    /// # Arguments
    ///
    /// * `text` - `&str` with the additional user input to continue a conversation
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::pipelines::conversation::Conversation;
    ///
    /// let mut conversation = Conversation::new_empty();
    /// conversation.add_user_input("Hi there!").unwrap();
    /// ```
    pub fn add_user_input(&mut self, text: &str) -> Result<(), RustBertError> {
        if self.new_user_input.is_some() {
            Err(RustBertError::ValueError(
                "User input already provided for this conversation".into(),
            ))
        } else {
            self.new_user_input = Some(text.to_string());
            Ok(())
        }
    }

    /// Adds a new user input to the conversation. If an unprocessed user input already exists,
    /// its contents are overwritten by the new value provided.
    ///
    /// # Arguments
    ///
    /// * `text` - `&str` with the additional user input to continue a conversation
    ///
    /// # Returns
    ///
    /// * `Option<String>` containing overwritten string if applicable
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::pipelines::conversation::Conversation;
    ///
    /// let mut conversation = Conversation::new_empty();
    /// conversation
    ///     .add_user_input("This input will not be used")
    ///     .unwrap();
    /// let unused_string = conversation.add_user_input_with_overwrite("Hi there!");
    /// ```
    pub fn add_user_input_with_overwrite(&mut self, text: &str) -> Option<String> {
        self.new_user_input.replace(text.to_string())
    }

    /// Returns `true` if the conversation contains new user inputs to process
    ///
    /// # Returns
    ///
    /// * `bool` flag indicating if the conversation contains new inputs to process
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::pipelines::conversation::Conversation;
    ///
    /// let mut conversation = Conversation::new_empty();
    /// let false_value = conversation.contains_new_input();
    /// conversation
    ///     .add_user_input("This input will not be used")
    ///     .unwrap();
    /// let true_value = conversation.contains_new_input();
    /// ```
    pub fn contains_new_input(&self) -> bool {
        self.new_user_input.is_some()
    }

    /// Marks the conversation as processed and moves the user input that was up for
    /// processing to the past user inputs.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::pipelines::conversation::Conversation;
    ///
    /// let mut conversation = Conversation::new_empty();
    /// let false_value = conversation.contains_new_input();
    /// conversation
    ///     .add_user_input("This input will not be used")
    ///     .unwrap();
    /// let true_value = conversation.contains_new_input();
    /// conversation.mark_processed();
    /// let false_value = conversation.contains_new_input();
    /// assert_eq!(conversation.past_user_inputs.len(), 1usize);
    /// ```
    pub fn mark_processed(&mut self) {
        if let Some(input) = self.new_user_input.take() {
            self.past_user_inputs.push(input);
        }
    }

    /// Returns the last user input provided (including non-processed inputs).
    ///
    /// # Returns
    ///
    /// * `Option<&str>` representation of the last user input provided
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::pipelines::conversation::Conversation;
    ///
    /// let mut conversation = Conversation::new_empty();
    /// let none_value = conversation.get_last_input();
    /// conversation
    ///     .add_user_input("This input will not be used")
    ///     .unwrap();
    /// let last_provided_input = conversation.get_last_input();
    /// assert_eq!(last_provided_input, Some("This input will not be used"));
    /// ```
    pub fn get_last_input(&self) -> Option<&str> {
        if self.new_user_input.is_some() {
            Some(self.new_user_input.as_ref().unwrap().as_str())
        } else if !self.past_user_inputs.is_empty() {
            Some(self.past_user_inputs.last().unwrap().as_str())
        } else {
            None
        }
    }

    /// Returns the last response generated by the system.
    ///
    /// # Returns
    ///
    /// * `Option<&str>` representation of the last response generated by the system.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::pipelines::conversation::Conversation;
    ///
    /// let mut conversation = Conversation::new("Hi There");
    /// let none_value = conversation.get_last_response();
    /// ```
    pub fn get_last_response(&self) -> Option<&str> {
        if !self.generated_responses.is_empty() {
            Some(self.generated_responses.last().unwrap().as_str())
        } else {
            None
        }
    }

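    /// Appends a turn loaded from a prior state, alternating between user inputs and generated
    /// responses: a pending user input is first marked as processed, then the new text is stored
    /// as a generated response when generated responses do not outnumber past user inputs, and
    /// as the new pending user input otherwise. The token ids are pushed to the history in all cases.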
    fn append(&mut self, text: &str, ids: &[i64]) {
        match &self.new_user_input {
            Some(_) => {
                self.mark_processed();
                if self.past_user_inputs.len() >= self.generated_responses.len() {
                    self.generated_responses.push(text.to_string());
                } else {
                    let _ = self.add_user_input(text);
                }
            }
            None => {
                let _ = self.add_user_input(text);
            }
        }
        self.history.push(ids.to_vec());
    }

    /// Initializes a conversation from a prior state. It is assumed that a conversation always
    /// starts from a user interaction.
    ///
    /// # Arguments
    /// - texts: sequence of strings, alternating between past user inputs and past generated responses.
    /// - ids: sequence of sequences of ids, alternating between past user inputs and past generated responses.
    ///
    /// These can be generated via a `ConversationModel`'s `encode_prompts`.
    ///
    /// # Example:
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
    /// use rust_bert::pipelines::generation_utils::LanguageGenerator;
    /// let model = ConversationModel::new(Default::default())?;
    ///
    /// let mut conversation_manager = ConversationManager::new();
    /// let history = [
    ///     "Going to the movies tonight - any suggestions?",
    ///     "The Big Lebowski",
    ///     "Is it an action movie?",
    /// ];
    /// let encoded_history = model.encode_prompts(&history);
    ///
    /// let conversation_1_id = conversation_manager.create_empty();
    /// let _ = conversation_manager
    ///     .get(&conversation_1_id)
    ///     .unwrap()
    ///     .load_from_history(&history, &encoded_history);
    /// # Ok(())
    /// # }
    /// ```
    pub fn load_from_history<S, SI>(&mut self, texts: &[S], ids: &[SI])
    where
        S: AsRef<str>,
        SI: AsRef<[i64]>,
    {
        for (round_text, round_ids) in texts.iter().zip(ids.iter()) {
            self.append(round_text.as_ref(), round_ids.as_ref());
        }

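        // An odd number of turns means the last entry is a pending user input: its ids are not
        // yet part of the processed history, so the ids pushed last are removed.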
        if texts.len() % 2 == 1 {
            self.history.pop();
        }
    }
}

/// Data structure allowing the management of conversations and main input to the dialogue model.
/// It contains a `HashMap` of conversations with `UUID` keys
#[derive(Debug)]
pub struct ConversationManager {
    conversations: HashMap<Uuid, Conversation>,
}

impl ConversationManager {
    /// Build a new `ConversationManager`
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::pipelines::conversation::ConversationManager;
    ///
    /// let conversation_manager = ConversationManager::new();
    /// ```
    pub fn new() -> ConversationManager {
        ConversationManager {
            conversations: HashMap::new(),
        }
    }

    /// Returns a list of the active conversations (containing new inputs to be processed by the model)
    ///
    /// # Returns
    ///
    /// * `(Vec<&Uuid>, Vec<&mut Conversation>)` Tuple of vectors with the active `UUID` and `Conversations`
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
    ///
    /// let mut conversation_manager = ConversationManager::new();
    ///
    /// let conversation = Conversation::new("Hi there!");
    /// let empty_conversation = Conversation::new_empty();
    /// let conversation_id = conversation_manager.add(conversation);
    /// let empty_conversation_id = conversation_manager.add(empty_conversation);
    ///
    /// let active_conversations = conversation_manager.get_active_conversations();
    /// assert_eq!(active_conversations.0.len(), 1usize);
    /// ```
    pub fn get_active_conversations(&mut self) -> (Vec<&Uuid>, Vec<&mut Conversation>) {
        let mut active_uuid = vec![];
        let mut active_conversations = vec![];
        for (uuid, conversation) in self.conversations.iter_mut() {
            if conversation.new_user_input.is_some() {
                active_uuid.push(uuid);
                active_conversations.push(conversation)
            }
        }
        (active_uuid, active_conversations)
    }

    /// Returns a mutable reference to the conversation with the provided UUID
    ///
    /// # Arguments
    ///
    /// * `uuid` - `&Uuid` of the conversation to retrieve
    ///
    /// # Returns
    ///
    /// * `Option<&mut Conversation>` Optional mutable reference to the conversation matching the UUID provided
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
    ///
    /// let mut conversation_manager = ConversationManager::new();
    ///
    /// let conversation = Conversation::new("Hi there!");
    /// let conversation_id = conversation_manager.add(conversation);
    ///
    /// let conversation_ref = conversation_manager.get(&conversation_id);
    /// ```
    pub fn get(&mut self, uuid: &Uuid) -> Option<&mut Conversation> {
        self.conversations.get_mut(uuid)
    }

    /// Returns a HashMap containing references to all conversations stored in the manager
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
    ///
    /// let mut conversation_manager = ConversationManager::new();
    ///
    /// let conversation = Conversation::new("Hi there!");
    /// let conversation_id = conversation_manager.add(conversation);
    ///
    /// let all_conversations = conversation_manager.get_all();
    /// ```
    pub fn get_all(&mut self) -> HashMap<&Uuid, &Conversation> {
        let mut output = HashMap::with_capacity(self.conversations.len());
        for (uuid, conversation) in self.conversations.iter() {
            output.insert(uuid, conversation);
        }
        output
    }

    /// Creates a conversation and adds it to the conversation manager
    ///
    /// # Arguments
    ///
    /// * `text` - `&str` string slice with an original user input
    ///
    /// # Returns
    ///
    /// * `Uuid` for the conversation created
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
    ///
    /// let mut conversation_manager = ConversationManager::new();
    ///
    /// let conversation_id = conversation_manager.create("Hi there!");
    /// ```
    pub fn create(&mut self, text: &str) -> Uuid {
        let conversation = Conversation::new(text);
        self.add(conversation)
    }

    /// Creates an empty conversation and adds it to the conversation manager
    ///
    /// # Returns
    ///
    /// * `Uuid` for the conversation created
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
    ///
    /// let mut conversation_manager = ConversationManager::new();
    ///
    /// let conversation_id = conversation_manager.create_empty();
    /// ```
    pub fn create_empty(&mut self) -> Uuid {
        let conversation = Conversation::new_empty();
        self.add(conversation)
    }

    /// Adds an existing conversation to the conversation manager
    ///
    /// # Arguments
    ///
    /// * `conversation` - `Conversation` to be added to the conversation manager
    ///
    /// # Returns
    ///
    /// * `Uuid` for the conversation created
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
    ///
    /// let mut conversation_manager = ConversationManager::new();
    ///
    /// let conversation = Conversation::new("Hi there!");
    /// let conversation_id = conversation_manager.add(conversation);
    /// ```
    pub fn add(&mut self, conversation: Conversation) -> Uuid {
        let mut uuid = Uuid::new_v4();
        while self.conversations.contains_key(&uuid) {
            uuid = Uuid::new_v4();
        }
        self.conversations.insert(uuid, conversation);
        uuid
    }

    /// Deregister a conversation from the conversation manager
    ///
    /// # Arguments
    ///
    /// * `uuid` - `&Uuid` of the conversation to deregister from the conversation manager
    ///
    /// # Returns
    ///
    /// * `Option<Conversation>` de-registered conversation
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
    ///
    /// let mut conversation_manager = ConversationManager::new();
    ///
    /// let conversation_id = conversation_manager.create("Hi there!");
    /// conversation_manager.remove(&conversation_id);
    /// ```
    pub fn remove(&mut self, uuid: &Uuid) -> Option<Conversation> {
        self.conversations.remove(uuid)
    }

    /// Clear all conversations from the conversation manager, and returns the conversations and their
    /// former UUID.
    ///
    /// # Returns
    ///
    /// * `HashMap<Uuid, Conversation>` de-registered conversations
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
    ///
    /// let mut conversation_manager = ConversationManager::new();
    ///
    /// let conversation_id = conversation_manager.create("Hi there!");
    /// let conversations = conversation_manager.clear();
    /// ```
    pub fn clear(&mut self) -> HashMap<Uuid, Conversation> {
        let mut output = HashMap::with_capacity(self.conversations.len());
        for (uuid, conversation) in self.conversations.iter() {
            output.insert(*uuid, conversation.clone());
        }
        self.conversations = HashMap::new();
        output
    }
}

impl Default for ConversationManager {
    fn default() -> Self {
        Self::new()
    }
}

/// # Abstraction that holds one particular conversation model, for any of the supported models
pub enum ConversationOption {
    /// Conversation based on GPT2 model
    GPT2(GPT2Generator),
}

impl ConversationOption {
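    /// Creates a new `ConversationOption` from a `ConversationConfig`. Returns an
    /// `InvalidConfigurationError` for model types other than GPT2, currently the only
    /// supported architecture for this pipeline.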
    pub fn new(config: ConversationConfig) -> Result<Self, RustBertError> {
        match config.model_type {
            ModelType::GPT2 => Ok(ConversationOption::GPT2(GPT2Generator::new(config.into())?)),
            _ => Err(RustBertError::InvalidConfigurationError(
                "GPT2 is currently the only supported model for conversation generation"
                    .to_string(),
            )),
        }
    }

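    /// Creates a new `ConversationOption` from a `ConversationConfig` and a pre-built
    /// `TokenizerOption`, following the same model type restrictions as `new`.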
    pub fn new_with_tokenizer(
        config: ConversationConfig,
        tokenizer: TokenizerOption,
    ) -> Result<Self, RustBertError> {
        match config.model_type {
            ModelType::GPT2 => Ok(ConversationOption::GPT2(GPT2Generator::new_with_tokenizer(
                config.into(),
                tokenizer,
            )?)),
            _ => Err(RustBertError::InvalidConfigurationError(
                "GPT2 is currently the only supported model for conversation generation"
                    .to_string(),
            )),
        }
    }

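    /// Returns the first EOS token id registered for the model, used to separate conversation
    /// turns, or an error if the model does not define EOS tokens.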
    pub fn get_eos_id(&self) -> Result<i64, RustBertError> {
        match self {
            Self::GPT2(model_ref) => model_ref
                .get_eos_ids()
                .as_ref()
                .and_then(|eos_ids| eos_ids.first().copied())
                .ok_or_else(|| {
                    RustBertError::ValueError("No EOS token id found for this model".into())
                }),
        }
    }

    /// Get a reference to the model tokenizer.
    pub fn get_tokenizer(&self) -> &TokenizerOption {
        match self {
            Self::GPT2(model_ref) => model_ref._get_tokenizer(),
        }
    }

    /// Get a mutable reference to the model tokenizer.
    pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        match self {
            Self::GPT2(model_ref) => model_ref._get_tokenizer_mut(),
        }
    }

    /// Returns the `ModelType` for this ConversationOption
    pub fn model_type(&self) -> ModelType {
        match *self {
            Self::GPT2(_) => ModelType::GPT2,
        }
    }

    /// Interface method to generate_from_ids_and_past() of the particular models.
    pub fn generate_from_ids_and_past(
        &self,
        input_ids: Tensor,
        attention_mask: Option<Tensor>,
    ) -> Result<Vec<Vec<i64>>, RustBertError> {
        Ok(match *self {
            Self::GPT2(ref model) => model
                .generate_from_ids_and_past(input_ids, attention_mask, None)?
                .into_iter()
                .map(|output| output.indices)
                .collect(),
        })
    }
}

/// # Conversation model
/// Processes a ConversationManager and generates system responses for active conversations.
pub struct ConversationModel {
    model: ConversationOption,
    eos_token_id: i64,
    max_allowed_context_length: Option<i64>,
    device: Device,
}

impl ConversationModel {
    /// Build a new `ConversationModel`
    ///
    /// # Arguments
    ///
    /// * `conversation_config` - `ConversationConfig` object containing the resource references (model, vocabulary, configuration), conversation options and device placement (CPU/GPU)
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::pipelines::conversation::ConversationModel;
    ///
    /// let conversation_model = ConversationModel::new(Default::default())?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new(
        conversation_config: ConversationConfig,
    ) -> Result<ConversationModel, RustBertError> {
        let max_allowed_length = conversation_config
            .max_length
            .map(|max_length| max_length - conversation_config.min_length_for_response);
        let device = conversation_config.device;
        let model = ConversationOption::new(conversation_config)?;
        let eos_token_id = model.get_eos_id()?;
        Ok(ConversationModel {
            model,
            eos_token_id,
            max_allowed_context_length: max_allowed_length,
            device,
        })
    }

    /// Build a new `ConversationModel` with a provided tokenizer.
    ///
    /// # Arguments
    ///
    /// * `conversation_config` - `ConversationConfig` object containing the resource references (model, vocabulary, configuration), conversation options and device placement (CPU/GPU)
    /// * `tokenizer` - `TokenizerOption` tokenizer to use for conversation
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::pipelines::common::{ModelType, TokenizerOption};
    /// use rust_bert::pipelines::conversation::ConversationModel;
    /// let tokenizer = TokenizerOption::from_file(
    ///     ModelType::GPT2,
    ///     "path/to/vocab.json",
    ///     Some("path/to/merges.txt"),
    ///     false,
    ///     None,
    ///     None,
    /// )?;
    /// let conversation_model = ConversationModel::new_with_tokenizer(Default::default(), tokenizer)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new_with_tokenizer(
        conversation_config: ConversationConfig,
        tokenizer: TokenizerOption,
    ) -> Result<ConversationModel, RustBertError> {
        let max_allowed_length = conversation_config
            .max_length
            .map(|max_length| max_length - conversation_config.min_length_for_response);
        let device = conversation_config.device;
        let model = ConversationOption::new_with_tokenizer(conversation_config, tokenizer)?;
        let eos_token_id = model.get_eos_id()?;
        Ok(ConversationModel {
            model,
            eos_token_id,
            max_allowed_context_length: max_allowed_length,
            device,
        })
    }

    /// Perform a multi-turn conversation based on user input
    ///
    /// # Arguments
    ///
    /// * `conversation_manager` - `&mut ConversationManager` Conversation manager keeping track of active conversations
    ///
    /// # Returns
    /// * `HashMap<&Uuid, &str>` Responses from the model for each active conversation, referenced by Uuid
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
    /// use rust_bert::pipelines::generation_utils::LanguageGenerator;
    /// let model = ConversationModel::new(Default::default())?;
    ///
    /// let mut conversation_manager = ConversationManager::new();
    /// conversation_manager.create("Hello, how are you?");
    ///
    /// let output = model.generate_responses(&mut conversation_manager);
    /// # Ok(())
    /// # }
    /// ```
    pub fn generate_responses<'a>(
        &self,
        conversation_manager: &'a mut ConversationManager,
    ) -> Result<HashMap<&'a Uuid, &'a str>, RustBertError> {
        let (active_uuid, active_conversations) = conversation_manager.get_active_conversations();
        let updated_conversations = if !active_uuid.is_empty() {
            let texts = active_conversations
                .iter()
                .map(|c| c.new_user_input.as_ref().unwrap().as_str())
                .collect::<Vec<&str>>();

            let history = active_conversations
                .iter()
                .map(|c| c.history.iter().flatten().copied().collect())
                .collect::<Vec<Vec<i64>>>();

            let prompt_ids = self.encode_prompts(texts.as_ref());
            let (input_tensor, attention_mask) =
                self.concat_input_history(prompt_ids.as_ref(), history);
            let input_length = *input_tensor.size().last().unwrap() as usize;
            let mut generated = self
                .model
                .generate_from_ids_and_past(input_tensor, Some(attention_mask))?;
            let removed_padding_quantities = self.clean_padding_indices(&mut generated);

            let mut output = HashMap::with_capacity(active_uuid.len());

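            // Zip each active conversation with its generated sequence, its encoded prompt,
            // its UUID and the quantity of padding removed from its output.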
            for (
                ((conversation, (generated_sequence, conversation_prompt_ids)), uuid),
                removed_padding,
            ) in active_conversations
                .into_iter()
                .zip(generated.into_iter().zip(prompt_ids.into_iter()))
                .zip(active_uuid.into_iter())
                .zip(removed_padding_quantities.into_iter())
            {
                let generated_response = &generated_sequence[input_length - removed_padding.0..];
                conversation.generated_responses.push(
                    self.model
                        .get_tokenizer()
                        .decode(generated_response, true, true),
                );
                conversation.history.push(conversation_prompt_ids);
                conversation.history.push(generated_response.to_vec());
                conversation.mark_processed();
                output.insert(uuid, conversation.get_last_response().unwrap());
            }
            output
        } else {
            HashMap::new()
        };
        Ok(updated_conversations)
    }

    fn clean_padding_indices(&self, model_output: &mut Vec<Vec<i64>>) -> Vec<(usize, usize)> {
        // In case inputs are sent in batches, this cleans the padding indices in the history for shorter outputs.
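        // For each sequence, the number of leading and trailing padding tokens removed is
        // returned so that response offsets into the generated sequences can be adjusted.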
        let pad_token = self
            .model
            .get_tokenizer()
            .get_pad_id()
            .unwrap_or(self.eos_token_id);
        let mut removed_tokens = Vec::with_capacity(model_output.len());
        for sequence_history in model_output {
            let index_end = sequence_history
                .iter()
                .rev()
                .position(|&r| r != pad_token)
                .unwrap();
            let index_start = sequence_history
                .iter()
                .position(|&r| r != pad_token)
                .unwrap();
            if index_end > 0 {
                sequence_history.drain(sequence_history.len() - index_end + 1..);
            }
            sequence_history.drain(..index_start);
            removed_tokens.push((index_start, index_end));
        }
        removed_tokens
    }

    fn concat_input_history(
        &self,
        inputs: &[Vec<i64>],
        history: Vec<Vec<i64>>,
    ) -> (Tensor, Tensor) {
        // Concatenates the history token indices with new user input
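        // Sequences are left-padded to the length of the longest concatenated input so that
        // the most recent tokens sit next to the generation position, and the attention mask
        // is zeroed over the padded prefix.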
        let pad_token = self
            .model
            .get_tokenizer()
            .get_pad_id()
            .unwrap_or(self.eos_token_id);

        assert_eq!(
            inputs.len(),
            history.len(),
            "Length of inputs should equal length of history"
        );

        let mut concatenated_inputs = Vec::with_capacity(inputs.len());
        for (input, history) in inputs.iter().zip(history.iter()) {
            let mut concatenated_element = Vec::with_capacity(input.len() + history.len());
            concatenated_element.extend_from_slice(history);
            concatenated_element.extend_from_slice(input);
            concatenated_inputs.push(concatenated_element);
        }

        let truncated_concatenated_inputs = concatenated_inputs
            .iter()
            .map(|input| match self.max_allowed_context_length {
                Some(max_allowed_context_length)
                    if input.len() > max_allowed_context_length as usize =>
                {
                    let start = self.get_truncated_input_index(
                        input,
                        max_allowed_context_length as usize,
                        pad_token,
                    );
                    &input[start..]
                }
                _ => input.as_slice(),
            })
            .collect::<Vec<&[i64]>>();

        let max_len = truncated_concatenated_inputs
            .iter()
            .map(|input| input.len())
            .max()
            .unwrap();

        let attention_mask = Tensor::ones(
            [inputs.len() as i64, max_len as i64],
            (Kind::Int8, self.device),
        );

        let concatenated_inputs = truncated_concatenated_inputs
            .into_iter()
            .enumerate()
            .map(|(input_idx, input)| {
                let _ = attention_mask
                    .get(input_idx as i64)
                    .slice(0, 0, (max_len - input.len()) as i64, 1)
                    .fill_(0);
                let mut padded_input = vec![pad_token; max_len - input.len()];
                padded_input.extend(input);
                padded_input
            })
            .map(|tokens| Tensor::from_slice(&tokens).to(self.device))
            .collect::<Vec<Tensor>>();

        (Tensor::stack(&concatenated_inputs, 0), attention_mask)
    }

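    /// Returns the index at which a (history + prompt) sequence should be truncated to fit
    /// within `max_length` tokens, preferring to cut just after an EOS/padding token so that
    /// whole conversation turns are dropped rather than partial ones.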
    fn get_truncated_input_index(
        &self,
        history: &[i64],
        max_length: usize,
        pad_token: i64,
    ) -> usize {
        let start_length = history.len();
        let eos_indices: Vec<usize> = history
            .iter()
            .enumerate()
            .filter(|(i, &e)| {
                (e == pad_token)
                    & (*i != start_length - 1)
                    & ((start_length as isize - max_length as isize - *i as isize) < 0)
            })
            .map(|(i, _)| i + 1)
            .collect();

        // Return the position of the first EOS index that fits the max length requirement.
        // If it does not exist, no solution exists and the text is truncated at a non-EOS position.
        *eos_indices.first().unwrap_or(&(start_length - max_length))
    }

    /// Encodes prompts into vectors of indices to be processed by the model. This method may be used to
    /// initialize the history of a conversation with a prior state.
    ///
    /// # Example:
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
    /// use rust_bert::pipelines::generation_utils::LanguageGenerator;
    /// let model = ConversationModel::new(Default::default())?;
    /// let history = [
    ///     "Going to the movies tonight - any suggestions?",
    ///     "The Big Lebowski",
    ///     "Is it an action movie?",
    /// ];
    /// let encoded_history = model.encode_prompts(&history);
    /// # Ok(())
    /// # }
    /// ```
    pub fn encode_prompts(&self, texts: &[&str]) -> Vec<Vec<i64>> {
        // Encode the user prompts into token ids
        let tokens = self.model.get_tokenizer().tokenize_list(texts);

        tokens
            .into_iter()
            .map(|prompt_tokens| {
                self.model
                    .get_tokenizer()
                    .convert_tokens_to_ids(&prompt_tokens)
            })
            .map(|mut tokens| {
                if let Some(max_allowed_context_length) = self.max_allowed_context_length {
                    tokens.truncate(max_allowed_context_length as usize - 1);
                }
                tokens.push(self.eos_token_id);
                tokens
            })
            .collect::<Vec<Vec<i64>>>()
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    #[ignore] // no need to run, compilation is enough to verify it is Send
    fn test() {
        let config = ConversationConfig::default();
        let _: Box<dyn Send> = Box::new(ConversationModel::new(config));
    }
}