diff --git a/python/ray/serve/examples/doc/tutorial_pytorch.py b/python/ray/serve/examples/doc/tutorial_pytorch.py index 106eaf0ed..0dde8abda 100644 --- a/python/ray/serve/examples/doc/tutorial_pytorch.py +++ b/python/ray/serve/examples/doc/tutorial_pytorch.py @@ -16,12 +16,14 @@ from torchvision.models import resnet18 # __doc_define_servable_begin__ class ImageModel: def __init__(self): - self.model = resnet18(pretrained=True) + self.model = resnet18(pretrained=True).eval() self.preprocessor = transforms.Compose([ transforms.Resize(224), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Lambda(lambda t: t[:3, ...]), # remove alpha channel + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) def __call__(self, flask_request):