rust_ai/openai/apis/
embedding.rs

//!
//! Get a vector representation of a given input that can be easily consumed by
//! machine learning models and algorithms.
//!
//! Source: OpenAI documentation

////////////////////////////////////////////////////////////////////////////////

9use crate::openai::{
10    endpoint::{endpoint_filter, request_endpoint, Endpoint, EndpointVariant},
11    types::{common::Error, embedding::EmbeddingResponse, model::Model},
12};
13use log::{debug, warn};
14use serde::{Deserialize, Serialize};
15use serde_with::serde_as;
16
17/// Given a prompt and an instruction, the model will return an edited version
18/// of the prompt.
19#[serde_as]
20#[derive(Serialize, Deserialize, Debug)]
21pub struct Embedding {
22    pub model: Model,
23
24    pub input: String,
25
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub user: Option<String>,
28}
29
30impl Default for Embedding {
31    fn default() -> Self {
32        Self {
33            model: Model::TEXT_EMBEDDING_ADA_002,
34            input: String::from(""),
35            user: None,
36        }
37    }
38}
39
40impl Embedding {
41    /// ID of the model to use. You can use the [List models](https://platform.openai.com/docs/api-reference/models/list) API to see all of
42    /// your available models, or see our [Model overview](https://platform.openai.com/docs/models/overview) for descriptions of them.
43    pub fn model(self, model: Model) -> Self {
44        Self { model, ..self }
45    }
46
47    /// Input text to get embeddings for, encoded as a string or array of tokens.
48    /// To get embeddings for multiple inputs in a single request, pass an array
49    /// of strings or array of token arrays. Each input must not exceed 8192
50    /// tokens in length.
51    pub fn input(self, content: &str) -> Self {
52        Self {
53            input: content.into(),
54            ..self
55        }
56    }
57
58    /// A unique identifier representing your end-user, which can help OpenAI to
59    /// monitor and detect abuse. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
60    pub fn user(self, user: &str) -> Self {
61        Self {
62            user: Some(user.into()),
63            ..self
64        }
65    }
66
67    /// Send embedding request to OpenAI.
68    pub async fn embedding(self) -> Result<EmbeddingResponse, Box<dyn std::error::Error>> {
69        if !endpoint_filter(&self.model, &Endpoint::Embedding_v1) {
70            return Err("Model not compatible with this endpoint".into());
71        }
72
73        let mut embedding_response: Option<EmbeddingResponse> = None;
74
75        request_endpoint(&self, &Endpoint::Embedding_v1, EndpointVariant::None, |res| {
76          if let Ok(text) = res {
77              if let Ok(response_data) = serde_json::from_str::<EmbeddingResponse>(&text) {
78                  debug!(target: "openai", "Response parsed, embedding response deserialized.");
79                  embedding_response = Some(response_data);
80              } else {
81                  if let Ok(response_error) = serde_json::from_str::<Error>(&text) {
82                      warn!(target: "openai",
83                          "OpenAI error code {}: `{:?}`",
84                          response_error.error.code.unwrap_or(0),
85                          text
86                      );
87                  } else {
88                      warn!(target: "openai", "Embedding response not deserializable.");
89                  }
90              }
91          }
92      })
93      .await?;
94
95        if let Some(response_data) = embedding_response {
96            Ok(response_data)
97        } else {
98            Err("No response or error parsing response".into())
99        }
100    }
101}