// synaptic_fireworks/src/lib.rs
use std::sync::Arc;

pub use synaptic_core::{ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapticError};
use synaptic_models::ProviderBackend;
use synaptic_openai::{OpenAiChatModel, OpenAiConfig};

/// Well-known chat models hosted on Fireworks, plus [`FireworksModel::Custom`]
/// as an escape hatch for any other account-qualified model id.
///
/// `Hash` is derived (consistent with the existing `Eq`) so model selections
/// can be used directly as `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FireworksModel {
    /// Llama 3.1 70B instruction-tuned.
    Llama3_1_70bInstruct,
    /// Llama 3.1 8B instruction-tuned.
    Llama3_1_8bInstruct,
    /// DeepSeek R1.
    DeepSeekR1,
    /// Qwen 2.5 72B instruction-tuned.
    Qwen2_5_72bInstruct,
    /// Any other model, given as the full Fireworks model id
    /// (e.g. `accounts/<account>/models/<model>`).
    Custom(String),
}
14
15impl FireworksModel {
16 pub fn as_str(&self) -> &str {
17 match self {
18 FireworksModel::Llama3_1_70bInstruct => {
19 "accounts/fireworks/models/llama-v3p1-70b-instruct"
20 }
21 FireworksModel::Llama3_1_8bInstruct => {
22 "accounts/fireworks/models/llama-v3p1-8b-instruct"
23 }
24 FireworksModel::DeepSeekR1 => "accounts/fireworks/models/deepseek-r1",
25 FireworksModel::Qwen2_5_72bInstruct => "accounts/fireworks/models/qwen2p5-72b-instruct",
26 FireworksModel::Custom(s) => s.as_str(),
27 }
28 }
29}
30
31impl std::fmt::Display for FireworksModel {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 write!(f, "{}", self.as_str())
34 }
35}
36
/// Configuration for a Fireworks-backed chat model.
///
/// All sampling fields are optional; `None` means the field is omitted and
/// the provider's default applies (see `From<FireworksConfig> for OpenAiConfig`,
/// which only forwards fields that are `Some`).
#[derive(Debug, Clone)]
pub struct FireworksConfig {
    /// Fireworks API key used to authenticate requests.
    pub api_key: String,
    /// Full Fireworks model id (see [`FireworksModel`]).
    pub model: String,
    /// Maximum number of tokens to generate, if set.
    pub max_tokens: Option<u32>,
    /// Sampling temperature, if set.
    pub temperature: Option<f64>,
    /// Nucleus-sampling (top-p) parameter, if set.
    pub top_p: Option<f64>,
    /// Stop sequences, if set.
    pub stop: Option<Vec<String>>,
}
46
47impl FireworksConfig {
48 pub fn new(api_key: impl Into<String>, model: FireworksModel) -> Self {
49 Self {
50 api_key: api_key.into(),
51 model: model.to_string(),
52 max_tokens: None,
53 temperature: None,
54 top_p: None,
55 stop: None,
56 }
57 }
58 pub fn new_custom(api_key: impl Into<String>, model: impl Into<String>) -> Self {
59 Self {
60 api_key: api_key.into(),
61 model: model.into(),
62 max_tokens: None,
63 temperature: None,
64 top_p: None,
65 stop: None,
66 }
67 }
68 pub fn with_max_tokens(mut self, v: u32) -> Self {
69 self.max_tokens = Some(v);
70 self
71 }
72 pub fn with_temperature(mut self, v: f64) -> Self {
73 self.temperature = Some(v);
74 self
75 }
76 pub fn with_top_p(mut self, v: f64) -> Self {
77 self.top_p = Some(v);
78 self
79 }
80 pub fn with_stop(mut self, v: Vec<String>) -> Self {
81 self.stop = Some(v);
82 self
83 }
84}
85
86impl From<FireworksConfig> for OpenAiConfig {
87 fn from(c: FireworksConfig) -> Self {
88 let mut cfg = OpenAiConfig::new(c.api_key, c.model)
89 .with_base_url("https://api.fireworks.ai/inference/v1");
90 if let Some(v) = c.max_tokens {
91 cfg = cfg.with_max_tokens(v);
92 }
93 if let Some(v) = c.temperature {
94 cfg = cfg.with_temperature(v);
95 }
96 if let Some(v) = c.top_p {
97 cfg = cfg.with_top_p(v);
98 }
99 if let Some(v) = c.stop {
100 cfg = cfg.with_stop(v);
101 }
102 cfg
103 }
104}
105
/// Chat model backed by the Fireworks API, implemented as a thin wrapper
/// around the OpenAI-compatible client configured for the Fireworks endpoint.
pub struct FireworksChatModel {
    /// Underlying OpenAI-compatible client; all calls delegate to it.
    inner: OpenAiChatModel,
}
109
110impl FireworksChatModel {
111 pub fn new(config: FireworksConfig, backend: Arc<dyn ProviderBackend>) -> Self {
112 Self {
113 inner: OpenAiChatModel::new(config.into(), backend),
114 }
115 }
116}
117
#[async_trait::async_trait]
impl ChatModel for FireworksChatModel {
    /// One-shot chat completion; delegates directly to the inner
    /// OpenAI-compatible client.
    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
        self.inner.chat(request).await
    }
    /// Streaming chat completion; delegates directly to the inner
    /// OpenAI-compatible client.
    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
        self.inner.stream_chat(request)
    }
}