training_gpytorch
Functions
main(args=None) | Main training pipeline.
parse_args() | Parse command-line arguments for GPR training script.
Classes
Arguments(input, file_type, config, output, directory, test_set)
- class training_gpytorch.Arguments(input: str, file_type: str, config: str, output: str, directory: str, test_set: str)[source]
Bases: object
- config: str
- directory: str
- file_type: str
- input: str
- output: str
- test_set: str
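The field list above can be illustrated with a small stand-in. This is only a sketch: `ArgumentsSketch` is a hypothetical class, and treating the real `Arguments` as a dataclass is an assumption based on its constructor signature matching its attributes. The file names are placeholders.

```python
from dataclasses import asdict, dataclass


@dataclass
class ArgumentsSketch:
    """Hypothetical stand-in mirroring the documented fields of Arguments."""

    input: str
    file_type: str
    config: str
    output: str
    directory: str
    test_set: str


# Placeholder values; none of these files are assumed to exist.
example = ArgumentsSketch(
    input="data.csv",
    file_type="csv",
    config="config.yaml",
    output="model.pth",
    directory="results/",
    test_set="test.csv",
)

# asdict() yields a plain dict whose keys match the documented field names.
args_dict = asdict(example)
```

If the real class behaves similarly, a dictionary of this shape carries the same keys that main() documents for its args parameter.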
- training_gpytorch.main(args=None)[source]
Main training pipeline.
Executes the complete GPR training workflow: data loading, preprocessing, model training, evaluation, and saving. Can be run either from the command line or programmatically with the arguments supplied as a dictionary.
- Parameters:
args (dict or None, optional) –
Dictionary containing training arguments. If None, arguments are parsed from command line. Expected keys match those returned by parse_args():
- input : str or Path
path to training data file
- file_type : str
data file format ('csv' or 'pickle')
- config : str or Path
path to configuration file
- output : str
output filename for model
- directory : str or Path
output directory
- test_set : str or Path or None
path to test data file
- Returns:
Nothing is returned; the trained model and its metadata are written to disk.
- Return type:
None
- Raises:
FileNotFoundError – If the input data file or the configuration file cannot be found
ValueError – If file_type is not 'csv' or 'pickle'
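The two documented error conditions can be sketched as a pre-flight check. The helper `check_inputs` and the constant `SUPPORTED_FILE_TYPES` are hypothetical names introduced for illustration, and the order of the checks is an assumption; the actual validation inside main() may differ.

```python
from pathlib import Path

# Formats named in the documentation above.
SUPPORTED_FILE_TYPES = ("csv", "pickle")


def check_inputs(args):
    """Hypothetical pre-flight check; not part of training_gpytorch."""
    file_type = args["file_type"]
    if file_type not in SUPPORTED_FILE_TYPES:
        raise ValueError(
            f"file_type must be one of {SUPPORTED_FILE_TYPES}, got {file_type!r}"
        )
    # Both the data file and the configuration file must exist.
    for key in ("input", "config"):
        path = Path(args[key])
        if not path.is_file():
            raise FileNotFoundError(f"{key} file not found: {path}")
```

A caller that wants to fail early could run such a check before handing the same dictionary to main().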
Notes
The function performs the following workflow:
1. Load and validate input data
2. Parse configuration settings
3. Split data into input/output features
4. Apply data transformations using scikit-learn transformers if requested
5. Train GPR model using GPyTorch
6. Evaluate model performance (RMSE, correlation)
7. Save trained model with metadata
Training uses float64 precision by default.
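The evaluation step in the workflow reports RMSE and a correlation. A minimal pure-Python sketch of both metrics follows; it assumes the correlation is the Pearson coefficient, which the documentation does not state, and the function names `rmse` and `pearson_r` are illustrative rather than taken from the module.

```python
import math


def rmse(y_true, y_pred):
    """Root-mean-square error between two equal-length sequences."""
    n = len(y_true)
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / n)


def pearson_r(y_true, y_pred):
    """Pearson correlation coefficient (assumed meaning of 'correlation')."""
    n = len(y_true)
    mean_t = sum(y_true) / n
    mean_p = sum(y_pred) / n
    cov = sum((t - mean_t) * (p - mean_p) for t, p in zip(y_true, y_pred))
    var_t = sum((t - mean_t) ** 2 for t in y_true)
    var_p = sum((p - mean_p) ** 2 for p in y_pred)
    return cov / math.sqrt(var_t * var_p)


# Toy data: predictions are offset from the targets by a constant 0.5.
y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.5, 2.5, 3.5, 4.5]

error = rmse(y_true, y_pred)             # 0.5
correlation = pearson_r(y_true, y_pred)  # 1.0
```

The toy data shows why both numbers are worth reporting: a constant offset leaves the correlation perfect while the RMSE still registers the bias.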
Examples
Command line usage:
>>> main()  # Uses command line arguments
Programmatic usage:
>>> args = {
...     'input': 'data.csv',
...     'file_type': 'csv',
...     'config': 'config.yaml',
...     'output': 'model.pth',
...     'directory': 'results/',
...     'test_set': None
... }
>>> main(args)
See also
parse_args
Parse command line arguments
DataReader.read_data
Load data from file
train_model
Train GPR model
evaluate_model
Evaluate model performance
- training_gpytorch.parse_args()[source]
Parse command-line arguments for GPR training script.
Parses the command-line arguments required for training a Gaussian Process Regressor with GPyTorch: the input data file, the configuration file, the output file name and directory, and an optional test set.
- Returns:
Parsed command-line arguments with the following attributes:
- input : pathlib.Path
Path to file containing the training data
- file_type : str
Format of the data file ('csv' or 'pickle')
- config : pathlib.Path
Path to configuration file containing script options
- output : str
Name of the output file for saving model and metadata
- directory : pathlib.Path
Output directory path (created if it doesn't exist)
- test_set : pathlib.Path or None
Path to test data file, or None if not provided
- Return type:
argparse.Namespace
Notes
The function creates the output directory automatically if it does not exist. The test_set argument is incompatible with cross-validation mode when that mode is specified in the configuration file.
Examples
>>> args = parse_args()
>>> args.input
PosixPath('/path/to/data.csv')
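A parser with the documented return attributes could be built along the following lines. This is a sketch under assumptions, not the module's implementation: the flag spellings (`--input`, `--file_type`, and so on), the defaults, and the names `build_parser_sketch` and `parse_args_sketch` are all hypothetical. Only the attribute names, their types, and the directory-creation behavior come from the documentation above.

```python
import argparse
from pathlib import Path


def build_parser_sketch():
    """Hypothetical parser; flag names and defaults are assumptions."""
    parser = argparse.ArgumentParser(description="Sketch of a GPR training CLI")
    parser.add_argument("--input", type=Path, required=True,
                        help="file containing the training data")
    parser.add_argument("--file_type", choices=["csv", "pickle"], default="csv",
                        help="format of the data file")
    parser.add_argument("--config", type=Path, required=True,
                        help="configuration file with script options")
    parser.add_argument("--output", type=str, default="model.pth",
                        help="output file name for model and metadata")
    parser.add_argument("--directory", type=Path, default=Path("."),
                        help="output directory (created if missing)")
    parser.add_argument("--test_set", type=Path, default=None,
                        help="optional test data file")
    return parser


def parse_args_sketch(argv=None):
    """Parse argv (or sys.argv when None) and create the output directory."""
    args = build_parser_sketch().parse_args(argv)
    # Mirrors the documented behavior: the directory is created if absent.
    args.directory.mkdir(parents=True, exist_ok=True)
    return args
```

Passing an explicit list such as `["--input", "data.csv", "--config", "config.yaml"]` to `parse_args_sketch` makes the sketch easy to exercise without touching the real command line.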