load imagenet

This commit is contained in:
Philipp Moritz
2016-06-10 17:25:55 -07:00
parent 137909d177
commit acc51309e7
16 changed files with 209 additions and 6 deletions
+1
View File
@@ -1,3 +1,4 @@
import libraylib as lib
import serialization
from worker import scheduler_info, register_module, connect, disconnect, pull, push, remote
from libraylib import ObjRef
+1 -1
View File
@@ -176,7 +176,7 @@ def qr(a):
for r in range(i, a.num_blocks[0]):
y_ri = y_val.objrefs[r - i, 0]
W_rcs.append(qr_helper2(y_ri, a_work.objrefs[r, c]))
W_c = ra.sum(0, *W_rcs)
W_c = ra.linalg.sum_list(*W_rcs)
for r in range(i, a.num_blocks[0]):
y_ri = y_val.objrefs[r - i, 0]
A_rc = qr_helper1(a_work.objrefs[r, c], y_ri, t, W_c)
+3 -3
View File
@@ -68,9 +68,9 @@ def add(x1, x2):
def subtract(x1, x2):
return np.subtract(x1, x2)
@ray.remote([int, np.ndarray], [np.ndarray])
def sum(axis, *xs):
return np.sum(xs, axis=axis)
@ray.remote([np.ndarray, int], [np.ndarray])
def sum(x, axis=-1):
return np.sum(x, axis=axis if axis != -1 else None)
@ray.remote([np.ndarray], [tuple])
def shape(a):
+4
View File
@@ -86,3 +86,7 @@ def matrix_rank(M):
@ray.remote([np.ndarray], [np.ndarray])
def multi_dot(*a):
raise NotImplementedError
@ray.remote([np.ndarray], [np.ndarray])
def sum_list(*xs):
return np.sum(xs, axis=0)
+71
View File
@@ -0,0 +1,71 @@
import tarfile, io
from typing import List
import PIL.Image
import numpy as np
import boto3
import ray
s3 = boto3.client("s3")
def load_chunk(tarfile, size=None):
"""Load a number of images from a single imagenet .tar file.
This function also converts the image from grayscale to RGB if neccessary.
Args:
tarfile (tarfile.TarFile): The archive from which the files get loaded.
size (Optional[Tuple[int, int]]): Resize the image to this size if provided.
Returns:
numpy.ndarray: Contains the image data in format [batch, w, h, c]
"""
result = []
for member in tarfile.getmembers():
filename = member.path
content = tarfile.extractfile(member)
img = PIL.Image.open(content)
rgbimg = PIL.Image.new("RGB", img.size)
rgbimg.paste(img)
if size != None:
rgbimg = rgbimg.resize(size, PIL.Image.ANTIALIAS)
result.append(np.array(rgbimg).reshape(1, rgbimg.size[0], rgbimg.size[1], 3))
return np.concatenate(result)
@ray.remote([str, str, List[int]], [np.ndarray])
def load_tarfile_from_s3(bucket, s3_key, size=[]):
"""Load an imagenet .tar file.
Args:
bucket (str): Bucket holding the imagenet .tar.
s3_key (str): s3 key from which the .tar file is loaded.
size (List[int]): Resize the image to this size if size != []; len(size) == 2 required.
Returns:
np.ndarray: The image data (see load_chunk).
"""
response = s3.get_object(Bucket=bucket, Key=s3_key)
output = io.BytesIO()
chunk = response["Body"].read(1024 * 8)
while chunk:
output.write(chunk)
chunk = response["Body"].read(1024 * 8)
output.seek(0) # go to the beginning of the .tar file
tar = tarfile.open(mode= "r", fileobj=output)
return load_chunk(tar, size=size if size != [] else None)
@ray.remote([str, List[str], List[int]], [List[ray.ObjRef]])
def load_tarfiles_from_s3(bucket, s3_keys, size=[]):
"""Load a number of imagenet .tar files.
Args:
bucket (str): Bucket holding the imagenet .tars.
s3_keys (List[str]): List of s3 keys from which the .tar files are being loaded.
size (List[int]): Resize the image to this size if size != []; len(size) == 2 required.
Returns:
np.ndarray: Contains object references to the chunks of the images (see load_chunk).
"""
return [load_tarfile_from_s3(bucket, s3_key, size) for s3_key in s3_keys]