diff --git a/zipline/finance/risk.py b/zipline/finance/risk.py index 516a35f0..38e8159e 100644 --- a/zipline/finance/risk.py +++ b/zipline/finance/risk.py @@ -374,49 +374,3 @@ class RiskReport(): if len(col) == 1: return col[0] return None - - -class TradingEnvironment(object): - - def __init__( - self, - benchmark_returns, - treasury_curves, - period_start=None, - period_end=None, - capital_base=None, - frame_index=None - ): - - self.trading_days = [] - self.trading_day_map = {} - self.treasury_curves = treasury_curves - self.benchmark_returns = benchmark_returns - self.frame_index = frame_index - self.period_start = period_start - self.period_end = period_end - self.capital_base = capital_base - - for bm in benchmark_returns: - self.trading_days.append(bm.date) - self.trading_day_map[bm.date] = bm - - def normalize_date(self, test_date): - return datetime.datetime( - year=test_date.year, - month=test_date.month, - day=test_date.day, - tzinfo=pytz.utc - ) - - def is_trading_day(self, test_date): - dt = self.normalize_date(test_date) - return self.trading_day_map.has_key(dt) - - def get_benchmark_daily_return(self, test_date): - date = self.normalize_date(test_date) - if self.trading_day_map.has_key(date): - return self.trading_day_map[date].returns - else: - return 0.0 - diff --git a/zipline/finance/trading.py b/zipline/finance/trading.py index a7d0e65b..3392f1a9 100644 --- a/zipline/finance/trading.py +++ b/zipline/finance/trading.py @@ -303,6 +303,57 @@ class TransactionSimulator(qmsg.BaseTransform): } return zp.namedict(txn) - + +class TradingEnvironment(object): + + def __init__( + self, + benchmark_returns, + treasury_curves, + period_start=None, + period_end=None, + capital_base=None + ): + + self.trading_days = [] + self.trading_day_map = {} + self.treasury_curves = treasury_curves + self.benchmark_returns = benchmark_returns + self.frame_index = ['sid', 'volume', 'dt', 'price', 'changed'] + self.period_start = period_start + self.period_end = period_end + self.capital_base = capital_base + + for bm in benchmark_returns: + self.trading_days.append(bm.date) + self.trading_day_map[bm.date] = bm + + def normalize_date(self, test_date): + return datetime.datetime( + year=test_date.year, + month=test_date.month, + day=test_date.day, + tzinfo=pytz.utc + ) + + def is_trading_day(self, test_date): + dt = self.normalize_date(test_date) + return self.trading_day_map.has_key(dt) + + def get_benchmark_daily_return(self, test_date): + date = self.normalize_date(test_date) + if self.trading_day_map.has_key(date): + return self.trading_day_map[date].returns + else: + return 0.0 + + def add_to_frame(self, name): + """ + Add an entry to the frame index. + :param name: new index entry name. Used by TradingSimulationClient + to + """ + self.frame_index.append(name) + diff --git a/zipline/lines.py b/zipline/lines.py index e5a82aca..88ead984 100644 --- a/zipline/lines.py +++ b/zipline/lines.py @@ -55,13 +55,74 @@ class SimulatedTrading(object): a control monitor, which can kill the entire zipline in the event of exceptions in one of the components or an external request to end the simulation. + + Here is a diagram of the SimulatedTrading zipline: + + + +----------------------+ +------------------------+ + +-->| Orders DataSource | | (DataSource added | + | | Integrates algo | | via add_source) | + | | orders into history | | | + | +--------------------+-+ +-+----------------------+ + | | | + | | | + | v v + | +---------+ + | | Feed | + | +-+------++ + | | | + | | | + | v v + | +----------------------+ +----------------------+ + | | Transaction | | | + | | Transform simulates | | (Transforms added | + | | trades based on | | via add_transform) | + | | orders from algo. | | | + | +-------------------+--+ +-+--------------------+ + | | | + | | | + | v v + | +------------+ + | | Merge | + | +------+-----+ + | | + | | + | V + | +--------------------------------+ + | | | + | | TradingSimulationClient | + | orders | tracks performance and | + +---------------+ provides API to algorithm. | + | | + +---------------------+----------+ + ^ | + | orders | frames + | | + | v + +---------+-----------------------+ + | | + | Algorithm added via | + | __init__. | + | | + | | + | | + +---------------------------------+ """ - def __init__(self, trading_environment, allocator): + def __init__(self, algorithm, trading_environment, allocator): + """ + :param algorithm: a class that follows the algorithm protocol. Must + have a handle_frame method that accepts a pandas.Dataframe of the + current state of the simulation universe. Must have an order property + which can be set equal to the order method of trading_client. + :param trading_environment: TradingEnvironment object. + """ + self.algorithm = algorithm self.allocator = allocator self.leased_sockets = [] self.trading_environment = trading_environment self.sim_context = None + self.algorithm = algorithm sockets = self.allocate_sockets(8) addresses = { @@ -81,9 +142,6 @@ class SimulatedTrading(object): self.sim = Simulator(addresses) - self.trading_environment.frame_index = ['sid', 'volume', 'dt', \ - 'price', 'changed'] - self.clients = {} self.trading_client = TradeSimulationClient(self.trading_environment) self.clients[self.trading_client.get_id] = self.trading_client @@ -109,6 +167,9 @@ class SimulatedTrading(object): self.sim.on_done = self.shutdown() self.started = False + self.trading_client.add_event_callback(self.algorithm.handle_frame) + self.algorithm.set_order(self.trading_client.order) + def add_source(self, source): assert isinstance(source, zmsg.DataSource) self.check_started()