Struct rust_bert::pipelines::question_answering::QuestionAnsweringModel [src]

pub struct QuestionAnsweringModel { /* fields omitted */ }

Methods

impl QuestionAnsweringModel[src]

pub fn new(
    vocab_path: &Path,
    config_path: &Path,
    weights_path: &Path,
    device: Device
) -> Fallible<QuestionAnsweringModel>
[src]

Build a new QuestionAnsweringModel

Arguments

  • vocab_path - Path to the model vocabulary, expected to have a structure following the Transformers library convention
  • config_path - Path to the model configuration, expected to have a structure following the Transformers library convention
  • weights_path - Path to the model weight file. The weights must be converted from the .bin format to the .ot format using the provided utility script.
  • device - Device to run the model on, e.g. Device::Cpu or Device::Cuda(0)

Example

use tch::Device;
use std::path::{Path, PathBuf};
use rust_bert::pipelines::question_answering::QuestionAnsweringModel;

let mut home: PathBuf = dirs::home_dir().unwrap();
let config_path = &home.as_path().join("config.json");
let vocab_path = &home.as_path().join("vocab.txt");
let weights_path = &home.as_path().join("model.ot");
let device = Device::Cpu;
let qa_model =  QuestionAnsweringModel::new(vocab_path,
                                            config_path,
                                            weights_path,
                                            device)?;

pub fn predict(
    &self,
    qa_inputs: &[QaInput],
    top_k: i64,
    batch_size: usize
) -> Vec<Vec<Answer>>
[src]

Perform extractive question answering given a list of QaInputs

Arguments

  • qa_inputs - &[QaInput] Slice of question answering inputs (context and question pairs)
  • top_k - return the top-k answers for each QaInput. Set to 1 to return only the best answer.
  • batch_size - maximum batch size for the model forward pass.

Returns

  • Vec<Vec<Answer>> Vector (same length as qa_inputs) of vectors (each of length top_k) containing the extracted answers.

Example

use tch::Device;
use std::path::{Path, PathBuf};
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};

let mut home: PathBuf = dirs::home_dir().unwrap();
let config_path = &home.as_path().join("config.json");
let vocab_path = &home.as_path().join("vocab.txt");
let weights_path = &home.as_path().join("model.ot");
let device = Device::Cpu;
let qa_model =  QuestionAnsweringModel::new(vocab_path,
                                            config_path,
                                            weights_path,
                                            device)?;

let question_1 = String::from("Where does Amy live ?");
let context_1 = String::from("Amy lives in Amsterdam");
let question_2 = String::from("Where does Eric live");
let context_2 = String::from("While Amy lives in Amsterdam, Eric is in The Hague.");

let qa_input_1 = QaInput { question: question_1, context: context_1 };
let qa_input_2 = QaInput { question: question_2, context: context_2 };
let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);

Auto Trait Implementations

Blanket Implementations

impl<T> Any for T where
    T: 'static + ?Sized
[src]

impl<T> Borrow<T> for T where
    T: ?Sized
[src]

impl<T> BorrowMut<T> for T where
    T: ?Sized
[src]

impl<T> From<T> for T[src]

impl<T, U> Into<U> for T where
    U: From<T>, 
[src]

impl<T, U> TryFrom<U> for T where
    U: Into<T>, 
[src]

type Error = Infallible

The type returned in the event of a conversion error.

impl<T, U> TryInto<U> for T where
    U: TryFrom<T>, 
[src]

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.