1pub mod ai21;
2pub mod amazon;
3pub mod anthropic;
4pub mod cohere;
5pub mod meta;
6pub mod stability;
7
8use crate::ai21::AI21LabsModel;
9use crate::amazon::AmazonModel;
10use crate::anthropic::AnthropicModel;
11use anyhow::Result;
12use aws_sdk_bedrockruntime::operation::invoke_model::InvokeModelOutput;
13use cohere::CohereModel;
14use meta::MetaModel;
15use serde::Deserialize;
16use stability::StabilityAIModel;
17use std::fmt::{Display, Formatter};
18
/// A model, tagged by the provider that publishes it.
///
/// Each variant wraps that provider's own model enum, imported from the
/// corresponding submodule (`ai21`, `amazon`, `anthropic`, `cohere`, `meta`,
/// `stability`). The `Display` impl for this type renders a value as
/// `<provider>.<model>`, e.g. `ai21.j2-mid-v1`.
pub enum BaseModel {
    /// AI21 Labs model; rendered with the `ai21.` prefix.
    AI21Labs(AI21LabsModel),
    /// Amazon model; rendered with the `amazon.` prefix.
    Amazon(AmazonModel),
    /// Anthropic model; rendered with the `anthropic.` prefix.
    Anthropic(AnthropicModel),
    /// Cohere model; rendered with the `cohere.` prefix.
    Cohere(CohereModel),
    /// Meta model; rendered with the `meta.` prefix.
    Meta(MetaModel),
    /// Stability AI model; rendered with the `stability.` prefix.
    StabilityAI(StabilityAIModel),
}
45
/// Version tag carried by the provider model variants,
/// e.g. `AnthropicModel::Claude(V2)`.
///
/// The textual form of a version is produced outside this file (presumably by
/// each provider module when it formats its model); for example
/// `AI21LabsModel::Jurassic2Mid(V1)` is expected to render as `j2-mid-v1`.
///
/// This is a plain fieldless tag, so it derives the common value traits:
/// `Debug` for diagnostics and `assert_eq!`, `Copy` so it can be passed around
/// by value, and `Eq`/`Hash` so it can be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelVersion {
    V0,
    V1,
    V2,
    V3,
    V14,
    V15,
}
55
56impl Display for BaseModel {
57 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
58 let id: String = match self {
59 BaseModel::AI21Labs(model) => format!("ai21.{model}"),
60 BaseModel::Amazon(model) => format!("amazon.{model}"),
61 BaseModel::Anthropic(model) => format!("anthropic.{model}"),
62 BaseModel::Cohere(model) => format!("cohere.{model}"),
63 BaseModel::Meta(model) => format!("meta.{model}"),
64 BaseModel::StabilityAI(model) => format!("stability.{model}"),
65 };
66 write!(f, "{id}")
67 }
68}
69
70pub trait FromModelOutput<'de, T>
71where
72 T: Deserialize<'de>,
73{
74 fn from_model_output(output: &'de InvokeModelOutput) -> Result<T> {
75 let res = std::str::from_utf8(output.body.as_ref())?;
76 Ok(serde_json::from_str(res)?)
77 }
78}
79
#[cfg(test)]
mod tests {
    // `super::*` already brings in `BaseModel` and every provider model enum
    // (via the parent module's `use` declarations).
    use super::*;
    use crate::ModelVersion::{V0, V1, V14, V2, V3};

    /// Each `BaseModel` renders as `<provider>.<model-id>`.
    #[test]
    fn test_base_model_to_string() {
        let model = BaseModel::AI21Labs(AI21LabsModel::Jurassic2Mid(V1));
        assert_eq!(model.to_string(), "ai21.j2-mid-v1");

        let model = BaseModel::AI21Labs(AI21LabsModel::Jurassic2Ultra(V1));
        assert_eq!(model.to_string(), "ai21.j2-ultra-v1");

        let model = BaseModel::Amazon(AmazonModel::TitanTextLite(V1));
        assert_eq!(model.to_string(), "amazon.titan-text-lite-v1");

        let model = BaseModel::Amazon(AmazonModel::TitanEmbeddingsText(V1));
        assert_eq!(model.to_string(), "amazon.titan-embed-text-v1");

        let model = BaseModel::Amazon(AmazonModel::TitanTextExpress(V1));
        assert_eq!(model.to_string(), "amazon.titan-text-express-v1");

        let model = BaseModel::Amazon(AmazonModel::TitanTextAgile(V1));
        assert_eq!(model.to_string(), "amazon.titan-text-agile-v1");

        let model = BaseModel::Anthropic(AnthropicModel::Claude(V1));
        assert_eq!(model.to_string(), "anthropic.claude-v1");

        let model = BaseModel::Anthropic(AnthropicModel::Claude(V2));
        assert_eq!(model.to_string(), "anthropic.claude-v2");

        let model = BaseModel::Anthropic(AnthropicModel::ClaudeInstant(V1));
        assert_eq!(model.to_string(), "anthropic.claude-instant-v1");

        let model = BaseModel::Cohere(CohereModel::Command(V14));
        assert_eq!(model.to_string(), "cohere.command-text-v14");

        // Command Light's identifier is `command-light-text-v14`; the version
        // tag must match the suffix being asserted.
        let model = BaseModel::Cohere(CohereModel::CommandLight(V14));
        assert_eq!(model.to_string(), "cohere.command-light-text-v14");

        let model = BaseModel::Cohere(CohereModel::EmbedEnglish(V3));
        assert_eq!(model.to_string(), "cohere.embed-english-v3");

        let model = BaseModel::Cohere(CohereModel::EmbedMultilingual(V3));
        assert_eq!(model.to_string(), "cohere.embed-multilingual-v3");

        let model = BaseModel::Meta(MetaModel::Llama2Chat13B(V1));
        assert_eq!(model.to_string(), "meta.llama2-13b-chat-v1");

        let model = BaseModel::StabilityAI(StabilityAIModel::StableDiffusionXL(V0));
        assert_eq!(model.to_string(), "stability.stable-diffusion-xl-v0");
    }
}