pt2ts
Functions
- create_test_data(num_inputs) – Creates test data for checking the validity of the traced model.
- main() – Main conversion pipeline from PyTorch model to TorchScript.
- parse_args() – Parse command-line arguments for PyTorch to TorchScript model conversion.
- test_traced_model(model, traced_model, input_transformer, num_inputs) – Verify integrity of traced model against original model.
- trace_model(model, len_training_data, transformer, num_inputs) – Trace a GPyTorch model for TorchScript conversion.
Classes
- MeanVarModelWrapper(gp) – Wrapper class for GPyTorch models to extract mean and variance.
- class pt2ts.MeanVarModelWrapper(gp)[source]
Bases: Module
Wrapper class for GPyTorch models to extract mean and variance.
Wraps a GPyTorch Gaussian Process model to provide a simplified interface that returns both mean and variance predictions, making it suitable for TorchScript tracing.
- Parameters:
gp (gpytorch.models.GP) – The GPyTorch Gaussian Process model to wrap
- gp
The wrapped Gaussian Process model
- Type:
gpytorch.models.GP
Notes
This wrapper is necessary because GPyTorch models return distribution objects that are not directly compatible with TorchScript tracing. The wrapper extracts the mean and variance components which are tensor objects suitable for tracing.
Examples
>>> model = SomeGPyTorchModel()
>>> wrapped_model = MeanVarModelWrapper(model)
>>> mean, var = wrapped_model(test_x)
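The wrapper body itself is small. A minimal sketch consistent with the description above (an illustration, not necessarily the exact pt2ts implementation) is:

import torch


class MeanVarModelWrapper(torch.nn.Module):
    """Illustrative wrapper returning plain tensors instead of a distribution."""

    def __init__(self, gp):
        super().__init__()
        self.gp = gp  # the wrapped GPyTorch GP model

    def forward(self, x):
        # The GP returns a MultivariateNormal; tracing needs plain tensors,
        # so extract and return its mean and variance.
        dist = self.gp(x)
        return dist.mean, dist.variance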
- forward(x)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- pt2ts.create_test_data(num_inputs: int) → ndarray[Any, dtype[_ScalarType_co]][source]
Creates test data for checking the validity of the traced model.
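A hedged sketch of such a helper, assuming uniform random samples of shape (1, num_inputs); the ranges and shape used by pt2ts may differ:

import numpy as np


def create_test_data(num_inputs: int) -> np.ndarray:
    # Assumed behaviour: one random sample with `num_inputs` features in [0, 1).
    rng = np.random.default_rng()
    return rng.random((1, num_inputs))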
- pt2ts.main()[source]
Main conversion pipeline from PyTorch model to TorchScript.
Executes the complete conversion workflow including model loading, tracing, integrity testing, and saving the TorchScript model. Handles both transformed and non-transformed input data scenarios.
- Returns:
Function performs conversion and saves TorchScript model to disk
- Return type:
None
- Raises:
FileNotFoundError – If input model file cannot be found
torch.jit.TracingError – If model tracing fails due to control flow issues
AssertionError – If traced model integrity test fails
Notes
The conversion process includes:
1. Load PyTorch model and configuration from .pth file
2. Extract training data and input transformers
3. Trace the model using torch.jit.trace with test data
4. Verify traced model produces identical outputs to original
5. Save traced model in TorchScript format
The function uses GPyTorch-specific settings for optimal tracing:
- fast_pred_var() for efficient predictive variance
- trace_mode() to disable incompatible GPyTorch features
- max_eager_kernel_size() to disable lazy evaluation
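A condensed sketch of this workflow, with the checkpoint layout (keys such as "model" and "input_transformer") assumed purely for illustration:

import torch


def convert(input_path, output_path, num_inputs):
    # 1. Load the PyTorch model and configuration (checkpoint keys are assumptions)
    checkpoint = torch.load(input_path)
    model = checkpoint["model"]
    transformer = checkpoint.get("input_transformer")  # may be None
    len_training_data = checkpoint["len_training_data"]

    # 2.-3. Trace the model with generated test data
    traced = trace_model(model, len_training_data, transformer, num_inputs)

    # 4. Verify the traced model reproduces the original outputs
    test_traced_model(model, traced, transformer, num_inputs)

    # 5. Save the traced model in TorchScript format
    traced.save(str(output_path))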
Examples
Command line usage:
>>> # python pt2ts.py -i model.pth -o converted_model.ts -d output/
>>> main()
See also
trace_model
Perform model tracing
test_traced_model
Verify traced model integrity
- pt2ts.parse_args()[source]
Parse command-line arguments for PyTorch to TorchScript model conversion.
Parses command-line arguments required for converting a trained PyTorch Gaussian Process model to TorchScript format for deployment.
- Returns:
Parsed command-line arguments with the following attributes:
- input (pathlib.Path) – Path to the input PyTorch model (.pth file)
- output (str) – Name of the output TorchScript model file (default: 'model.ts')
- directory (pathlib.Path) – Directory path where the TorchScript model will be saved (created if needed)
- Return type:
argparse.Namespace
Notes
The function automatically creates the output directory if it doesn’t exist. Input validation ensures the input file path is converted to a Path object.
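A sketch of an argument parser matching the documented attributes and the -i/-o/-d flags shown in main()'s example; the exact defaults and help strings are assumptions:

import argparse
import pathlib


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Convert a trained PyTorch GP model to TorchScript."
    )
    parser.add_argument("-i", "--input", type=pathlib.Path,
                        help="path to the input PyTorch model (.pth file)")
    parser.add_argument("-o", "--output", default="model.ts",
                        help="name of the output TorchScript model file")
    parser.add_argument("-d", "--directory", type=pathlib.Path,
                        default=pathlib.Path("."),
                        help="directory in which to save the TorchScript model")
    args = parser.parse_args()
    # Create the output directory if it does not already exist.
    args.directory.mkdir(parents=True, exist_ok=True)
    return args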
Examples
>>> args = parse_args()
>>> args.input
PosixPath('/path/to/model.pth')
>>> args.output
'model.ts'
- pt2ts.test_traced_model(model, traced_model, input_transformer, num_inputs)[source]
Verify integrity of traced model against original model.
Compares outputs of the original GPyTorch model with the traced TorchScript model to ensure conversion accuracy. Uses randomly generated test data for comparison.
- Parameters:
model (gpytorch.models.GP) – Original GPyTorch Gaussian Process model
traced_model (torch.jit.ScriptModule) – Traced TorchScript version of the model
input_transformer (sklearn.preprocessing transformer or None) – Input data transformer, or None if no transformation is applied
num_inputs (int) – Number of input features/dimensions
- Return type:
None
- Raises:
AssertionError – If traced model outputs don’t match original model outputs within specified tolerance (1e-14)
Notes
The test compares both mean and variance predictions from both models. Uses torch.allclose() with absolute tolerance of 1e-14 for numerical precision validation. Both models are evaluated in no_grad() mode with fast_pred_var() setting for consistency.
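A sketch of the comparison described in these Notes, assuming the traced model is the mean/variance wrapper; pt2ts itself may structure the check differently:

import torch
import gpytorch


def check_traced_outputs(model, traced_model, test_x, atol=1e-14):
    model.eval()
    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        dist = model(test_x)                            # original model: a distribution
        mean, var = dist.mean, dist.variance
        traced_mean, traced_var = traced_model(test_x)  # traced wrapper: plain tensors
    assert torch.allclose(mean, traced_mean, atol=atol)
    assert torch.allclose(var, traced_var, atol=atol)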
Examples
>>> test_traced_model(original_model, traced_model, transformer, 5)
# Passes silently if models match, raises AssertionError if not
See also
create_test_data
Generate test data for comparison
- pt2ts.trace_model(model, len_training_data, transformer, num_inputs)[source]
Trace a GPyTorch model for TorchScript conversion.
Creates a TorchScript-compatible traced version of the GPyTorch model using torch.jit.trace with appropriate GPyTorch settings for optimal performance and compatibility.
- Parameters:
model (gpytorch.models.GP) – The trained GPyTorch Gaussian Process model to trace
len_training_data (int) – Number of training data points, used for kernel size optimization
transformer (sklearn.preprocessing transformer or None) – Input data transformer, or None if no transformation is applied
num_inputs (int) – Number of input features/dimensions
- Returns:
Traced TorchScript model that can be saved and deployed
- Return type:
torch.jit.ScriptModule
Notes
The tracing process uses several GPyTorch-specific, performance-enhancing settings:
- fast_pred_var(): Enables LOVE method for efficient predictive variance
- fast_pred_samples(): Enables LOVE method for predictive samples
- trace_mode(): Disables GPyTorch features incompatible with tracing
- max_eager_kernel_size(): Disables lazy evaluation for better tracing
Test data is generated randomly and transformed if a transformer is provided. The model is set to evaluation mode before tracing.
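Under these settings the tracing step could look roughly as follows; the warm-up call, the max_eager_kernel_size value, and check_trace=False are assumptions, not confirmed details of pt2ts:

import torch
import gpytorch


def trace_gp(model, len_training_data, test_x):
    model.eval()
    wrapped = MeanVarModelWrapper(model)
    with torch.no_grad(), \
         gpytorch.settings.fast_pred_var(), \
         gpytorch.settings.fast_pred_samples(), \
         gpytorch.settings.trace_mode(), \
         gpytorch.settings.max_eager_kernel_size(len_training_data + 1):
        # Warm-up prediction so caches are built before tracing (a common pattern).
        wrapped(test_x)
        traced = torch.jit.trace(wrapped, test_x, check_trace=False)
    return traced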
Examples
>>> traced_model = trace_model(model, 1000, transformer, 5)
>>> traced_model.save('model.ts')
See also
create_test_data
Generate random test data for tracing
MeanVarModelWrapper
Wrapper class for GPyTorch models