Struct ModelSoup 

pub struct ModelSoup { /* private fields */ }

Model Soup - Weight-space averaging for improved generalization.

From “Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time” (Wortsman et al., 2022).

Model soups average the weights of multiple models rather than their predictions. This approach:

  • Can improve accuracy over the individual models
  • Adds no inference cost, since a single model is used at test time
  • Works across different hyperparameters and random seeds
  • Is particularly effective for models fine-tuned from the same initialization

Two main recipes:

  • Uniform Soup: Simple average of all model weights
  • Greedy Soup: Iteratively add models that improve validation performance
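
As a rough illustration of the uniform recipe, the sketch below averages each named parameter element-wise across models. It uses only the standard library, with `Vec<f64>` standing in for `Array2<f64>`; the helper name `uniform_average` is hypothetical and is not part of `tensorlogic_train`.

```rust
use std::collections::HashMap;

/// Hypothetical helper: element-wise mean of each named parameter across models.
/// Returns None if the input is empty, or if a model lacks a parameter found in
/// the first model or stores it with a different length.
fn uniform_average(models: &[HashMap<String, Vec<f64>>]) -> Option<HashMap<String, Vec<f64>>> {
    let first = models.first()?;
    let k = models.len() as f64;
    let mut soup = HashMap::new();
    for (name, base) in first {
        // Accumulate the sum of this parameter over all models.
        let mut sum = vec![0.0; base.len()];
        for model in models {
            let param = model.get(name)?;
            if param.len() != base.len() {
                return None;
            }
            for (s, v) in sum.iter_mut().zip(param) {
                *s += v;
            }
        }
        // Divide by the number of models to get the mean.
        soup.insert(name.clone(), sum.into_iter().map(|s| s / k).collect());
    }
    Some(soup)
}
```

Calling `uniform_average` on two models whose parameter `"w"` holds `[1.0, 2.0]` and `[3.0, 4.0]` yields `[2.0, 3.0]` for `"w"`.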

§Example

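use tensorlogic_train::ModelSoup;
use std::collections::HashMap;
use scirs2_core::ndarray::array;

// Collect weights from multiple fine-tuned models (one name-to-matrix map per model)
let weights1 = HashMap::from([("w".to_string(), array![[1.0, 2.0]])]);
let weights2 = HashMap::from([("w".to_string(), array![[3.0, 4.0]])]);
let model_weights = vec![weights1, weights2];

// uniform_soup returns a TrainResult, so unwrap it (or propagate with `?`) before use
let soup = ModelSoup::uniform_soup(model_weights).unwrap();
let averaged_weights = soup.weights();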

Implementations§

impl ModelSoup

pub fn uniform_soup(model_weights: Vec<HashMap<String, Array2<f64>>>) -> TrainResult<Self>

Create a uniform soup by averaging all model weights equally.

§Arguments
  • model_weights - Weights from multiple fine-tuned models
§Returns

Model soup with uniformly averaged weights

§Example
use tensorlogic_train::ModelSoup;
use std::collections::HashMap;
use scirs2_core::ndarray::array;

let mut weights1 = HashMap::new();
weights1.insert("w".to_string(), array![[1.0, 2.0]]);

let mut weights2 = HashMap::new();
weights2.insert("w".to_string(), array![[3.0, 4.0]]);

let soup = ModelSoup::uniform_soup(vec![weights1, weights2]).unwrap();
// Averaged weights: [[2.0, 3.0]]

pub fn greedy_soup(model_weights: Vec<HashMap<String, Array2<f64>>>, val_accuracies: Vec<f64>) -> TrainResult<Self>

Create a greedy soup by iteratively adding models that improve validation performance.

§Arguments
  • model_weights - Weights from multiple fine-tuned models
  • val_accuracies - Validation accuracy for each model
§Returns

Model soup with greedily selected and averaged weights

§Algorithm
  1. Start with the single best model (highest validation accuracy)
  2. Try adding each remaining model to the soup
  3. Keep the additions that improve validation performance
  4. Repeat until no addition improves it
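
The steps above can be sketched roughly as follows. This follows the greedy recipe described by Wortsman et al. (visit models from best to worst, keep a model only if the averaged soup does not score worse on held-out data) and is not the library's implementation: `greedy_soup` takes only precomputed per-model accuracies, so its internal selection rule may differ. The names `average`, `greedy_select`, and the `evaluate` closure are hypothetical, and flat `Vec<f64>` parameters stand in for the real weight maps.

```rust
/// Element-wise mean of equally sized parameter vectors (assumes non-empty input).
fn average(models: &[&Vec<f64>]) -> Vec<f64> {
    let k = models.len() as f64;
    let mut mean = vec![0.0; models[0].len()];
    for m in models {
        for (acc, v) in mean.iter_mut().zip(m.iter()) {
            *acc += v / k;
        }
    }
    mean
}

/// Hypothetical greedy selection: `evaluate` scores a candidate soup on held-out data.
/// Returns the indices of the models kept in the soup, in the order they were added.
fn greedy_select(
    models: &[Vec<f64>],
    val_accuracies: &[f64],
    evaluate: impl Fn(&[f64]) -> f64,
) -> Vec<usize> {
    // Visit models from highest to lowest individual validation accuracy.
    let mut order: Vec<usize> = (0..models.len()).collect();
    order.sort_by(|&a, &b| val_accuracies[b].total_cmp(&val_accuracies[a]));

    let mut kept: Vec<usize> = Vec::new();
    let mut best = f64::NEG_INFINITY;
    for i in order {
        // Candidate soup = models kept so far plus model i.
        let mut candidate: Vec<&Vec<f64>> = kept.iter().map(|&j| &models[j]).collect();
        candidate.push(&models[i]);
        let soup = average(&candidate);
        let score = evaluate(soup.as_slice());
        // Keep the model only if the soup does not get worse.
        if score >= best {
            kept.push(i);
            best = score;
        }
    }
    kept
}
```

The final soup would then be the average of the kept models. Passing the evaluator as a closure keeps the selection logic independent of how validation accuracy is measured.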

pub fn weighted_soup(model_weights: Vec<HashMap<String, Array2<f64>>>, weights: Vec<f64>) -> TrainResult<Self>

Create a weighted soup with custom weights for each model.

§Arguments
  • model_weights - Weights from multiple fine-tuned models
  • weights - Weight for each model (will be normalized to sum to 1)
§Returns

Model soup whose parameters are the weighted average of the input models
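
To make the normalization concrete, here is a minimal std-only sketch of a weighted average in which the coefficients are divided by their sum. The helper `weighted_average` is hypothetical, `Vec<f64>` stands in for `Array2<f64>`, and rejecting a non-positive coefficient sum is an assumption of this sketch rather than documented behavior of `weighted_soup`.

```rust
/// Hypothetical helper: weighted average of equally sized parameter vectors.
/// Coefficients are divided by their sum so that they add up to 1.
fn weighted_average(models: &[Vec<f64>], coeffs: &[f64]) -> Option<Vec<f64>> {
    let total: f64 = coeffs.iter().sum();
    // Assumption: reject empty input, mismatched counts, and non-positive sums.
    if models.is_empty() || models.len() != coeffs.len() || total <= 0.0 {
        return None;
    }
    let mut out = vec![0.0; models[0].len()];
    for (model, c) in models.iter().zip(coeffs) {
        if model.len() != out.len() {
            return None;
        }
        let w = c / total; // normalized coefficient
        for (o, v) in out.iter_mut().zip(model) {
            *o += w * v;
        }
    }
    Some(out)
}
```

With models `[1.0, 2.0]` and `[3.0, 4.0]` and coefficients `[1.0, 3.0]`, the normalized coefficients are 0.25 and 0.75, giving `[2.5, 3.5]`.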

pub fn weights(&self) -> &HashMap<String, Array2<f64>>

Get the averaged weights from the soup.

pub fn num_models(&self) -> usize

Get the number of models in the soup.

pub fn recipe(&self) -> SoupRecipe

Get the recipe used to create the soup.

pub fn get_parameter(&self, name: &str) -> Option<&Array2<f64>>

Get a specific parameter by name.

pub fn into_weights(self) -> HashMap<String, Array2<f64>>

Consume the soup and return its averaged weights.

This is a convenience method that hands over ownership of the weights so they can be loaded into a model.

Trait Implementations§

impl Clone for ModelSoup

fn clone(&self) -> ModelSoup

Returns a duplicate of the value.

Stable since 1.0.0:

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source.

impl Debug for ModelSoup

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.

Blanket Implementations§

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> CloneToUninit for T
where T: Clone,

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.

impl<T> Pointable for T

const ALIGN: usize

The alignment of pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a with the given initializer.

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.

impl<T> ToOwned for T
where T: Clone,

type Owned = T

The resulting type after obtaining ownership.

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning.

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V