zai_rs/model/text_embedded/
request.rs1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
5#[serde(rename_all = "kebab-case")]
6pub enum EmbeddingModel {
7 #[serde(rename = "embedding-3")]
8 Embedding3,
9 #[serde(rename = "embedding-2")]
10 Embedding2,
11}
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15#[serde(untagged)]
16pub enum EmbeddingInput {
17 Single(String),
18 Batch(Vec<String>),
19}
20
21#[derive(Debug, Clone, Copy, PartialEq, Eq)]
23pub enum EmbeddingDimensions {
24 D2048,
25 D1024,
26 D512,
27 D256,
28}
29
30impl Serialize for EmbeddingDimensions {
31 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
32 where
33 S: serde::Serializer,
34 {
35 let v: u16 = match self {
36 EmbeddingDimensions::D2048 => 2048,
37 EmbeddingDimensions::D1024 => 1024,
38 EmbeddingDimensions::D512 => 512,
39 EmbeddingDimensions::D256 => 256,
40 };
41 serializer.serialize_u16(v)
42 }
43}
44
45#[derive(Debug, Clone, Serialize)]
47pub struct EmbeddingBody {
48 pub model: EmbeddingModel,
50
51 pub input: EmbeddingInput,
53
54 #[serde(skip_serializing_if = "Option::is_none")]
57 pub dimensions: Option<EmbeddingDimensions>,
58}
59
60impl EmbeddingBody {
61 pub fn new(model: EmbeddingModel, input: EmbeddingInput) -> Self {
62 Self {
63 model,
64 input,
65 dimensions: None,
66 }
67 }
68
69 pub fn with_dimensions(mut self, dims: EmbeddingDimensions) -> Self {
70 self.dimensions = Some(dims);
71 self
72 }
73
74 pub fn validate_model_constraints(&self) -> Result<(), validator::ValidationError> {
77 use validator::ValidationError;
78 if let EmbeddingModel::Embedding3 = self.model
80 && let EmbeddingInput::Batch(ref v) = self.input
81 && v.len() > 64
82 {
83 return Err(ValidationError::new("batch_too_long"));
84 }
85 if let EmbeddingModel::Embedding2 = self.model
87 && let Some(d) = self.dimensions
88 && d != EmbeddingDimensions::D1024
89 {
90 return Err(ValidationError::new("embedding2_dims_must_be_1024"));
91 }
92 Ok(())
93 }
94}