mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 22:04:02 +08:00
Merge pull request #593 from grundgruen/close_pos_event
ENH: Add CLOSE_POSTION as DataSource event type
This commit is contained in:
@@ -46,7 +46,8 @@ from zipline.finance.commission import PerShare, PerTrade, PerDollar
|
||||
from zipline.finance import trading
|
||||
from zipline.utils.factory import create_random_simulation_parameters
|
||||
import zipline.protocol as zp
|
||||
from zipline.protocol import Event
|
||||
from zipline.protocol import Event, DATASOURCE_TYPE
|
||||
from zipline.sources.data_frame_source import DataPanelSource
|
||||
|
||||
logger = logging.getLogger('Test Perf Tracking')
|
||||
|
||||
@@ -1937,6 +1938,37 @@ class TestPerformanceTracker(unittest.TestCase):
|
||||
|
||||
check_perf_tracker_serialization(tracker)
|
||||
|
||||
def test_close_position_event(self):
|
||||
pt = perf.PositionTracker()
|
||||
dt = pd.Timestamp("1984/03/06 3:00PM")
|
||||
pos1 = perf.Position(1, amount=np.float64(120.0),
|
||||
last_sale_date=dt, last_sale_price=3.4)
|
||||
pos2 = perf.Position(2, amount=np.float64(-100.0),
|
||||
last_sale_date=dt, last_sale_price=3.4)
|
||||
pt.update_positions({1: pos1, 2: pos2})
|
||||
|
||||
event_type = DATASOURCE_TYPE.CLOSE_POSITION
|
||||
index = [dt + timedelta(days=1)]
|
||||
pan = pd.Panel({1: pd.DataFrame({'price': 1, 'volume': 0,
|
||||
'type': event_type}, index=index),
|
||||
2: pd.DataFrame({'price': 1, 'volume': 0,
|
||||
'type': event_type}, index=index),
|
||||
3: pd.DataFrame({'price': 1, 'volume': 0,
|
||||
'type': event_type}, index=index)})
|
||||
|
||||
source = DataPanelSource(pan)
|
||||
for i, event in enumerate(source):
|
||||
txn = pt.create_close_position_transaction(event)
|
||||
if event.sid == 1:
|
||||
# Test owned long
|
||||
self.assertEqual(-120, txn.amount)
|
||||
elif event.sid == 2:
|
||||
# Test owned short
|
||||
self.assertEqual(100, txn.amount)
|
||||
elif event.sid == 3:
|
||||
# Test not-owned SID
|
||||
self.assertIsNone(txn)
|
||||
|
||||
def test_serialization(self):
|
||||
start_dt = datetime(year=2008,
|
||||
month=10,
|
||||
|
||||
@@ -47,14 +47,16 @@ class TestDataFrameSource(TestCase):
|
||||
"DataFrameSource should only stream selected sid 0, not sid 1."
|
||||
|
||||
def test_panel_source(self):
|
||||
source, panel = factory.create_test_panel_source()
|
||||
source, panel = factory.create_test_panel_source(source_type=5)
|
||||
assert isinstance(source.start, pd.lib.Timestamp)
|
||||
assert isinstance(source.end, pd.lib.Timestamp)
|
||||
for event in source:
|
||||
self.assertTrue('sid' in event)
|
||||
self.assertTrue('arbitrary' in event)
|
||||
self.assertTrue('type' in event)
|
||||
self.assertTrue(hasattr(event, 'volume'))
|
||||
self.assertTrue(hasattr(event, 'price'))
|
||||
self.assertEquals(event['type'], 5)
|
||||
self.assertEquals(event['arbitrary'], 1.)
|
||||
self.assertEquals(event['sid'], 0)
|
||||
self.assertTrue(isinstance(event['volume'], int))
|
||||
|
||||
@@ -13,6 +13,7 @@ except ImportError:
|
||||
from six import iteritems
|
||||
from six.moves import map, filter
|
||||
|
||||
from zipline.finance.slippage import Transaction
|
||||
from zipline.utils.serialization_utils import (
|
||||
VERSION_LABEL
|
||||
)
|
||||
@@ -217,6 +218,19 @@ class PositionTracker(object):
|
||||
net_cash_payment = payments['cash_amount'].fillna(0).sum()
|
||||
return net_cash_payment
|
||||
|
||||
def create_close_position_transaction(self, event):
|
||||
if not self._position_amounts.get(event.sid):
|
||||
return None
|
||||
txn = Transaction(
|
||||
sid=event.sid,
|
||||
amount=(-1 * self._position_amounts[event.sid]),
|
||||
dt=event.dt,
|
||||
price=event.price,
|
||||
commission=0,
|
||||
order_id=0
|
||||
)
|
||||
return txn
|
||||
|
||||
def get_positions(self):
|
||||
|
||||
positions = self._positions_store
|
||||
|
||||
@@ -337,6 +337,11 @@ class PerformanceTracker(object):
|
||||
|
||||
self.all_benchmark_returns[midnight] = event.returns
|
||||
|
||||
def process_close_position(self, event):
|
||||
txn = self.position_tracker.create_close_position_transaction(event)
|
||||
if txn:
|
||||
self.process_transaction(txn)
|
||||
|
||||
def check_upcoming_dividends(self, midnight_of_date_that_just_ended):
|
||||
"""
|
||||
Check if we currently own any stocks with dividends whose ex_date is
|
||||
|
||||
@@ -204,6 +204,8 @@ class AlgorithmSimulator(object):
|
||||
perf_process_split = self.algo.perf_tracker.process_split
|
||||
perf_process_dividend = self.algo.perf_tracker.process_dividend
|
||||
perf_process_commission = self.algo.perf_tracker.process_commission
|
||||
perf_process_close_position = \
|
||||
self.algo.perf_tracker.process_close_position
|
||||
blotter_process_trade = self.algo.blotter.process_trade
|
||||
blotter_process_benchmark = self.algo.blotter.process_benchmark
|
||||
|
||||
@@ -219,6 +221,7 @@ class AlgorithmSimulator(object):
|
||||
# custom events.
|
||||
trades = []
|
||||
customs = []
|
||||
closes = []
|
||||
|
||||
# splits and dividends are processed once a day.
|
||||
#
|
||||
@@ -247,6 +250,8 @@ class AlgorithmSimulator(object):
|
||||
if dividends is None:
|
||||
dividends = []
|
||||
dividends.append(event)
|
||||
elif event.type == DATASOURCE_TYPE.CLOSE_POSITION:
|
||||
closes.append(event)
|
||||
else:
|
||||
raise log.warn("Unrecognized event=%s".format(event))
|
||||
|
||||
@@ -282,6 +287,10 @@ class AlgorithmSimulator(object):
|
||||
for custom in customs:
|
||||
self.update_universe(custom)
|
||||
|
||||
for close in closes:
|
||||
self.update_universe(close)
|
||||
perf_process_close_position(close)
|
||||
|
||||
if splits is not None:
|
||||
for split in splits:
|
||||
# process_split is not assigned to a variable since it is
|
||||
|
||||
+2
-1
@@ -42,7 +42,8 @@ DATASOURCE_TYPE = Enum(
|
||||
'DONE',
|
||||
'CUSTOM',
|
||||
'BENCHMARK',
|
||||
'COMMISSION'
|
||||
'COMMISSION',
|
||||
'CLOSE_POSITION'
|
||||
)
|
||||
|
||||
# Expected fields/index values for a dividend Series.
|
||||
|
||||
@@ -45,11 +45,12 @@ class DataSource(with_metaclass(ABCMeta)):
|
||||
"""
|
||||
Override this to hand craft conversion of row.
|
||||
"""
|
||||
row = {target: mapping_func(raw_row[source_key])
|
||||
for target, (mapping_func, source_key)
|
||||
in self.mapping.items()}
|
||||
row.update({'source_id': self.get_hash()})
|
||||
row = {}
|
||||
row.update({'type': self.event_type})
|
||||
row.update({target: mapping_func(raw_row[source_key])
|
||||
for target, (mapping_func, source_key)
|
||||
in self.mapping.items()})
|
||||
row.update({'source_id': self.get_hash()})
|
||||
return row
|
||||
|
||||
@property
|
||||
|
||||
@@ -315,7 +315,7 @@ def create_test_df_source(sim_params=None, bars='daily'):
|
||||
return DataFrameSource(df), df
|
||||
|
||||
|
||||
def create_test_panel_source(sim_params=None):
|
||||
def create_test_panel_source(sim_params=None, source_type=None):
|
||||
start = sim_params.first_open \
|
||||
if sim_params else pd.datetime(1990, 1, 3, 0, 0, 0, 0, pytz.utc)
|
||||
|
||||
@@ -329,12 +329,17 @@ def create_test_panel_source(sim_params=None):
|
||||
|
||||
price = np.arange(0, len(index))
|
||||
volume = np.ones(len(index)) * 1000
|
||||
|
||||
arbitrary = np.ones(len(index))
|
||||
|
||||
df = pd.DataFrame({'price': price,
|
||||
'volume': volume,
|
||||
'arbitrary': arbitrary},
|
||||
index=index)
|
||||
if source_type:
|
||||
source_types = np.full(len(index), source_type)
|
||||
df['type'] = source_types
|
||||
|
||||
panel = pd.Panel.from_dict({0: df})
|
||||
|
||||
return DataPanelSource(panel), panel
|
||||
|
||||
Reference in New Issue
Block a user