//! Spin SDK interface to large language models (LLMs): inferencing
//! against a prompt and generating embeddings for a collection of text.
1#[doc(hidden)]
2/// Module containing wit bindgen generated code.
3///
4/// This is only meant for internal consumption.
5pub mod wit {
6    #![allow(missing_docs)]
7    use crate::wit_bindgen;
8
9    wit_bindgen::generate!({
10        runtime_path: "crate::wit_bindgen::rt",
11        world: "spin-sdk-llm",
12        path: "wit",
13        generate_all,
14    });
15
16    pub use fermyon::spin::llm;
17}
18
19pub use wit::llm::{Error, InferencingParams, InferencingResult, InferencingUsage};
20
21/// Provides access to the underlying WIT interface. You should not normally need
22/// to use this module: use the re-exports in this module instead.
23#[doc(inline)]
24pub use wit::llm;
25
26/// The result of generating embeddings.
27///
28/// # Examples
29///
30/// Generate embeddings using the all-minilm-l6-v2 LLM.
31///
32/// ```no_run
33/// use spin_sdk::llm::{generate_embeddings, EmbeddingModel};
34///
35/// # fn run() -> anyhow::Result<()> {
36/// let text = &[
37///     "I've just broken a priceless turnip".to_owned(),
38/// ];
39///
40/// let embed_result = generate_embeddings(EmbeddingModel::AllMiniLmL6V2, text)?;
41///
42/// println!("prompt token count: {}", embed_result.usage.prompt_token_count);
43/// println!("embedding: {:?}", embed_result.embeddings.first());
44/// # Ok(())
45/// # }
46/// ```
47#[doc(inline)]
48pub use wit::llm::EmbeddingsResult;
49
50/// Usage related to an embeddings generation request.
51///
52/// # Examples
53///
54/// ```no_run
55/// use spin_sdk::llm::{generate_embeddings, EmbeddingModel};
56///
57/// # fn run() -> anyhow::Result<()> {
58/// # let text = &[];
59/// let embed_result = generate_embeddings(EmbeddingModel::AllMiniLmL6V2, text)?;
60/// println!("prompt token count: {}", embed_result.usage.prompt_token_count);
61/// # Ok(())
62/// # }
63/// ```
64pub use wit::llm::EmbeddingsUsage;
65
/// The model used for inferencing
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy)]
pub enum InferencingModel<'a> {
    /// The Llama 2 chat model.
    Llama2Chat,
    /// The Code Llama instruct model.
    // NOTE(review): the scraped source read `Codellarunstruct`, but the
    // Display arm below and the upstream Spin SDK both use the
    // "codellama-instruct" / `CodellamaInstruct` spelling.
    CodellamaInstruct,
    /// Any other model, identified by the name the host knows it by.
    Other(&'a str),
}

impl std::fmt::Display for InferencingModel<'_> {
    /// Formats the model as the string identifier passed to the host.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let str = match self {
            InferencingModel::Llama2Chat => "llama2-chat",
            InferencingModel::CodellamaInstruct => "codellama-instruct",
            InferencingModel::Other(s) => s,
        };
        f.write_str(str)
    }
}

86impl Default for InferencingParams {
87    fn default() -> Self {
88        Self {
89            max_tokens: 100,
90            repeat_penalty: 1.1,
91            repeat_penalty_last_n_token_count: 64,
92            temperature: 0.8,
93            top_k: 40,
94            top_p: 0.9,
95        }
96    }
97}
98
99/// Perform inferencing using the provided model and prompt
100pub fn infer(model: InferencingModel, prompt: &str) -> Result<InferencingResult, Error> {
101    llm::infer(&model.to_string(), prompt, None)
102}
103
104/// Perform inferencing using the provided model, prompt, and options
105pub fn infer_with_options(
106    model: InferencingModel,
107    prompt: &str,
108    options: InferencingParams,
109) -> Result<InferencingResult, Error> {
110    llm::infer(&model.to_string(), prompt, Some(options))
111}
112
/// Model used for generating embeddings
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy)]
pub enum EmbeddingModel<'a> {
    /// The all-minilm-l6-v2 embedding model.
    AllMiniLmL6V2,
    /// Any other model, identified by the name the host knows it by.
    Other(&'a str),
}

impl std::fmt::Display for EmbeddingModel<'_> {
    /// Formats the model as the string identifier passed to the host.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let str = match self {
            EmbeddingModel::AllMiniLmL6V2 => "all-minilm-l6-v2",
            EmbeddingModel::Other(s) => s,
        };
        f.write_str(str)
    }
}

/// Generate embeddings using the provided model and collection of text
///
/// Each string in `text` is embedded; the result carries the embeddings
/// and token usage (see [`EmbeddingsResult`]).
pub fn generate_embeddings(
    model: EmbeddingModel,
    text: &[String],
) -> Result<llm::EmbeddingsResult, Error> {
    // The model enum is rendered to its string identifier via `Display`.
    llm::generate_embeddings(&model.to_string(), text)
137}