// stone_mason/amazon.rs

use crate::ModelVersion;
use crate::ModelVersion::V1;
use derive_builder::Builder;
use serde::Serialize;
use std::fmt::{Display, Error, Formatter};

7pub enum AmazonModel {
8    TitanTextLite(ModelVersion),
9    TitanEmbeddingsText(ModelVersion),
10    TitanTextExpress(ModelVersion),
11    TitanTextAgile(ModelVersion),
12}
13
14impl Display for AmazonModel {
15    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
16        match self {
17            AmazonModel::TitanTextLite(v) if *v == V1 => write!(f, "titan-text-lite-v1"),
18            AmazonModel::TitanEmbeddingsText(v) if *v == V1 => write!(f, "titan-embed-text-v1"),
19            AmazonModel::TitanTextExpress(v) if *v == V1 => write!(f, "titan-text-express-v1"),
20            AmazonModel::TitanTextAgile(v) if *v == V1 => write!(f, "titan-text-agile-v1"),
21            _ => Err(Error),
22        }
23    }
24}
25
26#[derive(Serialize, Builder, Clone, Debug)]
27pub struct AmazonParams {
28    #[serde(rename(serialize = "inputText"))]
29    input_text: String,
30
31    #[serde(rename(serialize = "textGenerationConfig"))]
32    text_generation_config: TextGenerationConfig,
33}
34
35#[derive(Serialize, Builder, Clone, Debug)]
36#[builder(setter(strip_option))]
37pub struct TextGenerationConfig {
38    #[builder(default = "None")]
39    #[serde(skip_serializing_if = "Option::is_none")]
40    temperature: Option<f32>,
41
42    #[builder(default = "None")]
43    #[serde(skip_serializing_if = "Option::is_none")]
44    #[serde(rename(serialize = "topP"))]
45    top_p: Option<f32>,
46
47    #[builder(default = "None")]
48    #[serde(skip_serializing_if = "Option::is_none")]
49    #[serde(rename(serialize = "maxTokenCount"))]
50    max_token_count: Option<u32>,
51
52    #[serde(rename(serialize = "stopSequences"))]
53    stop_sequences: Vec<String>,
54}