1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
use std::path::PathBuf;

use serde::Deserialize;
use tch::{Device, Kind};

use crate::pipelines::common::ModelType;
use crate::pipelines::sentence_embeddings::{
    SentenceEmbeddingsConfig, SentenceEmbeddingsModel, SentenceEmbeddingsModulesConfig,
};
use crate::{Config, RustBertError};

#[cfg(feature = "remote")]
use crate::{
    pipelines::sentence_embeddings::resources::SentenceEmbeddingsModelType,
    resources::RemoteResource,
};

/// # SentenceEmbeddings Model Builder
///
/// Allows the user to build a model from standard Sentence-Transformer files
/// (configuration and weights).
///
/// The type parameter `T` holds the source-specific state of the builder (see the
/// `inner` field); the device and kind settings are shared by every source.
pub struct SentenceEmbeddingsBuilder<T> {
    /// Device handed to the model configuration; the constructors in this file
    /// initialise it with `Device::cuda_if_available()`.
    device: Device,
    /// Optional tensor kind; `None` until `with_kind` is called.
    kind: Option<Kind>,
    /// Source-specific builder state (for example a local model directory).
    inner: T,
}

impl<T> SentenceEmbeddingsBuilder<T> {
    /// Replaces the device stored in the builder and returns the updated builder.
    pub fn with_device(self, device: Device) -> Self {
        Self { device, ..self }
    }

    /// Stores `kind` in the builder, replacing any previously chosen kind, and returns
    /// the updated builder.
    pub fn with_kind(self, kind: Kind) -> Self {
        Self {
            kind: Some(kind),
            ..self
        }
    }
}

/// Builder state for a model whose files live in a local directory.
pub struct Local {
    /// Directory that `create_model` joins file names onto (`modules.json`,
    /// `config.json`, `rust_model.ot`, tokenizer files, ...).
    model_dir: PathBuf,
}

/// Minimal view of the transformer `config.json`: only the `model_type` field is
/// declared, which is all the local builder needs to pick the tokenizer files.
#[derive(Debug, Deserialize)]
struct ModelConfig {
    /// Transformer architecture named by the configuration file.
    model_type: ModelType,
}

// No methods are overridden: `ModelConfig` relies entirely on what the `Config`
// trait provides (presumably including the `from_file` loader used by the local
// builder; confirm against the trait definition).
impl Config for ModelConfig {}

impl SentenceEmbeddingsBuilder<Local> {
    /// Starts a builder for a model stored in the local directory `model_dir`.
    ///
    /// The device defaults to `Device::cuda_if_available()` and no kind is set; use
    /// `with_device` / `with_kind` to change them before calling `create_model`.
    pub fn local<P: Into<PathBuf>>(model_dir: P) -> Self {
        Self {
            device: Device::cuda_if_available(),
            kind: None,
            inner: Local {
                model_dir: model_dir.into(),
            },
        }
    }

    /// Verifies that `path` can be inspected and does not point to a directory.
    ///
    /// `from_file` hands back the parsed value directly instead of a `Result`, so it
    /// cannot report a missing file as an error value. Checking up front lets
    /// `create_model` return an `Err` for that case.
    fn ensure_config_file(path: &std::path::Path) -> Result<(), RustBertError> {
        match std::fs::metadata(path) {
            Ok(metadata) if !metadata.is_dir() => Ok(()),
            Ok(_) => Err(RustBertError::InvalidConfigurationError(format!(
                "Expected a configuration file but found a directory: {}",
                path.display()
            ))),
            Err(err) => Err(RustBertError::InvalidConfigurationError(format!(
                "Could not access configuration file {}: {err}",
                path.display()
            ))),
        }
    }

