Skip to main content

u_sdk/deep_seek/
mod.rs

1//! DeepSeek sdk
2
3mod types;
4pub use types::*;
5
6mod error;
7pub use error::Error;
8
9mod utils;
10
11use async_stream::try_stream;
12use bon::bon;
13use bytes::{Buf, BytesMut};
14use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue};
15use tokio_stream::{Stream, StreamExt};
16use u_sdk_common::helper::{into_request_failed_error, parse_json_response};
17use utils::check_msg_list;
18
19const BASE_URL: &str = "https://api.deepseek.com";
20
21//region client
22pub struct Client {
23    http_client: reqwest::Client,
24}
25
26#[bon]
27impl Client {
28    #[builder(on(String, into))]
29    pub fn new(api_key: String) -> Self {
30        let mut header_map = HeaderMap::new();
31        let mut auth_val = HeaderValue::from_str(&format!("Bearer {}", api_key)).unwrap();
32        auth_val.set_sensitive(true);
33        header_map.insert(AUTHORIZATION, auth_val);
34
35        let http_client = reqwest::Client::builder()
36            .default_headers(header_map)
37            .build()
38            .unwrap();
39
40        Self { http_client }
41    }
42
43    pub fn chat_builder(&self) -> ChatBuilder<'_> {
44        Chat::builder(self)
45    }
46
47    pub async fn check_balance(&self) -> Result<CheckBalanceResponse, Error> {
48        let resp = self
49            .http_client
50            .get(format!("{}/user/balance", BASE_URL))
51            .send()
52            .await?;
53
54        let res = parse_json_response(resp).await?;
55        Ok(res)
56    }
57}
58//endregion
59
60//region chat
61impl Chat<'_> {
62    /// 多轮对话形式
63    ///
64    /// 发送的形式:
65    ///
66    /// ```json
67    /// // 第一条可以是prompt
68    /// {"content": "You are a helpful assistant", "role": "system" }
69    /// {"content": "Hi", "role": "user" }
70    ///
71    /// // 或者直接是user
72    /// {"content": "Hi", "role": "user" }
73    /// ```
74    pub async fn chat(&self) -> Result<ChatResponse, Error> {
75        check_msg_list(self.messages)?;
76
77        // 防止 stream 为 true
78        if self.stream {
79            return Err(Error::Common(
80                "Stream mode is enabled. Use chat_by_stream instead.".to_string(),
81            ));
82        }
83
84        let client = self.client;
85        let resp = client
86            .http_client
87            .post(format!("{}/chat/completions", BASE_URL))
88            .json(self)
89            .send()
90            .await?;
91
92        let res = parse_json_response(resp).await?;
93        Ok(res)
94    }
95
96    pub async fn chat_by_stream(
97        &self,
98    ) -> Result<impl Stream<Item = Result<StreamEventData, Error>> + use<>, Error> {
99        check_msg_list(self.messages)?;
100
101        if !self.stream {
102            return Err(Error::Common(
103                "Stream mode is not enabled. Use chat instead.".to_string(),
104            ));
105        }
106
107        let resp = self
108            .client
109            .http_client
110            .post(format!("{}/chat/completions", BASE_URL))
111            .json(self)
112            .send()
113            .await?;
114
115        if !resp.status().is_success() {
116            return Err(into_request_failed_error(resp).await.into());
117        }
118
119        let mut byte_stream = resp.bytes_stream();
120
121        let event_stream = try_stream! {
122            let mut buffer = BytesMut::with_capacity(4096);
123
124            while let Some(chunk) = byte_stream.next().await {
125                // 如果底层网络错误,会通过 `?` 返回 Err(Error) 并终止流
126                let chunk = chunk?;
127                buffer.extend(chunk);
128
129                // SSE 协议中,每条事件以 "\n\n" 分隔
130                while let Some(pos) = buffer.windows(2).position(|w| w == b"\n\n") {
131                    // 转成 &str,UTF-8 错误同样会返回 Err(Error)
132                    let text = std::str::from_utf8(&buffer[..pos])
133                        .map_err(|e| Error::Common(format!("Invalid UTF-8 sequence: {}", e)))?;
134                    // 解析这一条事件,没有事件时(data: [DONE])会返回 Ok(None)
135                    if let Some(evt) = parse_event_block(text)? {
136                        yield evt;
137                    }
138                    // 清除已处理的部分
139                    buffer.advance(pos + 2);
140                }
141            }
142        };
143
144        Ok(Box::pin(event_stream))
145    }
146}
147
148// 解析一段完整的 SSE 事件文本
149fn parse_event_block(s: &str) -> Result<Option<StreamEventData>, Error> {
150    let s = s.trim();
151    // 结束标志
152    if s.starts_with("data: [DONE]") {
153        return Ok(None);
154    }
155    // 正常的数据行
156    if let Some(rest) = s.strip_prefix("data:") {
157        let json_str = rest.trim_start();
158        let data: StreamEventData = serde_json::from_str(json_str)
159            .map_err(|e| Error::Common(format!("Failed to parse stream event data: {}", e)))?;
160        Ok(Some(data))
161    } else {
162        Err(Error::Common("Unknown event format".to_string()))
163    }
164}
165//endregion