rusty_openai/openai_api/
fine_tuning.rs

1use crate::{error_handling::OpenAIResult, openai::OpenAI};
2use serde::Serialize;
3use serde_json::Value;
4
5/// [`FineTuningApi`] struct to interact with the fine-tuning endpoints of the API.
6pub struct FineTuningApi<'a>(pub(crate) &'a OpenAI<'a>);
7
8#[derive(Serialize)]
9struct FineTuningRequest<'a> {
10    /// Model to be fine-tuned
11    model: &'a str,
12
13    /// Path to training data file
14    training_file: &'a str,
15
16    /// Optional validation data file
17    #[serde(skip_serializing_if = "Option::is_none")]
18    validation_file: Option<&'a str>,
19
20    /// Optional number of epochs for training
21    #[serde(skip_serializing_if = "Option::is_none")]
22    n_epochs: Option<u32>,
23
24    /// Optional batch size for training
25    #[serde(skip_serializing_if = "Option::is_none")]
26    batch_size: Option<u32>,
27
28    /// Optional learning rate multiplier
29    #[serde(skip_serializing_if = "Option::is_none")]
30    learning_rate_multiplier: Option<f64>,
31
32    /// Optional prompt loss weight
33    #[serde(skip_serializing_if = "Option::is_none")]
34    prompt_loss_weight: Option<f64>,
35
36    /// Optional flag to compute classification metrics
37    #[serde(skip_serializing_if = "Option::is_none")]
38    compute_classification_metrics: Option<bool>,
39
40    /// Optional number of classes for classification
41    #[serde(skip_serializing_if = "Option::is_none")]
42    classification_n_classes: Option<u32>,
43
44    /// Optional positive class for classification
45    #[serde(skip_serializing_if = "Option::is_none")]
46    classification_positive_class: Option<&'a str>,
47
48    /// Optional betas for classification metrics
49    #[serde(skip_serializing_if = "Option::is_none")]
50    classification_betas: Option<Vec<f64>>,
51}
52
53impl<'a> FineTuningApi<'a> {
54    /// Create a new fine-tuning job with the specified parameters.
55    ///
56    /// # Arguments
57    ///
58    /// * `model` - The model to be fine-tuned.
59    /// * `training_file` - The file containing training data.
60    /// * `validation_file` - Optional validation data file.
61    /// * `n_epochs` - Optional number of training epochs.
62    /// * `batch_size` - Optional batch size for training.
63    /// * `learning_rate_multiplier` - Optional learning rate multiplier.
64    /// * `prompt_loss_weight` - Optional weight for the prompt loss.
65    /// * `compute_classification_metrics` - Optional flag to compute classification metrics.
66    /// * `classification_n_classes` - Optional number of classes for classification.
67    /// * `classification_positive_class` - Optional positive class for classification.
68    /// * `classification_betas` - Optional betas for classification metrics.
69    ///
70    /// # Returns
71    ///
72    /// A Result containing the JSON response as [`serde_json::Value`] on success, or an [`OpenAIError`][crate::error_handling::OpenAIError] on failure.
73    pub async fn create_fine_tuning_job(
74        &self,
75        model: &str,                                  // Model to be fine-tuned
76        training_file: &str,                          // Path to training data file
77        validation_file: Option<&str>,                // Optional validation data file
78        n_epochs: Option<u32>,                        // Optional number of epochs for training
79        batch_size: Option<u32>,                      // Optional batch size for training
80        learning_rate_multiplier: Option<f64>,        // Optional learning rate multiplier
81        prompt_loss_weight: Option<f64>,              // Optional prompt loss weight
82        compute_classification_metrics: Option<bool>, // Optional flag to compute classification metrics
83        classification_n_classes: Option<u32>, // Optional number of classes for classification
84        classification_positive_class: Option<&str>, // Optional positive class for classification
85        classification_betas: Option<Vec<f64>>, // Optional betas for classification metrics
86    ) -> OpenAIResult<Value> {
87        // Initialize a JSON map to build the request body.
88        let body = FineTuningRequest {
89            model,
90            training_file,
91            validation_file,
92            n_epochs,
93            batch_size,
94            learning_rate_multiplier,
95            prompt_loss_weight,
96            compute_classification_metrics,
97            classification_n_classes,
98            classification_positive_class,
99            classification_betas,
100        };
101
102        // Send a POST request to the fine-tuning jobs endpoint with the request body.
103        self.0.post_json("/fine-tuning/jobs", &body).await
104    }
105
106    /// List all fine-tuning jobs.
107    ///
108    /// # Returns
109    ///
110    /// A Result containing the JSON response as [`serde_json::Value`] on success, or an [`OpenAIError`][crate::error_handling::OpenAIError] on failure.
111
112    pub async fn list_fine_tuning_jobs(&self) -> OpenAIResult<Value> {
113        // Send a GET request to the fine-tuning jobs endpoint.
114        self.0.get("/fine-tuning/jobs").await
115    }
116
117    /// Retrieve information about a specific fine-tuning job.
118    ///
119    /// # Arguments
120    ///
121    /// * `job_id` - The ID of the fine-tuning job to retrieve.
122    ///
123    /// # Returns
124    ///
125    /// A Result containing the JSON response as [`serde_json::Value`] on success, or an [`OpenAIError`][crate::error_handling::OpenAIError] on failure.
126    pub async fn retrieve_fine_tuning_job(&self, job_id: &str) -> OpenAIResult<Value> {
127        // Construct the full URL for retrieving a specific fine-tuning job.
128        let url = format!("/fine-tuning/jobs/{job_id}");
129
130        // Send a GET request to the specific fine-tuning job endpoint.
131        self.0.get(&url).await
132    }
133}