mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:32:11 +08:00
[DataFrame] Implement where (#1989)
This commit is contained in:
committed by
Robert Nishihara
parent
d2c193ed2c
commit
72a3a6cb02
@@ -4464,9 +4464,105 @@ class DataFrame(object):
|
||||
|
||||
def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
          errors='raise', try_cast=False, raise_on_error=None):
    """Replaces values not meeting condition with values in other.

    Args:
        cond: A condition to be met, can be callable, array-like or a
            DataFrame.
        other: A value or DataFrame of values to use for setting this.
        inplace: Whether or not to operate inplace.
        axis: The axis to apply over. Only valid when a Series is passed
            as other.
        level: The MultiLevel index level to apply over.
        errors: Whether or not to raise errors. Does nothing in Pandas.
        try_cast: Try to cast the result back to the input type.
        raise_on_error: Whether to raise invalid datatypes (deprecated).

    Returns:
        A new DataFrame with the replaced values, or None when inplace is
        True (self is updated instead).

    Raises:
        ValueError: If other is a pandas Series and axis is not given, or
            if an array-like cond does not have the same shape as self.
        NotImplementedError: If level is not None.
    """
    # NOTE: the previous stub ``raise NotImplementedError(...)`` must not
    # precede this body; it made everything below unreachable.
    inplace = validate_bool_kwarg(inplace, 'inplace')

    if isinstance(other, pd.Series) and axis is None:
        raise ValueError("Must specify axis=0 or 1")

    if level is not None:
        raise NotImplementedError("Multilevel Index not yet supported on "
                                  "Pandas on Ray.")

    # Normalize axis names ('index'/'columns') to 0/1; default to 0.
    axis = pd.DataFrame()._get_axis_number(axis) if axis is not None else 0

    cond = cond(self) if callable(cond) else cond

    if not isinstance(cond, DataFrame):
        if not hasattr(cond, 'shape'):
            cond = np.asanyarray(cond)
        if cond.shape != self.shape:
            raise ValueError("Array conditional must be same shape as "
                             "self")
        cond = DataFrame(cond, index=self.index, columns=self.columns)

    zipped_partitions = self._copartition(cond, self.index)
    # The helper always runs out-of-place (inplace=False); an inplace
    # request is honored below by swapping in the new partitions.
    args = (False, axis, level, errors, try_cast, raise_on_error)

    if isinstance(other, DataFrame):
        # Only the partitions of `other` are needed here; the partitions
        # of self come from `zipped_partitions`.
        other_zipped = (v for k, v in self._copartition(other,
                                                        self.index))

        new_partitions = [_where_helper.remote(k, v, next(other_zipped),
                                               self.columns, cond.columns,
                                               other.columns, *args)
                          for k, v in zipped_partitions]

    # Series has to be treated specially because we're operating on row
    # partitions from here on.
    elif isinstance(other, pd.Series):
        if axis == 0:
            # Pandas determines which index to use based on axis.
            other = other.reindex(self.index)
            other.index = pd.RangeIndex(len(other))

            # Since we're working on row partitions, we have to partition
            # the Series based on the partitioning of self (since both
            # self and cond are co-partitioned by self).
            other_builder = []
            for length in self._row_metadata._lengths:
                other_builder.append(other[:length])
                other = other[length:]
                # Resetting the index here ensures that we apply each part
                # to the correct row within the partitions.
                other.index = pd.RangeIndex(len(other))

            other = (obj for obj in other_builder)

            # Fall back to an empty Series if the chunks run out before
            # the partitions do.
            new_partitions = [_where_helper.remote(k, v, next(other,
                                                              pd.Series()),
                                                   self.columns,
                                                   cond.columns,
                                                   None, *args)
                              for k, v in zipped_partitions]
        else:
            # axis == 1: the same Series (aligned to the columns) is used
            # for every row partition.
            other = other.reindex(self.columns)
            new_partitions = [_where_helper.remote(k, v, other,
                                                   self.columns,
                                                   cond.columns,
                                                   None, *args)
                              for k, v in zipped_partitions]

    else:
        # Scalar (or other non-DataFrame, non-Series) replacement value.
        new_partitions = [_where_helper.remote(k, v, other, self.columns,
                                               cond.columns, None, *args)
                          for k, v in zipped_partitions]

    if inplace:
        self._update_inplace(row_partitions=new_partitions,
                             row_metadata=self._row_metadata,
                             col_metadata=self._col_metadata)
    else:
        return DataFrame(row_partitions=new_partitions,
                         row_metadata=self._row_metadata,
                         col_metadata=self._col_metadata)
