Source code for gpytorchwrapper.src.models.model_load

from gpytorch.likelihoods import Likelihood
from gpytorch.models import GP
from torch import Tensor

from gpytorchwrapper.src.config.config_classes import Config
from gpytorchwrapper.src.config.model_factory import get_likelihood, get_model
from gpytorchwrapper.src.models.model_train import define_likelihood, define_model


[docs] def load_model( config: Config, model_dump: dict, train_x: Tensor, train_y: Tensor, ) -> tuple[GP, Likelihood]: """ Load a model from a config file and the dumped model. The model and likelihood objects are set to training mode with double precision. Parameters ---------- config : Config Object specifying the configuration used for training. model_dump : dict The unpickled dumped model file train_x : Tensor The input tensor used for training train_y : Tensor The output tensor used for training Returns ------- return : tuple of GP and Likelihood The loaded GP and likelihood objects """ likelihood_class = get_likelihood(config.training_conf.likelihood) model_class = get_model(config.training_conf.model) likelihood = define_likelihood( config.training_conf.likelihood, likelihood_class, train_x ) model = define_model( config.training_conf.model, model_class, train_x, train_y, likelihood ) model.double() likelihood.double() model.load_state_dict(model_dump["state_dict"]) return model, likelihood