1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
#![deny(missing_docs, missing_debug_implementations)] //! # sbr-rs //! //! `sbr` implements efficient recommender algorithms which operate on //! sequences of items: given previous items a user has interacted with, //! the model will recommend the items the user is likely to interact with //! in the future. //! //! Implemented models: //! - LSTM: a model that uses an LSTM network over the sequence of a user's interaction //! to predict their next action; //! - EWMA: a model that uses a simpler exponentially-weighted average of past actions //! to predict future interactions. //! //! Which model performs the best will depend on your dataset. The EWMA model is much //! quicker to fit, and will probably be a good starting point. //! //! ## Example //! You can fit a model on the Movielens 100K dataset in about 10 seconds: //! //! ```rust //! # extern crate sbr; //! # extern crate rand; //! # use std::time::Instant; //! # use rand::SeedableRng; //! let mut data = sbr::datasets::download_movielens_100k().unwrap(); //! //! let mut rng = rand::XorShiftRng::from_seed([42; 16]); //! //! let (train, test) = sbr::data::user_based_split(&mut data, &mut rng, 0.2); //! let train_mat = train.to_compressed(); //! let test_mat = test.to_compressed(); //! //! println!("Train: {}, test: {}", train.len(), test.len()); //! //! let mut model = sbr::models::lstm::Hyperparameters::new(data.num_items(), 32) //! .embedding_dim(32) //! .learning_rate(0.16) //! .l2_penalty(0.0004) //! .lstm_variant(sbr::models::lstm::LSTMVariant::Normal) //! .loss(sbr::models::Loss::WARP) //! .optimizer(sbr::models::Optimizer::Adagrad) //! .num_epochs(10) //! .rng(rng) //! .build(); //! //! let start = Instant::now(); //! let loss = model.fit(&train_mat).unwrap(); //! let elapsed = start.elapsed(); //! let train_mrr = sbr::evaluation::mrr_score(&model, &train_mat).unwrap(); //! let test_mrr = sbr::evaluation::mrr_score(&model, &test_mat).unwrap(); //! //! println!( //! "Train MRR {} at loss {} and test MRR {} (in {:?})", //! train_mrr, loss, test_mrr, elapsed //! ); //! ``` #[macro_use] extern crate serde_derive; #[macro_use] extern crate itertools; #[cfg(feature = "default")] extern crate csv; #[macro_use] extern crate failure; extern crate ndarray; extern crate rand; extern crate rayon; extern crate serde; extern crate siphasher; #[cfg(feature = "default")] extern crate reqwest; extern crate wyrm; pub mod data; #[cfg(feature = "default")] pub mod datasets; pub mod evaluation; pub mod models; /// Alias for user indices. pub type UserId = usize; /// Alias for item indices. pub type ItemId = usize; /// Alias for timestamps. pub type Timestamp = usize; /// Prediction error types. #[derive(Debug, Fail)] pub enum PredictionError { /// Failed prediction due to numerical issues. #[fail(display = "Invalid prediction value: non-finite or not a number.")] InvalidPredictionValue, } /// Fitting error types. #[derive(Debug, Fail)] pub enum FittingError { /// No interactions were given. #[fail(display = "No interactions were supplied.")] NoInteractions, } /// Trait describing models that can compute predictions given /// a user's sequences of past interactions. pub trait OnlineRankingModel { /// The representation the model computes from past interactions. type UserRepresentation: std::fmt::Debug; /// Compute a user representation from past interactions. fn user_representation( &self, item_ids: &[ItemId], ) -> Result<Self::UserRepresentation, PredictionError>; /// Given a user representation, rank `item_ids` according /// to how likely the user is to interact with them in the future. fn predict( &self, user: &Self::UserRepresentation, item_ids: &[ItemId], ) -> Result<Vec<f32>, PredictionError>; }