From e89b8d37ae35e136021dca3c2e4de9eead3c1a96 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Thu, 18 Jun 2020 15:41:43 -0700 Subject: [PATCH] [Serve] Fix PyTorch Tutorial (#9019) --- python/ray/serve/examples/doc/tutorial_pytorch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/ray/serve/examples/doc/tutorial_pytorch.py b/python/ray/serve/examples/doc/tutorial_pytorch.py index 106eaf0ed..0dde8abda 100644 --- a/python/ray/serve/examples/doc/tutorial_pytorch.py +++ b/python/ray/serve/examples/doc/tutorial_pytorch.py @@ -16,12 +16,14 @@ from torchvision.models import resnet18 # __doc_define_servable_begin__ class ImageModel: def __init__(self): - self.model = resnet18(pretrained=True) + self.model = resnet18(pretrained=True).eval() self.preprocessor = transforms.Compose([ transforms.Resize(224), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Lambda(lambda t: t[:3, ...]), # remove alpha channel + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) def __call__(self, flask_request):