yammer/
lib.rs

1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3
4use std::fs::OpenOptions;
5use std::io::{Read, Write};
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, SystemTime};
9
10use reqwest::RequestBuilder;
11use utf8path::Path;
12
13mod chat;
14mod chats;
15mod cli;
16mod fmt;
17mod types;
18mod wrap;
19
20pub use chat::{Chat, ChatOptions};
21pub use chats::{Chats, ChatsOptions};
22pub use fmt::Formattable;
23pub use types::{
24    ChatMessage, ChatRequest, ChatResponse, EmbedRequest, EmbedResponse, GenerateRequest,
25    GenerateResponse,
26};
27pub use wrap::WordWrap;
28
29///////////////////////////////////////////// constants ////////////////////////////////////////////
30
/// The default host to connect to.
///
/// [`ollama_host`] falls back to this value when no host is passed explicitly and the
/// `OLLAMA_HOST` environment variable cannot be read.
pub const OLLAMA_HOST: &str = "http://localhost:11434";
33
34/////////////////////////////////////////////// Error //////////////////////////////////////////////
35
/// An error that can occur when interacting with the ollama API.
///
/// The `Io`, `Utf8Error`, `Json`, and `Reqwest` variants wrap the underlying error and have
/// matching `From` implementations in this file, so `?` converts those errors automatically.
#[derive(Debug)]
pub enum Error {
    /// An Internal error occurred.
    ///
    /// Returned by the streaming code when the service answers with a non-200 status and no body.
    Internal,
    /// A signal interrupted the call.
    ///
    /// Returned by [`stream`] when a signal other than `SIGCHLD` is pending.
    Signal,
    /// The EDITOR environment variable is not set.
    EditorNotSet,
    /// EDITOR failed.
    ///
    /// Carries the editor's exit code, or `None` when the process reported no exit code.
    EditorFailed(Option<i32>),
    /// The YAMMER_CHAT environment variable is not set.
    ChatNotSet,
    /// An invalid argument was passed.
    InvalidArgument(String),
    /// An error occurred in the ollama service.
    ///
    /// Carries the service's error message, or a description of a malformed response.
    Ollama(String),
    /// An I/O error occurred.
    Io(std::io::Error),
    /// A UTF-8 error occurred.
    Utf8Error(std::str::Utf8Error),
    /// A JSON error occurred.
    Json(serde_json::Error),
    /// A Reqwest error occurred.
    Reqwest(reqwest::Error),
}
62
// No methods are overridden: the wrapped errors are reported through the `Display` text of
// `Error` rather than through `source()`.
impl std::error::Error for Error {}
64
65impl std::fmt::Display for Error {
66    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
67        match self {
68            Self::Internal => write!(f, "Internal error"),
69            Self::Signal => write!(f, "Signal received"),
70            Self::EditorNotSet => write!(f, "EDITOR not set"),
71            Self::EditorFailed(Some(code)) => write!(f, "Editor failed with exit code {}", code),
72            Self::EditorFailed(None) => write!(f, "Editor failed without exit code"),
73            Self::ChatNotSet => write!(f, "YAMMER_CHAT not set"),
74            Self::InvalidArgument(message) => write!(f, "invalid argument: {}", message),
75            Self::Ollama(s) => write!(f, "Ollama error: {}", s),
76            Self::Io(e) => write!(f, "I/O error: {}", e),
77            Self::Utf8Error(e) => write!(f, "UTF-8 error: {}", e),
78            Self::Json(e) => write!(f, "JSON error: {}", e),
79            Self::Reqwest(e) => write!(f, "Reqwest error: {}", e),
80        }
81    }
82}
83
84impl From<std::io::Error> for Error {
85    fn from(err: std::io::Error) -> Self {
86        Self::Io(err)
87    }
88}
89
90impl From<std::str::Utf8Error> for Error {
91    fn from(err: std::str::Utf8Error) -> Self {
92        Self::Utf8Error(err)
93    }
94}
95
96impl From<serde_json::Error> for Error {
97    fn from(e: serde_json::Error) -> Self {
98        Self::Json(e)
99    }
100}
101
102impl From<reqwest::Error> for Error {
103    fn from(e: reqwest::Error) -> Self {
104        Self::Reqwest(e)
105    }
106}
107
108//////////////////////////////////////////// Parameters ////////////////////////////////////////////
109
/// Parameters for the model.
///
/// These correspond to the same name as PARAMETER options in Ollama.
///
/// Every field is optional.  A field left as `None` is omitted entirely when the struct is
/// serialized (`skip_serializing_if = "Option::is_none"`), and is likewise left out of the
/// `serde_json::Value` produced by the `From<Parameters>` conversion in this file.  The defaults
/// quoted in the field docs are not applied anywhere in this struct.
#[derive(
    Clone, Debug, Default, arrrg_derive::CommandLine, serde::Deserialize, serde::Serialize,
)]
pub struct Parameters {
    /// Enable Mirostat sampling for controlling perplexity. (Default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)
    #[arrrg(optional, "Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mirostat: Option<i32>,

