rusty_openai/openai_api/fine_tuning.rs
1use crate::{error_handling::OpenAIResult, openai::OpenAI};
2use serde::Serialize;
3use serde_json::Value;
4
5/// [`FineTuningApi`] struct to interact with the fine-tuning endpoints of the API.
6pub struct FineTuningApi<'a>(pub(crate) &'a OpenAI<'a>);
7
8#[derive(Serialize)]
9struct FineTuningRequest<'a> {
10 /// Model to be fine-tuned
11 model: &'a str,
12
13 /// Path to training data file
14 training_file: &'a str,
15
16 /// Optional validation data file
17 #[serde(skip_serializing_if = "Option::is_none")]
18 validation_file: Option<&'a str>,
19
20 /// Optional number of epochs for training
21 #[serde(skip_serializing_if = "Option::is_none")]
22 n_epochs: Option<u32>,
23
24 /// Optional batch size for training
25 #[serde(skip_serializing_if = "Option::is_none")]
26 batch_size: Option<u32>,
27
28 /// Optional learning rate multiplier
29 #[serde(skip_serializing_if = "Option::is_none")]
30 learning_rate_multiplier: Option<f64>,
31
32 /// Optional prompt loss weight
33 #[serde(skip_serializing_if = "Option::is_none")]
34 prompt_loss_weight: Option<f64>,
35
36 /// Optional flag to compute classification metrics
37 #[serde(skip_serializing_if = "Option::is_none")]
38 compute_classification_metrics: Option<bool>,
39
40 /// Optional number of classes for classification
41 #[serde(skip_serializing_if = "Option::is_none")]
42 classification_n_classes: Option<u32>,
43
44 /// Optional positive class for classification
45 #[serde(skip_serializing_if = "Option::is_none")]
46 classification_positive_class: Option<&'a str>,
47
48 /// Optional betas for classification metrics
49 #[serde(skip_serializing_if = "Option::is_none")]
50 classification_betas: Option<Vec<f64>>,
51}
52
53impl<'a> FineTuningApi<'a> {
54 /// Create a new fine-tuning job with the specified parameters.
55 ///
56 /// # Arguments
57 ///
58 /// * `model` - The model to be fine-tuned.
59 /// * `training_file` - The file containing training data.
60 /// * `validation_file` - Optional validation data file.
61 /// * `n_epochs` - Optional number of training epochs.
62 /// * `batch_size` - Optional batch size for training.
63 /// * `learning_rate_multiplier` - Optional learning rate multiplier.
64 /// * `prompt_loss_weight` - Optional weight for the prompt loss.
65 /// * `compute_classification_metrics` - Optional flag to compute classification metrics.
66 /// * `classification_n_classes` - Optional number of classes for classification.
67 /// * `classification_positive_class` - Optional positive class for classification.
68 /// * `classification_betas` - Optional betas for classification metrics.
69 ///
70 /// # Returns
71 ///
72 /// A Result containing the JSON response as [`serde_json::Value`] on success, or an [`OpenAIError`][crate::error_handling::OpenAIError] on failure.
73 pub async fn create_fine_tuning_job(
74 &self,
75 model: &str, // Model to be fine-tuned
76 training_file: &str, // Path to training data file
77 validation_file: Option<&str>, // Optional validation data file
78 n_epochs: Option<u32>, // Optional number of epochs for training
79 batch_size: Option<u32>, // Optional batch size for training
80 learning_rate_multiplier: Option<f64>, // Optional learning rate multiplier
81 prompt_loss_weight: Option<f64>, // Optional prompt loss weight
82 compute_classification_metrics: Option<bool>, // Optional flag to compute classification metrics
83 classification_n_classes: Option<u32>, // Optional number of classes for classification
84 classification_positive_class: Option<&str>, // Optional positive class for classification
85 classification_betas: Option<Vec<f64>>, // Optional betas for classification metrics
86 ) -> OpenAIResult<Value> {
87 // Initialize a JSON map to build the request body.
88 let body = FineTuningRequest {
89 model,
90 training_file,
91 validation_file,
92 n_epochs,
93 batch_size,
94 learning_rate_multiplier,
95 prompt_loss_weight,
96 compute_classification_metrics,
97 classification_n_classes,
98 classification_positive_class,
99 classification_betas,
100 };
101
102 // Send a POST request to the fine-tuning jobs endpoint with the request body.
103 self.0.post_json("/fine-tuning/jobs", &body).await
104 }
105
106 /// List all fine-tuning jobs.
107 ///
108 /// # Returns
109 ///
110 /// A Result containing the JSON response as [`serde_json::Value`] on success, or an [`OpenAIError`][crate::error_handling::OpenAIError] on failure.
111
112 pub async fn list_fine_tuning_jobs(&self) -> OpenAIResult<Value> {
113 // Send a GET request to the fine-tuning jobs endpoint.
114 self.0.get("/fine-tuning/jobs").await
115 }
116
117 /// Retrieve information about a specific fine-tuning job.
118 ///
119 /// # Arguments
120 ///
121 /// * `job_id` - The ID of the fine-tuning job to retrieve.
122 ///
123 /// # Returns
124 ///
125 /// A Result containing the JSON response as [`serde_json::Value`] on success, or an [`OpenAIError`][crate::error_handling::OpenAIError] on failure.
126 pub async fn retrieve_fine_tuning_job(&self, job_id: &str) -> OpenAIResult<Value> {
127 // Construct the full URL for retrieving a specific fine-tuning job.
128 let url = format!("/fine-tuning/jobs/{job_id}");
129
130 // Send a GET request to the specific fine-tuning job endpoint.
131 self.0.get(&url).await
132 }
133}