zai_rs/model/text_rerank/
request.rs1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
5#[serde(rename_all = "lowercase")]
6#[derive(Default)]
7pub enum RerankModel {
8 #[default]
9 Rerank,
10}
11
12#[derive(Debug, Clone, Serialize)]
14pub struct RerankBody {
15 pub model: RerankModel,
17
18 pub query: String,
20
21 pub documents: Vec<String>,
23
24 #[serde(skip_serializing_if = "Option::is_none")]
26 pub top_n: Option<usize>,
27
28 #[serde(skip_serializing_if = "Option::is_none")]
30 pub return_documents: Option<bool>,
31
32 #[serde(skip_serializing_if = "Option::is_none")]
34 pub return_raw_scores: Option<bool>,
35
36 #[serde(skip_serializing_if = "Option::is_none")]
38 pub request_id: Option<String>,
39
40 #[serde(skip_serializing_if = "Option::is_none")]
42 pub user_id: Option<String>,
43}
44
45impl RerankBody {
46 pub fn new(model: RerankModel, query: impl Into<String>, documents: Vec<String>) -> Self {
47 Self {
48 model,
49 query: query.into(),
50 documents,
51 top_n: None,
52 return_documents: None,
53 return_raw_scores: None,
54 request_id: None,
55 user_id: None,
56 }
57 }
58
59 pub fn with_top_n(mut self, n: usize) -> Self {
60 self.top_n = Some(n);
61 self
62 }
63 pub fn with_return_documents(mut self, v: bool) -> Self {
64 self.return_documents = Some(v);
65 self
66 }
67 pub fn with_return_raw_scores(mut self, v: bool) -> Self {
68 self.return_raw_scores = Some(v);
69 self
70 }
71 pub fn with_request_id(mut self, v: impl Into<String>) -> Self {
72 self.request_id = Some(v.into());
73 self
74 }
75 pub fn with_user_id(mut self, v: impl Into<String>) -> Self {
76 self.user_id = Some(v.into());
77 self
78 }
79
80 pub fn validate_constraints(&self) -> crate::ZaiResult<()> {
82 if self.query.chars().count() > 4096 {
83 return Err(crate::client::error::ZaiError::ApiError {
84 code: 1200,
85 message: "query length exceeds 4096 characters".to_string(),
86 });
87 }
88 if self.documents.is_empty() {
89 return Err(crate::client::error::ZaiError::ApiError {
90 code: 1200,
91 message: "documents must not be empty".to_string(),
92 });
93 }
94 if self.documents.len() > 128 {
95 return Err(crate::client::error::ZaiError::ApiError {
96 code: 1200,
97 message: "documents length exceeds 128".to_string(),
98 });
99 }
100 for (i, d) in self.documents.iter().enumerate() {
101 if d.chars().count() > 4096 {
102 return Err(crate::client::error::ZaiError::ApiError {
103 code: 1200,
104 message: format!("document at index {} exceeds 4096 characters", i),
105 });
106 }
107 }
108 if let Some(n) = self.top_n
109 && n > self.documents.len()
110 {
111 return Err(crate::client::error::ZaiError::ApiError {
112 code: 1200,
113 message: "top_n cannot exceed documents length".to_string(),
114 });
115 }
116 Ok(())
117 }
118}