Skip to main content

zai_rs/model/text_rerank/
request.rs

1use serde::{Deserialize, Serialize};
2
3/// Rerank model enum
4#[derive(Debug, Clone, Serialize, Deserialize)]
5#[serde(rename_all = "lowercase")]
6#[derive(Default)]
7pub enum RerankModel {
8    #[default]
9    Rerank,
10}
11
/// Request body for the rerank API.
///
/// Optional fields that are `None` are left out of the serialized body
/// entirely (`skip_serializing_if = "Option::is_none"`). Length/count limits
/// noted below are not enforced when the struct is built; call
/// `validate_constraints` to check them.
#[derive(Debug, Clone, Serialize)]
pub struct RerankBody {
    /// Model code; `RerankModel`'s default variant is `rerank`.
    pub model: RerankModel,

    /// Query text (maximum length 4096 characters).
    pub query: String,

    /// Candidate texts (at most 128 entries, each at most 4096 characters).
    pub documents: Vec<String>,

    /// Return only the `n` highest-scoring results.
    ///
    /// Omitted from the request when `None`. The upstream API note says the
    /// default is 0, meaning "return all results".
    /// NOTE(review): presumably that default is applied server-side when the
    /// field is omitted; confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_n: Option<usize>,

    /// Whether the response should include the original document text.
    /// Omitted when `None`; the upstream API note gives the default as `false`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub return_documents: Option<bool>,

    /// Whether the response should include raw (unnormalized) scores.
    /// Omitted when `None`; the upstream API note gives the default as `false`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub return_raw_scores: Option<bool>,

    /// Client-supplied request ID. Omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_id: Option<String>,

    /// End-user ID. Omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_id: Option<String>,
}
44
45impl RerankBody {
46    pub fn new(model: RerankModel, query: impl Into<String>, documents: Vec<String>) -> Self {
47        Self {
48            model,
49            query: query.into(),
50            documents,
51            top_n: None,
52            return_documents: None,
53            return_raw_scores: None,
54            request_id: None,
55            user_id: None,
56        }
57    }
58
59    pub fn with_top_n(mut self, n: usize) -> Self {
60        self.top_n = Some(n);
61        self
62    }
63    pub fn with_return_documents(mut self, v: bool) -> Self {
64        self.return_documents = Some(v);
65        self
66    }
67    pub fn with_return_raw_scores(mut self, v: bool) -> Self {
68        self.return_raw_scores = Some(v);
69        self
70    }
71    pub fn with_request_id(mut self, v: impl Into<String>) -> Self {
72        self.request_id = Some(v.into());
73        self
74    }
75    pub fn with_user_id(mut self, v: impl Into<String>) -> Self {
76        self.user_id = Some(v.into());
77        self
78    }
79
80    /// Optional runtime validation for constraints expressed in the docs
81    pub fn validate_constraints(&self) -> crate::ZaiResult<()> {
82        if self.query.chars().count() > 4096 {
83            return Err(crate::client::error::ZaiError::ApiError {
84                code: 1200,
85                message: "query length exceeds 4096 characters".to_string(),
86            });
87        }
88        if self.documents.is_empty() {
89            return Err(crate::client::error::ZaiError::ApiError {
90                code: 1200,
91                message: "documents must not be empty".to_string(),
92            });
93        }
94        if self.documents.len() > 128 {
95            return Err(crate::client::error::ZaiError::ApiError {
96                code: 1200,
97                message: "documents length exceeds 128".to_string(),
98            });
99        }
100        for (i, d) in self.documents.iter().enumerate() {
101            if d.chars().count() > 4096 {
102                return Err(crate::client::error::ZaiError::ApiError {
103                    code: 1200,
104                    message: format!("document at index {} exceeds 4096 characters", i),
105                });
106            }
107        }
108        if let Some(n) = self.top_n
109            && n > self.documents.len()
110        {
111            return Err(crate::client::error::ZaiError::ApiError {
112                code: 1200,
113                message: "top_n cannot exceed documents length".to_string(),
114            });
115        }
116        Ok(())
117    }
118}