import argparse
import warnings
from pathlib import Path
import gpytorch
import numpy as np
import torch
from numpy.typing import NDArray
import logging
from gpytorchwrapper.src.config.config_classes import create_config
from gpytorchwrapper.src.models.model_load import load_model
warnings.filterwarnings("ignore") # Ignore warnings from the torch.jit.trace function
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)
def parse_args():
    """
    Parse command-line arguments for the PyTorch-to-TorchScript converter.

    Returns
    -------
    argparse.Namespace
        Namespace with the following attributes:

        - input : pathlib.Path
            Path to the input PyTorch model (.pth file).
        - output : str
            File name of the output TorchScript model (default: ``model.ts``).
        - directory : pathlib.Path
            Directory where the TorchScript model is saved; created on
            demand if it does not exist yet.

    Notes
    -----
    Both ``input`` and ``directory`` are converted to :class:`pathlib.Path`
    objects before being returned, and the output directory is created
    (including parents) if missing.
    """
    parser = argparse.ArgumentParser(
        prog="pt2ts", description="Convert a PyTorch model to a TorchScript model"
    )
    parser.add_argument(
        "-i", "--input", type=str, required=True, help="Path to the model .pth file."
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        required=False,
        default="model.ts",
        help="Name of the output TorchScript model",
    )
    parser.add_argument(
        "-d",
        "--directory",
        type=str,
        required=False,
        default="./",
        help="Directory where the TorchScript model is saved.",
    )
    args = parser.parse_args()

    # Normalize the string arguments into Path objects and make sure the
    # destination directory exists before anything tries to write into it.
    args.input = Path(args.input)
    args.directory = Path(args.directory)
    args.directory.mkdir(parents=True, exist_ok=True)
    return args
class MeanVarModelWrapper(torch.nn.Module):
    """
    Adapter that turns a GPyTorch model's output into plain tensors.

    GPyTorch models return distribution objects, which ``torch.jit.trace``
    cannot handle directly. This wrapper calls the underlying GP and
    returns only the mean and variance tensors of the predictive
    distribution, which makes the model traceable.

    Parameters
    ----------
    gp : gpytorch.models.GP
        The trained GPyTorch Gaussian Process model to wrap.

    Attributes
    ----------
    gp : gpytorch.models.GP
        The wrapped Gaussian Process model.

    Examples
    --------
    >>> wrapped = MeanVarModelWrapper(gp_model)
    >>> mean, variance = wrapped(test_inputs)
    """

    def __init__(self, gp):
        super().__init__()
        self.gp = gp

    def forward(self, x):
        # The GP returns a distribution; expose only its tensor components.
        prediction = self.gp(x)
        return prediction.mean, prediction.variance
def trace_model(model, len_training_data, transformer, num_inputs):
    """
    Trace a GPyTorch model into a TorchScript module.

    Generates random probe inputs, optionally transforms them with the
    training-time input transformer, and runs ``torch.jit.trace`` on a
    :class:`MeanVarModelWrapper` around the model under GPyTorch settings
    suitable for tracing.

    Parameters
    ----------
    model : gpytorch.models.GP
        The trained GPyTorch Gaussian Process model to trace.
    len_training_data : int
        Number of training points; used to size the eager-kernel limit.
    transformer : sklearn.preprocessing transformer or None
        Input transformer applied to the probe data, or None for raw data.
    num_inputs : int
        Number of input features/dimensions.

    Returns
    -------
    torch.jit.ScriptModule
        Traced TorchScript model ready to be saved and deployed.

    Notes
    -----
    The GPyTorch settings used here are:
    - ``fast_pred_var()`` / ``fast_pred_samples()``: LOVE method for fast
      predictive variances and samples.
    - ``trace_mode()``: disables GPyTorch features incompatible with tracing.
    - ``max_eager_kernel_size()``: disables lazy kernel evaluation so the
      trace captures concrete tensor operations.

    See Also
    --------
    create_test_data : Generate random probe data for tracing.
    MeanVarModelWrapper : Tensor-output wrapper for GPyTorch models.
    """
    probe = create_test_data(num_inputs)
    if transformer is not None:
        probe = transformer.transform(probe)
    probe = torch.tensor(probe, dtype=torch.float64, requires_grad=True)

    # Force eager kernel evaluation for all training + probe points.
    eager_kernel_limit = len_training_data + len(probe)
    with (
        gpytorch.settings.fast_pred_var(),  # LOVE method for predictive variance
        gpytorch.settings.fast_pred_samples(),  # LOVE method for predictive samples
        gpytorch.settings.trace_mode(),  # Turns off GPyTorch features that cannot be traced
        gpytorch.settings.max_eager_kernel_size(eager_kernel_limit),  # Disables lazy evaluation
    ):
        model.eval()
        model(probe)  # Warm-up call so precomputed caches are part of the trace setup
        return torch.jit.trace(MeanVarModelWrapper(model), probe)
