Source code for training_gpytorch

import argparse
import logging
import pathlib
from dataclasses import asdict, dataclass
from pathlib import Path
from sys import platform
import torch

from gpytorchwrapper.src.config.config_reader import read_yaml
from gpytorchwrapper.src.data.data_reader import DataReader
from gpytorchwrapper.src.data.data_splitter import input_output_split, split_data
from gpytorchwrapper.src.data.data_transform import transform
from gpytorchwrapper.src.models.model_train import train_model
from gpytorchwrapper.src.models.model_evaluate import evaluate_model
from gpytorchwrapper.src.models.model_save import save_model
from gpytorchwrapper.src.utils import metadata_dict, dataframe_to_tensor, Timer

__author__ = "Jenne Van Veerdeghem"
__version__ = "0.0.1"

logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)

training_timer = Timer("training")

torch.set_default_dtype(torch.float64)

# Needed for training on HPC cluster
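# (presumably because pathlib.WindowsPath objects cannot be instantiated on Linux,
#  e.g. when loading Path objects pickled on a Windows machine)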
if platform == "linux":
    pathlib.WindowsPath = pathlib.PosixPath


@dataclass
class Arguments:
    input: str
    file_type: str
    config: str
    output: str
    directory: str
    test_set: str

    def __post_init__(self):
        self.input = Path(self.input)
        self.config = Path(self.config)
        self.directory = Path(self.directory)
        if self.test_set:
            self.test_set = Path(self.test_set)
        self.directory.mkdir(parents=True, exist_ok=True)

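
# Illustrative sketch (not part of the original module): main() builds an
# Arguments instance from a plain dict when called programmatically; the
# paths below are placeholders.
#
#   args = Arguments(
#       input="data.csv",        # placeholder path
#       file_type="csv",
#       config="config.yaml",    # placeholder path
#       output="model.pth",
#       directory="results/",
#       test_set=None,           # pass None when no separate test set is used
#   )
#
# __post_init__ then converts the path-like fields to pathlib.Path objects and
# creates the output directory.
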
def parse_args():
    """
    Parse command-line arguments for the GPR training script.

    Parses the command-line arguments required for training a Gaussian Process
    Regressor using GPyTorch. Handles the input data file, configuration file,
    output specifications, and an optional test set.

    Returns
    -------
    argparse.Namespace
        Parsed command-line arguments with the following attributes:

        - input : pathlib.Path
            Path to the file containing the training data
        - file_type : str
            Format of the data file ('csv' or 'pickle')
        - config : pathlib.Path
            Path to the configuration file containing the script options
        - output : str
            Name of the output file for saving the model and metadata
        - directory : pathlib.Path
            Output directory path (created if it doesn't exist)
        - test_set : pathlib.Path or None
            Path to the test data file, or None if not provided

    Notes
    -----
    The function automatically creates the output directory if it doesn't
    exist. The test_set argument is incompatible with the cross-validation
    mode specified in the configuration file.

    Examples
    --------
    >>> args = parse_args()
    >>> print(args.input)
    PosixPath('/path/to/data.csv')
    """
    parser = argparse.ArgumentParser(
        prog="GPR Training",
        description="Train a Gaussian Process Regressor using GPyTorch.",
    )
    parser.add_argument(
        "-i",
        "--input",
        type=str,
        required=True,
        help="File containing the data",
    )
    parser.add_argument(
        "-f",
        "--file-type",
        type=str,
        required=True,
        help="Format of the data file. Can be either csv or pickle.",
    )
    parser.add_argument(
        "-c",
        "--config",
        required=True,
        help="The config file containing the script options.",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        required=True,
        help="Name of the output file containing the model and its metadata.",
    )
    parser.add_argument(
        "-d",
        "--directory",
        type=str,
        required=True,
        help="Output directory",
    )
    parser.add_argument(
        "-t",
        "--test-set",
        type=str,
        required=False,
        help="File containing the test data. Not usable when cross-validation is selected in the config file.",
    )

    args = parser.parse_args()

    args.input, args.config, args.directory = map(
        Path, [args.input, args.config, args.directory]
    )
    if args.test_set:
        args.test_set = Path(args.test_set)

    # Create the output directory if it does not exist
    args.directory.mkdir(parents=True, exist_ok=True)

    return args

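
# Example command-line invocation (illustrative; the script name and paths are
# placeholders):
#
#   python training_gpytorch.py -i data.csv -f csv -c config.yaml \
#       -o model.pth -d results/ -t test.csv
#
# The -t/--test-set flag is optional and should be omitted when cross-validation
# is selected in the config file.
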
def main(args=None):
    """
    Main training pipeline.

    Executes the complete GPR training workflow, including data loading,
    preprocessing, model training, evaluation, and saving. Can be run either
    from the command line or programmatically with provided arguments.

    Parameters
    ----------
    args : dict or None, optional
        Dictionary containing the training arguments. If None, arguments are
        parsed from the command line. Expected keys match those returned by
        parse_args():

        - input : str or Path
            Path to the training data file
        - file_type : str
            Data file format ('csv' or 'pickle')
        - config : str or Path
            Path to the configuration file
        - output : str
            Output filename for the model
        - directory : str or Path
            Output directory
        - test_set : str or Path or None
            Path to the test data file

    Returns
    -------
    None
        The function performs training and saves the results to disk.

    Raises
    ------
    FileNotFoundError
        If the input data file or configuration file cannot be found
    ValueError
        If file_type is not 'csv' or 'pickle'

    Notes
    -----
    The function performs the following workflow:

    #. Load and validate the input data
    #. Parse the configuration settings
    #. Split the data into input/output features
    #. Apply data transformations using scikit-learn transformers if requested
    #. Train the GPR model using GPyTorch
    #. Evaluate the model performance (RMSE, correlation)
    #. Save the trained model with its metadata

    Training uses float64 precision by default.

    Examples
    --------
    Command-line usage:

    >>> main()  # Uses command-line arguments

    Programmatic usage:

    >>> args = {
    ...     'input': 'data.csv',
    ...     'file_type': 'csv',
    ...     'config': 'config.yaml',
    ...     'output': 'model.pth',
    ...     'directory': 'results/',
    ...     'test_set': None
    ... }
    >>> main(args)

    See Also
    --------
    parse_args : Parse command-line arguments
    DataReader.read_data : Load data from file
    train_model : Train the GPR model
    evaluate_model : Evaluate the model performance
    """
    if args is None:
        args = parse_args()
    else:
        args = Arguments(**args)

    reader = DataReader()
    data = reader.read_data(file=args.input, file_type=args.file_type)
    logger.info(f"Data loaded from {args.input}.")

    # Read the config file and split it into its specifications
    config = read_yaml(args.config)

    data_conf = config.data_conf
    transform_conf = config.transform_conf
    training_conf = config.training_conf
    testing_conf = config.testing_conf

    logger.info(f"Input file {args.config} read.")

    # Data processing
    x, y = input_output_split(data, data_conf)

    if not args.test_set:
        train_x, test_x, train_y, test_y = split_data(
            x, y, data_conf, transform_conf, training_conf, testing_conf, args.directory
        )
        train_x, test_x, train_y, test_y, input_transformer, output_transformer = (
            transform(train_x, train_y, test_x, test_y, transform_conf)
        )
    else:
        train_x, _, train_y, _ = split_data(
            x, y, data_conf, transform_conf, training_conf, testing_conf, args.directory
        )
        train_x, _, train_y, _, input_transformer, output_transformer = transform(
            train_x, train_y, None, None, transform_conf
        )

        test_data = reader.read_data(file=args.test_set, file_type=args.file_type)
        logger.info(f"Test data loaded from {args.test_set}.")

        x, y = input_output_split(test_data, data_conf)
        test_x, _, test_y, _ = split_data(
            x, y, data_conf, transform_conf, training_conf, testing_conf, args.directory
        )
        test_x, _, test_y, _, input_transformer, output_transformer = transform(
            test_x, test_y, None, None, transform_conf
        )

    train_x, train_y = map(dataframe_to_tensor, [train_x, train_y])
    if test_x is not None:
        test_x, test_y = map(dataframe_to_tensor, [test_x, test_y])

    # Model training
    training_timer.set_init_time()
    model, likelihood, _ = train_model(train_x, train_y, training_conf, test_x, test_y)
    training_timer.set_final_time()
    training_timer.log_timings()

    # Evaluate the model on the training and test sets
    train_rmse, test_rmse, test_corr = evaluate_model(
        model, likelihood, output_transformer, train_x, train_y, test_x, test_y
    )

    # Save metadata to dictionaries
    training_metadata = metadata_dict(
        train_x=train_x,
        train_y=train_y,
        test_x=test_x,
        test_y=test_y,
        input_transformer=input_transformer,
        output_transformer=output_transformer,
    )
    metrics_metadata = metadata_dict(
        train_rmse=train_rmse, test_rmse=test_rmse, test_corr=test_corr
    )

    # Save the model to a .pth file
    save_model(
        model.state_dict(),
        asdict(config),
        training_metadata,
        metrics_metadata,
        args.output,
        args.directory,
    )


if __name__ == "__main__":
    main()