1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3
4use std::fs::OpenOptions;
5use std::io::{Read, Write};
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, SystemTime};
9
10use reqwest::RequestBuilder;
11use utf8path::Path;
12
13mod chat;
14mod chats;
15mod cli;
16mod fmt;
17mod types;
18mod wrap;
19
20pub use chat::{Chat, ChatOptions};
21pub use chats::{Chats, ChatsOptions};
22pub use fmt::Formattable;
23pub use types::{
24 ChatMessage, ChatRequest, ChatResponse, EmbedRequest, EmbedResponse, GenerateRequest,
25 GenerateResponse,
26};
27pub use wrap::WordWrap;
28
/// Default Ollama endpoint, used by [`ollama_host`] when neither an explicit host nor the
/// `OLLAMA_HOST` environment variable is provided.
pub const OLLAMA_HOST: &str = "http://localhost:11434";
33
/// Errors returned by this crate.
#[derive(Debug)]
pub enum Error {
    /// An internal failure with no further detail (e.g. a non-200 response with an empty body).
    Internal,
    /// The operation was interrupted because a signal other than `SIGCHLD` became pending.
    Signal,
    /// The `EDITOR` environment variable is not set.
    EditorNotSet,
    /// The editor exited unsuccessfully. Holds its exit code, or `None` when there was none
    /// (on Unix, e.g. when the editor was killed by a signal).
    EditorFailed(Option<i32>),
    /// The `YAMMER_CHAT` environment variable is not set.
    ChatNotSet,
    /// An argument was invalid; holds a description of the problem.
    InvalidArgument(String),
    /// An error reported by the Ollama server, or a malformed response from it.
    Ollama(String),
    /// An I/O error.
    Io(std::io::Error),
    /// Invalid UTF-8 was encountered.
    Utf8Error(std::str::Utf8Error),
    /// JSON serialization or deserialization failed.
    Json(serde_json::Error),
    /// The HTTP client failed.
    Reqwest(reqwest::Error),
}
62
// `source` is not overridden: wrapped errors (I/O, UTF-8, JSON, reqwest) are surfaced only
// through the `Display` message, not as a source chain.
impl std::error::Error for Error {}
64
65impl std::fmt::Display for Error {
66 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
67 match self {
68 Self::Internal => write!(f, "Internal error"),
69 Self::Signal => write!(f, "Signal received"),
70 Self::EditorNotSet => write!(f, "EDITOR not set"),
71 Self::EditorFailed(Some(code)) => write!(f, "Editor failed with exit code {}", code),
72 Self::EditorFailed(None) => write!(f, "Editor failed without exit code"),
73 Self::ChatNotSet => write!(f, "YAMMER_CHAT not set"),
74 Self::InvalidArgument(message) => write!(f, "invalid argument: {}", message),
75 Self::Ollama(s) => write!(f, "Ollama error: {}", s),
76 Self::Io(e) => write!(f, "I/O error: {}", e),
77 Self::Utf8Error(e) => write!(f, "UTF-8 error: {}", e),
78 Self::Json(e) => write!(f, "JSON error: {}", e),
79 Self::Reqwest(e) => write!(f, "Reqwest error: {}", e),
80 }
81 }
82}
83
84impl From<std::io::Error> for Error {
85 fn from(err: std::io::Error) -> Self {
86 Self::Io(err)
87 }
88}
89
90impl From<std::str::Utf8Error> for Error {
91 fn from(err: std::str::Utf8Error) -> Self {
92 Self::Utf8Error(err)
93 }
94}
95
96impl From<serde_json::Error> for Error {
97 fn from(e: serde_json::Error) -> Self {
98 Self::Json(e)
99 }
100}
101
102impl From<reqwest::Error> for Error {
103 fn from(e: reqwest::Error) -> Self {
104 Self::Reqwest(e)
105 }
106}
107
/// Model sampling parameters, sent to Ollama as the `options` of a generate request.
///
/// Every field is optional. Unset fields are omitted both when serializing and when
/// converting to `serde_json::Value`.
#[derive(
    Clone, Debug, Default, arrrg_derive::CommandLine, serde::Deserialize, serde::Serialize,
)]
pub struct Parameters {
    /// Mirostat mode: 0 disables it, 1 selects Mirostat, 2 selects Mirostat 2.0.
    #[arrrg(optional, "Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mirostat: Option<i32>,

