//! Fireworks AI provider for Synaptic.
//!
//! Fireworks exposes an OpenAI-compatible inference API, so this crate is a
//! thin wrapper around `synaptic_openai` that supplies the Fireworks base URL
//! and account-qualified model identifiers.
1use std::sync::Arc;
2pub use synaptic_core::{ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapticError};
3use synaptic_models::ProviderBackend;
4use synaptic_openai::{OpenAiChatModel, OpenAiConfig};
5
/// Well-known model identifiers hosted on Fireworks AI.
///
/// [`as_str`](FireworksModel::as_str) maps each variant to the full
/// account-qualified id (`accounts/fireworks/models/...`) the API expects.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FireworksModel {
    /// Llama 3.1 70B instruction-tuned.
    Llama3_1_70bInstruct,
    /// Llama 3.1 8B instruction-tuned.
    Llama3_1_8bInstruct,
    /// DeepSeek R1 reasoning model.
    DeepSeekR1,
    /// Qwen 2.5 72B instruction-tuned.
    Qwen2_5_72bInstruct,
    /// Any other model, given as its full account-qualified id; the string
    /// is passed through to the API verbatim.
    Custom(String),
}
14
15impl FireworksModel {
16    pub fn as_str(&self) -> &str {
17        match self {
18            FireworksModel::Llama3_1_70bInstruct => {
19                "accounts/fireworks/models/llama-v3p1-70b-instruct"
20            }
21            FireworksModel::Llama3_1_8bInstruct => {
22                "accounts/fireworks/models/llama-v3p1-8b-instruct"
23            }
24            FireworksModel::DeepSeekR1 => "accounts/fireworks/models/deepseek-r1",
25            FireworksModel::Qwen2_5_72bInstruct => "accounts/fireworks/models/qwen2p5-72b-instruct",
26            FireworksModel::Custom(s) => s.as_str(),
27        }
28    }
29}
30
31impl std::fmt::Display for FireworksModel {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        write!(f, "{}", self.as_str())
34    }
35}
36
/// Configuration for a Fireworks chat model.
///
/// Converted into an [`OpenAiConfig`] (with the Fireworks base URL) via the
/// `From` impl below; unset optional fields are simply not forwarded.
#[derive(Debug, Clone)]
pub struct FireworksConfig {
    /// Fireworks API key sent with every request.
    pub api_key: String,
    /// Full account-qualified model id (see [`FireworksModel`]).
    pub model: String,
    /// Maximum number of tokens to generate, if set.
    pub max_tokens: Option<u32>,
    /// Sampling temperature, if set.
    pub temperature: Option<f64>,
    /// Nucleus-sampling probability mass, if set.
    pub top_p: Option<f64>,
    /// Stop sequences that end generation, if set.
    pub stop: Option<Vec<String>>,
}
46
47impl FireworksConfig {
48    pub fn new(api_key: impl Into<String>, model: FireworksModel) -> Self {
49        Self {
50            api_key: api_key.into(),
51            model: model.to_string(),
52            max_tokens: None,
53            temperature: None,
54            top_p: None,
55            stop: None,
56        }
57    }
58    pub fn new_custom(api_key: impl Into<String>, model: impl Into<String>) -> Self {
59        Self {
60            api_key: api_key.into(),
61            model: model.into(),
62            max_tokens: None,
63            temperature: None,
64            top_p: None,
65            stop: None,
66        }
67    }
68    pub fn with_max_tokens(mut self, v: u32) -> Self {
69        self.max_tokens = Some(v);
70        self
71    }
72    pub fn with_temperature(mut self, v: f64) -> Self {
73        self.temperature = Some(v);
74        self
75    }
76    pub fn with_top_p(mut self, v: f64) -> Self {
77        self.top_p = Some(v);
78        self
79    }
80    pub fn with_stop(mut self, v: Vec<String>) -> Self {
81        self.stop = Some(v);
82        self
83    }
84}
85
86impl From<FireworksConfig> for OpenAiConfig {
87    fn from(c: FireworksConfig) -> Self {
88        let mut cfg = OpenAiConfig::new(c.api_key, c.model)
89            .with_base_url("https://api.fireworks.ai/inference/v1");
90        if let Some(v) = c.max_tokens {
91            cfg = cfg.with_max_tokens(v);
92        }
93        if let Some(v) = c.temperature {
94            cfg = cfg.with_temperature(v);
95        }
96        if let Some(v) = c.top_p {
97            cfg = cfg.with_top_p(v);
98        }
99        if let Some(v) = c.stop {
100            cfg = cfg.with_stop(v);
101        }
102        cfg
103    }
104}
105
/// Chat model backed by the Fireworks AI API.
///
/// Thin newtype over [`OpenAiChatModel`]; all requests are delegated to the
/// inner OpenAI-compatible client configured for Fireworks.
pub struct FireworksChatModel {
    // OpenAI-compatible client already configured with the Fireworks base URL.
    inner: OpenAiChatModel,
}
109
110impl FireworksChatModel {
111    pub fn new(config: FireworksConfig, backend: Arc<dyn ProviderBackend>) -> Self {
112        Self {
113            inner: OpenAiChatModel::new(config.into(), backend),
114        }
115    }
116}
117
#[async_trait::async_trait]
impl ChatModel for FireworksChatModel {
    /// Sends a single chat request and awaits the full response.
    /// Delegates directly to the inner OpenAI-compatible client.
    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
        self.inner.chat(request).await
    }
    /// Starts a streaming chat completion, returning the inner client's
    /// stream unchanged.
    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
        self.inner.stream_chat(request)
    }
}
126}