[DataFrame] Implement where (#1989)

This commit is contained in:
Devin Petersohn
2018-05-09 14:05:52 -07:00
committed by Robert Nishihara
parent d2c193ed2c
commit 72a3a6cb02
3 changed files with 153 additions and 11 deletions
+123 -3
View File
@@ -4464,9 +4464,105 @@ class DataFrame(object):
def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
errors='raise', try_cast=False, raise_on_error=None):
raise NotImplementedError(
"To contribute to Pandas on Ray, please visit "
"github.com/ray-project/ray.")
"""Replaces values not meeting condition with values in other.
Args:
cond: A condition to be met, can be callable, array-like or a
DataFrame.
other: A value or DataFrame of values to use for setting this.
inplace: Whether or not to operate inplace.
axis: The axis to apply over. Only valid when a Series is passed
as other.
level: The MultiLevel index level to apply over.
errors: Whether or not to raise errors. Does nothing in Pandas.
try_cast: Try to cast the result back to the input type.
raise_on_error: Whether to raise invalid datatypes (deprecated).
Returns:
A new DataFrame with the replaced values.
"""
inplace = validate_bool_kwarg(inplace, 'inplace')
if isinstance(other, pd.Series) and axis is None:
raise ValueError("Must specify axis=0 or 1")
if level is not None:
raise NotImplementedError("Multilevel Index not yet supported on "
"Pandas on Ray.")
axis = pd.DataFrame()._get_axis_number(axis) if axis is not None else 0
cond = cond(self) if callable(cond) else cond
if not isinstance(cond, DataFrame):
if not hasattr(cond, 'shape'):
cond = np.asanyarray(cond)
if cond.shape != self.shape:
raise ValueError("Array conditional must be same shape as "
"self")
cond = DataFrame(cond, index=self.index, columns=self.columns)
zipped_partitions = self._copartition(cond, self.index)
args = (False, axis, level, errors, try_cast, raise_on_error)
if isinstance(other, DataFrame):
other_zipped = (v for k, v in self._copartition(other,
self.index))
new_partitions = [_where_helper.remote(k, v, next(other_zipped),
self.columns, cond.columns,
other.columns, *args)
for k, v in zipped_partitions]
# Series has to be treated specially because we're operating on row
# partitions from here on.
elif isinstance(other, pd.Series):
if axis == 0:
# Pandas determines which index to use based on axis.
other = other.reindex(self.index)
other.index = pd.RangeIndex(len(other))
# Since we're working on row partitions, we have to partition
# the Series based on the partitioning of self (since both
# self and cond are co-partitioned by self.
other_builder = []
for length in self._row_metadata._lengths:
other_builder.append(other[:length])
other = other[length:]
# Resetting the index here ensures that we apply each part
# to the correct row within the partitions.
other.index = pd.RangeIndex(len(other))
other = (obj for obj in other_builder)
new_partitions = [_where_helper.remote(k, v, next(other,
pd.Series()),
self.columns,
cond.columns,
None, *args)
for k, v in zipped_partitions]
else:
other = other.reindex(self.columns)
new_partitions = [_where_helper.remote(k, v, other,
self.columns,
cond.columns,
None, *args)
for k, v in zipped_partitions]
else:
new_partitions = [_where_helper.remote(k, v, other, self.columns,
cond.columns, None, *args)
for k, v in zipped_partitions]
if inplace:
self._update_inplace(row_partitions=new_partitions,
row_metadata=self._row_metadata,
col_metadata=self._col_metadata)
else:
return DataFrame(row_partitions=new_partitions,
row_metadata=self._row_metadata,
col_metadata=self._col_metadata)
def xs(self, key, axis=0, level=None, drop_level=True):
raise NotImplementedError(
@@ -5093,3 +5189,27 @@ def _merge_columns(left_columns, right_columns, *args):
return pd.DataFrame(columns=left_columns, index=[0], dtype='uint8').merge(
pd.DataFrame(columns=right_columns, index=[0], dtype='uint8'),
*args).columns
@ray.remote
def _where_helper(left, cond, other, left_columns, cond_columns,
other_columns, *args):
left = pd.concat(ray.get(left.tolist()), axis=1)
# We have to reset the index and columns here because we are coming
# from blocks and the axes are set according to the blocks. We have
# already correctly copartitioned everything, so there's no
# correctness problems with doing this.
left.reset_index(inplace=True, drop=True)
left.columns = left_columns
cond = pd.concat(ray.get(cond.tolist()), axis=1)
cond.reset_index(inplace=True, drop=True)
cond.columns = cond_columns
if isinstance(other, np.ndarray):
other = pd.concat(ray.get(other.tolist()), axis=1)
other.reset_index(inplace=True, drop=True)
other.columns = other_columns
return left.where(cond, other, *args)
+29 -3
View File
@@ -3053,10 +3053,36 @@ def test_var(ray_df, pandas_df):
def test_where():
ray_df = create_test_dataframe()
pandas_df = pd.DataFrame(np.random.randn(100, 10),
columns=list('abcdefghij'))
ray_df = rdf.DataFrame(pandas_df)
with pytest.raises(NotImplementedError):
ray_df.where(None)
pandas_cond_df = pandas_df % 5 < 2
ray_cond_df = ray_df % 5 < 2
pandas_result = pandas_df.where(pandas_cond_df, -pandas_df)
ray_result = ray_df.where(ray_cond_df, -ray_df)
assert ray_df_equals_pandas(ray_result, pandas_result)
other = pandas_df.loc[3]
pandas_result = pandas_df.where(pandas_cond_df, other, axis=1)
ray_result = ray_df.where(ray_cond_df, other, axis=1)
assert ray_df_equals_pandas(ray_result, pandas_result)
other = pandas_df['e']
pandas_result = pandas_df.where(pandas_cond_df, other, axis=0)
ray_result = ray_df.where(ray_cond_df, other, axis=0)
assert ray_df_equals_pandas(ray_result, pandas_result)
pandas_result = pandas_df.where(pandas_df < 2, True)
ray_result = ray_df.where(ray_df < 2, True)
assert ray_df_equals_pandas(ray_result, pandas_result)
def test_xs():
+1 -5
View File
@@ -107,11 +107,7 @@ def to_pandas(df):
Returns:
A new pandas DataFrame.
"""
if df._row_partitions is not None:
pd_df = pd.concat(ray.get(df._row_partitions))
else:
pd_df = pd.concat(ray.get(df._col_partitions),
axis=1)
pd_df = pd.concat(ray.get(df._row_partitions), copy=False)
pd_df.index = df.index
pd_df.columns = df.columns
return pd_df