[Serve] Fix PyTorch Tutorial (#9019)

This commit is contained in:
Simon Mo
2020-06-18 15:41:43 -07:00
committed by GitHub
parent 92f67cd2ae
commit e89b8d37ae
@@ -16,12 +16,14 @@ from torchvision.models import resnet18
# __doc_define_servable_begin__
class ImageModel:
def __init__(self):
self.model = resnet18(pretrained=True)
self.model = resnet18(pretrained=True).eval()
self.preprocessor = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Lambda(lambda t: t[:3, ...]), # remove alpha channel
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def __call__(self, flask_request):