mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 10:49:16 +08:00
[Serve] Fix PyTorch Tutorial (#9019)
This commit is contained in:
@@ -16,12 +16,14 @@ from torchvision.models import resnet18
|
||||
# __doc_define_servable_begin__
|
||||
class ImageModel:
|
||||
def __init__(self):
|
||||
self.model = resnet18(pretrained=True)
|
||||
self.model = resnet18(pretrained=True).eval()
|
||||
self.preprocessor = transforms.Compose([
|
||||
transforms.Resize(224),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Lambda(lambda t: t[:3, ...]), # remove alpha channel
|
||||
transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
])
|
||||
|
||||
def __call__(self, flask_request):
|
||||
|
||||
Reference in New Issue
Block a user