    /// Influences how quickly the algorithm responds to feedback from the generated text.
    ///
    /// A lower learning rate will result in slower adjustments, while a higher learning rate will
    /// make the algorithm more responsive. (Default: 0.1)
    #[arrrg(
        optional,
        "Influences how quickly the algorithm responds to feedback from the generated text."
    )]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mirostat_eta: Option<f64>,

    /// Controls the balance between coherence and diversity of the output.
    ///
    /// A lower value will result in more focused and coherent text. (Default: 5.0)
    #[arrrg(
        optional,
        "Controls the balance between coherence and diversity of the output."
    )]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mirostat_tau: Option<f64>,

    /// The number of tokens worth of context to allocate.
    #[arrrg(optional, "The number of tokens worth of context to allocate.")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_ctx: Option<u32>,

    /// Sets how far back for the model to look back to prevent repetition.
    ///
    /// (Default: 64, 0 = disabled, -1 = num_ctx)
    #[arrrg(
        optional,
        "Sets how far back for the model to look back to prevent repetition."
    )]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repeat_last_n: Option<i32>,

    /// Sets how strongly to penalize repetitions.
    ///
    /// A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value
    /// (e.g., 0.9) will be more lenient. (Default: 1.1)
    #[arrrg(optional, "Sets how strongly to penalize repetitions.")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repeat_penalty: Option<f64>,

    /// The temperature of the model.
    ///
    /// Increasing the temperature will make the model answer more creatively. (Default: 0.8)
    #[arrrg(optional, "The temperature of the model.")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,

    /// Sets the random number seed to use for generation.
    ///
    /// Setting this to a specific number will make the model generate the same text for the same
    /// prompt.  (Default: 0)
    #[arrrg(optional, "Sets the random number seed to use for generation.")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i32>,

    /// Tail free sampling is used to reduce the impact of less probable tokens from the output.
    ///
    /// A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this
    /// setting. (Default: 1)
    #[arrrg(
        optional,
        "Tail free sampling is used to reduce the impact of less probable tokens from the output."
    )]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tfs_z: Option<f64>,

    /// Maximum number of tokens to predict when generating text.
    ///
    /// (Default: 128, -1 = infinite generation, -2 = fill context)
    #[arrrg(optional, "Maximum number of tokens to predict when generating text.")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_predict: Option<i32>,

    /// Reduces the probability of generating nonsense.
    ///
    /// A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10)
    /// will be more conservative. (Default: 40)
    #[arrrg(optional, "Reduces the probability of generating nonsense.")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<i32>,

    /// Works together with top-k.
    ///
    /// A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5)
    /// will generate more focused and conservative text. (Default: 0.9)
    #[arrrg(optional, "Works together with top-k.")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,

    /// Alternative to the top_p, and aims to ensure a balance of quality and variety.
    ///
    /// The parameter p represents the minimum probability for a token to be considered, relative
    /// to the probability of the most likely token. For example, with p=0.05 and the most likely
    /// token having a probability of 0.9, logits with a value less than 0.045 are filtered out.
    /// (Default: 0.0)
    #[arrrg(
        optional,
        "Alternative to the top_p, and aims to ensure a balance of quality and variety."
    )]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_p: Option<f64>,
}
228
229impl Parameters {
230    /// Overlay the parameters from another set of parameters.
231    pub fn apply(&mut self, from: Self) {
232        if let Some(mirostat) = from.mirostat {
233            self.mirostat = Some(mirostat);
234        }
235        if let Some(mirostat_eta) = from.mirostat_eta {
236            self.mirostat_eta = Some(mirostat_eta);
237        }
238        if let Some(mirostat_tau) = from.mirostat_tau {
239            self.mirostat_tau = Some(mirostat_tau);
240        }
241        if let Some(num_ctx) = from.num_ctx {
242            self.num_ctx = Some(num_ctx);
243        }
244        if let Some(repeat_last_n) = from.repeat_last_n {
245            self.repeat_last_n = Some(repeat_last_n);
246        }
247        if let Some(repeat_penalty) = from.repeat_penalty {
248            self.repeat_penalty = Some(repeat_penalty);
249        }
250        if let Some(temperature) = from.temperature {
251            self.temperature = Some(temperature);
252        }
253        if let Some(seed) = from.seed {
254            self.seed = Some(seed);
255        }
256        if let Some(tfs_z) = from.tfs_z {
257            self.tfs_z = Some(tfs_z);
258        }
259        if let Some(num_predict) = from.num_predict {
260            self.num_predict = Some(num_predict);
261        }
262        if let Some(top_k) = from.top_k {
263            self.top_k = Some(top_k);
264        }
265        if let Some(top_p) = from.top_p {
266            self.top_p = Some(top_p);
267        }
268        if let Some(min_p) = from.min_p {
269            self.min_p = Some(min_p);
270        }
271    }
272}
273
274impl From<Parameters> for serde_json::Value {
275    fn from(p: Parameters) -> serde_json::Value {
276        let mut json = serde_json::json!({});
277        if let Some(mirostat) = p.mirostat {
278            json["mirostat"] = serde_json::json!(mirostat);
279        }
280        if let Some(mirostat_eta) = p.mirostat_eta {
281            json["mirostat_eta"] = serde_json::json!(mirostat_eta);
282        }
283        if let Some(mirostat_tau) = p.mirostat_tau {
284            json["mirostat_tau"] = serde_json::json!(mirostat_tau);
285        }
286        if let Some(num_ctx) = p.num_ctx {
287            json["num_ctx"] = serde_json::json!(num_ctx);
288        }
289        if let Some(repeat_last_n) = p.repeat_last_n {
290            json["repeat_last_n"] = serde_json::json!(repeat_last_n);
291        }
292        if let Some(repeat_penalty) = p.repeat_penalty {
293            json["repeat_penalty"] = serde_json::json!(repeat_penalty);
294        }
295        if let Some(temperature) = p.temperature {
296            json["temperature"] = serde_json::json!(temperature);
297        }
298        if let Some(seed) = p.seed {
299            json["seed"] = serde_json::json!(seed);
300        }
301        if let Some(tfs_z) = p.tfs_z {
302            json["tfs_z"] = serde_json::json!(tfs_z);
303        }
304        if let Some(num_predict) = p.num_predict {
305            json["num_predict"] = serde_json::json!(num_predict);
306        }
307        if let Some(top_k) = p.top_k {
308            json["top_k"] = serde_json::json!(top_k);
309        }
310        if let Some(top_p) = p.top_p {
311            json["top_p"] = serde_json::json!(top_p);
312        }
313        if let Some(min_p) = p.min_p {
314            json["min_p"] = serde_json::json!(min_p);
315        }
316        json
317    }
318}
319
// NOTE(review): several fields are `Option<f64>` and `PartialEq` compares them with `==`, so a
// `Parameters` holding a NaN compares unequal to itself, which the `Eq` contract does not allow.
// This marker impl is only sound as long as NaN is never stored in those fields.
impl Eq for Parameters {}
321
322impl PartialEq for Parameters {
323    fn eq(&self, other: &Self) -> bool {
324        self.mirostat == other.mirostat
325            && self.mirostat_eta == other.mirostat_eta
326            && self.mirostat_tau == other.mirostat_tau
327            && self.num_ctx == other.num_ctx
328            && self.repeat_last_n == other.repeat_last_n
329            && self.repeat_penalty == other.repeat_penalty
330            && self.temperature == other.temperature
331            && self.seed == other.seed
332            && self.tfs_z == other.tfs_z
333            && self.num_predict == other.num_predict
334            && self.top_k == other.top_k
335            && self.top_p == other.top_p
336            && self.min_p == other.min_p
337    }
338}
339
340////////////////////////////////////////////// Shellm //////////////////////////////////////////////
341
/// Options for the `shellm` command.
///
/// [`shellm`] issues one streaming generate request per prompt file, built from these options.
///
/// NOTE(review): `shellm` currently constructs its request with `keep_alive: None`, so the
/// `keep_alive` option below is accepted but not forwarded to the service — confirm whether that
/// is intended.
#[derive(Clone, Debug, Eq, PartialEq, arrrg_derive::CommandLine)]
pub struct ShellmOptions {
    /// The host to connect to.
    ///
    /// When `None`, the host is resolved by [`ollama_host`].
    #[arrrg(optional, "The host to connect to.")]
    pub ollama_host: Option<String>,
    /// The model to use from the ollama library.
    #[arrrg(optional, "The model to use from the ollama library.")]
    pub model: String,
    /// The suffix to append to the response.
    #[arrrg(optional, "The suffix to append to the response.")]
    pub suffix: Option<String>,
    /// The system to use in the template.
    #[arrrg(optional, "The system to use in the template.")]
    pub system: Option<String>,
    /// The template to use for the prompt.
    #[arrrg(optional, "The template to use for the prompt.")]
    pub template: Option<String>,
    /// Format the response in JSON.  You must also ask the model to do so.
    #[arrrg(
        flag,
        "Format the response in JSON.  You must also ask the model to do so."
    )]
    pub json: bool,
    /// Schema to adhere to when formatting the response in JSON.  Has no effect without --json.
    ///
    /// When `json` is set and no schema is given, the request format is the string `"json"`.
    #[arrrg(
        optional,
        "Schema to adhere to when formatting the response in JSON.  Has no effect without --json."
    )]
    pub schema: Option<serde_json::Value>,
    /// Whether to bypass formatting of the prompt.
    #[arrrg(optional, "Whether to pass bypass formatting of the prompt.")]
    pub raw: Option<bool>,
    /// Duration to keep the model in memory for after the call.
    #[arrrg(optional, "Duration to keep the model in memory for after the call.")]
    pub keep_alive: Option<String>,
    /// Additional options to pass to the model.
    #[arrrg(nested)]
    pub param: Parameters,
    /// Wrap at this line length, or 0 to disable yammer-induced wrapping.
    ///
    /// When `None`, `shellm` passes 100 to `WordWrap::new`.
    #[arrrg(
        optional,
        "Wrap at this line length, or 0 to disable yammer-induced wrapping."
    )]
    pub wrap: Option<usize>,
}
388
389impl Default for ShellmOptions {
390    fn default() -> Self {
391        ShellmOptions {
392            ollama_host: None,
393            // TODO(rescrv):  Don't hard-code the default model.
394            model: "gemma2".to_string(),
395            suffix: None,
396            system: None,
397            template: None,
398            json: false,
399            schema: None,
400            raw: None,
401            keep_alive: None,
402            param: Parameters::default(),
403            wrap: None,
404        }
405    }
406}
407
408////////////////////////////////////////////// shellm //////////////////////////////////////////////
409
410/// The `shellm` command.
411pub async fn shellm(
412    options: ShellmOptions,
413    promptfiles: &[impl AsRef<str>],
414) -> Result<(), Box<dyn std::error::Error>> {
415    let mut stdin: Option<String> = None;
416    for promptfile in promptfiles {
417        let promptfile = promptfile.as_ref();
418        let prompt = if promptfile == "-" {
419            if let Some(stdin) = stdin.as_ref() {
420                stdin.clone()
421            } else {
422                let mut s = String::new();
423                std::io::stdin().read_to_string(&mut s)?;
424                stdin = Some(s.clone());
425                s
426            }
427        } else {
428            match std::fs::read_to_string(promptfile) {
429                Ok(s) => s,
430                Err(e) => {
431                    eprintln!("shellm: {}: {}", promptfile, e);
432                    continue;
433                }
434            }
435        };
436        let gen = GenerateRequest {
437            model: options.model.clone(),
438            prompt,
439            suffix: options.suffix.clone(),
440            images: None,
441            format: if options.json {
442                if let Some(schema) = options.schema.clone() {
443                    Some(schema)
444                } else {
445                    Some(serde_json::Value::String("json".to_string()))
446                }
447            } else {
448                None
449            },
450            system: options.system.clone(),
451            template: options.template.clone(),
452            stream: Some(true),
453            raw: options.raw,
454            keep_alive: None,
455            options: Some(options.param.clone().into()),
456        };
457        let req = gen.make_request(&ollama_host(options.ollama_host.clone()));
458        let mut ww = WordWrap::new(options.wrap.unwrap_or(100));
459        let res = stream(req, |v| {
460            if let Some(serde_json::Value::String(message)) = v.get("response") {
461                ww.push(message.clone(), &mut std::io::stdout())?;
462            }
463            Ok(())
464        })
465        .await;
466        if let Err(Error::Signal) = res {
467            break;
468        } else if let Err(err) = res {
469            return Err(err.into());
470        }
471        writeln!(std::io::stdout())?;
472    }
473    Ok(())
474}
475
476////////////////////////////////////////// OneShotOptions //////////////////////////////////////////
477
/// Options for the `oneshot` command.
///
/// This carries the same fields as [`ShellmOptions`] except `model`: [`oneshot`] takes the models
/// as a separate argument and copies these fields into one `ShellmOptions` per model.
#[derive(Clone, Debug, Default, Eq, PartialEq, arrrg_derive::CommandLine)]
pub struct OneshotOptions {
    /// The host to connect to.
    #[arrrg(optional, "The host to connect to.")]
    pub ollama_host: Option<String>,
    /// The suffix to append to the response.
    #[arrrg(optional, "The suffix to append to the response.")]
    pub suffix: Option<String>,
    /// The system to use in the template.
    #[arrrg(optional, "The system to use in the template.")]
    pub system: Option<String>,
    /// The template to use for the prompt.
    #[arrrg(optional, "The template to use for the prompt.")]
    pub template: Option<String>,
    /// Format the response in JSON.  You must also ask the model to do so.
    #[arrrg(
        flag,
        "Format the response in JSON.  You must also ask the model to do so."
    )]
    pub json: bool,
    /// Schema to adhere to when formatting the response in JSON.  Has no effect without --json.
    #[arrrg(
        optional,
        "Schema to adhere to when formatting the response in JSON.  Has no effect without --json."
    )]
    pub schema: Option<serde_json::Value>,
    /// Whether to bypass formatting of the prompt.
    #[arrrg(optional, "Whether to pass bypass formatting of the prompt.")]
    pub raw: Option<bool>,
    /// Duration to keep the model in memory for after the call.
    #[arrrg(optional, "Duration to keep the model in memory for after the call.")]
    pub keep_alive: Option<String>,
    /// Additional options to pass to the model.
    #[arrrg(nested)]
    pub param: Parameters,
    /// Wrap at this line length, or 0 to disable yammer-induced wrapping.
    #[arrrg(
        optional,
        "Wrap at this line length, or 0 to disable yammer-induced wrapping."
    )]
    pub wrap: Option<usize>,
}
521
522////////////////////////////////////////////// editor //////////////////////////////////////////////
523
524/// Invoke an editor with a default message and return something like a string.
525pub fn editor(default: &str) -> Result<impl AsRef<String>, Error> {
526    let path = format!(
527        ".yammer.{}.{}",
528        std::process::id(),
529        SystemTime::now()
530            .duration_since(SystemTime::UNIX_EPOCH)
531            .unwrap_or(Duration::ZERO)
532            .as_micros()
533    );
534    let editor = std::env::var("EDITOR").map_err(|_| Error::EditorNotSet)?;
535    struct UnlinkOnDrop(String);
536    impl Drop for UnlinkOnDrop {
537        fn drop(&mut self) {
538            let _ = std::fs::remove_file(&self.0);
539        }
540    }
541    impl AsRef<String> for UnlinkOnDrop {
542        fn as_ref(&self) -> &String {
543            &self.0
544        }
545    }
546    let mut file = OpenOptions::new()
547        .create_new(true)
548        .write(true)
549        .open(&path)?;
550    let unlink = UnlinkOnDrop(path.clone());
551    file.write_all(default.as_bytes())?;
552    file.flush()?;
553    file.sync_all()?;
554    drop(file);
555    let status = std::process::Command::new(editor).arg(&path).status()?;
556    if Some(0) != status.code() {
557        return Err(Error::EditorFailed(status.code()));
558    }
559    Ok(unlink)
560}
561
562////////////////////////////////////////////// oneshot /////////////////////////////////////////////
563
564/// The `oneshot` command.
565pub async fn oneshot(
566    options: OneshotOptions,
567    models: &[impl AsRef<str>],
568) -> Result<(), Box<dyn std::error::Error>> {
569    let path = editor("Replace this text with your prompt.")?;
570    for model in models {
571        let options = ShellmOptions {
572            ollama_host: options.ollama_host.clone(),
573            model: model.as_ref().to_string(),
574            suffix: options.suffix.clone(),
575            system: options.system.clone(),
576            template: options.template.clone(),
577            json: options.json,
578            schema: options.schema.clone(),
579            raw: options.raw,
580            keep_alive: options.keep_alive.clone(),
581            param: options.param.clone(),
582            wrap: options.wrap,
583        };
584        shellm(options, &[path.as_ref()]).await?;
585    }
586    Ok(())
587}
588
589/////////////////////////////////////////// PromptOptions //////////////////////////////////////////
590
/// Options for the `prompt` command.
///
/// [`prompt`] issues one streaming generate request per prompt string, built from these options.
///
/// NOTE(review): `prompt` currently constructs its request with `keep_alive: None`, so the
/// `keep_alive` option below is accepted but not forwarded to the service — confirm whether that
/// is intended.
#[derive(Clone, Debug, Eq, PartialEq, arrrg_derive::CommandLine)]
pub struct PromptOptions {
    /// The host to connect to.
    ///
    /// When `None`, the host is resolved by [`ollama_host`].
    #[arrrg(optional, "The host to connect to.")]
    pub ollama_host: Option<String>,
    /// The model to use from the ollama library.
    #[arrrg(optional, "The model to use from the ollama library.")]
    pub model: String,
    /// The suffix to append to the response.
    #[arrrg(optional, "The suffix to append to the response.")]
    pub suffix: Option<String>,
    /// The system to use in the template.
    #[arrrg(optional, "The system to use in the template.")]
    pub system: Option<String>,
    /// The template to use for the prompt.
    #[arrrg(optional, "The template to use for the prompt.")]
    pub template: Option<String>,
    /// Format the response in JSON.  You must also ask the model to do so.
    #[arrrg(
        flag,
        "Format the response in JSON.  You must also ask the model to do so."
    )]
    pub json: bool,
    /// Schema to adhere to when formatting the response in JSON.  Has no effect without --json.
    ///
    /// When `json` is set and no schema is given, the request format is the string `"json"`.
    #[arrrg(
        optional,
        "Schema to adhere to when formatting the response in JSON.  Has no effect without --json."
    )]
    pub schema: Option<serde_json::Value>,
    /// Whether to bypass formatting of the prompt.
    #[arrrg(optional, "Whether to pass bypass formatting of the prompt.")]
    pub raw: Option<bool>,
    /// Duration to keep the model in memory for after the call.
    #[arrrg(optional, "Duration to keep the model in memory for after the call.")]
    pub keep_alive: Option<String>,
    /// Additional options to pass to the model.
    #[arrrg(nested)]
    pub param: Parameters,
    /// Wrap at this line length, or 0 to disable yammer-induced wrapping.
    ///
    /// When `None`, `prompt` passes 100 to `WordWrap::new`.
    #[arrrg(
        optional,
        "Wrap at this line length, or 0 to disable yammer-induced wrapping."
    )]
    pub wrap: Option<usize>,
}
637
638impl Default for PromptOptions {
639    fn default() -> Self {
640        PromptOptions {
641            ollama_host: None,
642            model: "gemma2".to_string(),
643            suffix: None,
644            system: None,
645            template: None,
646            json: false,
647            raw: None,
648            schema: None,
649            keep_alive: None,
650            param: Parameters::default(),
651            wrap: None,
652        }
653    }
654}
655
656////////////////////////////////////////////// Prompt //////////////////////////////////////////////
657
658/// The `prompt` command.
659pub async fn prompt(
660    options: PromptOptions,
661    prompts: &[impl AsRef<str>],
662) -> Result<(), Box<dyn std::error::Error>> {
663    for prompt in prompts {
664        let gen = GenerateRequest {
665            model: options.model.clone(),
666            prompt: prompt.as_ref().to_string(),
667            suffix: options.suffix.clone(),
668            images: None,
669            format: if options.json {
670                if let Some(schema) = options.schema.clone() {
671                    Some(schema)
672                } else {
673                    Some(serde_json::Value::String("json".to_string()))
674                }
675            } else {
676                None
677            },
678            system: options.system.clone(),
679            template: options.template.clone(),
680            stream: Some(true),
681            raw: options.raw,
682            keep_alive: None,
683            options: Some(options.param.clone().into()),
684        };
685        let req = gen.make_request(&ollama_host(options.ollama_host.clone()));
686        let mut ww = WordWrap::new(options.wrap.unwrap_or(100));
687        let res = stream(req, |v| {
688            if let Some(serde_json::Value::String(message)) = v.get("response") {
689                ww.push(message.clone(), &mut std::io::stdout())?;
690            }
691            Ok(())
692        })
693        .await;
694        if let Err(Error::Signal) = res {
695            break;
696        } else if let Err(err) = res {
697            return Err(err.into());
698        }
699        writeln!(std::io::stdout())?;
700    }
701    Ok(())
702}
703
704//////////////////////////////////////////// chat_shell ////////////////////////////////////////////
705
706/// Start the `chat` shell.
707pub async fn chat_shell(changelog: Option<Path<'_>>, options: ChatOptions) -> Result<(), Error> {
708    let chat = Chat::new(changelog, options)?;
709    chat.shell().await
710}
711
712//////////////////////////////////////////// chats_shell ///////////////////////////////////////////
713
714/// Start the `chats` shell.
715pub async fn chats_shell(options: ChatsOptions) -> Result<(), Error> {
716    let chats = Chats::new(options)?;
717    chats.shell().await
718}
719
720////////////////////////////////////////////// stream //////////////////////////////////////////////
721
722/// Stream the response of a request, calling `for_each` on each JSON object in the response.
723pub async fn stream(
724    req: RequestBuilder,
725    for_each: impl FnMut(serde_json::Value) -> Result<(), Error>,
726) -> Result<(), Error> {
727    let sns = stream_no_signal(req, for_each);
728    let sig = async {
729        loop {
730            tokio::time::sleep(Duration::from_millis(50)).await;
731            if minimal_signals::pending()
732                .iter()
733                .filter(|s| *s != minimal_signals::SIGCHLD)
734                .count()
735                > 0
736            {
737                break;
738            }
739        }
740    };
741    tokio::select! {
742        res = sns => res,
743        _ = sig => Err(Error::Signal),
744    }
745}
746
747async fn stream_no_signal(
748    req: RequestBuilder,
749    mut for_each: impl FnMut(serde_json::Value) -> Result<(), Error>,
750) -> Result<(), Error> {
751    let mut resp = req.send().await?;
752    if resp.status() != 200 {
753        return if let Some(chunk) = resp.chunk().await? {
754            #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
755            struct ErrorResponse {
756                pub error: String,
757            }
758            let err = serde_json::from_slice::<ErrorResponse>(&chunk)?;
759            Err(Error::Ollama(err.error))
760        } else {
761            Err(Error::Internal)
762        };
763    }
764    let mut leftovers: Vec<u8> = vec![];
765    while let Some(chunk) = resp.chunk().await? {
766        leftovers.extend(&chunk);
767        if let Ok(value) = serde_json::from_slice(&leftovers) {
768            for_each(value)?;
769            leftovers.clear();
770        }
771    }
772    if !leftovers.is_empty() {
773        let Ok(value) = serde_json::from_slice(&leftovers) else {
774            return Err(Error::Ollama(format!(
775                "Host returned invalid JSON chunk {leftovers:?}"
776            )));
777        };
778        for_each(value)?;
779    }
780    Ok(())
781}
782
783//////////////////////////////////////////// ollama_host ///////////////////////////////////////////
784
785/// Return the Ollama host, preferring the value passed in, falling back to the env var, falling
786/// back to the hard-coded default.
787pub fn ollama_host(host: Option<String>) -> String {
788    host.unwrap_or_else(|| std::env::var("OLLAMA_HOST").unwrap_or_else(|_| OLLAMA_HOST.to_string()))
789}
790
791///////////////////////////////////////////// chat_root ////////////////////////////////////////////
792
793/// The root on the filesystem for chats.
794pub fn chat_root() -> Result<Path<'static>, Error> {
795    let root = std::env::var("YAMMER_CHAT").map_err(|_| Error::ChatNotSet)?;
796    Ok(Path::from(root))
797}
798
799///////////////////////////////////////////// chat_path ////////////////////////////////////////////
800
801/// The path for one specific chat.
802pub fn chat_path(chat_id: &str) -> Result<Path<'static>, Error> {
803    let root = chat_root()?;
804    Ok(root.join("chats").join(format!("{}.ndjson", chat_id)))
805}
806
807////////////////////////////////////////////// Spinner /////////////////////////////////////////////
808
/// Animation frames of the spinner, shown one per tick in order and then repeated.
const SPINNER: &[&str] = &["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"];

