mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-05 10:32:06 +08:00
ENH: Add function to concatenate list of dataframes with categoricals
STY: Alphabetized import list
This commit is contained in:
@@ -0,0 +1,29 @@
|
||||
from pandas import Timestamp
|
||||
|
||||
from nose_parameterized import parameterized
|
||||
|
||||
from zipline.testing import ZiplineTestCase
|
||||
from zipline.utils.calendars import get_calendar
|
||||
from zipline.utils.date_utils import roll_dates_to_previous_session
|
||||
|
||||
|
||||
class TestRollDatesToPreviousSession(ZiplineTestCase):
    """Checks ``roll_dates_to_previous_session`` against the NYSE calendar."""

    @parameterized.expand([
        # (input date, session it is expected to roll back to)
        (
            # An actual trading date: expected to be returned as-is.
            Timestamp('05-19-2017', tz='UTC'),
            Timestamp('05-19-2017', tz='UTC'),
        ),
        (
            # NYSE holiday falling on a weekend.
            Timestamp('07-04-2015', tz='UTC'),
            Timestamp('07-02-2015', tz='UTC'),
        ),
        (
            # NYSE holiday falling on a weekday.
            Timestamp('01-16-2017', tz='UTC'),
            Timestamp('01-13-2017', tz='UTC'),
        ),
    ])
    def test_roll_dates_to_previous_session(self, date, expected_rolled_date):
        """Rolling a single date yields the expected session first."""
        nyse = get_calendar('NYSE')
        rolled = roll_dates_to_previous_session(nyse, date)
        first_rolled = rolled[0]
        self.assertEqual(first_rolled, expected_rolled_date)
|
||||
@@ -4,7 +4,11 @@ Tests for zipline/utils/pandas_utils.py
|
||||
import pandas as pd
|
||||
|
||||
from zipline.testing import parameter_space, ZiplineTestCase
|
||||
from zipline.utils.pandas_utils import nearest_unequal_elements
|
||||
from zipline.testing.predicates import assert_equal
|
||||
from zipline.utils.pandas_utils import (
|
||||
categorical_df_concat,
|
||||
nearest_unequal_elements
|
||||
)
|
||||
|
||||
|
||||
class TestNearestUnequalElements(ZiplineTestCase):
|
||||
@@ -80,3 +84,97 @@ class TestNearestUnequalElements(ZiplineTestCase):
|
||||
str(e.exception),
|
||||
'dts must be sorted in increasing order',
|
||||
)
|
||||
|
||||
|
||||
class TestCatDFConcat(ZiplineTestCase):
    """Tests for ``categorical_df_concat``."""

    def test_categorical_df_concat(self):
        """Frames with differing categories concatenate to one categorical."""

        # Three frames whose 'A' and 'C' columns are categorical but carry
        # different category sets from frame to frame.
        inp = [
            pd.DataFrame(
                {
                    'A': pd.Series(['a', 'b', 'c'], dtype='category'),
                    'B': pd.Series([100, 102, 103], dtype='int64'),
                    'C': pd.Series(['x', 'x', 'x'], dtype='category'),
                }
            ),
            pd.DataFrame(
                {
                    'A': pd.Series(['c', 'b', 'd'], dtype='category'),
                    'B': pd.Series([103, 102, 104], dtype='int64'),
                    'C': pd.Series(['y', 'y', 'y'], dtype='category'),
                }
            ),
            pd.DataFrame(
                {
                    'A': pd.Series(['a', 'b', 'd'], dtype='category'),
                    'B': pd.Series([101, 102, 104], dtype='int64'),
                    'C': pd.Series(['z', 'z', 'z'], dtype='category'),
                }
            ),
        ]
        result = categorical_df_concat(inp)

        expected = pd.DataFrame(
            {
                'A': pd.Series(
                    ['a', 'b', 'c', 'c', 'b', 'd', 'a', 'b', 'd'],
                    dtype='category'
                ),
                'B': pd.Series(
                    [100, 102, 103, 103, 102, 104, 101, 102, 104],
                    dtype='int64'
                ),
                'C': pd.Series(
                    ['x', 'x', 'x', 'y', 'y', 'y', 'z', 'z', 'z'],
                    dtype='category'
                ),
            },
        )
        # The expected result repeats each input frame's 0..2 index.
        # ``pd.Index(..., dtype='int64')`` is used instead of
        # ``pd.Int64Index``, which newer pandas releases no longer provide.
        expected.index = pd.Index(
            [0, 1, 2, 0, 1, 2, 0, 1, 2],
            dtype='int64',
        )
        assert_equal(expected, result)

        # The categorical columns must also agree on their category sets.
        for column in ('A', 'C'):
            assert_equal(
                expected[column].cat.categories,
                result[column].cat.categories
            )

    def test_categorical_df_concat_value_error(self):
        """Mismatched dtypes or column names are rejected with ValueError."""

        # Column 'B' is int64 in the first frame and float64 in the second.
        mismatched_dtypes = [
            pd.DataFrame(
                {
                    'A': pd.Series(['a', 'b', 'c'], dtype='category'),
                    'B': pd.Series([100, 102, 103], dtype='int64'),
                }
            ),
            pd.DataFrame(
                {
                    'A': pd.Series(['c', 'b', 'd'], dtype='category'),
                    'B': pd.Series([103, 102, 104], dtype='float64'),
                }
            ),
        ]
        # Second column is named 'B' in the first frame and 'X' in the second.
        mismatched_column_names = [
            pd.DataFrame(
                {
                    'A': pd.Series(['a', 'b', 'c'], dtype='category'),
                    'B': pd.Series([100, 102, 103], dtype='int64'),
                }
            ),
            pd.DataFrame(
                {
                    'A': pd.Series(['c', 'b', 'd'], dtype='category'),
                    'X': pd.Series([103, 102, 104], dtype='int64'),
                }
            ),
        ]

        for bad_input in (mismatched_dtypes, mismatched_column_names):
            with self.assertRaises(ValueError):
                categorical_df_concat(bad_input)
