rust_macios/natural_language/
nl_embedding.rs

1use block::{ConcreteBlock, IntoConcreteBlock};
2use libc::{c_double, c_float};
3use objc::{msg_send, sel, sel_impl};
4
5use crate::{
6    foundation::{NSArray, NSDictionary, NSError, NSIndexSet, NSNumber, NSString, UInt, NSURL},
7    object,
8    objective_c_runtime::{
9        macros::interface_impl,
10        nil,
11        traits::{FromId, PNSObject},
12    },
13    utils::{to_bool, to_optional},
14};
15
16use super::NLLanguage;
17
/// Distance separating two strings inside a text embedding.
///
/// Alias for the C `double` type, matching the value the framework passes across the FFI boundary.
pub type NLDistance = c_double;
20
/// Selects how the distance between two locations in a text embedding is computed.
///
/// The enum is laid out as an `i64` so it can be passed directly as a message argument.
#[repr(i64)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum NLDistanceType {
    /// Distance derived from cosine similarity.
    Cosine = 0,
}
28
29object! {
30    /// A map of strings to vectors, which locates neighboring, similar strings.
31    unsafe pub struct NLEmbedding;
32}
33
34#[interface_impl(NSObject)]
35impl NLEmbedding {
36    /* Creating a word embedding
37     */
38
39    /// Retrieves a word embedding for a given language.
40    #[method]
41    pub fn word_embedding_for_language(language: NLLanguage) -> Option<NLEmbedding> {
42        unsafe {
43            to_optional(msg_send![
44                Self::m_class(),
45                wordEmbeddingForLanguage: language
46            ])
47        }
48    }
49
50    /// Retrieves a word embedding for a given language and revision.
51    #[method]
52    pub fn word_embedding_for_language_revision(
53        language: NLLanguage,
54        revision: UInt,
55    ) -> Option<NLEmbedding> {
56        unsafe {
57            to_optional(
58                msg_send![Self::m_class(), wordEmbeddingForLanguage:language revision: revision],
59            )
60        }
61    }
62
63    /// Creates a word embedding from a model file.
64    #[method]
65    pub fn embedding_with_contents_of_url(url: &NSURL) -> Result<Self, NSError>
66    where
67        Self: Sized + FromId,
68    {
69        let mut error = NSError::m_alloc();
70
71        let ptr = unsafe {
72            Self::from_id(
73                msg_send![Self::m_class(), embeddingWithContentsOfURL:url.m_self() error: &mut error],
74            )
75        };
76
77        if error.m_self() != nil {
78            Err(error)
79        } else {
80            Ok(ptr)
81        }
82    }
83
84    /* Creating a sentence embedding
85     */
86
87    /// Retrieves a sentence embedding for a given language.
88    #[method]
89    pub fn sentence_embedding_for_language(language: NLLanguage) -> Option<NLEmbedding> {
90        unsafe {
91            to_optional(msg_send![
92                Self::m_class(),
93                sentenceEmbeddingForLanguage: language
94            ])
95        }
96    }
97
98    /// Retrieves a sentence embedding for a given language and revision.
99    #[method]
100    pub fn sentence_embedding_for_language_revision(
101        language: NLLanguage,
102        revision: UInt,
103    ) -> Option<NLEmbedding> {
104        unsafe {
105            to_optional(msg_send![
106                Self::m_class(),
107                sentenceEmbeddingForLanguage: language revision: revision
108            ])
109        }
110    }
111
112    /* Finding strings and their distances in an embedding
113     */
114
115    /// Retrieves a limited number of strings near a string in the vocabulary.
116    #[method]
117    pub fn neighbors_for_string_maximum_count_distance_type(
118        &self,
119        string: &NSString,
120        max_count: UInt,
121        distance_type: NLDistanceType,
122    ) -> NSArray<NSString> {
123        unsafe {
124            NSArray::from_id(
125                msg_send![self.m_self(), neighborsForString:string.m_self() maximumCount: max_count distanceType: distance_type],
126            )
127        }
128    }
129
130    /// Retrieves a limited number of strings, within a radius of a string, in the vocabulary.
131    #[method]
132    pub fn neighbors_for_string_maximum_count_maximum_distance_distance_type(
133        &self,
134        string: &NSString,
135        max_count: UInt,
136        max_distance: NLDistance,
137        distance_type: NLDistanceType,
138    ) -> NSArray<NSString> {
139        unsafe {
140            NSArray::from_id(
141                msg_send![self.m_self(), neighborsForString: string.m_self() maximumCount: max_count maximumDistance: max_distance distanceType: distance_type],
142            )
143        }
144    }
145
146    /// Retrieves a limited number of strings near a location in the vocabulary space.
147    #[method]
148    pub fn neighbors_for_vector_maximum_count_distance_type(
149        &self,
150        vector: &NSArray<NSNumber>,
151        max_count: UInt,
152        distance_type: NLDistanceType,
153    ) -> NSArray<NSString> {
154        unsafe {
155            NSArray::from_id(
156                msg_send![self.m_self(), neighborsForVector:vector.m_self() maximumCount: max_count distanceType: distance_type],
157            )
158        }
159    }
160
161    /// Retrieves a limited number of strings within a radius of a location in the vocabulary space.
162    #[method]
163    pub fn neighbors_for_vector_maximum_count_maximum_distance_distance_type(
164        &self,
165        vector: &NSArray<NSNumber>,
166        max_count: UInt,
167        max_distance: NLDistance,
168        distance_type: NLDistanceType,
169    ) -> NSArray<NSString> {
170        unsafe {
171            NSArray::from_id(
172                msg_send![self.m_self(), neighborsForVector: vector.m_self() maximumCount: max_count maximumDistance: max_distance distanceType: distance_type],
173            )
174        }
175    }
176
177    /// Passes the nearest strings of a string in the vocabulary to a block.
178    #[method]
179    pub fn enumerate_neighbors_for_string_maximum_count_distance_type_using_block<F>(
180        &self,
181        string: &NSString,
182        max_count: UInt,
183        distance_type: NLDistanceType,
184        block: F,
185    ) where
186        F: IntoConcreteBlock<(NSString, NLDistance, *mut bool), Ret = ()> + 'static,
187    {
188        let block = ConcreteBlock::new(block);
189        let block = block.copy();
190
191        unsafe {
192            msg_send![self.m_self(), enumerateNeighborsForString: string.m_self() maximumCount: max_count distanceType: distance_type usingBlock: block]
193        }
194    }
195
196    /// Passes the nearest strings, within a radius of a string in the vocabulary, to a block.
197    #[method]
198    pub fn enumerate_neighbors_for_string_maximum_count_maximum_distance_distance_type_using_block<
199        F,
200    >(
201        &self,
202        string: &NSString,
203        max_count: UInt,
204        max_distance: NLDistance,
205        distance_type: NLDistanceType,
206        block: F,
207    ) where
208        F: IntoConcreteBlock<(NSString, NLDistance, *mut bool), Ret = ()> + 'static,
209    {
210        let block = ConcreteBlock::new(block);
211        let block = block.copy();
212
213        unsafe {
214            msg_send![self.m_self(), enumerateNeighborsForString: string.m_self() maximumCount: max_count maximumDistance: max_distance distanceType: distance_type usingBlock: block]
215        }
216    }
217
218    /// Passes the nearest strings of a location in the vocabulary space to a block.
219    #[method]
220    pub fn enumerate_neighbors_for_vector_maximum_count_distance_type_using_block<F>(
221        &self,
222        vector: &NSArray<NSNumber>,
223        max_count: UInt,
224        distance_type: NLDistanceType,
225        block: F,
226    ) where
227        F: IntoConcreteBlock<(NSString, NLDistance, *mut bool), Ret = ()> + 'static,
228    {
229        let block = ConcreteBlock::new(block);
230        let block = block.copy();
231
232        unsafe {
233            msg_send![self.m_self(), enumerateNeighborsForVector: vector.m_self() maximumCount: max_count distanceType: distance_type usingBlock: block]
234        }
235    }
236
237    /// Passes the nearest strings, within a radius of a location in the vocabulary space, to a block.
238    #[method]
239    pub fn enumerate_neighbors_for_vector_maximum_count_maximum_distance_distance_type_using_block<
240        F,
241    >(
242        &self,
243        vector: &NSArray<NSNumber>,
244        max_count: UInt,
245        max_distance: NLDistance,
246        distance_type: NLDistanceType,
247        block: F,
248    ) where
249        F: IntoConcreteBlock<(NSString, NLDistance, *mut bool), Ret = ()> + 'static,
250    {
251        let block = ConcreteBlock::new(block);
252        let block = block.copy();
253
254        unsafe {
255            msg_send![self.m_self(), enumerateNeighborsForVector: vector.m_self() maximumCount: max_count maximumDistance: max_distance distanceType: distance_type usingBlock: block]
256        }
257    }
258
259    /// Calculates the distance between two strings in the vocabulary space.
260    #[method]
261    pub fn distance_between_string_and_string_distance_type(
262        &self,
263        first: &NSString,
264        second: &NSString,
265        distance_type: NLDistanceType,
266    ) -> NLDistance {
267        unsafe {
268            msg_send![self.m_self(), distanceBetweenString: first.m_self() andString: second.m_self() distanceType: distance_type]
269        }
270    }
271
272    /* Inspecting the vocabulary of an embedding
273     */
274
275    /// The number of dimensions in the vocabulary’s vector space.
276    #[property]
277    pub fn dimension(&self) -> UInt {
278        unsafe { msg_send![self.m_self(), dimension] }
279    }
280
281    /// The number of words in the vocabulary.
282    #[property]
283    pub fn vocabulary_size(&self) -> UInt {
284        unsafe { msg_send![self.m_self(), vocabularySize] }
285    }
286
287    /// The language of the text in the word embedding.
288    #[property]
289    pub fn language(&self) -> Option<NLLanguage> {
290        unsafe { to_optional(msg_send![self.m_self(), language]) }
291    }
292
293    /// Requests a Boolean value that indicates whether the term is in the vocabulary.
294    #[method]
295    pub fn contains_string(&self, string: &NSString) -> bool {
296        unsafe { to_bool(msg_send![self.m_self(), containsString: string.m_self()]) }
297    }
298
299    /// Requests the vector for the given term.
300    #[method]
301    pub fn vector_for_string(&self, string: &NSString) -> NSArray<NSNumber> {
302        unsafe { NSArray::from_id(msg_send![self.m_self(), vectorForString: string.m_self()]) }
303    }
304
305    /// Copies a vector into the given a pointer to a float array.
306    #[method]
307    pub fn get_vector_for_string(&self, vector: &mut [c_float], string: &NSString) -> bool {
308        unsafe { to_bool(msg_send![self.m_self(), getVector: vector forString: string.m_self()]) }
309    }
310
311    /// The revision of the word embedding.
312    #[property]
313    pub fn revision(&self) -> UInt {
314        unsafe { msg_send![self.m_self(), revision] }
315    }
316
317    /* Saving an embedding
318     */
319
320    /// Exports the word embedding contained within a Core ML model file at the given URL.
321    #[method]
322    pub fn write_embedding_for_dictionary_language_revision_to_url(
323        dictionary: &NSDictionary<NSString, NSArray<NSNumber>>,
324        language: NLLanguage,
325        revision: UInt,
326        url: &NSURL,
327    ) -> Result<bool, NSError> {
328        let mut error = NSError::m_alloc();
329
330        let ptr = unsafe {
331            to_bool(
332                msg_send![Self::m_class(), writeEmbeddingForDictionary: dictionary.m_self() language: language revision: revision toURL: url.m_self() error: &mut error ],
333            )
334        };
335
336        if error.m_self() != nil {
337            Err(error)
338        } else {
339            Ok(ptr)
340        }
341    }
342
343    /* Checking for Natural Language support
344     */
345
346    /// Retrieves the current version of a word embedding for the given language.
347    #[method]
348    pub fn current_revision_for_language(language: NLLanguage) -> UInt {
349        unsafe { msg_send![Self::m_class(), currentRevisionForLanguage: language] }
350    }
351
352    /// Retrieves all version numbers of a word embedding for the given language.
353    #[method]
354    pub fn supported_revisions_for_language(language: NLLanguage) -> NSIndexSet {
355        unsafe {
356            NSIndexSet::from_id(msg_send![
357                Self::m_class(),
358                supportedRevisionsForLanguage: language
359            ])
360        }
361    }
362
363    /// Retrieves the current version of a sentence embedding for the given language.
364    #[method]
365    pub fn current_sentence_embedding_revision_for_language(language: NLLanguage) -> UInt {
366        unsafe {
367            msg_send![
368                Self::m_class(),
369                currentSentenceEmbeddingRevisionForLanguage: language
370            ]
371        }
372    }
373
374    /// Retrieves all version numbers of a sentence embedding for the given language.
375    #[method]
376    pub fn supported_sentence_embedding_revisions_for_language(language: NLLanguage) -> NSIndexSet {
377        unsafe {
378            NSIndexSet::from_id(msg_send![
379                Self::m_class(),
380                supportedSentenceEmbeddingRevisionsForLanguage: language
381            ])
382        }
383    }
384}