/// A spinner widget.
///
/// Creating a spinner starts a background thread that wakes every 50 milliseconds and, unless
/// the spinner is inhibited, redraws the current frame on standard error.  A new spinner starts
/// out inhibited; call [`Spinner::start`] to show it and [`Spinner::inhibit`] to hide it again.
/// Dropping the spinner hides it and waits for the background thread to finish.
#[derive(Debug)]
pub struct Spinner {
    /// Set on drop to tell the background thread to exit.
    done: Arc<AtomicBool>,
    /// `true` while drawing is suppressed.  The lock is held while a frame is drawn, so a frame
    /// can never be written after `inhibit` has cleared the line.
    inhibited: Arc<Mutex<bool>>,
    /// Handle of the drawing thread; taken and joined on drop.
    background: Option<std::thread::JoinHandle<()>>,
}

impl Spinner {
    /// Create a new spinner.
    ///
    /// The spinner is created inhibited, so nothing is drawn until [`Spinner::start`] is called.
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        let done = Arc::new(AtomicBool::new(false));
        let done_p = Arc::clone(&done);
        let inhibited = Arc::new(Mutex::new(true));
        let inhibited_p = Arc::clone(&inhibited);
        let background = std::thread::spawn(move || {
            let mut i = 0;
            while !done_p.load(Ordering::Relaxed) {
                std::thread::sleep(std::time::Duration::from_millis(50));
                let inhibited_p = Self::lock_flag(&inhibited_p);
                if *inhibited_p {
                    continue;
                }
                let mut stderr = std::io::stderr().lock();
                // `write!` writes the whole sequence; a bare `write` may stop part-way through
                // and leave a truncated escape sequence on the terminal.
                let _ = write!(stderr, "\x1b[2K\r{} ", SPINNER[i % SPINNER.len()]);
                let _ = stderr.flush();
                i += 1;
            }
        });
        Self {
            done,
            inhibited,
            background: Some(background),
        }
    }

    /// Start the spinner.
    pub fn start(&self) {
        *Self::lock_flag(&self.inhibited) = false;
    }

    /// Inhibit the spinner.
    ///
    /// If the spinner was showing, the current terminal line is cleared as well.
    pub fn inhibit(&self) {
        let mut inhibited = Self::lock_flag(&self.inhibited);
        if !*inhibited {
            *inhibited = true;
            let mut stderr = std::io::stderr().lock();
            let _ = stderr.write_all(b"\x1b[2K\r");
        }
    }

    /// Lock the inhibit flag, recovering the guard if the mutex was poisoned.
    ///
    /// The protected value is a plain `bool`, so it cannot be left half-updated by a panic;
    /// recovering keeps `Drop` from panicking.
    fn lock_flag(flag: &Mutex<bool>) -> std::sync::MutexGuard<'_, bool> {
        flag.lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

