// rust_macios/natural_language/nl_model.rs

1use objc::{msg_send, sel, sel_impl};
2
3use crate::{
4    core_ml::MLModel,
5    foundation::{NSArray, NSDictionary, NSError, NSNumber, NSString, UInt, NSURL},
6    object,
7    objective_c_runtime::{
8        macros::interface_impl,
9        nil,
10        traits::{FromId, PNSObject},
11    },
12    utils::to_optional,
13};
14
15use super::NLModelConfiguration;
16
object! {
    /// A custom model trained to classify or tag natural language text.
    ///
    /// Instances are obtained through `model_with_mlmodel` (from an in-memory Core ML
    /// model) or `model_with_contents_of_url` (from a model file), and queried with the
    /// `predicted_label*` methods.
    unsafe pub struct NLModel;
}
21
22#[interface_impl(NSObject)]
23impl NLModel {
24    /* Creating a model
25     */
26
27    /// Creates a new natural language model based on the given Core ML model instance.
28    #[method]
29    pub fn model_with_mlmodel(model: &MLModel) -> Result<Self, NSError>
30    where
31        Self: Sized + FromId,
32    {
33        unsafe {
34            let mut error = NSError::m_alloc();
35
36            let ptr = Self::from_id(
37                msg_send![Self::m_class(), modelWithMLModel: model.m_self() error: &mut error],
38            );
39
40            if error.m_self() != nil {
41                Err(error)
42            } else {
43                Ok(ptr)
44            }
45        }
46    }
47
48    /// Creates a new natural language model based on the given Core ML model instance.
49    #[method]
50    pub fn model_with_contents_of_url(url: &NSURL) -> Result<Self, NSError>
51    where
52        Self: Sized + FromId,
53    {
54        unsafe {
55            let mut error = NSError::m_alloc();
56
57            let ptr = Self::from_id(
58                msg_send![Self::m_class(), modelWithContentsOfURL: url.m_self() error: &mut error],
59            );
60
61            if error.m_self() != nil {
62                Err(error)
63            } else {
64                Ok(ptr)
65            }
66        }
67    }
68
69    /* Making predictions
70     */
71
72    /// Predicts a label for the given input string.
73    #[method]
74    pub fn predicted_label_for_string(&self, string: &NSString) -> Option<NSString> {
75        unsafe { to_optional(msg_send![self.m_self(), predictedLabelForString: string.m_self()]) }
76    }
77
78    /// Predicts a label for each string in the given array.
79    #[method]
80    pub fn predicted_labels_for_tokens(&self, tokens: &NSArray<NSString>) -> NSArray<NSString> {
81        unsafe {
82            NSArray::from_id(msg_send![self.m_self(), predictedLabelsForTokens: tokens.m_self()])
83        }
84    }
85
86    /// Predicts multiple possible labels for the given input string.
87    #[method]
88    pub fn predicted_label_hypotheses_for_string_maximum_count(
89        &self,
90        string: &NSString,
91        max_count: UInt,
92    ) -> NSDictionary<NSString, NSNumber> {
93        unsafe {
94            NSDictionary::from_id(
95                msg_send![self.m_self(), predictedLabelHypothesesForString: string.m_self() maximumCount: max_count],
96            )
97        }
98    }
99
100    /// Predicts multiple possible labels for each token in the given array.
101    #[method]
102    pub fn predicted_label_hypotheses_for_tokens_maximum_count(
103        &self,
104        tokens: &NSArray<NSString>,
105        max_count: UInt,
106    ) -> NSArray<NSDictionary<NSString, NSNumber>> {
107        unsafe {
108            NSArray::from_id(
109                msg_send![self.m_self(), predictedLabelHypothesesForTokens: tokens.m_self() maximumCount: max_count],
110            )
111        }
112    }
113
114    /* Inspecting a model
115     */
116
117    /// A configuration describing the natural language model.
118    #[property]
119    pub fn configuration(&self) -> NLModelConfiguration {
120        unsafe { NLModelConfiguration::from_id(msg_send![self.m_self(), configuration]) }
121    }
122}