pt2ts

Functions

create_test_data(num_inputs)

Creates test data for checking the validity of the traced model.

main()

Main conversion pipeline from PyTorch model to TorchScript.

parse_args()

Parse command-line arguments for PyTorch to TorchScript model conversion.

test_traced_model(model, traced_model, ...)

Verify integrity of traced model against original model.

trace_model(model, len_training_data, ...)

Trace a GPyTorch model for TorchScript conversion.

Classes

MeanVarModelWrapper(gp)

Wrapper class for GPyTorch models to extract mean and variance.

class pt2ts.MeanVarModelWrapper(gp)

Bases: Module

Wrapper class for GPyTorch models to extract mean and variance.

Wraps a GPyTorch Gaussian Process model to provide a simplified interface that returns both mean and variance predictions, making it suitable for TorchScript tracing.

Parameters:

gp (gpytorch.models.GP) – The GPyTorch Gaussian Process model to wrap

gp

The wrapped Gaussian Process model

Type:

gpytorch.models.GP

Notes

This wrapper is necessary because GPyTorch models return distribution objects that are not directly compatible with TorchScript tracing. The wrapper extracts the mean and variance components, which are tensors suitable for tracing.
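To make this concrete, here is a minimal sketch using only plain PyTorch. ToyGP is a hypothetical stand-in for a GPyTorch model: like a GPyTorch model returning a MultivariateNormal, it returns a torch.distributions.Normal object rather than tensors. torch.jit.trace only accepts tensors (or tuples, lists and dicts of tensors) as outputs, so the wrapper unpacks the distribution into a (mean, variance) tuple. MeanVarWrapperSketch is an illustrative re-creation, not this module's actual source.

```python
import torch
from torch import nn
from torch.distributions import Normal


class ToyGP(nn.Module):
    """Hypothetical stand-in for a GPyTorch model: returns a distribution, not tensors."""

    def __init__(self, num_inputs: int):
        super().__init__()
        self.linear = nn.Linear(num_inputs, 1)

    def forward(self, x: torch.Tensor) -> Normal:
        loc = self.linear(x).squeeze(-1)
        scale = nn.functional.softplus(loc) + 1e-3  # strictly positive std-dev
        # validate_args=False skips constraint checks that would branch on tensor values
        return Normal(loc, scale, validate_args=False)


class MeanVarWrapperSketch(nn.Module):
    """Illustrative wrapper: exposes (mean, variance) tensors so the model can be traced."""

    def __init__(self, gp: nn.Module):
        super().__init__()
        self.gp = gp

    def forward(self, x: torch.Tensor):
        dist = self.gp(x)
        return dist.mean, dist.variance


torch.manual_seed(0)
gp = ToyGP(num_inputs=3).eval()
wrapped = MeanVarWrapperSketch(gp)
example_x = torch.rand(5, 3)
traced = torch.jit.trace(wrapped, example_x)  # tracing the bare ToyGP would fail
mean, var = traced(example_x)
```

With a real GPyTorch model the same idea applies, but tracing normally also needs the GPyTorch settings described under trace_model.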

Examples

>>> model = SomeGPyTorchModel()
>>> wrapped_model = MeanVarModelWrapper(model)
>>> mean, var = wrapped_model(test_x)

forward(x)

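Return the predictive mean and variance of the wrapped GP evaluated at x.

Calls the wrapped model on x and returns the mean and variance tensors of the resulting distribution as a tuple (mean, variance).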

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

pt2ts.create_test_data(num_inputs: int) → ndarray[Any, dtype[_ScalarType_co]]

Creates test data for checking the validity of the traced model.
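A minimal sketch of what such a helper might look like, assuming uniformly distributed random inputs of shape (num_points, num_inputs). The num_points, seed and transformer parameters are hypothetical additions for illustration; the documented function takes only num_inputs, and its actual distribution and sample count are not specified here. The optional transformer is assumed to follow the scikit-learn transform() convention mentioned for input_transformer.

```python
import numpy as np


def create_test_data_sketch(num_inputs: int, num_points: int = 100,
                            seed: int = 0, transformer=None) -> np.ndarray:
    """Hypothetical stand-in for pt2ts.create_test_data.

    Draws num_points random samples with num_inputs features each and, if a
    fitted scikit-learn-style transformer is given, applies its transform().
    """
    rng = np.random.default_rng(seed)           # reproducible random generator
    data = rng.random((num_points, num_inputs))  # uniform samples in [0, 1)
    if transformer is not None:
        data = transformer.transform(data)
    return data
```

For example, create_test_data_sketch(5) returns a float64 array of shape (100, 5).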

pt2ts.main()

Main conversion pipeline from PyTorch model to TorchScript.

Executes the complete conversion workflow including model loading, tracing, integrity testing, and saving the TorchScript model. Handles both transformed and non-transformed input data scenarios.

Returns:

Nothing; the converted TorchScript model is written to disk as a side effect.

Return type:

None

Raises:
  • FileNotFoundError – If input model file cannot be found

  • torch.jit.TracingError – If model tracing fails due to control flow issues

  • AssertionError – If traced model integrity test fails

Notes

The conversion process includes:

  1. Load the PyTorch model and configuration from the .pth file

  2. Extract the training data and input transformers

  3. Trace the model using torch.jit.trace with test data

  4. Verify that the traced model's outputs match the original's within tolerance

  5. Save the traced model in TorchScript format

The function uses GPyTorch-specific settings for tracing:

  • fast_pred_var() for efficient predictive variance

  • trace_mode() to disable GPyTorch features incompatible with tracing

  • max_eager_kernel_size() to disable lazy evaluation

Examples

Command line usage:

>>> # python pt2ts.py -i model.pth -o converted_model.ts -d output/
>>> main()

See also

trace_model

Perform model tracing

test_traced_model

Verify traced model integrity

pt2ts.parse_args()

Parse command-line arguments for PyTorch to TorchScript model conversion.

Parses command-line arguments required for converting a trained PyTorch Gaussian Process model to TorchScript format for deployment.

Returns:

Parsed command-line arguments with the following attributes:

  • input (pathlib.Path)

    Path to the input PyTorch model (.pth file)

  • output (str)

    Name of the output TorchScript model file (default: ‘model.ts’)

  • directory (pathlib.Path)

    Directory path where the TorchScript model will be saved (created if needed)

Return type:

argparse.Namespace

Notes

The function automatically creates the output directory if it does not exist. The input file path is converted to a pathlib.Path object.
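A sketch of how such a parser could be written with argparse. The short flags -i, -o and -d come from the command-line example for main(), and the long option names mirror the attribute names listed above. The default directory (the current directory) and the argv parameter (added so the function can be exercised without a real command line) are assumptions, and parse_args_sketch is a hypothetical name.

```python
import argparse
import pathlib


def parse_args_sketch(argv=None) -> argparse.Namespace:
    """Hypothetical re-creation of pt2ts.parse_args for illustration."""
    parser = argparse.ArgumentParser(
        description="Convert a PyTorch GP model (.pth) to TorchScript.")
    parser.add_argument("-i", "--input", type=pathlib.Path, required=True,
                        help="path to the input PyTorch model (.pth file)")
    parser.add_argument("-o", "--output", default="model.ts",
                        help="name of the output TorchScript model file")
    # Assumption: the real default directory is not documented.
    parser.add_argument("-d", "--directory", type=pathlib.Path,
                        default=pathlib.Path("."),
                        help="directory where the TorchScript model is saved")
    args = parser.parse_args(argv)  # argv=None falls back to sys.argv[1:]
    # Create the output directory (and any missing parents) if needed.
    args.directory.mkdir(parents=True, exist_ok=True)
    return args
```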

Examples

>>> args = parse_args()
>>> args.input
PosixPath('/path/to/model.pth')
>>> args.output
'model.ts'

pt2ts.test_traced_model(model, traced_model, input_transformer, num_inputs)

Verify integrity of traced model against original model.

Compares outputs of the original GPyTorch model with the traced TorchScript model to ensure conversion accuracy. Uses randomly generated test data for comparison.

Parameters:
  • model (gpytorch.models.GP) – Original GPyTorch Gaussian Process model

  • traced_model (torch.jit.ScriptModule) – Traced TorchScript version of the model

  • input_transformer (sklearn.preprocessing transformer or None) – Input data transformer, or None if no transformation is applied

  • num_inputs (int) – Number of input features/dimensions

Return type:

None

Raises:

AssertionError – If the traced model's outputs do not match the original model's outputs within the specified absolute tolerance (1e-14)

Notes

The test compares both the mean and variance predictions of the two models. It uses torch.allclose() with an absolute tolerance of 1e-14 to validate numerical precision. Both models are evaluated in no_grad() mode with the fast_pred_var() setting for consistency.
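A minimal sketch of this comparison using plain PyTorch in float64. ToyMeanVarModel and check_traced_outputs are hypothetical names, and the sketch omits the input transformer and the GPyTorch fast_pred_var() context that the real function applies. Note that torch.allclose(a, b, atol=...) tests |a - b| <= atol + rtol * |b| with rtol defaulting to 1e-5, so unless rtol is also passed, the relative term dominates a 1e-14 absolute tolerance for outputs of order one.

```python
import torch
from torch import nn


class ToyMeanVarModel(nn.Module):
    """Hypothetical model returning (mean, variance), standing in for a wrapped GP."""

    def __init__(self, num_inputs: int):
        super().__init__()
        self.linear = nn.Linear(num_inputs, 1).double()  # float64 weights

    def forward(self, x):
        mean = self.linear(x).squeeze(-1)
        var = nn.functional.softplus(mean)
        return mean, var


def check_traced_outputs(model, traced_model, test_x, atol=1e-14):
    """Raise AssertionError if the two models' means or variances differ."""
    with torch.no_grad():
        mean, var = model(test_x)
        traced_mean, traced_var = traced_model(test_x)
    # rtol keeps torch.allclose's default (1e-5) here.
    assert torch.allclose(mean, traced_mean, atol=atol), "mean mismatch"
    assert torch.allclose(var, traced_var, atol=atol), "variance mismatch"


torch.manual_seed(0)
num_inputs = 5
model = ToyMeanVarModel(num_inputs).eval()
test_x = torch.rand(20, num_inputs, dtype=torch.float64)
traced = torch.jit.trace(model, test_x)
check_traced_outputs(model, traced, test_x)  # passes silently
```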

Examples

>>> # Passes silently if the outputs match; raises AssertionError otherwise
>>> test_traced_model(original_model, traced_model, transformer, 5)

See also

create_test_data

Generate test data for comparison

pt2ts.trace_model(model, len_training_data, transformer, num_inputs)

Trace a GPyTorch model for TorchScript conversion.

Creates a TorchScript-compatible traced version of the GPyTorch model using torch.jit.trace with appropriate GPyTorch settings for optimal performance and compatibility.

Parameters:
  • model (gpytorch.models.GP) – The trained GPyTorch Gaussian Process model to trace

  • len_training_data (int) – Number of training data points, used for kernel size optimization

  • transformer (sklearn.preprocessing transformer or None) – Input data transformer, or None if no transformation is applied

  • num_inputs (int) – Number of input features/dimensions

Returns:

Traced TorchScript model that can be saved and deployed

Return type:

torch.jit.ScriptModule

Notes

The tracing process uses several GPyTorch-specific performance-enhancing settings:

  • fast_pred_var(): Enables the LOVE method for efficient predictive variance

  • fast_pred_samples(): Enables the LOVE method for predictive samples

  • trace_mode(): Disables GPyTorch features incompatible with tracing

  • max_eager_kernel_size(): Disables lazy evaluation for better tracing

Test data is generated randomly and transformed if a transformer is provided. The model is set to evaluation mode before tracing.

Examples

>>> traced_model = trace_model(model, 1000, transformer, 5)
>>> traced_model.save('model.ts')
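traced_model.save() writes the traced graph and its weights into a single archive that torch.jit.load() can read back without the original Python class definitions. A sketch of a save-and-reload check, assuming a hypothetical ToyMeanVarModel in place of the traced GP and a temporary directory instead of a real output path:

```python
import os
import tempfile

import torch
from torch import nn


class ToyMeanVarModel(nn.Module):
    """Hypothetical stand-in for a wrapped GP returning (mean, variance)."""

    def __init__(self, num_inputs: int):
        super().__init__()
        self.linear = nn.Linear(num_inputs, 1)

    def forward(self, x):
        mean = self.linear(x).squeeze(-1)
        return mean, nn.functional.softplus(mean)


torch.manual_seed(0)
example_x = torch.rand(10, 5)
traced_model = torch.jit.trace(ToyMeanVarModel(5).eval(), example_x)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.ts")
    traced_model.save(path)        # serialize graph + weights to disk
    loaded = torch.jit.load(path)  # no Python class definition needed to load
    mean, var = loaded(example_x)
    ref_mean, ref_var = traced_model(example_x)
```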

See also

create_test_data

Generate random test data for tracing

MeanVarModelWrapper

Wrapper class for GPyTorch models