transformrs/
lib.rs

1//! Transformrs is a Rust library for interacting with various AI APIs.
2
3pub mod chat;
4pub mod models;
5pub mod text_to_image;
6pub mod text_to_speech;
7
8use base64::prelude::BASE64_STANDARD;
9use base64::Engine;
10use reqwest::header::HeaderMap;
11use reqwest::header::HeaderValue;
12use serde::Deserialize;
13use serde::Serialize;
14use std::collections::HashMap;
15use std::error::Error;
16use std::fs::File;
17use std::io::Read;
18use std::str::FromStr;
19
20pub(crate) fn request_headers(key: &Key) -> Result<HeaderMap, Box<dyn Error + Send + Sync>> {
21    let mut headers = HeaderMap::new();
22    headers.insert(
23        "Authorization",
24        HeaderValue::from_str(&format!("Bearer {}", key.key))?,
25    );
26    headers.insert("Content-Type", HeaderValue::from_str("application/json")?);
27    Ok(headers)
28}
29
30pub(crate) fn openai_base_url(provider: &Provider) -> String {
31    match provider {
32        Provider::Google => format!("{}/v1beta/openai", provider.domain()),
33        Provider::Groq => format!("{}/openai/v1", provider.domain()),
34        Provider::Hyperbolic => format!("{}/v1", provider.domain()),
35        Provider::Mistral => format!("{}/v1", provider.domain()),
36        Provider::OpenAI => format!("{}/v1", provider.domain()),
37        Provider::OpenAICompatible(domain) => domain.clone(),
38        Provider::SambaNova => format!("{}/v1", provider.domain()),
39        Provider::TogetherAI => format!("{}/v1", provider.domain()),
40        _ => format!("{}/v1/openai", provider.domain()),
41    }
42}
43
44#[allow(rustdoc::bare_urls)]
45#[derive(Clone, Debug, Serialize, PartialEq)]
46pub enum Provider {
47    Amazon,
48    Azure,
49    Cerebras,
50    DeepInfra,
51    ElevenLabs,
52    Fireworks,
53    FriendliAI,
54    Google,
55    Groq,
56    Hyperbolic,
57    Mistral,
58    Nebius,
59    Novita,
60    OpenAI,
61    /// Another OpenAI-compatible provider.
62    ///
63    /// For example, "https://api.deepinfra.com/v1/openai".
64    OpenAICompatible(String),
65    SambaNova,
66    TogetherAI,
67}
68
69impl std::fmt::Display for Provider {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        write!(f, "{:?}", self)
72    }
73}
74
75impl Provider {
76    pub fn domain(&self) -> String {
77        match self {
78            Provider::Amazon => "https://api.amazon.com",
79            Provider::Azure => "https://api.azure.com",
80            Provider::Cerebras => "https://api.cerebras.ai",
81            Provider::DeepInfra => "https://api.deepinfra.com",
82            Provider::ElevenLabs => "https://api.elevenlabs.io",
83            Provider::Fireworks => "https://api.fireworks.ai",
84            Provider::FriendliAI => "https://api.friendli.ai",
85            Provider::Google => "https://generativelanguage.googleapis.com",
86            Provider::Groq => "https://api.groq.com",
87            Provider::Hyperbolic => "https://api.hyperbolic.xyz",
88            Provider::Mistral => "https://api.mistral.ai",
89            Provider::Nebius => "https://api.nebi.us",
90            Provider::Novita => "https://api.novita.ai",
91            Provider::OpenAI => "https://api.openai.com",
92            Provider::OpenAICompatible(base_url) => base_url,
93            Provider::SambaNova => "https://api.sambanova.ai",
94            Provider::TogetherAI => "https://api.together.xyz",
95        }
96        .to_string()
97    }
98    pub fn key_name(&self) -> String {
99        match self {
100            Provider::OpenAICompatible(_) => "OPENAI_COMPATIBLE_KEY".to_string(),
101            _ => self.to_string().to_uppercase() + "_KEY",
102        }
103    }
104}
105
106impl FromStr for Provider {
107    type Err = Box<dyn Error + Send + Sync>;
108
109    fn from_str(s: &str) -> Result<Self, Self::Err> {
110        let s = s.to_lowercase();
111        if s.starts_with("openai-compatible(") {
112            let s = s.strip_prefix("openai-compatible(").unwrap();
113            let s = s.strip_suffix(")").unwrap();
114            let mut domain = s.to_string();
115            if !domain.starts_with("https") {
116                if domain.contains("localhost") {
117                    domain = format!("http://{}", domain);
118                } else {
119                    domain = format!("https://{}", domain);
120                }
121            }
122            return Ok(Provider::OpenAICompatible(domain));
123        }
124        match s.as_str() {
125            "amazon" => Ok(Provider::Amazon),
126            "azure" => Ok(Provider::Azure),
127            "cerebras" => Ok(Provider::Cerebras),
128            "deepinfra" => Ok(Provider::DeepInfra),
129            "elevenlabs" => Ok(Provider::ElevenLabs),
130            "fireworks" => Ok(Provider::Fireworks),
131            "friendliai" => Ok(Provider::FriendliAI),
132            "google" => Ok(Provider::Google),
133            "groq" => Ok(Provider::Groq),
134            "hyperbolic" => Ok(Provider::Hyperbolic),
135            "mistral" => Ok(Provider::Mistral),
136            "nebi" => Ok(Provider::Nebius),
137            "novita" => Ok(Provider::Novita),
138            "openai" => Ok(Provider::OpenAI),
139            "sambanova" => Ok(Provider::SambaNova),
140            "togetherai" => Ok(Provider::TogetherAI),
141            _ => Err(format!("Unsupported provider: {s}.").into()),
142        }
143    }
144}
145
146#[derive(Clone, Debug, Deserialize)]
147pub enum SubContent {
148    TextContent { text: String },
149    ImageUrlContent { image_url: String },
150}
151
152impl SubContent {
153    pub fn new(r#type: &str, text: &str) -> Self {
154        match r#type {
155            "text" => Self::TextContent {
156                text: text.to_string(),
157            },
158            "image_url" => Self::ImageUrlContent {
159                image_url: text.to_string(),
160            },
161            _ => panic!("Invalid subcontent type: {}", r#type),
162        }
163    }
164}
165
166impl Serialize for SubContent {
167    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
168    where
169        S: serde::Serializer,
170    {
171        match self {
172            SubContent::TextContent { text } => serializer.serialize_str(text),
173            SubContent::ImageUrlContent { image_url } => {
174                let json = serde_json::json!({
175                    "type": "image_url",
176                    "image_url": {
177                        "url": image_url
178                    }
179                });
180                json.serialize(serializer)
181            }
182        }
183    }
184}
185
/// The body of a chat [`Message`].
///
/// `Text` is a single string. `Collection` holds a list of [`SubContent`] parts (for
/// example the image URL built by `Message::from_image_url`). The `Serialize` and
/// `Deserialize` impls map `Text` to a string and `Collection` to an array.
#[derive(Clone, Debug)]
pub enum Content {
    /// Plain text content.
    Text(String),
    /// Multi-part content made of individual [`SubContent`] parts.
    Collection(Vec<SubContent>),
}
191
192impl std::fmt::Display for Content {
193    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194        match self {
195            Content::Text(text) => write!(f, "{}", text),
196            Content::Collection(items) => {
197                write!(f, "{items:?}")
198            }
199        }
200    }
201}
202
203impl Serialize for Content {
204    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
205    where
206        S: serde::Serializer,
207    {
208        match self {
209            Content::Text(text) => serializer.serialize_str(text),
210            Content::Collection(items) => items.serialize(serializer),
211        }
212    }
213}
214
215impl<'de> Deserialize<'de> for Content {
216    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
217    where
218        D: serde::Deserializer<'de>,
219    {
220        let value = serde_json::Value::deserialize(deserializer)?;
221        if let serde_json::Value::String(text) = value {
222            Ok(Content::Text(text))
223        } else if let serde_json::Value::Array(items) = value {
224            let subcontent = items
225                .into_iter()
226                .map(SubContent::deserialize)
227                .collect::<Result<Vec<_>, _>>()
228                .unwrap();
229            Ok(Content::Collection(subcontent))
230        } else {
231            Err(serde::de::Error::custom("Invalid content format"))
232        }
233    }
234}
235
/// A single chat message: the sender's role plus its [`Content`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Message {
    /// Role of the sender as a free-form string (presumably `"system"`, `"user"` or
    /// `"assistant"` — it is not validated here).
    pub role: String,
    /// The message body.
    pub content: Content,
}
241
242impl Message {
243    pub fn from_str(role: &str, text: &str) -> Self {
244        Self {
245            role: role.to_string(),
246            content: Content::Text(text.to_string()),
247        }
248    }
249    pub fn from_image_url(role: &str, image_url: &str) -> Self {
250        Self {
251            role: role.to_string(),
252            content: Content::Collection(vec![SubContent::ImageUrlContent {
253                image_url: image_url.to_string(),
254            }]),
255        }
256    }
257    pub fn from_image_bytes(role: &str, image_type: &str, image: &[u8]) -> Self {
258        let base64 = BASE64_STANDARD.encode(image);
259        let image_url = format!("data:image/{image_type};base64,{base64}");
260        Self::from_image_url(role, &image_url)
261    }
262}
263
264#[derive(Clone, Debug)]
265pub struct Key {
266    pub provider: Provider,
267    pub key: String,
268}
269
/// A set of API keys, as returned by `load_keys`.
#[derive(Clone, Debug)]
pub struct Keys {
    /// The available keys; `Keys::for_provider` returns the first matching one.
    pub keys: Vec<Key>,
}
274
275impl Keys {
276    pub fn for_provider(&self, provider: &Provider) -> Option<Key> {
277        fn finder(provider: &Provider, key: &Key) -> bool {
278            match provider {
279                Provider::OpenAICompatible(_) => {
280                    matches!(&key.provider, Provider::OpenAICompatible(_))
281                }
282                _ => key.provider == *provider,
283            }
284        }
285
286        self.keys.iter().find(|key| finder(provider, key)).cloned()
287    }
288}
289
/// Reads `KEY=VALUE` pairs from the dotenv-style file at `path`.
///
/// A file that cannot be opened (e.g. it does not exist) yields an empty map, so callers
/// can fall back to environment variables. Blank lines, `#` comments and lines without
/// `=` are skipped. Only the first `=` separates key from value, so values may contain
/// `=` (common in base64-encoded secrets). Keys and values are trimmed, a leading
/// `export ` is ignored, and one pair of matching surrounding quotes is removed from
/// the value.
///
/// # Panics
///
/// Panics if the file can be opened but not read as UTF-8 text.
fn load_env_file(path: &str) -> HashMap<String, String> {
    let mut env_content = String::new();
    if let Ok(mut file) = File::open(path) {
        file.read_to_string(&mut env_content)
            .expect("Failed to read .env file");
    }
    env_content.lines().filter_map(parse_env_line).collect()
}

