Skip to main content

zai_rs/model/text_embedded/
request.rs

1use serde::{Deserialize, Serialize};
2
3/// Embedding model enum
4#[derive(Debug, Clone, Serialize, Deserialize)]
5#[serde(rename_all = "kebab-case")]
6pub enum EmbeddingModel {
7    #[serde(rename = "embedding-3")]
8    Embedding3,
9    #[serde(rename = "embedding-2")]
10    Embedding2,
11}
12
13/// Input can be a single string or an array of strings
14#[derive(Debug, Clone, Serialize, Deserialize)]
15#[serde(untagged)]
16pub enum EmbeddingInput {
17    Single(String),
18    Batch(Vec<String>),
19}
20
/// Size of the embedding vector requested from the API.
///
/// Only the listed sizes can be expressed; the numeric value sent on the wire
/// is produced by this type's `Serialize` implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbeddingDimensions {
    /// 2048-dimensional output.
    D2048,
    /// 1024-dimensional output.
    D1024,
    /// 512-dimensional output.
    D512,
    /// 256-dimensional output.
    D256,
}
29
30impl Serialize for EmbeddingDimensions {
31    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
32    where
33        S: serde::Serializer,
34    {
35        let v: u16 = match self {
36            EmbeddingDimensions::D2048 => 2048,
37            EmbeddingDimensions::D1024 => 1024,
38            EmbeddingDimensions::D512 => 512,
39            EmbeddingDimensions::D256 => 256,
40        };
41        serializer.serialize_u16(v)
42    }
43}
44
45/// Request body for embeddings
46#[derive(Debug, Clone, Serialize)]
47pub struct EmbeddingBody {
48    /// 嵌入模型:embedding-3 或 embedding-2
49    pub model: EmbeddingModel,
50
51    /// 输入文本,支持字符串或字符串数组
52    pub input: EmbeddingInput,
53
54    /// 输出维度,Embedding-3 支持 256/512/1024/2048;Embedding-2 固定
55    /// 1024(可不填)
56    #[serde(skip_serializing_if = "Option::is_none")]
57    pub dimensions: Option<EmbeddingDimensions>,
58}
59
60impl EmbeddingBody {
61    pub fn new(model: EmbeddingModel, input: EmbeddingInput) -> Self {
62        Self {
63            model,
64            input,
65            dimensions: None,
66        }
67    }
68
69    pub fn with_dimensions(mut self, dims: EmbeddingDimensions) -> Self {
70        self.dimensions = Some(dims);
71        self
72    }
73
74    /// Optional helper to enforce cross-field constraints at runtime.
75    /// Call this before sending if you want strict validation.
76    pub fn validate_model_constraints(&self) -> Result<(), validator::ValidationError> {
77        use validator::ValidationError;
78        // If input is Batch for embedding-3, enforce max 64 items (per API doc)
79        if let EmbeddingModel::Embedding3 = self.model
80            && let EmbeddingInput::Batch(ref v) = self.input
81            && v.len() > 64
82        {
83            return Err(ValidationError::new("batch_too_long"));
84        }
85        // If model = embedding-2 and dimensions is Some, it must be 1024
86        if let EmbeddingModel::Embedding2 = self.model
87            && let Some(d) = self.dimensions
88            && d != EmbeddingDimensions::D1024
89        {
90            return Err(ValidationError::new("embedding2_dims_must_be_1024"));
91        }
92        Ok(())
93    }
94}