TorchServe - Serving PyTorch Models

What is TorchServe?

TorchServe is a flexible and easy-to-use tool for serving PyTorch models. It provides the core features needed for serving, such as a server, a model archiver tool, an API endpoint specification, logging, metrics, batch inference and model snapshots. It also offers advanced features, for instance, support for custom inference services, unit tests and an easy way to collect benchmark data through JMeter.

TorchServe was released on 10th June 2020. Its latest version (0.2.0) is still in the experimental stage, but it works well for common use cases.

How to Install TorchServe?

  1. Installing TorchServe is straightforward. First open Terminal/Command Line.

  2. Run the following command

pip3 install torchserve torch-model-archiver

  3. Installation should be complete.

  4. To verify that the installation is correct, or to view information about it, run the following command

pip3 show torchserve
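The same check can be done from Python. The following is a minimal sketch, assuming Python 3.8 or newer (for importlib.metadata); the helper name installed_version is made up for this example and is not part of TorchServe.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # The two packages installed by the pip3 command above.
    for name in ("torchserve", "torch-model-archiver"):
        version = installed_version(name)
        print(name, version if version is not None else "not installed")
```

If either line prints "not installed", the pip3 install step most likely ran against a different Python environment.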

TorchServe Architecture

Important Terminology:

  • Frontend: The request/response handling component of TorchServe. This portion of the serving component both handles the requests and responses exchanged with clients and manages the life cycles of the models.

  • Model Workers: These workers are responsible for running the actual inference on the models. These are actual running instances of the models.

  • Model: Models could be a script_module (JIT saved models) or eager_mode_models. These models can provide custom pre- and post-processing of data along with any other model artifacts such as state_dicts. Models can be loaded from cloud storage or from local hosts.

  • Plugins: These are custom endpoints or authz/authn or batching algorithms that can be dropped into TorchServe at startup time.

  • Model Store: This is a directory in which all the loadable models exist.

Serving the Model

We will deploy an MNIST handwritten digits classifier. First, we will design a PyTorch model to solve the MNIST handwritten digits classification problem. Save the code below in a file named ''.

import torch
import torch.nn.functional as F
from torch import nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
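The value 9216 passed to the first linear layer is not arbitrary. The sketch below traces the spatial size of a 28x28 MNIST image through the two unpadded 3x3 convolutions and the 2x2 max pooling of the model, using plain arithmetic. The helper names are invented for this illustration.

```python
def conv_output_size(size, kernel, stride=1):
    """Spatial size after an unpadded convolution or pooling window."""
    return (size - kernel) // stride + 1


def flattened_features(input_size=28, channels=64):
    """Number of values reaching fc1 for a square input of `input_size` pixels."""
    size = conv_output_size(input_size, kernel=3)        # conv1: 28 -> 26
    size = conv_output_size(size, kernel=3)              # conv2: 26 -> 24
    size = conv_output_size(size, kernel=2, stride=2)    # max_pool2d(2): 24 -> 12
    return channels * size * size                        # 64 * 12 * 12


print(flattened_features())  # 9216
```

Changing the input resolution or a kernel size changes this number, and the in_features argument of fc1 would have to change with it.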

Next, we will write a script to train the above model. Save the following code in a file ''.

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from model import Net


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # move the batch to the same device as the model
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    # decay the learning rate by gamma after every epoch
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        # only the weights (state_dict) are saved; the file name is blank in the original
        torch.save(model.state_dict(), "")


if __name__ == '__main__':
    main()
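The training script decays the learning rate with StepLR using step_size=1, so the rate is multiplied by gamma after every epoch. As a rough sketch of what that means with the script's defaults (lr 1.0, gamma 0.7, 14 epochs), the closed-form schedule can be computed without PyTorch. The function name is invented for this illustration, and the values PyTorch reports may differ in the last floating-point digits.

```python
def step_lr_schedule(initial_lr=1.0, gamma=0.7, epochs=14, step_size=1):
    """Learning rate used during each epoch (1-indexed) under a step decay."""
    return [
        initial_lr * gamma ** ((epoch - 1) // step_size)
        for epoch in range(1, epochs + 1)
    ]


if __name__ == "__main__":
    for epoch, lr in enumerate(step_lr_schedule(), start=1):
        print("epoch {:2d}: lr = {:.5f}".format(epoch, lr))
```

By the last epoch the rate has dropped to about 1% of its starting value, which is why the loss changes very little in the final epochs.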

Train the new model and save it using

python3 --save-model

Now, move the saved model into a new directory called 'artefacts' using the following commands.

mkdir artefacts

mv artefacts/

Next, we will write a custom handler to run the inference on our model. Save the following code in a new file called ''.

import io
import os
import logging
import torch
import numpy as np
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms

logger = logging.getLogger(__name__)


class MNISTDigitClassifier(object):
    """
    MNISTDigitClassifier handler class. This handler takes a greyscale image
    and returns the digit in that image.
    """

    def __init__(self):
        self.model = None
        self.mapping = None
        self.device = None
        self.initialized = False

    def initialize(self, ctx):
        """Load the eager mode model from its state_dict"""
        properties = ctx.system_properties
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available() else "cpu")
        model_dir = properties.get("model_dir")

        # Read model serialize/pt file
        model_pt_path = os.path.join(model_dir, "")
        # Read model definition file
        model_def_path = os.path.join(model_dir, "")
        if not os.path.isfile(model_def_path):
            raise RuntimeError("Missing the model definition file")

        from model import Net
        state_dict = torch.load(model_pt_path, map_location=self.device)
        self.model = Net()
        self.model.load_state_dict(state_dict)
        # keep the model on the same device as the inputs built in inference()
        self.model.to(self.device)
        self.model.eval()

        logger.debug(
            'Model file {0} loaded successfully'.format(model_pt_path))
        self.initialized = True

    def preprocess(self, data):
        """
        Decodes the raw image bytes of the request and normalizes them
        for the MNIST model, returns a tensor
        """
        image = data[0].get("data")
        if image is None:
            image = data[0].get("body")

        mnist_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        # the request payload is raw bytes, so decode it into a PIL image first
        image = Image.open(io.BytesIO(image))
        image = mnist_transform(image)
        return image

    def inference(self, img, topk=5):
        '''
        Predict the class (or classes) of an image using a trained deep
        learning model.
        '''
        # Add a batch dimension in front of the single image
        img = np.expand_dims(img, 0)
        img = torch.from_numpy(img)

        self.model.eval()
        inputs = Variable(img).to(self.device)
        outputs = self.model.forward(inputs)

        _, y_hat = outputs.max(1)
        predicted_idx = str(y_hat.item())
        return [predicted_idx]

    def postprocess(self, inference_output):
        return inference_output


_service = MNISTDigitClassifier()


def handle(data, context):
    if not _service.initialized:
        _service.initialize(context)

    if data is None:
        return None

    data = _service.preprocess(data)
    data = _service.inference(data)
    data = _service.postprocess(data)

    return data

This handler runs inference on grayscale images, because all images in the MNIST database are grayscale.
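To make the preprocessing step concrete, here is a small sketch of the arithmetic applied to each 8-bit grayscale pixel: ToTensor scales the value to the range [0, 1], and Normalize((0.1307,), (0.3081,)) then subtracts the mean and divides by the standard deviation. The constant and function names are invented for this illustration.

```python
MNIST_MEAN = 0.1307  # mean used in the Normalize transform
MNIST_STD = 0.3081   # standard deviation used in the Normalize transform


def normalize_pixel(value):
    """Map an 8-bit grayscale value (0-255) to the value the model sees."""
    scaled = value / 255.0                    # ToTensor: 0..255 -> 0.0..1.0
    return (scaled - MNIST_MEAN) / MNIST_STD  # Normalize: center and rescale


if __name__ == "__main__":
    print("black pixel:", round(normalize_pixel(0), 4))
    print("white pixel:", round(normalize_pixel(255), 4))
```

An image sent to the server has to go through the same transform that was used during training, otherwise the model receives inputs on a different scale than the one it learned from.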

Now, we will create a torch model archive using the torch-model-archiver utility. Run the following command to archive the model we created and saved as artefacts/ into a .mar file

torch-model-archiver --model-name mnist --version 1.0 --model-file --serialized-file artefacts/ --handler
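To check what ended up inside the archive, it can be opened from Python. This is a sketch under the assumption that the archiver was run with its default archive format, which to my knowledge produces an ordinary zip file with a .mar extension; the helper name is invented for this example.

```python
import os
import zipfile


def list_archive_contents(mar_path):
    """Return the sorted file names stored inside a zip-format .mar archive."""
    with zipfile.ZipFile(mar_path) as archive:
        return sorted(archive.namelist())


if __name__ == "__main__":
    path = "mnist.mar"  # the archive produced by the command above
    if os.path.exists(path):
        for name in list_archive_contents(path):
            print(name)
    else:
        print(path, "not found in the current directory")
```

The listing should contain the model definition file, the serialized weights and the handler, along with the metadata that the archiver adds.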

Next, we create a new directory named 'model_store' and move the archived model into it.

mkdir model_store

mv mnist.mar model_store/

Finally, we can start the server and query it for results. Put a test image named '0.png' in a directory called 'test_data' and run the commands below in Terminal/Command Line.

torchserve --start --model-store model_store --models mnist=mnist.mar

curl -T test_data/0.png
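The server can also be queried from Python instead of curl. The sketch below assumes TorchServe's default inference address, http://127.0.0.1:8080, and its /predictions/<model name> endpoint; adjust base_url if your configuration differs. It uploads the image with PUT, as curl -T does. The function names are invented for this example.

```python
import urllib.request

DEFAULT_BASE_URL = "http://127.0.0.1:8080"  # assumed default inference address


def build_prediction_request(image_bytes, model_name="mnist",
                             base_url=DEFAULT_BASE_URL):
    """Build the HTTP request that uploads an image to the predictions endpoint."""
    url = "{}/predictions/{}".format(base_url.rstrip("/"), model_name)
    return urllib.request.Request(url, data=image_bytes, method="PUT")


def predict(image_path, model_name="mnist", base_url=DEFAULT_BASE_URL):
    """Send the image file to a running server and return the response text."""
    with open(image_path, "rb") as handle:
        request = build_prediction_request(handle.read(), model_name, base_url)
    with urllib.request.urlopen(request) as response:
        return response.read().decode("utf-8")
```

With the server running, calling predict("test_data/0.png") should return whatever the custom handler produces, which for this handler is the predicted digit as text.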

We can stop the server using the command below.

torchserve --stop