    /// Mirostat: how quickly the algorithm responds to feedback from the generated text.
    #[arrrg(
        optional,
        "Influences how quickly the algorithm responds to feedback from the generated text."
    )]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mirostat_eta: Option<f64>,

    /// Mirostat: balance between coherence and diversity of the output.
    #[arrrg(
        optional,
        "Controls the balance between coherence and diversity of the output."
    )]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mirostat_tau: Option<f64>,

    /// Size of the context window, in tokens.
    #[arrrg(optional, "The number of tokens worth of context to allocate.")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_ctx: Option<u32>,

    /// How far back the model looks when penalizing repetition.
    #[arrrg(
        optional,
        "Sets how far back for the model to look back to prevent repetition."
    )]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repeat_last_n: Option<i32>,

    /// How strongly repetitions are penalized.
    #[arrrg(optional, "Sets how strongly to penalize repetitions.")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repeat_penalty: Option<f64>,

    /// Sampling temperature.
    #[arrrg(optional, "The temperature of the model.")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,

    /// Random seed for generation.
    #[arrrg(optional, "Sets the random number seed to use for generation.")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i32>,

    /// Tail-free sampling parameter, reducing the impact of less probable tokens.
    #[arrrg(
        optional,
        "Tail free sampling is used to reduce the impact of less probable tokens from the output."
    )]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tfs_z: Option<f64>,

    /// Maximum number of tokens to generate.
    #[arrrg(optional, "Maximum number of tokens to predict when generating text.")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_predict: Option<i32>,

    /// Top-k sampling cutoff.
    #[arrrg(optional, "Reduces the probability of generating nonsense.")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<i32>,

    /// Top-p (nucleus) sampling threshold; works together with `top_k`.
    #[arrrg(optional, "Works together with top-k.")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,

    /// Min-p sampling threshold, an alternative to `top_p`.
    #[arrrg(
        optional,
        "Alternative to the top_p, and aims to ensure a balance of quality and variety."
    )]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_p: Option<f64>,
}
228
229impl Parameters {
230 pub fn apply(&mut self, from: Self) {
232 if let Some(mirostat) = from.mirostat {
233 self.mirostat = Some(mirostat);
234 }
235 if let Some(mirostat_eta) = from.mirostat_eta {
236 self.mirostat_eta = Some(mirostat_eta);
237 }
238 if let Some(mirostat_tau) = from.mirostat_tau {
239 self.mirostat_tau = Some(mirostat_tau);
240 }
241 if let Some(num_ctx) = from.num_ctx {
242 self.num_ctx = Some(num_ctx);
243 }
244 if let Some(repeat_last_n) = from.repeat_last_n {
245 self.repeat_last_n = Some(repeat_last_n);
246 }
247 if let Some(repeat_penalty) = from.repeat_penalty {
248 self.repeat_penalty = Some(repeat_penalty);
249 }
250 if let Some(temperature) = from.temperature {
251 self.temperature = Some(temperature);
252 }
253 if let Some(seed) = from.seed {
254 self.seed = Some(seed);
255 }
256 if let Some(tfs_z) = from.tfs_z {
257 self.tfs_z = Some(tfs_z);
258 }
259 if let Some(num_predict) = from.num_predict {
260 self.num_predict = Some(num_predict);
261 }
262 if let Some(top_k) = from.top_k {
263 self.top_k = Some(top_k);
264 }
265 if let Some(top_p) = from.top_p {
266 self.top_p = Some(top_p);
267 }
268 if let Some(min_p) = from.min_p {
269 self.min_p = Some(min_p);
270 }
271 }
272}
273
274impl From<Parameters> for serde_json::Value {
275 fn from(p: Parameters) -> serde_json::Value {
276 let mut json = serde_json::json!({});
277 if let Some(mirostat) = p.mirostat {
278 json["mirostat"] = serde_json::json!(mirostat);
279 }
280 if let Some(mirostat_eta) = p.mirostat_eta {
281 json["mirostat_eta"] = serde_json::json!(mirostat_eta);
282 }
283 if let Some(mirostat_tau) = p.mirostat_tau {
284 json["mirostat_tau"] = serde_json::json!(mirostat_tau);
285 }
286 if let Some(num_ctx) = p.num_ctx {
287 json["num_ctx"] = serde_json::json!(num_ctx);
288 }
289 if let Some(repeat_last_n) = p.repeat_last_n {
290 json["repeat_last_n"] = serde_json::json!(repeat_last_n);
291 }
292 if let Some(repeat_penalty) = p.repeat_penalty {
293 json["repeat_penalty"] = serde_json::json!(repeat_penalty);
294 }
295 if let Some(temperature) = p.temperature {
296 json["temperature"] = serde_json::json!(temperature);
297 }
298 if let Some(seed) = p.seed {
299 json["seed"] = serde_json::json!(seed);
300 }
301 if let Some(tfs_z) = p.tfs_z {
302 json["tfs_z"] = serde_json::json!(tfs_z);
303 }
304 if let Some(num_predict) = p.num_predict {
305 json["num_predict"] = serde_json::json!(num_predict);
306 }
307 if let Some(top_k) = p.top_k {
308 json["top_k"] = serde_json::json!(top_k);
309 }
310 if let Some(top_p) = p.top_p {
311 json["top_p"] = serde_json::json!(top_p);
312 }
313 if let Some(min_p) = p.min_p {
314 json["min_p"] = serde_json::json!(min_p);
315 }
316 json
317 }
318}
319
320impl Eq for Parameters {}
321
322impl PartialEq for Parameters {
323 fn eq(&self, other: &Self) -> bool {
324 self.mirostat == other.mirostat
325 && self.mirostat_eta == other.mirostat_eta
326 && self.mirostat_tau == other.mirostat_tau
327 && self.num_ctx == other.num_ctx
328 && self.repeat_last_n == other.repeat_last_n
329 && self.repeat_penalty == other.repeat_penalty
330 && self.temperature == other.temperature
331 && self.seed == other.seed
332 && self.tfs_z == other.tfs_z
333 && self.num_predict == other.num_predict
334 && self.top_k == other.top_k
335 && self.top_p == other.top_p
336 && self.min_p == other.min_p
337 }
338}
339
340#[derive(Clone, Debug, Eq, PartialEq, arrrg_derive::CommandLine)]
344pub struct ShellmOptions {
345 #[arrrg(optional, "The host to connect to.")]
347 pub ollama_host: Option<String>,
348 #[arrrg(optional, "The model to use from the ollama library.")]
350 pub model: String,
351 #[arrrg(optional, "The suffix to append to the response.")]
353 pub suffix: Option<String>,
354 #[arrrg(optional, "The system to use in the template.")]
356 pub system: Option<String>,
357 #[arrrg(optional, "The template to use for the prompt.")]
359 pub template: Option<String>,
360 #[arrrg(
362 flag,
363 "Format the response in JSON. You must also ask the model to do so."
364 )]
365 pub json: bool,
366 #[arrrg(
368 optional,
369 "Schema to adhere to when formatting the response in JSON. Has no effect without --json."
370 )]
371 pub schema: Option<serde_json::Value>,
372 #[arrrg(optional, "Whether to pass bypass formatting of the prompt.")]
374 pub raw: Option<bool>,
375 #[arrrg(optional, "Duration to keep the model in memory for after the call.")]
377 pub keep_alive: Option<String>,
378 #[arrrg(nested)]
380 pub param: Parameters,
381 #[arrrg(
383 optional,
384 "Wrap at this line length, or 0 to disable yammer-induced wrapping."
385 )]
386 pub wrap: Option<usize>,
387}
388
389impl Default for ShellmOptions {
390 fn default() -> Self {
391 ShellmOptions {
392 ollama_host: None,
393 model: "gemma2".to_string(),
395 suffix: None,
396 system: None,
397 template: None,
398 json: false,
399 schema: None,
400 raw: None,
401 keep_alive: None,
402 param: Parameters::default(),
403 wrap: None,
404 }
405 }
406}
407
408pub async fn shellm(
412 options: ShellmOptions,
413 promptfiles: &[impl AsRef<str>],
414) -> Result<(), Box<dyn std::error::Error>> {
415 let mut stdin: Option<String> = None;
416 for promptfile in promptfiles {
417 let promptfile = promptfile.as_ref();
418 let prompt = if promptfile == "-" {
419 if let Some(stdin) = stdin.as_ref() {
420 stdin.clone()
421 } else {
422 let mut s = String::new();
423 std::io::stdin().read_to_string(&mut s)?;
424 stdin = Some(s.clone());
425 s
426 }
427 } else {
428 match std::fs::read_to_string(promptfile) {
429 Ok(s) => s,
430 Err(e) => {
431 eprintln!("shellm: {}: {}", promptfile, e);
432 continue;
433 }
434 }
435 };
436 let gen = GenerateRequest {
437 model: options.model.clone(),
438 prompt,
439 suffix: options.suffix.clone(),
440 images: None,
441 format: if options.json {
442 if let Some(schema) = options.schema.clone() {
443 Some(schema)
444 } else {
445 Some(serde_json::Value::String("json".to_string()))
446 }
447 } else {
448 None
449 },
450 system: options.system.clone(),
451 template: options.template.clone(),
452 stream: Some(true),
453 raw: options.raw,
454 keep_alive: None,
455 options: Some(options.param.clone().into()),
456 };
457 let req = gen.make_request(&ollama_host(options.ollama_host.clone()));
458 let mut ww = WordWrap::new(options.wrap.unwrap_or(100));
459 let res = stream(req, |v| {
460 if let Some(serde_json::Value::String(message)) = v.get("response") {
461 ww.push(message.clone(), &mut std::io::stdout())?;
462 }
463 Ok(())
464 })
465 .await;
466 if let Err(Error::Signal) = res {
467 break;
468 } else if let Err(err) = res {
469 return Err(err.into());
470 }
471 writeln!(std::io::stdout())?;
472 }
473 Ok(())
474}
475
476#[derive(Clone, Debug, Default, Eq, PartialEq, arrrg_derive::CommandLine)]
480pub struct OneshotOptions {
481 #[arrrg(optional, "The host to connect to.")]
483 pub ollama_host: Option<String>,
484 #[arrrg(optional, "The suffix to append to the response.")]
486 pub suffix: Option<String>,
487 #[arrrg(optional, "The system to use in the template.")]
489 pub system: Option<String>,
490 #[arrrg(optional, "The template to use for the prompt.")]
492 pub template: Option<String>,
493 #[arrrg(
495 flag,
496 "Format the response in JSON. You must also ask the model to do so."
497 )]
498 pub json: bool,
499 #[arrrg(
501 optional,
502 "Schema to adhere to when formatting the response in JSON. Has no effect without --json."
503 )]
504 pub schema: Option<serde_json::Value>,
505 #[arrrg(optional, "Whether to pass bypass formatting of the prompt.")]
507 pub raw: Option<bool>,
508 #[arrrg(optional, "Duration to keep the model in memory for after the call.")]
510 pub keep_alive: Option<String>,
511 #[arrrg(nested)]
513 pub param: Parameters,
514 #[arrrg(
516 optional,
517 "Wrap at this line length, or 0 to disable yammer-induced wrapping."
518 )]
519 pub wrap: Option<usize>,
520}
521
522pub fn editor(default: &str) -> Result<impl AsRef<String>, Error> {
526 let path = format!(
527 ".yammer.{}.{}",
528 std::process::id(),
529 SystemTime::now()
530 .duration_since(SystemTime::UNIX_EPOCH)
531 .unwrap_or(Duration::ZERO)
532 .as_micros()
533 );
534 let editor = std::env::var("EDITOR").map_err(|_| Error::EditorNotSet)?;
535 struct UnlinkOnDrop(String);
536 impl Drop for UnlinkOnDrop {
537 fn drop(&mut self) {
538 let _ = std::fs::remove_file(&self.0);
539 }
540 }
541 impl AsRef<String> for UnlinkOnDrop {
542 fn as_ref(&self) -> &String {
543 &self.0
544 }
545 }
546 let mut file = OpenOptions::new()
547 .create_new(true)
548 .write(true)
549 .open(&path)?;
550 let unlink = UnlinkOnDrop(path.clone());
551 file.write_all(default.as_bytes())?;
552 file.flush()?;
553 file.sync_all()?;
554 drop(file);
555 let status = std::process::Command::new(editor).arg(&path).status()?;
556 if Some(0) != status.code() {
557 return Err(Error::EditorFailed(status.code()));
558 }
559 Ok(unlink)
560}
561
562pub async fn oneshot(
566 options: OneshotOptions,
567 models: &[impl AsRef<str>],
568) -> Result<(), Box<dyn std::error::Error>> {
569 let path = editor("Replace this text with your prompt.")?;
570 for model in models {
571 let options = ShellmOptions {
572 ollama_host: options.ollama_host.clone(),
573 model: model.as_ref().to_string(),
574 suffix: options.suffix.clone(),
575 system: options.system.clone(),
576 template: options.template.clone(),
577 json: options.json,
578 schema: options.schema.clone(),
579 raw: options.raw,
580 keep_alive: options.keep_alive.clone(),
581 param: options.param.clone(),
582 wrap: options.wrap,
583 };
584 shellm(options, &[path.as_ref()]).await?;
585 }
586 Ok(())
587}
588
589#[derive(Clone, Debug, Eq, PartialEq, arrrg_derive::CommandLine)]
593pub struct PromptOptions {
594 #[arrrg(optional, "The host to connect to.")]
596 pub ollama_host: Option<String>,
597 #[arrrg(optional, "The model to use from the ollama library.")]
599 pub model: String,
600 #[arrrg(optional, "The suffix to append to the response.")]
602 pub suffix: Option<String>,
603 #[arrrg(optional, "The system to use in the template.")]
605 pub system: Option<String>,
606 #[arrrg(optional, "The template to use for the prompt.")]
608 pub template: Option<String>,
609 #[arrrg(
611 flag,
612 "Format the response in JSON. You must also ask the model to do so."
613 )]
614 pub json: bool,
615 #[arrrg(
617 optional,
618 "Schema to adhere to when formatting the response in JSON. Has no effect without --json."
619 )]
620 pub schema: Option<serde_json::Value>,
621 #[arrrg(optional, "Whether to pass bypass formatting of the prompt.")]
623 pub raw: Option<bool>,
624 #[arrrg(optional, "Duration to keep the model in memory for after the call.")]
626 pub keep_alive: Option<String>,
627 #[arrrg(nested)]
629 pub param: Parameters,
630 #[arrrg(
632 optional,
633 "Wrap at this line length, or 0 to disable yammer-induced wrapping."
634 )]
635 pub wrap: Option<usize>,
636}
637
638impl Default for PromptOptions {
639 fn default() -> Self {
640 PromptOptions {
641 ollama_host: None,
642 model: "gemma2".to_string(),
643 suffix: None,
644 system: None,
645 template: None,
646 json: false,
647 raw: None,
648 schema: None,
649 keep_alive: None,
650 param: Parameters::default(),
651 wrap: None,
652 }
653 }
654}
655
656pub async fn prompt(
660 options: PromptOptions,
661 prompts: &[impl AsRef<str>],
662) -> Result<(), Box<dyn std::error::Error>> {
663 for prompt in prompts {
664 let gen = GenerateRequest {
665 model: options.model.clone(),
666 prompt: prompt.as_ref().to_string(),
667 suffix: options.suffix.clone(),
668 images: None,
669 format: if options.json {
670 if let Some(schema) = options.schema.clone() {
671 Some(schema)
672 } else {
673 Some(serde_json::Value::String("json".to_string()))
674 }
675 } else {
676 None
677 },
678 system: options.system.clone(),
679 template: options.template.clone(),
680 stream: Some(true),
681 raw: options.raw,
682 keep_alive: None,
683 options: Some(options.param.clone().into()),
684 };
685 let req = gen.make_request(&ollama_host(options.ollama_host.clone()));
686 let mut ww = WordWrap::new(options.wrap.unwrap_or(100));
687 let res = stream(req, |v| {
688 if let Some(serde_json::Value::String(message)) = v.get("response") {
689 ww.push(message.clone(), &mut std::io::stdout())?;
690 }
691 Ok(())
692 })
693 .await;
694 if let Err(Error::Signal) = res {
695 break;
696 } else if let Err(err) = res {
697 return Err(err.into());
698 }
699 writeln!(std::io::stdout())?;
700 }
701 Ok(())
702}
703
704pub async fn chat_shell(changelog: Option<Path<'_>>, options: ChatOptions) -> Result<(), Error> {
708 let chat = Chat::new(changelog, options)?;
709 chat.shell().await
710}
711
712pub async fn chats_shell(options: ChatsOptions) -> Result<(), Error> {
716 let chats = Chats::new(options)?;
717 chats.shell().await
718}
719
720pub async fn stream(
724 req: RequestBuilder,
725 for_each: impl FnMut(serde_json::Value) -> Result<(), Error>,
726) -> Result<(), Error> {
727 let sns = stream_no_signal(req, for_each);
728 let sig = async {
729 loop {
730 tokio::time::sleep(Duration::from_millis(50)).await;
731 if minimal_signals::pending()
732 .iter()
733 .filter(|s| *s != minimal_signals::SIGCHLD)
734 .count()
735 > 0
736 {
737 break;
738 }
739 }
740 };
741 tokio::select! {
742 res = sns => res,
743 _ = sig => Err(Error::Signal),
744 }
745}
746
747async fn stream_no_signal(
748 req: RequestBuilder,
749 mut for_each: impl FnMut(serde_json::Value) -> Result<(), Error>,
750) -> Result<(), Error> {
751 let mut resp = req.send().await?;
752 if resp.status() != 200 {
753 return if let Some(chunk) = resp.chunk().await? {
754 #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
755 struct ErrorResponse {
756 pub error: String,
757 }
758 let err = serde_json::from_slice::<ErrorResponse>(&chunk)?;
759 Err(Error::Ollama(err.error))
760 } else {
761 Err(Error::Internal)
762 };
763 }
764 let mut leftovers: Vec<u8> = vec![];
765 while let Some(chunk) = resp.chunk().await? {
766 leftovers.extend(&chunk);
767 if let Ok(value) = serde_json::from_slice(&leftovers) {
768 for_each(value)?;
769 leftovers.clear();
770 }
771 }
772 if !leftovers.is_empty() {
773 let Ok(value) = serde_json::from_slice(&leftovers) else {
774 return Err(Error::Ollama(format!(
775 "Host returned invalid JSON chunk {leftovers:?}"
776 )));
777 };
778 for_each(value)?;
779 }
780 Ok(())
781}
782
783pub fn ollama_host(host: Option<String>) -> String {
788 host.unwrap_or_else(|| std::env::var("OLLAMA_HOST").unwrap_or_else(|_| OLLAMA_HOST.to_string()))
789}
790
791pub fn chat_root() -> Result<Path<'static>, Error> {
795 let root = std::env::var("YAMMER_CHAT").map_err(|_| Error::ChatNotSet)?;
796 Ok(Path::from(root))
797}
798
799pub fn chat_path(chat_id: &str) -> Result<Path<'static>, Error> {
803 let root = chat_root()?;
804 Ok(root.join("chats").join(format!("{}.ndjson", chat_id)))
805}
806
/// Braille frames that [`Spinner`]'s background thread cycles through, one per redraw.
const SPINNER: &[&str] = &["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"];
810
/// A progress spinner drawn on stderr by a background thread.
///
/// The spinner starts inhibited. `start` begins drawing, and `inhibit` pauses drawing and
/// clears the line. Dropping the spinner stops and joins the thread.
#[derive(Debug)]
pub struct Spinner {
    /// Set to tell the background thread to exit.
    done: Arc<AtomicBool>,
    /// While `true`, the background thread skips drawing.
    inhibited: Arc<Mutex<bool>>,
    /// The drawing thread; taken and joined on drop.
    background: Option<std::thread::JoinHandle<()>>,
}
818
819impl Spinner {
820 #[allow(clippy::new_without_default)]
822 pub fn new() -> Self {
823 let done = Arc::new(AtomicBool::new(false));
824 let done_p = Arc::clone(&done);
825 let inhibited = Arc::new(Mutex::new(true));
826 let inhibited_p = Arc::clone(&inhibited);
827 let background = std::thread::spawn(move || {
828 let mut i = 0;
829 while !done_p.load(Ordering::Relaxed) {
830 std::thread::sleep(std::time::Duration::from_millis(50));
831 let inhibited_p = inhibited_p.lock().unwrap();
832 if *inhibited_p {
833 continue;
834 }
835 let mut stderr = std::io::stderr().lock();
836 let _ = stderr.write(b"\x1b[2K\r");
837 let _ = stderr.write(SPINNER[i % SPINNER.len()].as_bytes());
838 let _ = stderr.write(" ".as_bytes());
839 let _ = stderr.flush();
840 i += 1;
841 }
842 });
843 Self {
844 done,
845 inhibited,
846 background: Some(background),
847 }
848 }
849
850 pub fn start(&self) {
852 *self.inhibited.lock().unwrap() = false;
853 }
854
855 pub fn inhibit(&self) {
857 let mut inhibited = self.inhibited.lock().unwrap();
858 if !*inhibited {
859 *inhibited = true;
860 let mut stderr = std::io::stderr().lock();
861 let _ = stderr.write(b"\x1b[2K\r");
862 }
863 }
864}
865
866impl Drop for Spinner {
867 fn drop(&mut self) {
868 self.done.store(true, Ordering::Relaxed);
869 self.inhibit();
870 if let Some(background) = self.background.take() {
871 background.join().unwrap();
872 }
873 }
874}
875
/// Types that can describe themselves with a JSON schema fragment.
///
/// [`ToolBuilder::arg`] uses this to describe each tool argument.
pub trait JsonSchema {
    /// Returns the JSON schema describing values of `Self`.
    fn json_schema() -> serde_json::Value;
}
883
884impl JsonSchema for bool {
885 fn json_schema() -> serde_json::Value {
886 serde_json::json! {{ "type": "boolean" }}
887 }
888}
889
890impl JsonSchema for i8 {
891 fn json_schema() -> serde_json::Value {
892 serde_json::json! {{ "type": "integer" }}
893 }
894}
895
896impl JsonSchema for i16 {
897 fn json_schema() -> serde_json::Value {
898 serde_json::json! {{ "type": "integer" }}
899 }
900}
901
902impl JsonSchema for i32 {
903 fn json_schema() -> serde_json::Value {
904 serde_json::json! {{ "type": "integer" }}
905 }
906}
907
908impl JsonSchema for i64 {
909 fn json_schema() -> serde_json::Value {
910 serde_json::json! {{ "type": "integer" }}
911 }
912}
913
914impl JsonSchema for u8 {
915 fn json_schema() -> serde_json::Value {
916 serde_json::json! {{ "type": "integer" }}
917 }
918}
919
920impl JsonSchema for u16 {
921 fn json_schema() -> serde_json::Value {
922 serde_json::json! {{ "type": "integer" }}
923 }
924}
925
926impl JsonSchema for u32 {
927 fn json_schema() -> serde_json::Value {
928 serde_json::json! {{ "type": "integer" }}
929 }
930}
931
932impl JsonSchema for u64 {
933 fn json_schema() -> serde_json::Value {
934 serde_json::json! {{ "type": "integer" }}
935 }
936}
937
938impl JsonSchema for f32 {
939 fn json_schema() -> serde_json::Value {
940 serde_json::json! {{ "type": "number" }}
941 }
942}
943
944impl JsonSchema for f64 {
945 fn json_schema() -> serde_json::Value {
946 serde_json::json! {{ "type": "number" }}
947 }
948}
949
950impl JsonSchema for String {
951 fn json_schema() -> serde_json::Value {
952 serde_json::json! {{ "type": "string" }}
953 }
954}
955
956impl<T: JsonSchema> JsonSchema for Option<T> {
957 fn json_schema() -> serde_json::Value {
958 let mut res = <T as JsonSchema>::json_schema();
959 res["nullable"] = true.into();
960 res
961 }
962}
963
964impl<T: JsonSchema> JsonSchema for Vec<T> {
965 fn json_schema() -> serde_json::Value {
966 serde_json::json! {{ "type": "array", "items": <T as JsonSchema>::json_schema() }}
967 }
968}
969
970impl JsonSchema for serde_json::Value {
971 fn json_schema() -> serde_json::Value {
972 serde_json::json! {{}}
973 }
974}
975
976impl<Tz: chrono::TimeZone> JsonSchema for chrono::DateTime<Tz> {
977 fn json_schema() -> serde_json::Value {
978 String::json_schema()
979 }
980}
981
982pub struct ToolBuilder {
986 name: String,
987 description: String,
988 fields: Vec<(String, serde_json::Value)>,
989}
990
991impl ToolBuilder {
992 pub fn new(name: &str, description: &str) -> Self {
995 let name = name.to_string();
996 let description = description.to_string();
997 let fields = vec![];
998 Self {
999 name,
1000 description,
1001 fields,
1002 }
1003 }
1004
1005 pub fn arg<T: JsonSchema>(mut self, name: &str) -> Self {
1007 self.fields.push((name.to_string(), T::json_schema()));
1008 self
1009 }
1010
1011 pub fn build(self) -> serde_json::Value {
1013 let mut properties = serde_json::json! {{}};
1014 let mut required = vec![];
1015 for (name, schema) in self.fields.into_iter() {
1016 properties[name.clone()] = schema;
1017 required.push(name);
1018 }
1019 let required: serde_json::Value = required.into();
1020 let parameters = serde_json::json! {{
1021 "type": "object",
1022 "properties": properties,
1023 "required": required,
1024 }};
1025 serde_json::json! {{
1026 "type": "function",
1027 "function": {
1028 "name": self.name,
1029 "description": self.description,
1030 "parameters": parameters,
1031 }
1032 }}
1033 }
1034}
1035
#[cfg(test)]
mod tests {
    use super::*;

    /// `ToolBuilder` wraps the declared arguments in a function-tool description and marks
    /// each argument as required, in declaration order.
    #[test]
    fn tool_builder() {
        let tb = ToolBuilder::new(
            "build_widgets",
            "Create N different widgets of the specified color",
        )
        .arg::<String>("color")
        .arg::<f64>("count")
        .build();
        // Compare structurally rather than as pretty-printed text. Printed key order depends on
        // whether serde_json's `preserve_order` feature is enabled anywhere in the build.
        let expected = serde_json::json!({
            "function": {
                "description": "Create N different widgets of the specified color",
                "name": "build_widgets",
                "parameters": {
                    "properties": {
                        "color": { "type": "string" },
                        "count": { "type": "number" }
                    },
                    "required": ["color", "count"],
                    "type": "object"
                }
            },
            "type": "function"
        });
        assert_eq!(expected, tb);
    }
}
1077}