impl Drop for Spinner {
    /// Stop the background thread, clear the spinner from the terminal, and wait for the thread.
    fn drop(&mut self) {
        self.done.store(true, Ordering::Relaxed);
        self.inhibit();
        if let Some(background) = self.background.take() {
            // Never panic in `drop`: a panic while already unwinding aborts the process.
            let _ = background.join();
        }
    }
}
875
876//////////////////////////////////////////// JsonSchema ////////////////////////////////////////////
877
/// Implement JsonSchema to derive the schema for GenerateRequest automatically.
///
/// The schema is a property of the type rather than of a value, so the single method takes no
/// receiver.  [`ToolBuilder::arg`] uses it to describe the type of each tool argument.
pub trait JsonSchema {
    /// Return the json_schema.  Does not depend on an object.
    ///
    /// The result is a [`serde_json::Value`] holding the schema fragment for `Self`.
    fn json_schema() -> serde_json::Value;
}
883
884impl JsonSchema for bool {
885    fn json_schema() -> serde_json::Value {
886        serde_json::json! {{ "type": "boolean" }}
887    }
888}
889
890impl JsonSchema for i8 {
891    fn json_schema() -> serde_json::Value {
892        serde_json::json! {{ "type": "integer" }}
893    }
894}
895
896impl JsonSchema for i16 {
897    fn json_schema() -> serde_json::Value {
898        serde_json::json! {{ "type": "integer" }}
899    }
900}
901
902impl JsonSchema for i32 {
903    fn json_schema() -> serde_json::Value {
904        serde_json::json! {{ "type": "integer" }}
905    }
906}
907
908impl JsonSchema for i64 {
909    fn json_schema() -> serde_json::Value {
910        serde_json::json! {{ "type": "integer" }}
911    }
912}
913
914impl JsonSchema for u8 {
915    fn json_schema() -> serde_json::Value {
916        serde_json::json! {{ "type": "integer" }}
917    }
918}
919
920impl JsonSchema for u16 {
921    fn json_schema() -> serde_json::Value {
922        serde_json::json! {{ "type": "integer" }}
923    }
924}
925
926impl JsonSchema for u32 {
927    fn json_schema() -> serde_json::Value {
928        serde_json::json! {{ "type": "integer" }}
929    }
930}
931
932impl JsonSchema for u64 {
933    fn json_schema() -> serde_json::Value {
934        serde_json::json! {{ "type": "integer" }}
935    }
936}
937
938impl JsonSchema for f32 {
939    fn json_schema() -> serde_json::Value {
940        serde_json::json! {{ "type": "number" }}
941    }
942}
943
944impl JsonSchema for f64 {
945    fn json_schema() -> serde_json::Value {
946        serde_json::json! {{ "type": "number" }}
947    }
948}
949
950impl JsonSchema for String {
951    fn json_schema() -> serde_json::Value {
952        serde_json::json! {{ "type": "string" }}
953    }
954}
955
956impl<T: JsonSchema> JsonSchema for Option<T> {
957    fn json_schema() -> serde_json::Value {
958        let mut res = <T as JsonSchema>::json_schema();
959        res["nullable"] = true.into();
960        res
961    }
962}
963
964impl<T: JsonSchema> JsonSchema for Vec<T> {
965    fn json_schema() -> serde_json::Value {
966        serde_json::json! {{ "type": "array", "items": <T as JsonSchema>::json_schema() }}
967    }
968}
969
970impl JsonSchema for serde_json::Value {
971    fn json_schema() -> serde_json::Value {
972        serde_json::json! {{}}
973    }
974}
975
976impl<Tz: chrono::TimeZone> JsonSchema for chrono::DateTime<Tz> {
977    fn json_schema() -> serde_json::Value {
978        String::json_schema()
979    }
980}
981
982//////////////////////////////////////////// ToolBuilder ///////////////////////////////////////////
983
984/// Build a tool for use in chat completions.
985pub struct ToolBuilder {
986    name: String,
987    description: String,
988    fields: Vec<(String, serde_json::Value)>,
989}
990
991impl ToolBuilder {
992    /// Create a new tool.  Name is the name of the function, and description is a plain-language
993    /// description of what it does.
994    pub fn new(name: &str, description: &str) -> Self {
995        let name = name.to_string();
996        let description = description.to_string();
997        let fields = vec![];
998        Self {
999            name,
1000            description,
1001            fields,
1002        }
1003    }
1004
1005    /// Append an argument to the tool call.  All arguments are required by convention.
1006    pub fn arg<T: JsonSchema>(mut self, name: &str) -> Self {
1007        self.fields.push((name.to_string(), T::json_schema()));
1008        self
1009    }
1010
1011    /// Consume the [ToolBuilder] and return a JSON blob suitable for passing to Ollama.
1012    pub fn build(self) -> serde_json::Value {
1013        let mut properties = serde_json::json! {{}};
1014        let mut required = vec![];
1015        for (name, schema) in self.fields.into_iter() {
1016            properties[name.clone()] = schema;
1017            required.push(name);
1018        }
1019        let required: serde_json::Value = required.into();
1020        let parameters = serde_json::json! {{
1021            "type": "object",
1022            "properties": properties,
1023            "required": required,
1024        }};
1025        serde_json::json! {{
1026            "type": "function",
1027            "function": {
1028                "name": self.name,
1029                "description": self.description,
1030                "parameters": parameters,
1031            }
1032        }}
1033    }
1034}
1035
1036/////////////////////////////////////////////// tests //////////////////////////////////////////////
1037
#[cfg(test)]
mod tests {
    use super::*;

    /// A two-argument tool must serialize to exactly this pretty-printed JSON.
    #[test]
    fn tool_builder() {
        let expected = r#"{
  "function": {
    "description": "Create N different widgets of the specified color",
    "name": "build_widgets",
    "parameters": {
      "properties": {
        "color": {
          "type": "string"
        },
        "count": {
          "type": "number"
        }
      },
      "required": [
        "color",
        "count"
      ],
      "type": "object"
    }
  },
  "type": "function"
}"#;
        let builder = ToolBuilder::new(
            "build_widgets",
            "Create N different widgets of the specified color",
        );
        let tool = builder.arg::<String>("color").arg::<f64>("count").build();
        let actual = serde_json::to_string_pretty(&tool).unwrap();
        assert_eq!(expected, actual);
    }
}