def create_test_data(num_inputs: int, num_points: int = 500) -> NDArray:
    """
    Create random test data for tracing and validating the traced model.

    Parameters
    ----------
    num_inputs : int
        Number of input features (columns) per sample.
    num_points : int, optional
        Number of random samples (rows) to generate. Defaults to 500,
        preserving the previous hard-coded behavior.

    Returns
    -------
    NDArray
        Array of shape ``(num_points, num_inputs)`` with values drawn
        uniformly from the half-open interval [0, 1).
    """
    return np.random.rand(num_points, num_inputs)
def test_traced_model(model, traced_model, input_transformer, num_inputs):
    """
    Verify that the traced model reproduces the original model's outputs.

    Feeds identical random inputs to both the original GPyTorch model and
    its traced TorchScript counterpart and asserts that the predicted
    means and variances agree to within an absolute tolerance of 1e-14.

    Parameters
    ----------
    model : gpytorch.models.GP
        Original GPyTorch Gaussian Process model.
    traced_model : torch.jit.ScriptModule
        Traced TorchScript version of the model.
    input_transformer : sklearn.preprocessing transformer or None
        Input transformer applied to the test data, or None.
    num_inputs : int
        Number of input features/dimensions.

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        If traced and original predictions differ beyond the tolerance.

    See Also
    --------
    create_test_data : Generate the random comparison inputs.
    """
    inputs = create_test_data(num_inputs)
    if input_transformer is not None:
        inputs = input_transformer.transform(inputs)
    inputs = torch.tensor(inputs)

    # Evaluate both models under the same settings used during tracing.
    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        traced_mean, traced_var = traced_model(inputs)
        reference = model(inputs)
        assert torch.allclose(traced_mean, reference.mean, atol=1e-14)
        assert torch.allclose(traced_var, reference.variance, atol=1e-14)
def main():
    """
    Run the full PyTorch-to-TorchScript conversion pipeline.

    Loads a trained GP model dump from disk, reconstructs the model and
    its configuration, traces it to TorchScript, verifies the traced
    model against the original, and saves the result.

    Returns
    -------
    None
        The traced model is written to ``args.directory / args.output``.

    Raises
    ------
    FileNotFoundError
        If the input model file cannot be found.
    AssertionError
        If the traced model's outputs diverge from the original model's.

    See Also
    --------
    trace_model : Perform the model tracing.
    test_traced_model : Verify traced model integrity.
    """
    args = parse_args()

    # weights_only=False is required here: the dump contains arbitrary
    # pickled Python objects (config dict, sklearn input transformer),
    # not just tensors. torch >= 2.6 defaults to weights_only=True,
    # which would refuse to load this checkpoint. Only load trusted files.
    model_dump = torch.load(args.input, weights_only=False)
    config = create_config(model_dump["config"])

    train_x, train_y = (
        model_dump["training_data"]["train_x"],
        model_dump["training_data"]["train_y"],
    )
    num_inputs = config.data_conf.num_inputs

    # The input transformer is only present when training used one.
    if config.transform_conf.transform_input.transform_data:
        input_transformer = model_dump["training_data"]["input_transformer"]
    else:
        input_transformer = None

    logger.info("Loading model definition.")
    model, likelihood = load_model(config, model_dump, train_x, train_y)

    logger.info("Start tracing model.")
    traced_model = trace_model(model, len(train_x), input_transformer, num_inputs)
    logger.info("Finished tracing model.")

    logger.info("Test integrity traced model.")
    test_traced_model(model, traced_model, input_transformer, num_inputs)
    logger.info("Model integrity is good.")

    logger.info(f"Saving traced model to {args.directory / args.output}.")
    traced_model.save(f"{args.directory / args.output}")
# Script entry point: run the .pth -> TorchScript conversion pipeline.
if __name__ == "__main__":
    main()