// spin_sdk/llm.rs

pub use crate::wit::v2::llm::{Error, InferencingParams, InferencingResult, InferencingUsage};

/// Provides access to the underlying WIT interface. You should not normally need
/// to use it directly: use the re-exports in this module instead.
#[doc(inline)]
pub use crate::wit::v2::llm;

/// The result of generating embeddings.
///
/// # Examples
///
/// Generate embeddings using the all-minilm-l6-v2 LLM.
///
/// ```no_run
/// use spin_sdk::llm;
///
/// # fn main() -> anyhow::Result<()> {
/// let text = &[
///     "I've just broken a priceless turnip".to_owned(),
/// ];
///
/// let embed_result = llm::generate_embeddings(llm::EmbeddingModel::AllMiniLmL6V2, text)?;
///
/// println!("prompt token count: {}", embed_result.usage.prompt_token_count);
/// println!("embedding: {:?}", embed_result.embeddings.first());
/// # Ok(())
/// # }
/// ```
#[doc(inline)]
pub use crate::wit::v2::llm::EmbeddingsResult;

/// Usage related to an embeddings generation request.
///
/// # Examples
///
/// ```no_run
/// use spin_sdk::llm;
///
/// # fn main() -> anyhow::Result<()> {
/// # let text = &[];
/// let embed_result = llm::generate_embeddings(llm::EmbeddingModel::AllMiniLmL6V2, text)?;
/// println!("prompt token count: {}", embed_result.usage.prompt_token_count);
/// # Ok(())
/// # }
/// ```
pub use crate::wit::v2::llm::EmbeddingsUsage;

/// The model used for inferencing
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy)]
pub enum InferencingModel<'a> {
    Llama2Chat,
    CodellamaInstruct,
    Other(&'a str),
}

impl std::fmt::Display for InferencingModel<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let str = match self {
            InferencingModel::Llama2Chat => "llama2-chat",
            InferencingModel::CodellamaInstruct => "codellama-instruct",
            InferencingModel::Other(s) => s,
        };
        f.write_str(str)
    }
}

impl Default for InferencingParams {
    fn default() -> Self {
        Self {
            max_tokens: 100,
            repeat_penalty: 1.1,
            repeat_penalty_last_n_token_count: 64,
            temperature: 0.8,
            top_k: 40,
            top_p: 0.9,
        }
    }
}

/// Perform inferencing using the provided model and prompt
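///
/// # Examples
///
/// A minimal sketch of prompting the `llama2-chat` model. The prompt text is
/// illustrative, the example assumes that model is available to the host, and
/// the printed fields follow the [`InferencingResult`] and [`InferencingUsage`]
/// records re-exported above.
///
/// ```no_run
/// use spin_sdk::llm;
///
/// # fn main() -> anyhow::Result<()> {
/// // Run inference with the default parameters.
/// let result = llm::infer(llm::InferencingModel::Llama2Chat, "Say hello")?;
///
/// println!("response: {}", result.text);
/// println!("tokens generated: {}", result.usage.generated_token_count);
/// # Ok(())
/// # }
/// ```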
pub fn infer(model: InferencingModel, prompt: &str) -> Result<InferencingResult, Error> {
    llm::infer(&model.to_string(), prompt, None)
}

/// Perform inferencing using the provided model, prompt, and options
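///
/// # Examples
///
/// A minimal sketch that overrides a couple of the default [`InferencingParams`]
/// values via struct update syntax; the specific settings shown are illustrative,
/// not recommendations.
///
/// ```no_run
/// use spin_sdk::llm;
///
/// # fn main() -> anyhow::Result<()> {
/// // Start from the defaults and adjust only the fields of interest.
/// let params = llm::InferencingParams {
///     max_tokens: 200,
///     temperature: 0.5,
///     ..Default::default()
/// };
///
/// let result = llm::infer_with_options(llm::InferencingModel::Llama2Chat, "Say hello", params)?;
/// println!("response: {}", result.text);
/// # Ok(())
/// # }
/// ```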
pub fn infer_with_options(
    model: InferencingModel,
    prompt: &str,
    options: InferencingParams,
) -> Result<InferencingResult, Error> {
    llm::infer(&model.to_string(), prompt, Some(options))
}

/// Model used for generating embeddings
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy)]
pub enum EmbeddingModel<'a> {
    AllMiniLmL6V2,
    Other(&'a str),
}

impl std::fmt::Display for EmbeddingModel<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let str = match self {
            EmbeddingModel::AllMiniLmL6V2 => "all-minilm-l6-v2",
            EmbeddingModel::Other(s) => s,
        };
        f.write_str(str)
    }
}

/// Generate embeddings using the provided model and collection of text
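///
/// # Examples
///
/// A minimal sketch using [`EmbeddingModel::Other`] to name a custom model; the
/// model name `"my-embedding-model"` is hypothetical and must correspond to a
/// model available to the host.
///
/// ```no_run
/// use spin_sdk::llm;
///
/// # fn main() -> anyhow::Result<()> {
/// let text = &["A sentence to embed".to_owned()];
///
/// let result = llm::generate_embeddings(llm::EmbeddingModel::Other("my-embedding-model"), text)?;
///
/// // Each entry in `embeddings` is the vector for the corresponding input string.
/// println!("embedding length: {:?}", result.embeddings.first().map(|e| e.len()));
/// # Ok(())
/// # }
/// ```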
pub fn generate_embeddings(
    model: EmbeddingModel,
    text: &[String],
) -> Result<llm::EmbeddingsResult, Error> {
    llm::generate_embeddings(&model.to_string(), text)
}