    /// Builds the model from the files found in the local model directory.
    ///
    /// The transformer type is read from `config.json` and decides which tokenizer
    /// files are expected (`vocab.txt`, `vocab.json` + `merges.txt`, or `spiece.model`).
    /// Pooling and optional dense module locations come from `modules.json`.
    ///
    /// # Errors
    ///
    /// Returns `RustBertError::InvalidConfigurationError` when `modules.json` or
    /// `config.json` cannot be accessed, or when the transformer type is not one of
    /// Bert, DistilBert, Roberta, Albert or T5. Errors from validating the modules
    /// configuration and from `SentenceEmbeddingsModel::new` are propagated.
    pub fn create_model(self) -> Result<SentenceEmbeddingsModel, RustBertError> {
        let model_dir = self.inner.model_dir;

        let modules_config = model_dir.join("modules.json");
        Self::ensure_config_file(&modules_config)?;
        let modules = SentenceEmbeddingsModulesConfig::from_file(&modules_config).validate()?;

        let transformer_config = model_dir.join("config.json");
        Self::ensure_config_file(&transformer_config)?;
        let transformer_type = ModelConfig::from_file(&transformer_config).model_type;
        let transformer_weights = model_dir.join("rust_model.ot");

        let pooling_config = model_dir
            .join(&modules.pooling_module().path)
            .join("config.json");

        // The dense module is optional: resolve its directory once and derive both
        // file paths from it, leaving them as `None` when the module is absent.
        let dense_dir = modules
            .dense_module()
            .map(|module| model_dir.join(&module.path));
        let dense_config = dense_dir.as_ref().map(|dir| dir.join("config.json"));
        let dense_weights = dense_dir.as_ref().map(|dir| dir.join("rust_model.ot"));

        let tokenizer_config = model_dir.join("tokenizer_config.json");
        let sentence_bert_config = model_dir.join("sentence_bert_config.json");
        let (tokenizer_vocab, tokenizer_merges) = match transformer_type {
            ModelType::Bert | ModelType::DistilBert => (model_dir.join("vocab.txt"), None),
            ModelType::Roberta => (
                model_dir.join("vocab.json"),
                Some(model_dir.join("merges.txt")),
            ),
            ModelType::Albert | ModelType::T5 => (model_dir.join("spiece.model"), None),
            _ => {
                return Err(RustBertError::InvalidConfigurationError(format!(
                    "Unsupported transformer model {transformer_type:?} for Sentence Embeddings",
                )));
            }
        };

        let config = SentenceEmbeddingsConfig {
            modules_config_resource: modules_config.into(),
            transformer_type,
            transformer_config_resource: transformer_config.into(),
            transformer_weights_resource: transformer_weights.into(),
            pooling_config_resource: pooling_config.into(),
            dense_config_resource: dense_config.map(|path| path.into()),
            dense_weights_resource: dense_weights.map(|path| path.into()),
            sentence_bert_config_resource: sentence_bert_config.into(),
            tokenizer_config_resource: tokenizer_config.into(),
            tokenizer_vocab_resource: tokenizer_vocab.into(),
            tokenizer_merges_resource: tokenizer_merges.map(|path| path.into()),
            device: self.device,
            kind: self.kind,
        };

        SentenceEmbeddingsModel::new(config)
    }
}

/// Builder state for a model described by remote resources.
#[cfg(feature = "remote")]
pub struct Remote {
    /// Configuration created from the selected model type; the builder's setter
    /// methods replace individual resources in it.
    config: SentenceEmbeddingsConfig,
}

#[cfg(feature = "remote")]
impl SentenceEmbeddingsBuilder<Remote> {
    /// Starts a builder from the configuration associated with `model_type`.
    ///
    /// The device defaults to `Device::cuda_if_available()` and no kind is set; each
    /// resource of the configuration can be replaced with the setter methods below.
    pub fn remote(model_type: SentenceEmbeddingsModelType) -> Self {
        Self {
            device: Device::cuda_if_available(),
            kind: None,
            inner: Remote {
                config: SentenceEmbeddingsConfig::from(model_type),
            },
        }
    }

    /// Replaces the modules configuration resource.
    pub fn modules_config(mut self, resource: RemoteResource) -> Self {
        self.inner.config.modules_config_resource = Box::new(resource);
        self
    }

    /// Replaces the transformer configuration resource.
    pub fn transformer_config(mut self, resource: RemoteResource) -> Self {
        self.inner.config.transformer_config_resource = Box::new(resource);
        self
    }

    /// Replaces the transformer weights resource.
    pub fn transformer_weights(mut self, resource: RemoteResource) -> Self {
        self.inner.config.transformer_weights_resource = Box::new(resource);
        self
    }

    /// Replaces the pooling configuration resource.
    pub fn pooling_config(mut self, resource: RemoteResource) -> Self {
        self.inner.config.pooling_config_resource = Box::new(resource);
        self
    }

    /// Sets the (optional) dense layer configuration resource.
    pub fn dense_config(mut self, resource: RemoteResource) -> Self {
        self.inner.config.dense_config_resource = Some(Box::new(resource));
        self
    }

    /// Sets the (optional) dense layer weights resource.
    pub fn dense_weights(mut self, resource: RemoteResource) -> Self {
        self.inner.config.dense_weights_resource = Some(Box::new(resource));
        self
    }

    /// Replaces the sentence-BERT configuration resource.
    pub fn sentence_bert_config(mut self, resource: RemoteResource) -> Self {
        self.inner.config.sentence_bert_config_resource = Box::new(resource);
        self
    }

    /// Replaces the tokenizer configuration resource.
    pub fn tokenizer_config(mut self, resource: RemoteResource) -> Self {
        self.inner.config.tokenizer_config_resource = Box::new(resource);
        self
    }

    /// Replaces the tokenizer vocabulary resource.
    pub fn tokenizer_vocab(mut self, resource: RemoteResource) -> Self {
        self.inner.config.tokenizer_vocab_resource = Box::new(resource);
        self
    }

    /// Sets the (optional) tokenizer merges resource.
    pub fn tokenizer_merges(mut self, resource: RemoteResource) -> Self {
        self.inner.config.tokenizer_merges_resource = Some(Box::new(resource));
        self
    }

    /// Builds the model from the accumulated configuration.
    ///
    /// The builder's device always replaces the one in the configuration. A kind
    /// chosen through `with_kind` is forwarded as well; when none was chosen, the
    /// kind already present in the configuration is left untouched.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `SentenceEmbeddingsModel::new`.
    pub fn create_model(mut self) -> Result<SentenceEmbeddingsModel, RustBertError> {
        self.inner.config.device = self.device;
        // Previously the kind selected via `with_kind` was dropped for remote models.
        if let Some(kind) = self.kind {
            self.inner.config.kind = Some(kind);
        }
        SentenceEmbeddingsModel::new(self.inner.config)
    }
}