mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 13:12:46 +08:00
load imagenet
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import libraylib as lib
|
||||
import serialization
|
||||
from worker import scheduler_info, register_module, connect, disconnect, pull, push, remote
|
||||
from libraylib import ObjRef
|
||||
|
||||
@@ -176,7 +176,7 @@ def qr(a):
|
||||
for r in range(i, a.num_blocks[0]):
|
||||
y_ri = y_val.objrefs[r - i, 0]
|
||||
W_rcs.append(qr_helper2(y_ri, a_work.objrefs[r, c]))
|
||||
W_c = ra.sum(0, *W_rcs)
|
||||
W_c = ra.linalg.sum_list(*W_rcs)
|
||||
for r in range(i, a.num_blocks[0]):
|
||||
y_ri = y_val.objrefs[r - i, 0]
|
||||
A_rc = qr_helper1(a_work.objrefs[r, c], y_ri, t, W_c)
|
||||
|
||||
@@ -68,9 +68,9 @@ def add(x1, x2):
|
||||
def subtract(x1, x2):
    """Return the element-wise difference ``x1 - x2`` computed by ``np.subtract``."""
    difference = np.subtract(x1, x2)
    return difference
|
||||
|
||||
@ray.remote([int, np.ndarray], [np.ndarray])
def sum(axis, *xs):
    """Sum the positional arguments ``xs`` along ``axis``.

    The arguments are stacked into a single array (the tuple ``xs`` is
    converted by numpy) and reduced with a sum over the given axis.
    """
    stacked = np.asarray(xs)
    return stacked.sum(axis=axis)
|
||||
@ray.remote([np.ndarray, int], [np.ndarray])
def sum(x, axis=-1):
    """Sum the elements of ``x`` with ``np.sum``.

    Args:
        x: The array to reduce.
        axis: Axis to sum over. The value ``-1`` is a sentinel meaning
            "sum over all elements" (it is translated to numpy's
            ``axis=None``), so it does not select the last axis here.

    Returns:
        The result of ``np.sum`` for the selected axis.
    """
    if axis == -1:
        # -1 stands in for "no axis given": reduce over everything.
        return np.sum(x, axis=None)
    return np.sum(x, axis=axis)
|
||||
|
||||
@ray.remote([np.ndarray], [tuple])
|
||||
def shape(a):
|
||||
|
||||
@@ -86,3 +86,7 @@ def matrix_rank(M):
|
||||
@ray.remote([np.ndarray], [np.ndarray])
def multi_dot(*a):
    """Placeholder for a remote ``multi_dot``; no implementation exists yet.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray])
def sum_list(*xs):
    """Add up all positional arguments element-wise.

    The arguments are stacked along a new leading axis and summed over that
    axis, so the result has the shape of a single argument.
    """
    stacked = np.asarray(xs)
    return stacked.sum(axis=0)
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
import tarfile, io
|
||||
from typing import List
|
||||
import PIL.Image
|
||||
import numpy as np
|
||||
import boto3
|
||||
import ray
|
||||
|
||||
# S3 client created once at import time; load_tarfile_from_s3 uses it to fetch archives.
s3 = boto3.client("s3")
|
||||
|
||||
def load_chunk(tarfile, size=None):
    """Load a number of images from a single imagenet .tar file.

    Every regular file in the archive is opened as an image and pasted onto a
    fresh RGB canvas, which converts grayscale images to RGB if necessary.
    Members that are not regular files (for example directory entries) are
    skipped, because ``extractfile`` yields no file object for them.

    Args:
        tarfile (tarfile.TarFile): The archive from which the files get loaded.
        size (Optional[Tuple[int, int]]): Resize each image to this size if
            provided.

    Returns:
        numpy.ndarray: Contains the image data in format [batch, w, h, c]

    Raises:
        ValueError: If the archive contains no regular files to load.
    """
    resample = None
    if size is not None:
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS names the same
        # filter. Fall back to ANTIALIAS only on very old Pillow releases.
        resample = getattr(PIL.Image, "LANCZOS", None)
        if resample is None:
            resample = PIL.Image.ANTIALIAS

    result = []
    for member in tarfile.getmembers():
        if not member.isfile():
            continue
        content = tarfile.extractfile(member)
        img = PIL.Image.open(content)
        # Pasting onto an RGB canvas normalizes the mode (e.g. grayscale -> RGB).
        rgbimg = PIL.Image.new("RGB", img.size)
        rgbimg.paste(img)
        if size is not None:
            rgbimg = rgbimg.resize(size, resample)
        # NOTE(review): PIL reports size as (width, height) while np.array(img)
        # is laid out as (height, width, channels). This reshape relabels the
        # axes without transposing, so it only preserves the pixel layout for
        # square images -- confirm this is intended before relying on
        # non-square inputs.
        result.append(np.array(rgbimg).reshape(1, rgbimg.size[0], rgbimg.size[1], 3))

    if not result:
        raise ValueError("archive contains no image files")
    return np.concatenate(result)
|
||||
|
||||
@ray.remote([str, str, List[int]], [np.ndarray])
def load_tarfile_from_s3(bucket, s3_key, size=[]):
    """Load an imagenet .tar file.

    The object is downloaded from S3 into an in-memory buffer, opened as a
    tar archive, and handed to ``load_chunk``.

    Args:
        bucket (str): Bucket holding the imagenet .tar.
        s3_key (str): s3 key from which the .tar file is loaded.
        size (List[int]): Resize the image to this size if size is non-empty;
            len(size) == 2 required in that case. The empty-list default is
            only read, never mutated.

    Returns:
        np.ndarray: The image data (see load_chunk).
    """
    response = s3.get_object(Bucket=bucket, Key=s3_key)
    body = response["Body"]

    # Copy the streamed body into memory in 8 KiB pieces.
    output = io.BytesIO()
    chunk = body.read(1024 * 8)
    while chunk:
        output.write(chunk)
        chunk = body.read(1024 * 8)
    output.seek(0)  # go to the beginning of the .tar file

    # The context manager closes the archive once the images are loaded.
    with tarfile.open(mode="r", fileobj=output) as tar:
        # Any empty size (the [] default included) means "do not resize".
        return load_chunk(tar, size=size if size else None)
|
||||
|
||||
@ray.remote([str, List[str], List[int]], [List[ray.ObjRef]])
def load_tarfiles_from_s3(bucket, s3_keys, size=[]):
    """Load a number of imagenet .tar files.

    Args:
        bucket (str): Bucket holding the imagenet .tars.
        s3_keys (List[str]): List of s3 keys from which the .tar files are
            being loaded.
        size (List[int]): Resize the image to this size if size != [];
            len(size) == 2 required.

    Returns:
        list: One ``load_tarfile_from_s3`` result per key, in the order of
        ``s3_keys``; the decorator declares these as object references to
        the chunks of the images (see load_chunk).
    """
    chunk_refs = []
    for key in s3_keys:
        chunk_refs.append(load_tarfile_from_s3(bucket, key, size))
    return chunk_refs
|
||||
Reference in New Issue
Block a user