semantic_search/
api.rs

1//! # Silicon Flow module
2//!
3//! This module contains logic for the Silicon Flow API.
4
5use std::fmt::Display;
6
7use super::{embedding::EmbeddingBytes, SenseError};
8use base64::{engine::general_purpose::STANDARD as DECODER, Engine as _};
9use reqwest::{header::HeaderMap, Client, ClientBuilder, Url};
10use serde::{Deserialize, Serialize};
11
12// == API key validation and model definitions ==
13
14/// Available models.
15#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
16pub enum Model {
17    /// BAAI/bge-large-zh-v1.5
18    #[serde(rename = "BAAI/bge-large-zh-v1.5")]
19    BgeLargeZhV1_5,
20    /// BAAI/bge-large-en-v1.5
21    #[serde(rename = "BAAI/bge-large-en-v1.5")]
22    BgeLargeEnV1_5,
23    /// netease-youdao/bce-embedding-base_v1
24    #[serde(rename = "netease-youdao/bce-embedding-base_v1")]
25    BceEmbeddingBaseV1,
26    /// BAAI/bge-m3
27    #[serde(rename = "BAAI/bge-m3")]
28    BgeM3,
29    /// Pro/BAAI/bge-m3
30    #[serde(rename = "Pro/BAAI/bge-m3")]
31    ProBgeM3,
32}
33
34impl Default for Model {
35    fn default() -> Self {
36        Self::BgeLargeZhV1_5
37    }
38}
39
40impl Display for Model {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        write!(
43            f,
44            "{}",
45            serde_json::to_string(self).unwrap().trim_matches('"')
46        )
47    }
48}
49
50/// Validate that the API key is well-formed.
51fn validate_api_key(key: &str) -> Result<(), SenseError> {
52    if key.len() != 51 {
53        return Err(SenseError::MalformedApiKey);
54    }
55    for c in key.chars().skip(3) {
56        if !c.is_ascii_alphanumeric() {
57            return Err(SenseError::MalformedApiKey);
58        }
59    }
60    Ok(())
61}
62
63// == Request and response definitions ==
64
65/// The request body for the Silicon Flow API.
66#[derive(Serialize)]
67struct RequestBody<'a> {
68    /// The model to use.
69    model: &'a str,
70    /// The input text.
71    input: &'a str,
72    /// The encoding format, either "float" or "base64".
73    encoding_format: &'a str,
74}
75
76/// ResponseBody.data: The list of embeddings generated by the model.
77#[derive(Deserialize)]
78struct Data {
79    /// Fixed string "embedding".
80    #[serde(rename = "object")]
81    _object: String,
82    /// Base64-encoded embedding.
83    embedding: String,
84    /// Unused.
85    #[serde(rename = "index")]
86    _index: i32,
87}
88
89/// ResponseBody.usage: The usage information for the request.
90#[derive(Deserialize)]
91#[allow(dead_code, reason = "For deserialization only")]
92#[allow(clippy::struct_field_names, reason = "Consistency with API response")]
93struct Usage {
94    /// The number of tokens used by the prompt.
95    prompt_tokens: u32,
96    /// The number of tokens used by the completion.
97    completion_tokens: u32,
98    /// The total number of tokens used by the request.
99    total_tokens: u32,
100}
101
102/// The response body for the Silicon Flow API.
103#[derive(Deserialize)]
104struct ResponseBody {
105    /// The name of the model used to generate the embedding.
106    model: String,
107    /// The list of embeddings generated by the model.
108    data: Vec<Data>,
109    /// The usage information for the request.
110    #[serde(rename = "usage")]
111    _usage: Usage,
112}
113
114// == API client ==
115
116/// A client for the Silicon Flow API.
117#[derive(Clone)]
118pub struct ApiClient {
119    /// The model to use.
120    model: String,
121    /// API endpoint.
122    endpoint: Url,
123    /// HTTP client.
124    client: Client,
125}
126
127impl ApiClient {
128    /// Create a new API client.
129    ///
130    /// # Errors
131    ///
132    /// Returns an error if the API key is malformed or the HTTP client cannot be created.
133    #[allow(clippy::missing_panics_doc, reason = "URL is hardcoded")]
134    pub fn new(key: &str, model: Model) -> Result<Self, SenseError> {
135        validate_api_key(key)?;
136        let mut headers = HeaderMap::new();
137        headers.insert("Authorization", format!("Bearer {key}").parse()?);
138        let client = ClientBuilder::new().default_headers(headers).build()?;
139
140        Ok(Self {
141            model: model.to_string(),
142            endpoint: Url::parse("https://api.siliconflow.cn/v1/embeddings").unwrap(),
143            client,
144        })
145    }
146
147    /// Embed a text.
148    ///
149    /// # Errors
150    ///
151    /// Returns:
152    ///
153    /// - [`SenseError::RequestFailed`] if the request fails
154    /// - [`SenseError::Base64DecodingFailed`] if base64 decoding fails
155    /// - [`SenseError::DimensionMismatch`] if the embedding is not 1024-dimensional.
156    pub async fn embed(&self, text: &str) -> Result<EmbeddingBytes, SenseError> {
157        let request_body = RequestBody {
158            model: &self.model,
159            input: text,
160            encoding_format: "base64",
161        };
162        let request = self.client.post(self.endpoint.clone()).json(&request_body);
163
164        let response: ResponseBody = request.send().await?.json().await?;
165        debug_assert_eq!(response.model, self.model);
166
167        let embedding = DECODER.decode(response.data[0].embedding.as_bytes())?;
168        Ok(embedding.try_into()?)
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// A fake key of the expected shape: `sk-` followed by 48 alphanumerics.
    const KEY: &str = "sk-1234567890abcdef1234567890abcdef1234567890abcdef";

    #[test]
    fn test_api_key_ok() {
        assert!(validate_api_key(KEY).is_ok());
    }

    #[test]
    fn test_api_key_malformed() {
        // Dropping the final character violates the length requirement.
        let truncated = &KEY[..KEY.len() - 1];
        assert!(matches!(
            validate_api_key(truncated),
            Err(SenseError::MalformedApiKey)
        ));
    }

    #[tokio::test]
    #[ignore = "requires API key in `SILICONFLOW_API_KEY` env var"]
    async fn test_embed() {
        let key = std::env::var("SILICONFLOW_API_KEY").unwrap();
        let client = ApiClient::new(&key, Model::BgeLargeZhV1_5).unwrap();
        let _embedding = client.embed("Hello, world!").await.unwrap();
    }
}