1use std::sync::Arc;
2pub use synaptic_core::{ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapticError};
3use synaptic_models::ProviderBackend;
4pub use synaptic_openai::OpenAiEmbeddings;
5use synaptic_openai::{OpenAiChatModel, OpenAiConfig};
6
/// Identifiers for chat models hosted on Groq's OpenAI-compatible API.
///
/// Use [`GroqModel::Custom`] for model ids not covered by a named variant.
// `Hash` is derived in addition to the existing `Eq` so the type can be
// used directly as a `HashMap`/`HashSet` key (e.g. per-model caches).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum GroqModel {
    /// `llama-3.3-70b-versatile`
    Llama3_3_70bVersatile,
    /// `llama-3.1-8b-instant`
    Llama3_1_8bInstant,
    /// `llama-3.1-70b-versatile`
    Llama3_1_70bVersatile,
    /// `gemma2-9b-it`
    Gemma2_9bIt,
    /// `mixtral-8x7b-32768`
    Mixtral8x7b32768,
    /// Any other model id, passed through to the API verbatim.
    Custom(String),
}

impl GroqModel {
    /// Returns the model identifier string expected by the Groq API.
    pub fn as_str(&self) -> &str {
        match self {
            Self::Llama3_3_70bVersatile => "llama-3.3-70b-versatile",
            Self::Llama3_1_8bInstant => "llama-3.1-8b-instant",
            Self::Llama3_1_70bVersatile => "llama-3.1-70b-versatile",
            Self::Gemma2_9bIt => "gemma2-9b-it",
            Self::Mixtral8x7b32768 => "mixtral-8x7b-32768",
            Self::Custom(s) => s.as_str(),
        }
    }
}

impl std::fmt::Display for GroqModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Delegate to `as_str`; `write_str` skips the formatting machinery.
        f.write_str(self.as_str())
    }
}
33
34#[derive(Debug, Clone)]
35pub struct GroqConfig {
36 pub api_key: String,
37 pub model: String,
38 pub max_tokens: Option<u32>,
39 pub temperature: Option<f64>,
40 pub top_p: Option<f64>,
41 pub stop: Option<Vec<String>>,
42 pub seed: Option<u64>,
43}
44impl GroqConfig {
45 pub fn new(api_key: impl Into<String>, model: GroqModel) -> Self {
46 Self {
47 api_key: api_key.into(),
48 model: model.to_string(),
49 max_tokens: None,
50 temperature: None,
51 top_p: None,
52 stop: None,
53 seed: None,
54 }
55 }
56 pub fn new_custom(api_key: impl Into<String>, model: impl Into<String>) -> Self {
57 Self {
58 api_key: api_key.into(),
59 model: model.into(),
60 max_tokens: None,
61 temperature: None,
62 top_p: None,
63 stop: None,
64 seed: None,
65 }
66 }
67 pub fn with_max_tokens(mut self, v: u32) -> Self {
68 self.max_tokens = Some(v);
69 self
70 }
71 pub fn with_temperature(mut self, v: f64) -> Self {
72 self.temperature = Some(v);
73 self
74 }
75 pub fn with_top_p(mut self, v: f64) -> Self {
76 self.top_p = Some(v);
77 self
78 }
79 pub fn with_stop(mut self, v: Vec<String>) -> Self {
80 self.stop = Some(v);
81 self
82 }
83 pub fn with_seed(mut self, v: u64) -> Self {
84 self.seed = Some(v);
85 self
86 }
87}
88impl From<GroqConfig> for OpenAiConfig {
89 fn from(c: GroqConfig) -> Self {
90 let mut cfg =
91 OpenAiConfig::new(c.api_key, c.model).with_base_url("https://api.groq.com/openai/v1");
92 if let Some(v) = c.max_tokens {
93 cfg = cfg.with_max_tokens(v);
94 }
95 if let Some(v) = c.temperature {
96 cfg = cfg.with_temperature(v);
97 }
98 if let Some(v) = c.top_p {
99 cfg = cfg.with_top_p(v);
100 }
101 if let Some(v) = c.stop {
102 cfg = cfg.with_stop(v);
103 }
104 if let Some(v) = c.seed {
105 cfg = cfg.with_seed(v);
106 }
107 cfg
108 }
109}
110
111pub struct GroqChatModel {
112 inner: OpenAiChatModel,
113}
114
115impl GroqChatModel {
116 pub fn new(config: GroqConfig, backend: Arc<dyn ProviderBackend>) -> Self {
117 Self {
118 inner: OpenAiChatModel::new(config.into(), backend),
119 }
120 }
121}
122
#[async_trait::async_trait]
impl ChatModel for GroqChatModel {
    /// Performs a single (non-streaming) chat completion by delegating
    /// to the inner OpenAI-compatible client.
    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
        self.inner.chat(request).await
    }
    /// Starts a streaming chat completion, delegating to the inner client.
    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
        self.inner.stream_chat(request)
    }
}