diff --git a/tests/utils/test_date_utils.py b/tests/utils/test_date_utils.py new file mode 100644 index 00000000..ab43e11e --- /dev/null +++ b/tests/utils/test_date_utils.py @@ -0,0 +1,29 @@ +from pandas import Timestamp + +from nose_parameterized import parameterized + +from zipline.testing import ZiplineTestCase +from zipline.utils.calendars import get_calendar +from zipline.utils.date_utils import roll_dates_to_previous_session + + +class TestRollDatesToPreviousSession(ZiplineTestCase): + + @parameterized.expand([ + ( + Timestamp('05-19-2017', tz='UTC'), # actual trading date + Timestamp('05-19-2017', tz='UTC'), + ), + ( + Timestamp('07-04-2015', tz='UTC'), # weekend nyse holiday + Timestamp('07-02-2015', tz='UTC'), + ), + ( + Timestamp('01-16-2017', tz='UTC'), # weeknight nyse holiday + Timestamp('01-13-2017', tz='UTC'), + ), + ]) + def test_roll_dates_to_previous_session(self, date, expected_rolled_date): + calendar = get_calendar('NYSE') + result = roll_dates_to_previous_session(calendar, date) + self.assertEqual(result[0], expected_rolled_date) diff --git a/tests/utils/test_pandas_utils.py b/tests/utils/test_pandas_utils.py index e84ed3fe..870a99eb 100644 --- a/tests/utils/test_pandas_utils.py +++ b/tests/utils/test_pandas_utils.py @@ -4,7 +4,11 @@ Tests for zipline/utils/pandas_utils.py import pandas as pd from zipline.testing import parameter_space, ZiplineTestCase -from zipline.utils.pandas_utils import nearest_unequal_elements +from zipline.testing.predicates import assert_equal +from zipline.utils.pandas_utils import ( + categorical_df_concat, + nearest_unequal_elements +) class TestNearestUnequalElements(ZiplineTestCase): @@ -80,3 +84,97 @@ class TestNearestUnequalElements(ZiplineTestCase): str(e.exception), 'dts must be sorted in increasing order', ) + + +class TestCatDFConcat(ZiplineTestCase): + + def test_categorical_df_concat(self): + + inp = [ + pd.DataFrame( + { + 'A': pd.Series(['a', 'b', 'c'], dtype='category'), + 'B': pd.Series([100, 102, 103], dtype='int64'), + 'C': pd.Series(['x', 'x', 'x'], dtype='category'), + } + ), + pd.DataFrame( + { + 'A': pd.Series(['c', 'b', 'd'], dtype='category'), + 'B': pd.Series([103, 102, 104], dtype='int64'), + 'C': pd.Series(['y', 'y', 'y'], dtype='category'), + } + ), + pd.DataFrame( + { + 'A': pd.Series(['a', 'b', 'd'], dtype='category'), + 'B': pd.Series([101, 102, 104], dtype='int64'), + 'C': pd.Series(['z', 'z', 'z'], dtype='category'), + } + ), + ] + result = categorical_df_concat(inp) + + expected = pd.DataFrame( + { + 'A': pd.Series( + ['a', 'b', 'c', 'c', 'b', 'd', 'a', 'b', 'd'], + dtype='category' + ), + 'B': pd.Series( + [100, 102, 103, 103, 102, 104, 101, 102, 104], + dtype='int64' + ), + 'C': pd.Series( + ['x', 'x', 'x', 'y', 'y', 'y', 'z', 'z', 'z'], + dtype='category' + ), + }, + ) + expected.index = pd.Int64Index([0, 1, 2, 0, 1, 2, 0, 1, 2]) + assert_equal(expected, result) + assert_equal( + expected['A'].cat.categories, + result['A'].cat.categories + ) + assert_equal( + expected['C'].cat.categories, + result['C'].cat.categories + ) + + def test_categorical_df_concat_value_error(self): + + mismatched_dtypes = [ + pd.DataFrame( + { + 'A': pd.Series(['a', 'b', 'c'], dtype='category'), + 'B': pd.Series([100, 102, 103], dtype='int64'), + } + ), + pd.DataFrame( + { + 'A': pd.Series(['c', 'b', 'd'], dtype='category'), + 'B': pd.Series([103, 102, 104], dtype='float64'), + } + ), + ] + mismatched_column_names = [ + pd.DataFrame( + { + 'A': pd.Series(['a', 'b', 'c'], dtype='category'), + 'B': pd.Series([100, 102, 103], dtype='int64'), + } + ), + pd.DataFrame( + { + 'A': pd.Series(['c', 'b', 'd'], dtype='category'), + 'X': pd.Series([103, 102, 104], dtype='int64'), + } + ), + ] + + with self.assertRaises(ValueError): + categorical_df_concat(mismatched_dtypes) + + with self.assertRaises(ValueError): + categorical_df_concat(mismatched_column_names) diff --git a/zipline/utils/calendars/__init__.py b/zipline/utils/calendars/__init__.py index d3f3ecf0..61c22e6b 100644 --- a/zipline/utils/calendars/__init__.py +++ b/zipline/utils/calendars/__init__.py @@ -15,20 +15,20 @@ from .trading_calendar import TradingCalendar from .calendar_utils import ( - get_calendar, - register_calendar_alias, - register_calendar, - register_calendar_type, + clear_calendars, deregister_calendar, - clear_calendars + get_calendar, + register_calendar, + register_calendar_alias, + register_calendar_type, ) __all__ = [ - 'TradingCalendar', 'clear_calendars', 'deregister_calendar', 'get_calendar', 'register_calendar', 'register_calendar_alias', 'register_calendar_type', + 'TradingCalendar', ] diff --git a/zipline/utils/date_utils.py b/zipline/utils/date_utils.py new file mode 100644 index 00000000..07ceeda9 --- /dev/null +++ b/zipline/utils/date_utils.py @@ -0,0 +1,21 @@ +def roll_dates_to_previous_session(calendar, *dates): + """ + Roll ``dates`` to the next session of ``calendar``. + + Parameters + ---------- + calendar : zipline.utils.calendars.trading_calendar.TradingCalendar + The calendar to use as a reference. + *dates : pd.Timestamp + The dates for which the last trading date is needed. + + Returns + ------- + rolled_dates: pandas.tseries.index.DatetimeIndex + The last trading date of the input dates, inclusive. + + """ + all_sessions = calendar.all_sessions + + locs = [all_sessions.get_loc(dt, method='ffill') for dt in dates] + return all_sessions[locs] diff --git a/zipline/utils/pandas_utils.py b/zipline/utils/pandas_utils.py index ccac273a..39e4040b 100644 --- a/zipline/utils/pandas_utils.py +++ b/zipline/utils/pandas_utils.py @@ -2,6 +2,7 @@ Utilities for working with pandas objects. """ from contextlib import contextmanager +from copy import deepcopy from itertools import product import operator as op import warnings @@ -222,3 +223,45 @@ def clear_dataframe_indexer_caches(df): delattr(df, attr) except AttributeError: pass + + +def categorical_df_concat(df_list, inplace=False): + """ + Prepare list of pandas DataFrames to be used as input to pd.concat. + Ensure any columns of type 'category' have the same categories across each + dataframe. + + Parameters + ---------- + df_list : list + List of dataframes with same columns. + inplace : bool + True if input list can be modified. Default is False. + + Returns + ------- + concatenated : df + Dataframe of concatenated list. + """ + + if not inplace: + df_list = deepcopy(df_list) + + # Assert each dataframe has the same columns/dtypes + df = df_list[0] + if not all([(df.dtypes.equals(df_i.dtypes)) for df_i in df_list[1:]]): + raise ValueError("Input DataFrames must have the same columns/dtypes.") + + categorical_columns = df.columns[df.dtypes == 'category'] + + for col in categorical_columns: + new_categories = sorted( + set().union( + *(frame[col].cat.categories for frame in df_list) + ) + ) + + for df in df_list: + df[col].cat.set_categories(new_categories, inplace=True) + + return pd.concat(df_list)