//! synaptic_groq — Groq chat-model provider built on the OpenAI-compatible client.

1use std::sync::Arc;
2pub use synaptic_core::{ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapticError};
3use synaptic_models::ProviderBackend;
4pub use synaptic_openai::OpenAiEmbeddings;
5use synaptic_openai::{OpenAiChatModel, OpenAiConfig};
6
/// Chat models available on the Groq API.
///
/// The named variants cover common hosted models; `Custom` carries any
/// other model identifier verbatim, so new Groq models can be used
/// without a crate update.
//
// `Hash` added alongside the existing `Eq` so the model can be used as a
// map/set key (e.g. per-model rate limits or caches).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum GroqModel {
    Llama3_3_70bVersatile,
    Llama3_1_8bInstant,
    Llama3_1_70bVersatile,
    Gemma2_9bIt,
    Mixtral8x7b32768,
    /// Any other model id, passed through to the API unchanged.
    Custom(String),
}

impl GroqModel {
    /// Returns the model identifier string expected by the Groq API.
    pub fn as_str(&self) -> &str {
        match self {
            Self::Llama3_3_70bVersatile => "llama-3.3-70b-versatile",
            Self::Llama3_1_8bInstant => "llama-3.1-8b-instant",
            Self::Llama3_1_70bVersatile => "llama-3.1-70b-versatile",
            Self::Gemma2_9bIt => "gemma2-9b-it",
            Self::Mixtral8x7b32768 => "mixtral-8x7b-32768",
            Self::Custom(s) => s.as_str(),
        }
    }
}

impl std::fmt::Display for GroqModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Display mirrors `as_str` so `to_string()` yields the API id.
        f.write_str(self.as_str())
    }
}
33
34#[derive(Debug, Clone)]
35pub struct GroqConfig {
36    pub api_key: String,
37    pub model: String,
38    pub max_tokens: Option<u32>,
39    pub temperature: Option<f64>,
40    pub top_p: Option<f64>,
41    pub stop: Option<Vec<String>>,
42    pub seed: Option<u64>,
43}
44impl GroqConfig {
45    pub fn new(api_key: impl Into<String>, model: GroqModel) -> Self {
46        Self {
47            api_key: api_key.into(),
48            model: model.to_string(),
49            max_tokens: None,
50            temperature: None,
51            top_p: None,
52            stop: None,
53            seed: None,
54        }
55    }
56    pub fn new_custom(api_key: impl Into<String>, model: impl Into<String>) -> Self {
57        Self {
58            api_key: api_key.into(),
59            model: model.into(),
60            max_tokens: None,
61            temperature: None,
62            top_p: None,
63            stop: None,
64            seed: None,
65        }
66    }
67    pub fn with_max_tokens(mut self, v: u32) -> Self {
68        self.max_tokens = Some(v);
69        self
70    }
71    pub fn with_temperature(mut self, v: f64) -> Self {
72        self.temperature = Some(v);
73        self
74    }
75    pub fn with_top_p(mut self, v: f64) -> Self {
76        self.top_p = Some(v);
77        self
78    }
79    pub fn with_stop(mut self, v: Vec<String>) -> Self {
80        self.stop = Some(v);
81        self
82    }
83    pub fn with_seed(mut self, v: u64) -> Self {
84        self.seed = Some(v);
85        self
86    }
87}
88impl From<GroqConfig> for OpenAiConfig {
89    fn from(c: GroqConfig) -> Self {
90        let mut cfg =
91            OpenAiConfig::new(c.api_key, c.model).with_base_url("https://api.groq.com/openai/v1");
92        if let Some(v) = c.max_tokens {
93            cfg = cfg.with_max_tokens(v);
94        }
95        if let Some(v) = c.temperature {
96            cfg = cfg.with_temperature(v);
97        }
98        if let Some(v) = c.top_p {
99            cfg = cfg.with_top_p(v);
100        }
101        if let Some(v) = c.stop {
102            cfg = cfg.with_stop(v);
103        }
104        if let Some(v) = c.seed {
105            cfg = cfg.with_seed(v);
106        }
107        cfg
108    }
109}
110
/// Chat model backed by Groq's OpenAI-compatible API.
///
/// Thin wrapper around an [`OpenAiChatModel`] whose configuration has
/// been pointed at Groq's base URL; all requests are delegated to it.
pub struct GroqChatModel {
    // OpenAI-protocol client configured for the Groq endpoint.
    inner: OpenAiChatModel,
}
114
115impl GroqChatModel {
116    pub fn new(config: GroqConfig, backend: Arc<dyn ProviderBackend>) -> Self {
117        Self {
118            inner: OpenAiChatModel::new(config.into(), backend),
119        }
120    }
121}
122
// ChatModel is implemented purely by delegation: the inner OpenAI-protocol
// client already speaks Groq's OpenAI-compatible API.
#[async_trait::async_trait]
impl ChatModel for GroqChatModel {
    // Single (non-streaming) chat completion.
    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
        self.inner.chat(request).await
    }
    // Streaming chat; the returned stream borrows `self` for its lifetime.
    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
        self.inner.stream_chat(request)
    }
}