/// Parses one dotenv line into a `(key, value)` pair, or `None` if it holds no assignment.
fn parse_env_line(line: &str) -> Option<(String, String)> {
    let line = line.trim();
    if line.is_empty() || line.starts_with('#') {
        return None;
    }
    let line = line.strip_prefix("export ").unwrap_or(line);
    let (key, value) = line.split_once('=')?;
    let key = key.trim();
    if key.is_empty() {
        return None;
    }
    Some((key.to_string(), strip_matching_quotes(value.trim()).to_string()))
}

/// Removes one pair of matching surrounding single or double quotes, if present.
fn strip_matching_quotes(value: &str) -> &str {
    for quote in ['"', '\''] {
        if let Some(inner) = value
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote))
        {
            return inner;
        }
    }
    value
}
308
309/// Load the keys from either the .env file or environment variables.
310pub fn load_keys(path: &str) -> Keys {
311    let env_map = load_env_file(path);
312
313    let mut keys = vec![];
314
315    let providers = [
316        Provider::Amazon,
317        Provider::Azure,
318        Provider::Cerebras,
319        Provider::DeepInfra,
320        Provider::ElevenLabs,
321        Provider::Fireworks,
322        Provider::FriendliAI,
323        Provider::Google,
324        Provider::Groq,
325        Provider::Hyperbolic,
326        Provider::Mistral,
327        Provider::Nebius,
328        Provider::Novita,
329        Provider::OpenAI,
330        Provider::OpenAICompatible("".to_string()),
331        Provider::SambaNova,
332        Provider::TogetherAI,
333    ];
334    for provider in providers {
335        if let Ok(key_value) = std::env::var(provider.key_name()) {
336            keys.push(Key {
337                provider: provider.clone(),
338                key: key_value,
339            });
340        } else if let Some(key_value) = env_map.get(&provider.key_name()) {
341            keys.push(Key {
342                provider: provider.clone(),
343                key: key_value.to_string(),
344            });
345        }
346    }
347    Keys { keys }
348}