pub struct PrototypicalNetworks<S = Untrained> { /* private fields */ }
Prototypical Networks for Few-Shot Learning
Prototypical Networks learn a metric space in which classification is performed by computing distances to a prototype representation of each class. Each class prototype is the mean of that class's support examples in the embedding space.
The method is particularly effective for few-shot learning scenarios where only a few labeled examples are available per class.
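The prototype computation described above can be sketched in plain Rust. This is a minimal illustration, not the crate's implementation: it assumes the embeddings are already computed (the learned embedding network is omitted, so the raw features stand in for embeddings), uses squared Euclidean distance, and the names `compute_prototypes` and `nearest_prototype` are hypothetical.

```rust
use std::collections::BTreeMap;

/// Each class prototype is the mean of that class's support embeddings.
/// Returns prototypes keyed by class label (BTreeMap keeps labels sorted).
fn compute_prototypes(support: &[Vec<f64>], labels: &[i32]) -> BTreeMap<i32, Vec<f64>> {
    // Accumulate a running sum and a count per class
    let mut sums: BTreeMap<i32, (Vec<f64>, usize)> = BTreeMap::new();
    for (x, &label) in support.iter().zip(labels) {
        let (sum, count) = sums.entry(label).or_insert_with(|| (vec![0.0; x.len()], 0));
        for (s, v) in sum.iter_mut().zip(x) {
            *s += v;
        }
        *count += 1;
    }
    // Divide each sum by its count to get the mean
    sums.into_iter()
        .map(|(label, (sum, count))| {
            let mean = sum.into_iter().map(|s| s / count as f64).collect::<Vec<f64>>();
            (label, mean)
        })
        .collect()
}

/// Squared Euclidean distance between two embeddings.
fn squared_euclidean(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum()
}

/// Assign the label of the closest prototype.
fn nearest_prototype(query: &[f64], prototypes: &BTreeMap<i32, Vec<f64>>) -> i32 {
    prototypes
        .iter()
        .map(|(&label, p)| (label, squared_euclidean(query, p)))
        .min_by(|a, b| a.1.total_cmp(&b.1))
        .map(|(label, _)| label)
        .expect("at least one prototype is required")
}

fn main() {
    // Two support points per class, taken from the example data below
    let support = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![2.0, 3.0], vec![4.0, 5.0]];
    let labels = vec![0, 0, 1, 1];
    let prototypes = compute_prototypes(&support, &labels);
    println!("prototypes: {:?}", prototypes);
    println!("query [1.1, 2.1] -> class {}", nearest_prototype(&[1.1, 2.1], &prototypes));
}
```

Keeping labels in a `BTreeMap` makes the prototype order deterministic, which matters if the prototypes are later stacked into a matrix for batched distance computation.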
§Parameters
- embedding_dim: Dimensionality of the embedding space
- hidden_layers: Hidden layer dimensions for the embedding network
- distance_metric: Distance metric to use (‘euclidean’, ‘cosine’, ‘manhattan’)
- learning_rate: Learning rate for embedding network training
- n_episodes: Number of training episodes
- n_way: Number of classes per episode
- n_shot: Number of support examples per class
- n_query: Number of query examples per class
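The n_way, n_shot and n_query parameters describe how each training episode is drawn. The sketch below assumes a deterministic split so that it runs without a random number generator: it takes the first n_way eligible classes instead of sampling them, uses the first n_shot samples of each class as support and the next n_query as queries. The crate's actual sampling strategy is not documented here, and `Episode` and `build_episode` are hypothetical names.

```rust
use std::collections::BTreeMap;

/// Hypothetical episode: indices into the training set.
#[derive(Debug)]
struct Episode {
    support: Vec<usize>,
    query: Vec<usize>,
}

/// Build one episode, or return None if fewer than `n_way` classes
/// have at least `n_shot + n_query` samples.
fn build_episode(labels: &[i32], n_way: usize, n_shot: usize, n_query: usize) -> Option<Episode> {
    // Group sample indices by class label (BTreeMap keeps classes sorted)
    let mut by_class: BTreeMap<i32, Vec<usize>> = BTreeMap::new();
    for (i, &label) in labels.iter().enumerate() {
        by_class.entry(label).or_default().push(i);
    }
    // Keep classes with enough samples, then take the first n_way of them
    let classes: Vec<&Vec<usize>> = by_class
        .values()
        .filter(|idx| idx.len() >= n_shot + n_query)
        .take(n_way)
        .collect();
    if classes.len() < n_way {
        return None;
    }
    let mut episode = Episode { support: Vec::new(), query: Vec::new() };
    for idx in classes {
        // First n_shot indices become support, the next n_query become queries
        episode.support.extend_from_slice(&idx[..n_shot]);
        episode.query.extend_from_slice(&idx[n_shot..n_shot + n_query]);
    }
    Some(episode)
}

fn main() {
    // Labels from the example: 2 classes, 4 samples each
    let y = [0, 1, 0, 1, 0, 1, 0, 1];
    let episode = build_episode(&y, 2, 1, 3).expect("enough samples per class");
    println!("{:?}", episode);
}
```

With the example labels, a 2-way, 1-shot, 3-query episode uses all 4 samples of each class.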
§Examples
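use ndarray::array; // assumes ndarray is a direct dependency of the calling crate
use sklears_core::traits::{Fit, Predict};
use sklears_semi_supervised::PrototypicalNetworks;

// 8 samples with 2 features; labels alternate between class 0 and class 1
let x = array![
    [1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0],
    [1.1, 2.1], [2.1, 3.1], [3.1, 4.1], [4.1, 5.1]
];
let y = array![0, 1, 0, 1, 0, 1, 0, 1];

// 2-way, 1-shot episodes with 3 queries per class:
// 1 support + 3 query = 4 samples, matching the 4 samples per class above
let proto_net = PrototypicalNetworks::new()
    .embedding_dim(32)
    .n_way(2)
    .n_shot(1)
    .n_query(3);

let fitted = proto_net.fit(&x.view(), &y.view()).unwrap();
let predictions = fitted.predict(&x.view()).unwrap();
Implementations§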
impl PrototypicalNetworks<Untrained>
pub fn embedding_dim(self, embedding_dim: usize) -> Self
Set the embedding dimensionality
Set the hidden layer dimensions
pub fn distance_metric(self, metric: String) -> Self
Set the distance metric
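distance_metric takes the metric name as a String. The sketch below shows how the three names listed under Parameters might map to distance functions. The `Metric` enum, `parse_metric`, and the choice to reject unknown names with an error are assumptions; the crate may compute these differently (for example, squared rather than plain Euclidean distance).

```rust
/// Hypothetical enum for the metric names listed under Parameters.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Metric {
    Euclidean,
    Cosine,
    Manhattan,
}

fn parse_metric(name: &str) -> Result<Metric, String> {
    match name {
        "euclidean" => Ok(Metric::Euclidean),
        "cosine" => Ok(Metric::Cosine),
        "manhattan" => Ok(Metric::Manhattan),
        other => Err(format!("unknown distance metric: {other}")),
    }
}

fn distance(metric: Metric, a: &[f64], b: &[f64]) -> f64 {
    match metric {
        // L2 norm of the difference
        Metric::Euclidean => a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum::<f64>().sqrt(),
        // 1 - cosine similarity; NaN if either vector is all zeros
        Metric::Cosine => {
            let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
            let norm_a = a.iter().map(|x| x * x).sum::<f64>().sqrt();
            let norm_b = b.iter().map(|x| x * x).sum::<f64>().sqrt();
            1.0 - dot / (norm_a * norm_b)
        }
        // L1 norm of the difference
        Metric::Manhattan => a.iter().zip(b).map(|(x, y)| (x - y).abs()).sum(),
    }
}

fn main() {
    for name in ["euclidean", "cosine", "manhattan"] {
        let metric = parse_metric(name).unwrap();
        println!("{name}: {}", distance(metric, &[1.0, 2.0], &[4.0, 6.0]));
    }
}
```

Parsing the string once into an enum means an unsupported name can be reported at configuration time rather than on every distance computation.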
pub fn learning_rate(self, learning_rate: f64) -> Self
Set the learning rate
pub fn n_episodes(self, n_episodes: usize) -> Self
Set the number of training episodes
pub fn n_shot(self, n_shot: usize) -> Self
Set the number of support examples per class (N-shot)
pub fn temperature(self, temperature: f64) -> Self
Set the temperature parameter for softmax
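In Prototypical Networks, class probabilities come from a softmax over negative distances to the prototypes. The sketch below assumes the temperature divides the negative distances before the softmax (the crate's exact convention is not documented here): higher temperatures flatten the distribution and lower ones sharpen it. `softmax_over_distances` is a hypothetical name.

```rust
/// Convert distances to class probabilities: p_k is proportional to exp(-d_k / temperature).
fn softmax_over_distances(distances: &[f64], temperature: f64) -> Vec<f64> {
    assert!(temperature > 0.0, "temperature must be positive");
    let logits: Vec<f64> = distances.iter().map(|d| -d / temperature).collect();
    // Subtract the largest logit before exponentiating for numerical stability
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|l| (l - max).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

fn main() {
    let distances = [1.0, 2.0, 4.0];
    for t in [0.5, 1.0, 10.0] {
        println!("T = {t}: {:?}", softmax_over_distances(&distances, t));
    }
}
```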
Trait Implementations§
impl<S: Clone> Clone for PrototypicalNetworks<S>
fn clone(&self) -> PrototypicalNetworks<S>
Returns a duplicate of the value.
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.
impl<S: Debug> Debug for PrototypicalNetworks<S>
impl Default for PrototypicalNetworks<Untrained>
impl Estimator for PrototypicalNetworks<Untrained>
type Error = SklearsError
Error type for the estimator
fn validate_config(&self) -> Result<(), SklearsError>
Validate estimator configuration with detailed error context
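The checks validate_config performs are not listed on this page. The sketch below shows the kind of checks the documented parameters suggest, using a hypothetical `Config` struct (not the crate's private fields) and `String` errors in place of `SklearsError`; every rule here is an assumption.

```rust
/// Hypothetical mirror of the builder settings.
struct Config {
    embedding_dim: usize,
    distance_metric: String,
    learning_rate: f64,
    n_episodes: usize,
    n_way: usize,
    n_shot: usize,
    n_query: usize,
    temperature: f64,
}

/// Illustrative values: embedding_dim, n_way, n_shot and n_query follow the
/// example; the rest are made up and are not the crate's defaults.
fn example_config() -> Config {
    Config {
        embedding_dim: 32,
        distance_metric: "euclidean".to_string(),
        learning_rate: 0.001,
        n_episodes: 100,
        n_way: 2,
        n_shot: 1,
        n_query: 3,
        temperature: 1.0,
    }
}

fn validate_config(c: &Config) -> Result<(), String> {
    if c.embedding_dim == 0 {
        return Err("embedding_dim must be at least 1".to_string());
    }
    if !["euclidean", "cosine", "manhattan"].contains(&c.distance_metric.as_str()) {
        return Err(format!("unsupported distance_metric: {}", c.distance_metric));
    }
    // is_finite() also rejects NaN and infinity
    if !(c.learning_rate.is_finite() && c.learning_rate > 0.0) {
        return Err("learning_rate must be a positive finite number".to_string());
    }
    if !(c.temperature.is_finite() && c.temperature > 0.0) {
        return Err("temperature must be a positive finite number".to_string());
    }
    // Training needs at least one episode, and each episode at least one
    // class, one support example and one query example per class
    let counts = [("n_episodes", c.n_episodes), ("n_way", c.n_way), ("n_shot", c.n_shot), ("n_query", c.n_query)];
    for (name, value) in counts {
        if value == 0 {
            return Err(format!("{name} must be at least 1"));
        }
    }
    Ok(())
}

fn main() {
    let mut config = example_config();
    println!("{:?}", validate_config(&config));
    config.n_shot = 0;
    println!("{:?}", validate_config(&config));
}
```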
fn check_compatibility(
    &self,
    n_samples: usize,
    n_features: usize,
) -> Result<(), SklearsError>
Check if estimator is compatible with given data dimensions
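check_compatibility compares the configuration against the data's shape. One plausible rule, shown here as an assumption rather than the crate's documented behavior, is that the data needs at least one feature and at least n_way * (n_shot + n_query) samples to fill one episode. This total is necessary but not sufficient, since each of the n_way classes must also have n_shot + n_query samples of its own. The free function below is hypothetical.

```rust
/// Hypothetical check: can one episode be drawn from data of this shape?
fn check_compatibility(
    n_samples: usize,
    n_features: usize,
    n_way: usize,
    n_shot: usize,
    n_query: usize,
) -> Result<(), String> {
    if n_features == 0 {
        return Err("input has no features".to_string());
    }
    // Each of the n_way classes contributes n_shot support and n_query query samples
    let needed = n_way * (n_shot + n_query);
    if n_samples < needed {
        return Err(format!("need at least {needed} samples per episode, got {n_samples}"));
    }
    Ok(())
}

fn main() {
    // Shape of the example data (8 samples, 2 features) with 2-way, 1-shot, 3-query episodes
    println!("{:?}", check_compatibility(8, 2, 2, 1, 3));
}
```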
fn metadata(&self) -> EstimatorMetadata
Get estimator metadata
impl Fit<ArrayBase<ViewRepr<&f64>, Dim<[usize; 2]>>, ArrayBase<ViewRepr<&i32>, Dim<[usize; 1]>>> for PrototypicalNetworks<Untrained>
type Fitted = PrototypicalNetworks<PrototypicalNetworksTrained>
The fitted model type
fn fit(
    self,
    X: &ArrayView2<'_, Float>,
    y: &ArrayView1<'_, i32>,
) -> SklResult<Self::Fitted>
Fit the model to the provided data with validation
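Training runs for n_episodes episodes. In the original Prototypical Networks formulation, each episode minimizes the negative log-probability of every query's true class under the softmax over negative distances to the prototypes. The sketch below computes that per-episode loss from precomputed distances; the gradient step through the embedding network is omitted, temperature scaling is left out, and `episode_loss` is a hypothetical name, not necessarily what this crate optimizes.

```rust
/// Mean negative log-likelihood of the true class over the queries of one episode.
/// `distances[q][k]` is the distance from query q to prototype k, and
/// `targets[q]` is the index of query q's true class. Assumes at least one query.
fn episode_loss(distances: &[Vec<f64>], targets: &[usize]) -> f64 {
    let total: f64 = distances
        .iter()
        .zip(targets)
        .map(|(row, &t)| {
            // log-sum-exp of the negative distances, shifted by the max for stability
            let max = row.iter().map(|d| -d).fold(f64::NEG_INFINITY, f64::max);
            let log_sum_exp = max + row.iter().map(|d| (-d - max).exp()).sum::<f64>().ln();
            // -log p_t = d_t + log sum_k exp(-d_k)
            row[t] + log_sum_exp
        })
        .sum();
    total / distances.len() as f64
}

fn main() {
    // Two queries and two prototypes; each query is closest to its own class
    let distances = vec![vec![0.5, 3.0], vec![2.5, 0.2]];
    let targets = [0, 1];
    println!("loss = {:.4}", episode_loss(&distances, &targets));
}
```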
fn fit_with_validation(
    self,
    x: &X,
    y: &Y,
    _x_val: Option<&X>,
    _y_val: Option<&Y>,
) -> Result<(Self::Fitted, FitMetrics), SklearsError>
where
    Self: Sized,
Fit with custom validation and early stopping
impl Predict<ArrayBase<ViewRepr<&f64>, Dim<[usize; 2]>>, ArrayBase<OwnedRepr<i32>, Dim<[usize; 1]>>> for PrototypicalNetworks<PrototypicalNetworksTrained>
fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>>
Make predictions on the provided data
fn predict_with_uncertainty(
    &self,
    x: &X,
) -> Result<(Output, UncertaintyMeasure), SklearsError>
Make predictions with confidence intervals
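predict_with_uncertainty returns an UncertaintyMeasure whose contents are not described on this page. One common per-sample uncertainty score for a classifier that outputs class probabilities is predictive entropy, computed from a probability row such as one returned by predict_proba. Whether UncertaintyMeasure uses entropy is an assumption, and `predictive_entropy` is a hypothetical name.

```rust
/// Predictive entropy (in nats) of one row of class probabilities.
/// 0 means all mass on a single class; ln(n_classes) means a uniform distribution.
fn predictive_entropy(probs: &[f64]) -> f64 {
    probs
        .iter()
        .filter(|&&p| p > 0.0) // skip zero-probability classes (0 * ln 0 is taken as 0)
        .map(|&p| -p * p.ln())
        .sum()
}

fn main() {
    // A confident row and a maximally uncertain row for a 2-class problem
    for row in [[0.99, 0.01], [0.5, 0.5]] {
        println!("{:?} -> entropy {:.4}", row, predictive_entropy(&row));
    }
}
```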
impl PredictProba<ArrayBase<ViewRepr<&f64>, Dim<[usize; 2]>>, ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>> for PrototypicalNetworks<PrototypicalNetworksTrained>
fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>>
Predict class probabilities
Auto Trait Implementations§
impl<S> Freeze for PrototypicalNetworks<S>
where
    S: Freeze,
impl<S> RefUnwindSafe for PrototypicalNetworks<S>
where
    S: RefUnwindSafe,
impl<S> Send for PrototypicalNetworks<S>
where
    S: Send,
impl<S> Sync for PrototypicalNetworks<S>
where
    S: Sync,
impl<S> Unpin for PrototypicalNetworks<S>
where
    S: Unpin,
impl<S> UnwindSafe for PrototypicalNetworks<S>
where
    S: UnwindSafe,
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
impl<T> CloneToUninit for T
where
    T: Clone,
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.
impl<T> Pointable for T
impl<T> StableApi for T
where
    T: Estimator,
const STABLE_SINCE: &'static str = "0.1.0"
API version this type was stabilized in
const HAS_EXPERIMENTAL_FEATURES: bool = false
Whether this API has any experimental features