import importlib
import os
import pkgutil
from types import ModuleType
import sklearn.preprocessing as transformer_module
import gpytorch.likelihoods as likelihood_module
from gpytorch.likelihoods import Likelihood
from gpytorch.models import ExactGP
from torch.optim import Optimizer
import gpytorchwrapper.src.models.gp_models as model_module
import torch.optim as optimizer_module
from .config_classes import (
TransformerConf,
OptimizerConf,
LikelihoodConf,
ModelConf,
)
import logging
import sys
# Module-level logger; get_model uses it to report where a model class was loaded from.
logger = logging.getLogger(__name__)
[docs]
def get_likelihood(likelihood_conf: LikelihoodConf) -> type[Likelihood]:
    """
    Get the likelihood class named in the configuration

    Parameters
    -----------
    likelihood_conf : LikelihoodConf
        dataclass containing the likelihood specifications; its
        ``likelihood_class`` attribute is the name of an attribute of
        ``gpytorch.likelihoods``

    Returns
    --------
    selected_likelihood_class : type[Likelihood]
        The selected likelihood class (the class itself, not an instance)

    Raises
    ------
    AttributeError
        If ``gpytorch.likelihoods`` has no attribute with the configured name
    """
    selected_likelihood = likelihood_conf.likelihood_class
    try:
        return getattr(likelihood_module, selected_likelihood)
    except AttributeError as err:
        # Same exception type as a bare getattr, but point the user at the config value.
        raise AttributeError(
            f"The specified likelihood class, {selected_likelihood}, "
            f"is not available in gpytorch.likelihoods."
        ) from err
[docs]
def get_plugins(path: str | None = None) -> dict[str, ModuleType]:
"""
Parameters
----------
path : str or None, optional
path to the directory containing the model plugins
Returns
-------
discovered_plugins : dict
dict with the names of the model class as a string and the model classes as values
"""
if path is None:
# Dynamically find the plugins directory relative to this script
current_dir = os.path.dirname(os.path.abspath(__file__))
path = os.path.abspath(os.path.join(current_dir, "../../plugins"))
if not os.path.isdir(path):
raise FileNotFoundError(f"Plugins directory not found at {path}")
sys.path.insert(0, path)
discovered_plugins = {
name: importlib.import_module(name)
for finder, name, ispkg in pkgutil.iter_modules()
if name.startswith("model_")
}
return discovered_plugins
[docs]
def get_model(model_conf: ModelConf) -> type[ExactGP]:
    """
    Get the model class named in the configuration

    The built-in models in ``gp_models`` are searched first; only if the class
    is not found there are the plugin modules discovered and searched.

    Parameters
    -----------
    model_conf : ModelConf
        dataclass containing the model specifications; its ``model_class``
        attribute is the name of the model class to load

    Returns
    --------
    selected_model_class : type[ExactGP]
        The selected model class (the class itself, not an instance)

    Raises
    ------
    NotImplementedError
        If the class is found neither in gp_models nor in any plugin module
    FileNotFoundError
        If the class is not in gp_models and the plugins directory is missing
    """
    selected_model = model_conf.model_class

    # Check the built-in models before touching the plugins directory, so a
    # missing plugins folder cannot prevent loading a built-in model and no
    # plugin modules are imported needlessly.
    if hasattr(model_module, selected_model):
        logger.info("Loading model class %s from %s.", selected_model, model_module)
        return getattr(model_module, selected_model)

    for module in get_plugins().values():
        if hasattr(module, selected_model):
            logger.info("Loading model class %s from %s.", selected_model, module)
            return getattr(module, selected_model)

    raise NotImplementedError(
        f"The specified model class, {selected_model}, is not available in gp_models.py or the plugins folder."
    )
[docs]
def get_optimizer(optimizer_conf: OptimizerConf) -> type[Optimizer]:
    """
    Get the optimizer class named in the configuration

    Parameters
    -----------
    optimizer_conf : OptimizerConf
        dataclass containing the optimizer specifications; its
        ``optimizer_class`` attribute is the name of an attribute of
        ``torch.optim``

    Returns
    --------
    selected_optimizer_class : type[Optimizer]
        The selected optimizer class (the class itself, not an instance)

    Raises
    ------
    AttributeError
        If ``torch.optim`` has no attribute with the configured name
    """
    selected_optimizer = optimizer_conf.optimizer_class
    try:
        return getattr(optimizer_module, selected_optimizer)
    except AttributeError as err:
        # Same exception type as a bare getattr, but point the user at the config value.
        raise AttributeError(
            f"The specified optimizer class, {selected_optimizer}, "
            f"is not available in torch.optim."
        ) from err