|
||||
|
||||
def xs(self, key, axis=0, level=None, drop_level=True):
|
||||
raise NotImplementedError(
|
||||
@@ -5093,3 +5189,27 @@ def _merge_columns(left_columns, right_columns, *args):
|
||||
return pd.DataFrame(columns=left_columns, index=[0], dtype='uint8').merge(
|
||||
pd.DataFrame(columns=right_columns, index=[0], dtype='uint8'),
|
||||
*args).columns
|
||||
|
||||
|
||||
@ray.remote
def _where_helper(left, cond, other, left_columns, cond_columns,
                  other_columns, *args):
    """Apply pandas ``where`` to one co-partitioned set of blocks.

    Args:
        left: Container of block references for the data partition; its
            ``tolist()`` result is fetched with ``ray.get``.
        cond: Container of block references for the condition partition.
        other: Replacement values. When it is a numpy array it is treated
            like ``left``/``cond`` and assembled from blocks; otherwise it
            is passed to pandas unchanged.
        left_columns: Column labels to assign to the assembled data.
        cond_columns: Column labels to assign to the assembled condition.
        other_columns: Column labels to assign to the assembled ``other``
            (only used when ``other`` is a numpy array).
        *args: Extra positional arguments forwarded to
            ``pandas.DataFrame.where``.

    Returns:
        The result of ``where`` on the assembled pandas DataFrame.
    """

    def _assemble(block_refs, columns):
        # Blocks carry their own axes, so after stitching them together
        # column-wise the index is reset and the columns are relabeled.
        # Everything has already been copartitioned, so this is safe.
        frame = pd.concat(ray.get(block_refs.tolist()), axis=1)
        frame = frame.reset_index(drop=True)
        frame.columns = columns
        return frame

    data_frame = _assemble(left, left_columns)
    cond_frame = _assemble(cond, cond_columns)

    if isinstance(other, np.ndarray):
        other = _assemble(other, other_columns)

    return data_frame.where(cond_frame, other, *args)
|
||||
|
||||
@@ -3053,10 +3053,36 @@ def test_var(ray_df, pandas_df):
|
||||
|
||||
|
||||
def test_where():
    """Check DataFrame.where against pandas for several kinds of `other`."""
    pandas_df = pd.DataFrame(np.random.randn(100, 10),
                             columns=list('abcdefghij'))
    ray_df = rdf.DataFrame(pandas_df)

    pandas_cond_df = pandas_df % 5 < 2
    ray_cond_df = ray_df % 5 < 2

    # `other` is a DataFrame.
    pandas_result = pandas_df.where(pandas_cond_df, -pandas_df)
    ray_result = ray_df.where(ray_cond_df, -ray_df)
    assert ray_df_equals_pandas(ray_result, pandas_result)

    # `other` is a Series taken from a row, applied with axis=1.
    other = pandas_df.loc[3]
    pandas_result = pandas_df.where(pandas_cond_df, other, axis=1)
    ray_result = ray_df.where(ray_cond_df, other, axis=1)
    assert ray_df_equals_pandas(ray_result, pandas_result)

    # `other` is a Series taken from a column, applied with axis=0.
    other = pandas_df['e']
    pandas_result = pandas_df.where(pandas_cond_df, other, axis=0)
    ray_result = ray_df.where(ray_cond_df, other, axis=0)
    assert ray_df_equals_pandas(ray_result, pandas_result)

    # `other` is a scalar.
    pandas_result = pandas_df.where(pandas_df < 2, True)
    ray_result = ray_df.where(ray_df < 2, True)
    assert ray_df_equals_pandas(ray_result, pandas_result)
|
||||
|
||||
|
||||
def test_xs():
|
||||
|
||||
@@ -107,11 +107,7 @@ def to_pandas(df):
|
||||
Returns:
|
||||
A new pandas DataFrame.
|
||||
"""
|
||||
if df._row_partitions is not None:
|
||||
pd_df = pd.concat(ray.get(df._row_partitions))
|
||||
else:
|
||||
pd_df = pd.concat(ray.get(df._col_partitions),
|
||||
axis=1)
|
||||
pd_df = pd.concat(ray.get(df._row_partitions), copy=False)
|
||||
pd_df.index = df.index
|
||||
pd_df.columns = df.columns
|
||||
return pd_df
|
||||
|
||||
Reference in New Issue
Block a user