1use std::fmt::Display;
6
7use super::{embedding::EmbeddingBytes, SenseError};
8use base64::{engine::general_purpose::STANDARD as DECODER, Engine as _};
9use reqwest::{header::HeaderMap, Client, ClientBuilder, Url};
10use serde::{Deserialize, Serialize};
11
12#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
16pub enum Model {
17 #[serde(rename = "BAAI/bge-large-zh-v1.5")]
19 BgeLargeZhV1_5,
20 #[serde(rename = "BAAI/bge-large-en-v1.5")]
22 BgeLargeEnV1_5,
23 #[serde(rename = "netease-youdao/bce-embedding-base_v1")]
25 BceEmbeddingBaseV1,
26 #[serde(rename = "BAAI/bge-m3")]
28 BgeM3,
29 #[serde(rename = "Pro/BAAI/bge-m3")]
31 ProBgeM3,
32}
33
34impl Default for Model {
35 fn default() -> Self {
36 Self::BgeLargeZhV1_5
37 }
38}
39
40impl Display for Model {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 write!(
43 f,
44 "{}",
45 serde_json::to_string(self).unwrap().trim_matches('"')
46 )
47 }
48}
49
50fn validate_api_key(key: &str) -> Result<(), SenseError> {
52 if key.len() != 51 {
53 return Err(SenseError::MalformedApiKey);
54 }
55 for c in key.chars().skip(3) {
56 if !c.is_ascii_alphanumeric() {
57 return Err(SenseError::MalformedApiKey);
58 }
59 }
60 Ok(())
61}
62
63#[derive(Serialize)]
67struct RequestBody<'a> {
68 model: &'a str,
70 input: &'a str,
72 encoding_format: &'a str,
74}
75
76#[derive(Deserialize)]
78struct Data {
79 #[serde(rename = "object")]
81 _object: String,
82 embedding: String,
84 #[serde(rename = "index")]
86 _index: i32,
87}
88
89#[derive(Deserialize)]
91#[allow(dead_code, reason = "For deserialization only")]
92#[allow(clippy::struct_field_names, reason = "Consistency with API response")]
93struct Usage {
94 prompt_tokens: u32,
96 completion_tokens: u32,
98 total_tokens: u32,
100}
101
102#[derive(Deserialize)]
104struct ResponseBody {
105 model: String,
107 data: Vec<Data>,
109 #[serde(rename = "usage")]
111 _usage: Usage,
112}
113
114#[derive(Clone)]
118pub struct ApiClient {
119 model: String,
121 endpoint: Url,
123 client: Client,
125}
126
127impl ApiClient {
128 #[allow(clippy::missing_panics_doc, reason = "URL is hardcoded")]
134 pub fn new(key: &str, model: Model) -> Result<Self, SenseError> {
135 validate_api_key(key)?;
136 let mut headers = HeaderMap::new();
137 headers.insert("Authorization", format!("Bearer {key}").parse()?);
138 let client = ClientBuilder::new().default_headers(headers).build()?;
139
140 Ok(Self {
141 model: model.to_string(),
142 endpoint: Url::parse("https://api.siliconflow.cn/v1/embeddings").unwrap(),
143 client,
144 })
145 }
146
147 pub async fn embed(&self, text: &str) -> Result<EmbeddingBytes, SenseError> {
157 let request_body = RequestBody {
158 model: &self.model,
159 input: text,
160 encoding_format: "base64",
161 };
162 let request = self.client.post(self.endpoint.clone()).json(&request_body);
163
164 let response: ResponseBody = request.send().await?.json().await?;
165 debug_assert_eq!(response.model, self.model);
166
167 let embedding = DECODER.decode(response.data[0].embedding.as_bytes())?;
168 Ok(embedding.try_into()?)
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder key with the shape of a real one: `sk-` plus 48 alphanumerics.
    const KEY: &str = "sk-1234567890abcdef1234567890abcdef1234567890abcdef";

    /// A well-formed key passes validation.
    #[test]
    fn test_api_key_ok() {
        let outcome = validate_api_key(KEY);
        outcome.unwrap();
    }

    /// A key that is one character too short is rejected as malformed.
    #[test]
    fn test_api_key_malformed() {
        let (truncated, _last) = KEY.split_at(KEY.len() - 1);
        let outcome = validate_api_key(truncated);
        assert!(matches!(outcome.unwrap_err(), SenseError::MalformedApiKey));
    }

    /// End-to-end request against the real service; ignored by default.
    #[tokio::test]
    #[ignore = "requires API key in `SILICONFLOW_API_KEY` env var"]
    async fn test_embed() {
        let key = std::env::var("SILICONFLOW_API_KEY").unwrap();
        let client = ApiClient::new(&key, Model::BgeLargeZhV1_5).unwrap();
        let _ = client.embed("Hello, world!").await.unwrap();
    }
}