mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 17:56:34 +08:00
+35
-15
@@ -5,18 +5,22 @@ and zipline development
|
||||
import random
|
||||
import pytz
|
||||
|
||||
from itertools import chain, cycle, ifilter, izip
|
||||
from itertools import chain, cycle, ifilter, izip, repeat
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from zipline.gens.utils import hash_args, create_trade
|
||||
|
||||
def date_gen(start = datetime(2006, 6, 6, 12, tzinfo=pytz.utc),
|
||||
delta = timedelta(minutes = 1),
|
||||
count = 100):
|
||||
count = 100,
|
||||
repeats = None):
|
||||
"""
|
||||
Utility to generate a stream of dates.
|
||||
"""
|
||||
return (start + (i * delta) for i in xrange(count))
|
||||
if repeats:
|
||||
return (start + (i * delta) for i in xrange(count) for n in xrange(repeats))
|
||||
else:
|
||||
return (start + (i * delta) for i in xrange(count))
|
||||
|
||||
def mock_prices(count, rand = False):
|
||||
"""
|
||||
@@ -74,6 +78,7 @@ class SpecificEquityTrades(object):
|
||||
self.sids = kwargs.get('sids', [1, 2])
|
||||
self.start = kwargs.get('start', datetime(2008, 6, 6, 15, tzinfo = pytz.utc))
|
||||
self.delta = kwargs.get('delta', timedelta(minutes = 1))
|
||||
self.concurrent = kwargs.get('concurrent', False)
|
||||
|
||||
# Default to None for event_list and filter.
|
||||
self.event_list = kwargs.get('event_list')
|
||||
@@ -103,20 +108,35 @@ class SpecificEquityTrades(object):
|
||||
|
||||
# Set up iterators for each expected field.
|
||||
else:
|
||||
dates = date_gen(count=self.count,
|
||||
start=self.start,
|
||||
delta=self.delta
|
||||
)
|
||||
prices = mock_prices(self.count)
|
||||
volumes = mock_volumes(self.count)
|
||||
sids = cycle(self.sids)
|
||||
if self.concurrent:
|
||||
# in this context the count is the number of
|
||||
# trades per sid, not the total.
|
||||
dates = date_gen(
|
||||
count=self.count,
|
||||
start=self.start,
|
||||
delta=self.delta,
|
||||
repeats=len(self.sids),
|
||||
)
|
||||
|
||||
# Combine the iterators into a single iterator of arguments
|
||||
arg_gen = izip(sids, prices, volumes, dates)
|
||||
|
||||
# Convert argument packages into events.
|
||||
unfiltered = (create_trade(*args, source_id = self.get_hash())
|
||||
for args in arg_gen)
|
||||
else:
|
||||
|
||||
dates = date_gen(
|
||||
count=self.count,
|
||||
start=self.start,
|
||||
delta=self.delta
|
||||
)
|
||||
|
||||
prices = mock_prices(self.count)
|
||||
volumes = mock_volumes(self.count)
|
||||
|
||||
sids = cycle(self.sids)
|
||||
# Combine the iterators into a single iterator of arguments
|
||||
arg_gen = izip(sids, prices, volumes, dates)
|
||||
|
||||
# Convert argument packages into events.
|
||||
unfiltered = (create_trade(*args, source_id = self.get_hash())
|
||||
for args in arg_gen)
|
||||
|
||||
# If we specified a sid filter, filter out elements that don't
|
||||
# match the filter.
|
||||
|
||||
+10
-5
@@ -269,7 +269,12 @@ class SimulatedTrading(object):
|
||||
of StatefulTransform objects.
|
||||
"""
|
||||
assert isinstance(config, dict)
|
||||
sid = config['sid']
|
||||
sid_list = config.get('sid_list')
|
||||
if not sid_list:
|
||||
sid = config.get('sid')
|
||||
sid_list = [sid]
|
||||
|
||||
concurrent_trades = config.get('concurrent_trades', False)
|
||||
|
||||
#--------------------
|
||||
# Trading Environment
|
||||
@@ -307,17 +312,17 @@ class SimulatedTrading(object):
|
||||
#-------------------
|
||||
# Trade Source
|
||||
#-------------------
|
||||
sids = [sid]
|
||||
#-------------------
|
||||
if config.has_key('trade_source'):
|
||||
trade_source = config['trade_source']
|
||||
else:
|
||||
trade_source = factory.create_daily_trade_source(
|
||||
sids,
|
||||
sid_list,
|
||||
trade_count,
|
||||
trading_environment
|
||||
trading_environment,
|
||||
concurrent=concurrent_trades
|
||||
)
|
||||
|
||||
|
||||
#-------------------
|
||||
# Transforms
|
||||
#-------------------
|
||||
|
||||
@@ -174,7 +174,7 @@ def create_random_trade_source(sid, trade_count, trading_environment):
|
||||
|
||||
return source
|
||||
|
||||
def create_daily_trade_source(sids, trade_count, trading_environment):
|
||||
def create_daily_trade_source(sids, trade_count, trading_environment, concurrent=False):
|
||||
|
||||
"""
|
||||
creates trade_count trades for each sid in sids list.
|
||||
@@ -189,11 +189,12 @@ def create_daily_trade_source(sids, trade_count, trading_environment):
|
||||
sids,
|
||||
trade_count,
|
||||
timedelta(days=1),
|
||||
trading_environment
|
||||
trading_environment,
|
||||
concurrent=concurrent
|
||||
)
|
||||
|
||||
|
||||
def create_minutely_trade_source(sids, trade_count, trading_environment):
|
||||
def create_minutely_trade_source(sids, trade_count, trading_environment, concurrent=False):
|
||||
|
||||
"""
|
||||
creates trade_count trades for each sid in sids list.
|
||||
@@ -208,10 +209,11 @@ def create_minutely_trade_source(sids, trade_count, trading_environment):
|
||||
sids,
|
||||
trade_count,
|
||||
timedelta(minutes=1),
|
||||
trading_environment
|
||||
trading_environment,
|
||||
concurrent=concurrent
|
||||
)
|
||||
|
||||
def create_trade_source(sids, trade_count, trade_time_increment, trading_environment):
|
||||
def create_trade_source(sids, trade_count, trade_time_increment, trading_environment, concurrent=False):
|
||||
|
||||
args = tuple()
|
||||
kwargs = {
|
||||
@@ -219,7 +221,8 @@ def create_trade_source(sids, trade_count, trade_time_increment, trading_environ
|
||||
'sids' : sids,
|
||||
'start' : trading_environment.first_open,
|
||||
'delta' : trade_time_increment,
|
||||
'filter' : sids
|
||||
'filter' : sids,
|
||||
'concurrent' : concurrent
|
||||
}
|
||||
source = SpecificEquityTrades(*args, **kwargs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user