training_gpytorch

Functions

main([args])

Main training pipeline.

parse_args()

Parse command-line arguments for GPR training script.

Classes

Arguments(input, file_type, config, output, ...)

class training_gpytorch.Arguments(input: str, file_type: str, config: str, output: str, directory: str, test_set: str)

Bases: object

config: str
directory: str
file_type: str
input: str
output: str
test_set: str
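
The fields above share their names with the keys that main() expects in its args dictionary. The following is a minimal sketch of turning such a record into that dictionary. It assumes a dataclass-like container; ArgumentsSketch is a hypothetical stand-in written for illustration, not the module's actual Arguments class.

```python
from dataclasses import asdict, dataclass


@dataclass
class ArgumentsSketch:
    """Hypothetical stand-in mirroring the documented fields of Arguments."""

    input: str
    file_type: str
    config: str
    output: str
    directory: str
    test_set: str


example = ArgumentsSketch(
    input="data.csv",
    file_type="csv",
    config="config.yaml",
    output="model.pth",
    directory="results/",
    test_set="test.csv",
)

# asdict() produces a plain dict whose keys match the documented field names.
args_dict = asdict(example)
print(sorted(args_dict))
```

Whether the real class is a dataclass is not stated on this page, so treat the conversion step as an assumption.
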
training_gpytorch.main(args=None)

Main training pipeline.

Executes the complete GPR training workflow including data loading, preprocessing, model training, evaluation, and saving. Can be run either from command line or programmatically with provided arguments.

Parameters:

args (dict or None, optional) –

Dictionary containing training arguments. If None, arguments are parsed from command line. Expected keys match those returned by parse_args():

  • input (str or Path)

    Path to the training data file

  • file_type (str)

    Data file format (‘csv’ or ‘pickle’)

  • config (str or Path)

    Path to the configuration file

  • output (str)

    Output filename for the model

  • directory (str or Path)

    Output directory

  • test_set (str or Path or None)

    Path to the test data file

Returns:

Nothing; the function performs training and saves the results to disk

Return type:

None

Raises:
  • FileNotFoundError – If input data file or configuration file cannot be found

  • ValueError – If file_type is not ‘csv’ or ‘pickle’
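
The two documented exceptions can be illustrated with a small validation helper. This is only a sketch under stated assumptions: check_inputs is a hypothetical name, and the order of the checks is a guess, not the behaviour of main() itself.

```python
from pathlib import Path

# The two formats named in the documentation above.
SUPPORTED_FILE_TYPES = ("csv", "pickle")


def check_inputs(input_path, config_path, file_type):
    """Hypothetical helper that raises the documented exceptions for bad inputs."""
    if file_type not in SUPPORTED_FILE_TYPES:
        raise ValueError(
            f"file_type must be one of {SUPPORTED_FILE_TYPES}, got {file_type!r}"
        )
    for path in (Path(input_path), Path(config_path)):
        if not path.is_file():
            raise FileNotFoundError(f"No such file: {path}")
```

A caller running main() programmatically could wrap the call in try/except for these two exception types to report bad paths or formats cleanly.
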

Notes

The function performs the following workflow:

  1. Load and validate input data

  2. Parse configuration settings

  3. Split data into input/output features

  4. Apply data transformations using scikit-learn transformers if requested

  5. Train GPR model using GPyTorch

  6. Evaluate model performance (RMSE, correlation)

  7. Save trained model with metadata

Training uses float64 precision by default.
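
The two metrics named in step 6 can be sketched in plain Python. The page does not say which correlation coefficient is used, so Pearson's r is assumed here; rmse and pearson_r are illustrative names, not functions of this module.

```python
import math


def rmse(y_true, y_pred):
    """Root-mean-square error between two equal-length sequences."""
    squared = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    return math.sqrt(squared / len(y_true))


def pearson_r(y_true, y_pred):
    """Pearson correlation coefficient (assumed; the source only says 'correlation')."""
    n = len(y_true)
    mean_t = sum(y_true) / n
    mean_p = sum(y_pred) / n
    cov = sum((t - mean_t) * (p - mean_p) for t, p in zip(y_true, y_pred))
    var_t = sum((t - mean_t) ** 2 for t in y_true)
    var_p = sum((p - mean_p) ** 2 for p in y_pred)
    return cov / math.sqrt(var_t * var_p)


# Predictions offset by a constant 0.5: perfectly correlated, but not error-free.
y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.5, 2.5, 3.5, 4.5]
print(rmse(y_true, y_pred), pearson_r(y_true, y_pred))
```

The toy data shows why both metrics are worth reporting: a constant bias leaves the correlation at its maximum while the RMSE stays nonzero.
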

Examples

Command line usage:

>>> main()  # Uses command line arguments

Programmatic usage:

>>> args = {
...     'input': 'data.csv',
...     'file_type': 'csv',
...     'config': 'config.yaml',
...     'output': 'model.pth',
...     'directory': 'results/',
...     'test_set': None,
... }
>>> main(args)

See also

parse_args

Parse command line arguments

DataReader.read_data

Load data from file

train_model

Train GPR model

evaluate_model

Evaluate model performance

training_gpytorch.parse_args()

Parse command-line arguments for GPR training script.

Parses the command-line arguments required for training a Gaussian Process Regressor using GPyTorch. Handles the input data file, configuration file, output specifications, and optional test set.

Returns:

Parsed command-line arguments with the following attributes:

  • input (pathlib.Path)

    Path to file containing the training data

  • file_type (str)

    Format of the data file (‘csv’ or ‘pickle’)

  • config (pathlib.Path)

    Path to configuration file containing script options

  • output (str)

    Name of the output file for saving model and metadata

  • directory (pathlib.Path)

    Output directory path (created if it doesn’t exist)

  • test_set (pathlib.Path or None)

    Path to test data file, or None if not provided

Return type:

argparse.Namespace

Notes

The function automatically creates the output directory if it doesn’t exist. The test_set argument is incompatible with cross-validation mode specified in the configuration file.
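
A parser with the attributes listed above might look like the sketch below. The flag spellings, defaults, and the names build_parser and parse_args_sketch are all assumptions made for illustration; the actual command-line interface of the script is not shown on this page.

```python
import argparse
import tempfile
from pathlib import Path


def build_parser():
    """Hypothetical parser; flag names and defaults are assumptions."""
    parser = argparse.ArgumentParser(description="Train a GPR model (illustrative sketch).")
    parser.add_argument("--input", type=Path, required=True)
    parser.add_argument("--file_type", choices=["csv", "pickle"], default="csv")
    parser.add_argument("--config", type=Path, required=True)
    parser.add_argument("--output", type=str, default="model.pth")
    parser.add_argument("--directory", type=Path, default=Path("."))
    parser.add_argument("--test_set", type=Path, default=None)
    return parser


def parse_args_sketch(argv=None):
    """Parse argv and create the output directory if it is missing."""
    args = build_parser().parse_args(argv)
    args.directory.mkdir(parents=True, exist_ok=True)
    return args


# Demonstration inside a temporary directory so nothing is left behind.
with tempfile.TemporaryDirectory() as tmp:
    demo_dir = Path(tmp) / "results"
    demo_args = parse_args_sketch(
        ["--input", "data.csv", "--config", "config.yaml", "--directory", str(demo_dir)]
    )
    directory_created = demo_dir.is_dir()
```

Passing type=Path lets argparse do the string-to-path conversion, which matches the pathlib.Path attribute types documented above.
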

Examples

>>> args = parse_args()
>>> args.input
PosixPath('/path/to/data.csv')