|
||||
|
||||
@@ -15,20 +15,20 @@
|
||||
|
||||
from .trading_calendar import TradingCalendar
|
||||
from .calendar_utils import (
|
||||
get_calendar,
|
||||
register_calendar_alias,
|
||||
register_calendar,
|
||||
register_calendar_type,
|
||||
clear_calendars,
|
||||
deregister_calendar,
|
||||
clear_calendars
|
||||
get_calendar,
|
||||
register_calendar,
|
||||
register_calendar_alias,
|
||||
register_calendar_type,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'TradingCalendar',
|
||||
'clear_calendars',
|
||||
'deregister_calendar',
|
||||
'get_calendar',
|
||||
'register_calendar',
|
||||
'register_calendar_alias',
|
||||
'register_calendar_type',
|
||||
'TradingCalendar',
|
||||
]
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
def roll_dates_to_previous_session(calendar, *dates):
    """
    Roll ``dates`` back to the most recent session of ``calendar``.

    A date that is itself a session is returned unchanged; any other date
    is replaced by the closest session that precedes it.

    Parameters
    ----------
    calendar : zipline.utils.calendars.trading_calendar.TradingCalendar
        The calendar to use as a reference. Only its ``all_sessions``
        index is consulted.
    *dates : pd.Timestamp
        The dates for which the last trading date is needed.

    Returns
    -------
    rolled_dates : pd.DatetimeIndex
        The last trading date on or before each input date, in input
        order.

    Raises
    ------
    KeyError
        If a date falls before the first session of ``calendar``.
    """
    all_sessions = calendar.all_sessions

    if not dates:
        # Nothing to roll: hand back an empty slice of the session index.
        return all_sessions[:0]

    # One forward-fill lookup for all dates. ``Index.get_loc`` no longer
    # accepts ``method=`` in current pandas releases; ``get_indexer`` does.
    locs = all_sessions.get_indexer(list(dates), method='ffill')

    # ``get_indexer`` reports "no session on or before this date" as -1,
    # which positional indexing would silently read as "last session".
    for dt, loc in zip(dates, locs):
        if loc == -1:
            raise KeyError(dt)

    return all_sessions[locs]
|
||||
@@ -2,6 +2,7 @@
|
||||
Utilities for working with pandas objects.
|
||||
"""
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from itertools import product
|
||||
import operator as op
|
||||
import warnings
|
||||
@@ -222,3 +223,45 @@ def clear_dataframe_indexer_caches(df):
|
||||
delattr(df, attr)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
def _is_categorical_dtype(dtype):
    """Return True if ``dtype`` is a pandas categorical dtype."""
    # Categorical dtypes report the name 'category' whatever categories
    # they carry, so this check does not depend on the category values.
    return getattr(dtype, 'name', None) == 'category'


def _same_columns_and_dtypes(left, right):
    """Return True if two frames share columns (in order) and dtypes.

    Categorical columns match each other regardless of their categories.
    """
    if not left.columns.equals(right.columns):
        return False

    for left_dtype, right_dtype in zip(left.dtypes, right.dtypes):
        left_is_cat = _is_categorical_dtype(left_dtype)
        if left_is_cat != _is_categorical_dtype(right_dtype):
            return False
        if not left_is_cat and left_dtype != right_dtype:
            return False
    return True


def categorical_df_concat(df_list, inplace=False):
    """
    Concatenate a list of pandas DataFrames, preserving categorical columns.

    Before calling ``pd.concat``, every column of type 'category' is given
    the same categories (the sorted union over all frames) in each
    dataframe, so the concatenated column stays categorical.

    Parameters
    ----------
    df_list : list
        List of dataframes with same columns.
    inplace : bool
        True if input list can be modified. Default is False, in which
        case the function works on a deep copy of ``df_list``.

    Returns
    -------
    concatenated : df
        Dataframe of concatenated list.

    Raises
    ------
    ValueError
        If ``df_list`` is empty, or if the dataframes do not all have the
        same columns (in the same order) and dtypes.
    """
    if len(df_list) == 0:
        raise ValueError("df_list must contain at least one DataFrame.")

    if not inplace:
        df_list = deepcopy(df_list)

    # Assert each dataframe has the same columns/dtypes. Categorical
    # columns are compared by kind only: differing categories are what
    # this function reconciles, so they must not count as a mismatch.
    first = df_list[0]
    if not all(_same_columns_and_dtypes(first, other)
               for other in df_list[1:]):
        raise ValueError("Input DataFrames must have the same columns/dtypes.")

    categorical_columns = [
        col for col, dtype in zip(first.columns, first.dtypes)
        if _is_categorical_dtype(dtype)
    ]

    for col in categorical_columns:
        new_categories = sorted(
            set().union(
                *(frame[col].cat.categories for frame in df_list)
            )
        )

        for frame in df_list:
            # Assign the result back rather than using ``inplace=True``,
            # which newer pandas releases no longer accept.
            frame[col] = frame[col].cat.set_categories(new_categories)

    return pd.concat(df_list)
|
||||
|
||||
Reference in New Issue
Block a user