stone_mason/
lib.rs

1pub mod ai21;
2pub mod amazon;
3pub mod anthropic;
4pub mod cohere;
5pub mod meta;
6pub mod stability;
7
8use crate::ai21::AI21LabsModel;
9use crate::amazon::AmazonModel;
10use crate::anthropic::AnthropicModel;
11use anyhow::Result;
12use aws_sdk_bedrockruntime::operation::invoke_model::InvokeModelOutput;
13use cohere::CohereModel;
14use meta::MetaModel;
15use serde::Deserialize;
16use stability::StabilityAIModel;
17use std::fmt::{Display, Formatter};
18
19/// | Provider     | Model name                 | Version | Model Id                         |
20/// |--------------|----------------------------|---------|----------------------------------|
21/// | AI21 Labs    | Jurassic-2 Mid             | 1.x     | ai21.j2-mid-v1                   |
22/// | AI21 Labs    | Jurassic-2 Ultra           | 1.x     | ai21.j2-ultra-v1                 |
23/// | Amazon       | Titan Text G1 - Lite       | 1.x     | amazon.titan-text-lite-v1        |
24/// | Amazon       | Titan Embeddings G1 - Text | 1.x     | amazon.titan-embed-text-v1       |
25/// | Amazon       | Titan Text G1 - Express    | 1.x     | amazon.titan-text-express-v1     |
26/// | Amazon       | Titan Text G1 - Agile      | 1.x     | amazon.titan-text-agile-v1       |
27/// | Anthropic    | Claude                     | 1.x     | anthropic.claude-v1              |
28/// | Anthropic    | Claude                     | 2.x     | anthropic.claude-v2              |
29/// | Anthropic    | Claude Instant             | 1.x     | anthropic.claude-instant-v1      |
30/// | Cohere       | Command                    | 14.x    | cohere.command-text-v14          |
31/// | Cohere       | Command Light              | 15.x    | cohere.command-light-text-v14    |
32/// | Cohere       | Embed English              | 3.x     | cohere.embed-english-v3          |
33/// | Cohere       | Embed Multilingual         | 3.x     | cohere.embed-multilingual-v3     |
34/// | Meta         | Llama 2 Chat 13B           | 1.x     | meta.llama2-13b-chat-v1          |
35/// | Stability AI | Stable Diffusion XL        | 0.x     | stability.stable-diffusion-xl-v0 |
36
37pub enum BaseModel {
38    AI21Labs(AI21LabsModel),
39    Amazon(AmazonModel),
40    Anthropic(AnthropicModel),
41    Cohere(CohereModel),
42    Meta(MetaModel),
43    StabilityAI(StabilityAIModel),
44}
45
/// Version selector wrapped by the provider-specific model variants,
/// e.g. `AnthropicModel::Claude(V2)`.
///
/// How a variant maps onto the `-vN` suffix of a model ID is decided by each
/// provider module. For example, `CohereModel::CommandLight(V15)` renders as
/// `cohere.command-light-text-v14`.
// A fieldless enum is trivially `Copy`/`Eq`/`Hash`. `Debug` lets versions appear in
// `assert_eq!` failures and log output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelVersion {
    V0,
    V1,
    V2,
    V3,
    V14,
    V15,
}
55
56impl Display for BaseModel {
57    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
58        let id: String = match self {
59            BaseModel::AI21Labs(model) => format!("ai21.{model}"),
60            BaseModel::Amazon(model) => format!("amazon.{model}"),
61            BaseModel::Anthropic(model) => format!("anthropic.{model}"),
62            BaseModel::Cohere(model) => format!("cohere.{model}"),
63            BaseModel::Meta(model) => format!("meta.{model}"),
64            BaseModel::StabilityAI(model) => format!("stability.{model}"),
65        };
66        write!(f, "{id}")
67    }
68}
69
70pub trait FromModelOutput<'de, T>
71where
72    T: Deserialize<'de>,
73{
74    fn from_model_output(output: &'de InvokeModelOutput) -> Result<T> {
75        let res = std::str::from_utf8(output.body.as_ref())?;
76        Ok(serde_json::from_str(res)?)
77    }
78}
79
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's private `use` imports, which covers
    // every provider model type used below.
    use super::*;
    use crate::ModelVersion::{V0, V1, V14, V15, V2, V3};

    /// Checks the Bedrock model ID rendered for every supported model version,
    /// in the same order as the table documented on `BaseModel`.
    #[test]
    fn test_base_model_to_string() {
        let cases = [
            (BaseModel::AI21Labs(AI21LabsModel::Jurassic2Mid(V1)), "ai21.j2-mid-v1"),
            (BaseModel::AI21Labs(AI21LabsModel::Jurassic2Ultra(V1)), "ai21.j2-ultra-v1"),
            (BaseModel::Amazon(AmazonModel::TitanTextLite(V1)), "amazon.titan-text-lite-v1"),
            (
                BaseModel::Amazon(AmazonModel::TitanEmbeddingsText(V1)),
                "amazon.titan-embed-text-v1",
            ),
            (
                BaseModel::Amazon(AmazonModel::TitanTextExpress(V1)),
                "amazon.titan-text-express-v1",
            ),
            (
                BaseModel::Amazon(AmazonModel::TitanTextAgile(V1)),
                "amazon.titan-text-agile-v1",
            ),
            (BaseModel::Anthropic(AnthropicModel::Claude(V1)), "anthropic.claude-v1"),
            (BaseModel::Anthropic(AnthropicModel::Claude(V2)), "anthropic.claude-v2"),
            (
                BaseModel::Anthropic(AnthropicModel::ClaudeInstant(V1)),
                "anthropic.claude-instant-v1",
            ),
            (BaseModel::Cohere(CohereModel::Command(V14)), "cohere.command-text-v14"),
            // Command Light is version 15.x, but its Bedrock model ID ends in `-v14`.
            (
                BaseModel::Cohere(CohereModel::CommandLight(V15)),
                "cohere.command-light-text-v14",
            ),
            (BaseModel::Cohere(CohereModel::EmbedEnglish(V3)), "cohere.embed-english-v3"),
            (
                BaseModel::Cohere(CohereModel::EmbedMultilingual(V3)),
                "cohere.embed-multilingual-v3",
            ),
            (BaseModel::Meta(MetaModel::Llama2Chat13B(V1)), "meta.llama2-13b-chat-v1"),
            (
                BaseModel::StabilityAI(StabilityAIModel::StableDiffusionXL(V0)),
                "stability.stable-diffusion-xl-v0",
            ),
        ];

        for (model, expected) in cases {
            assert_eq!(model.to_string(), expected, "wrong model id, expected {expected}");
        }
    }
}