rust_ai/openai/apis/
embedding.rs1use crate::openai::{
10 endpoint::{endpoint_filter, request_endpoint, Endpoint, EndpointVariant},
11 types::{common::Error, embedding::EmbeddingResponse, model::Model},
12};
13use log::{debug, warn};
14use serde::{Deserialize, Serialize};
15use serde_with::serde_as;
16
17#[serde_as]
20#[derive(Serialize, Deserialize, Debug)]
21pub struct Embedding {
22 pub model: Model,
23
24 pub input: String,
25
26 #[serde(skip_serializing_if = "Option::is_none")]
27 pub user: Option<String>,
28}
29
30impl Default for Embedding {
31 fn default() -> Self {
32 Self {
33 model: Model::TEXT_EMBEDDING_ADA_002,
34 input: String::from(""),
35 user: None,
36 }
37 }
38}
39
40impl Embedding {
41 pub fn model(self, model: Model) -> Self {
44 Self { model, ..self }
45 }
46
47 pub fn input(self, content: &str) -> Self {
52 Self {
53 input: content.into(),
54 ..self
55 }
56 }
57
58 pub fn user(self, user: &str) -> Self {
61 Self {
62 user: Some(user.into()),
63 ..self
64 }
65 }
66
67 pub async fn embedding(self) -> Result<EmbeddingResponse, Box<dyn std::error::Error>> {
69 if !endpoint_filter(&self.model, &Endpoint::Embedding_v1) {
70 return Err("Model not compatible with this endpoint".into());
71 }
72
73 let mut embedding_response: Option<EmbeddingResponse> = None;
74
75 request_endpoint(&self, &Endpoint::Embedding_v1, EndpointVariant::None, |res| {
76 if let Ok(text) = res {
77 if let Ok(response_data) = serde_json::from_str::<EmbeddingResponse>(&text) {
78 debug!(target: "openai", "Response parsed, embedding response deserialized.");
79 embedding_response = Some(response_data);
80 } else {
81 if let Ok(response_error) = serde_json::from_str::<Error>(&text) {
82 warn!(target: "openai",
83 "OpenAI error code {}: `{:?}`",
84 response_error.error.code.unwrap_or(0),
85 text
86 );
87 } else {
88 warn!(target: "openai", "Embedding response not deserializable.");
89 }
90 }
91 }
92 })
93 .await?;
94
95 if let Some(response_data) = embedding_response {
96 Ok(response_data)
97 } else {
98 Err("No response or error parsing response".into())
99 }
100 }
101}