diff --git a/zipline/finance/performance.py b/zipline/finance/performance.py index 8f3d9aa1..a4b1223c 100644 --- a/zipline/finance/performance.py +++ b/zipline/finance/performance.py @@ -9,10 +9,9 @@ import zipline.util as qutil import zipline.protocol as zp import zipline.finance.risk as risk -class PortfolioClient(qmsg.Component): +class PerformanceTracker(): def __init__(self, period_start, period_end, capital_base, trading_environment): - qmsg.Component.__init__(self) self.trading_day = datetime.timedelta(hours=6, minutes=30) self.calendar_day = datetime.timedelta(hours=24) self.period_start = period_start @@ -27,35 +26,29 @@ class PortfolioClient(qmsg.Component): self.capital_base = capital_base self.trading_environment = trading_environment self.returns = [] - self.cumulative_performance = PerformancePeriod(self.period_start, self.period_end, {}, 0, capital_base = capital_base) - self.todays_performance = PerformancePeriod(self.market_open, self.market_close, {}, 0, capital_base = capital_base) - - @property - def get_id(self): - return str(zp.FINANCE_COMPONENT.PORTFOLIO_CLIENT) - - def open(self): - self.result_feed = self.connect_result() - - def do_work(self): - #next feed event - socks = dict(self.poll.poll(self.heartbeat_timeout)) - - if self.result_feed in socks and socks[self.result_feed] == self.zmq.POLLIN: - msg = self.result_feed.recv() - - if msg == str(zp.CONTROL_PROTOCOL.DONE): - self.handle_simulation_end() - qutil.LOGGER.info("Portfolio Client is DONE!") - self.signal_done() - return - - event = zp.MERGE_UNFRAME(msg) + self.txn_count = 0 + self.event_count = 0 + self.cumulative_performance = PerformancePeriod( + {}, + capital_base, + starting_cash = capital_base + ) + self.todays_performance = PerformancePeriod( + {}, + capital_base, + starting_cash = capital_base + ) + + + + def update(self, event): + self.event_count += 1 if(event.dt >= self.market_close): self.handle_market_close() - if event.TRANSACTION: + if event.TRANSACTION != None: + self.txn_count += 1 self.cumulative_performance.execute_transaction(event.TRANSACTION) self.todays_performance.execute_transaction(event.TRANSACTION) @@ -73,28 +66,35 @@ class PortfolioClient(qmsg.Component): #calculate performance as of last trade self.cumulative_performance.calculate_performance() self.todays_performance.calculate_performance() - - - + def handle_market_close(self): - self.market_open = self.market_open + self.calendar_day - while not self.trading_environment.is_trading_day(self.market_open): - if self.market_open > self.trading_environment.trading_days[-1]: - raise Exception("Attempting to backtest beyond available history.") - self.market_open = self.market_open + self.calendar_day - self.market_close = self.market_open + self.trading_day - self.day_count += 1.0 - self.progress = self.day_count / self.total_days - #add the return results from today to the list of daily return objects. - todays_date = self.todays_performance.period_end.replace(hour=0, minute=0, second=0) + #add the return results from today to the list of daily return objects. + todays_date = self.market_close.replace(hour=0, minute=0, second=0) todays_return_obj = risk.daily_return(todays_date, self.todays_performance.returns) self.returns.append(todays_return_obj) #calculate risk metrics for cumulative performance - self.cur_period_metrics = risk.RiskMetrics(start_date=self.cumulative_performance.period_start, - end_date=self.cumulative_performance.period_end.replace(hour=0, minute=0, second=0), - returns=self.returns, - trading_environment=self.trading_environment) + self.cumulative_risk_metrics = risk.RiskMetrics( + start_date=self.period_start, + end_date=self.market_close.replace(hour=0, minute=0, second=0), + returns=self.returns, + trading_environment=self.trading_environment + ) + + #move the market day markers forward + self.market_open = self.market_open + self.calendar_day + while not self.trading_environment.is_trading_day(self.market_open): + if self.market_open > self.trading_environment.trading_days[-1]: + raise Exception("Attempt to backtest beyond available history.") + self.market_open = self.market_open + self.calendar_day + self.market_close = self.market_open + self.trading_day + self.day_count += 1.0 + + #calculate progress of test + self.progress = self.day_count / self.total_days + + + ###################################################################################################### #######TODO: report/relay metrics out to qexec -- values come from self.cur_period_metrics ########### @@ -102,14 +102,19 @@ class PortfolioClient(qmsg.Component): ###################################################################################################### #roll over positions to current day. - self.todays_performance = PerformancePeriod(self.market_open, - self.market_close, - self.todays_performance.positions, - self.todays_performance.ending_value, - self.capital_base) - + self.todays_performance.calculate_performance() + self.todays_performance = PerformancePeriod( + self.todays_performance.positions, + self.todays_performance.ending_value, + self.todays_performance.ending_cash + ) + def handle_simulation_end(self): - self.risk_report = risk.RiskReport(self.returns, self.trading_environment) + self.risk_report = risk.RiskReport( + self.returns, + self.trading_environment + ) + ###################################################################################################### #######TODO: report/relay metrics out to qexec -- values come from self.risk_report ########### ###################################################################################################### @@ -131,14 +136,18 @@ class Position(): def update(self, txn): if(self.sid != txn.sid): - raise NameError('attempt to update position with transaction in different sid') + raise NameError('updating position with txn for a different sid') #throw exception if(self.amount + txn.amount == 0): #we're covering a short or closing a position self.cost_basis = 0.0 self.amount = 0 else: - self.cost_basis = (self.cost_basis*self.amount + (txn.amount*txn.price))/(self.amount + txn.amount) + prev_cost = self.cost_basis*self.amount + txn_cost = txn.amount*txn.price + total_cost = prev_cost + txn_cost + total_shares = self.amount + txn.amount + self.cost_basis = total_cost/total_shares self.amount = self.amount + txn.amount def currentValue(self): @@ -146,35 +155,40 @@ class Position(): def __repr__(self): - return "sid: {sid}, amount: {amount}, cost_basis: {cost_basis}, last_sale: {last_sale}".format( - sid=self.sid, amount=self.amount, cost_basis=self.cost_basis, last_sale=self.last_sale) + template = "sid: {sid}, amount: {amount}, cost_basis: {cost_basis}, \ + last_sale: {last_sale}" + return template.format( + sid=self.sid, + amount=self.amount, + cost_basis=self.cost_basis, + last_sale=self.last_sale + ) class PerformancePeriod(): - def __init__(self, period_start, period_end, initial_positions, initial_value, capital_base = None): + def __init__(self, initial_positions, starting_value, starting_cash): self.ending_value = 0.0 self.period_capital_used = 0.0 - self.period_start = period_start - self.period_end = period_end self.positions = initial_positions #sid => position object - self.starting_value = initial_value - if(capital_base != None): - self.capital_base = capital_base - else: - self.capital_base = 0 + self.starting_value = starting_value + #cash balance at start of period + self.starting_cash = starting_cash + self.ending_cash = starting_cash def calculate_performance(self): self.ending_value = self.calculate_positions_value() - self.pnl = (self.ending_value - self.starting_value) - self.period_capital_used - if(self.capital_base != 0): - self.returns = self.pnl / self.starting_value + + total_at_start = self.starting_cash + self.starting_value + self.ending_cash = self.starting_cash + self.period_capital_used + total_at_end = self.ending_cash + self.ending_value + + self.pnl = total_at_end - total_at_start + if(total_at_start != 0): + self.returns = self.pnl / total_at_start else: self.returns = 0.0 def execute_transaction(self, txn): - if(txn.dt > self.period_end): - raise Exception("transaction dated {dt} attempted for period ending {ending}". - format(dt=txn.dt, ending=self.period_end)) if(not self.positions.has_key(txn.sid)): self.positions[txn.sid] = Position(txn.sid) self.positions[txn.sid].update(txn) @@ -188,10 +202,9 @@ class PerformancePeriod(): return mktValue def update_last_sale(self, event): - if self.positions.has_key(event.sid): + if self.positions.has_key(event.sid) and event.type == zp.DATASOURCE_TYPE.TRADE: self.positions[event.sid].last_sale = event.price self.positions[event.sid].last_date = event.dt - \ No newline at end of file diff --git a/zipline/finance/risk.py b/zipline/finance/risk.py index 2d6195c0..c865aaa6 100644 --- a/zipline/finance/risk.py +++ b/zipline/finance/risk.py @@ -17,16 +17,18 @@ class daily_return(): return str(self.date) + " - " + str(self.returns) class RiskMetrics(): - def __init__(self, start_date, end_date, returns, benchmark_returns, treasury_curves, trading_calendar): + def __init__(self, start_date, end_date, returns, trading_environment): """ :param treasury_curves: {datetime in utc -> {duration label -> interest rate}} """ - self.treasury_curves = treasury_curves + self.treasury_curves = trading_environment.treasury_curves self.start_date = start_date self.end_date = end_date - self.trading_calendar = trading_calendar + self.trading_environment = trading_environment self.algorithm_period_returns, self.algorithm_returns = self.calculate_period_returns(returns) + benchmark_returns = [x for x in self.trading_environment.benchmark_returns if x.date >= returns[0].date and x.date <= returns[-1].date] + self.benchmark_period_returns, self.benchmark_returns = self.calculate_period_returns(benchmark_returns) if(len(self.benchmark_returns) != len(self.algorithm_returns)): raise Exception("Mismatch between benchmark_returns ({bm_count}) and algorithm_returns ({algo_count}) in range {start} : {end}".format( @@ -53,7 +55,7 @@ class RiskMetrics(): return '\n'.join(statements) def calculate_period_returns(self, daily_returns): - returns = [x.returns for x in daily_returns if x.date >= self.start_date and x.date <= self.end_date and self.trading_calendar.is_trading_day(x.date)] + returns = [x.returns for x in daily_returns if x.date >= self.start_date and x.date <= self.end_date and self.trading_environment.is_trading_day(x.date)] #qutil.LOGGER.debug("using {count} daily returns out of {total}".format(count=len(returns),total=len(daily_returns))) period_returns = 1.0 for r in returns: @@ -165,18 +167,13 @@ class RiskMetrics(): class RiskReport(): - def __init__(self, algorithm_returns, benchmark_returns, treasury_curves, trading_calendar): + def __init__(self, algorithm_returns, benchmark_returns, treasury_curves, trading_environment): """algorithm_returns needs to be a list of daily_return objects sorted in date ascending order""" self.algorithm_returns = algorithm_returns - self.bm_returns = [x for x in benchmark_returns if x.date >= self.algorithm_returns[0].date and x.date <= self.algorithm_returns[-1].date] self.treasury_curves = treasury_curves - self.trading_calendar = trading_calendar + self.trading_environment = trading_environment - qutil.LOGGER.debug("#### {start} thru {end} with {count} trading_days of {total} possible".format(start=self.algorithm_returns[0].date, - end=self.algorithm_returns[-1].date, - count=len(self.bm_returns), - total=len(benchmark_returns))) #calculate month ends self.month_periods = self.periodsInRange(1, self.algorithm_returns[0].date, self.algorithm_returns[-1].date) @@ -202,13 +199,13 @@ class RiskReport(): cur_end = advance_by_months(cur_start, months_per) - one_day if(cur_end > the_end): break - #qutil.LOGGER.debug("start: {start}, end: {end}".format(start=cur_start, end=cur_end)) - cur_period_metrics = RiskMetrics(start_date=cur_start, - end_date=cur_end, - returns=self.algorithm_returns, - benchmark_returns=self.bm_returns, - treasury_curves=self.treasury_curves, - trading_calendar=self.trading_calendar) + cur_period_metrics = RiskMetrics( + start_date=cur_start, + end_date=cur_end, + returns=self.algorithm_returns, + trading_environment=self.trading_environment + ) + ends.append(cur_period_metrics) cur_start = advance_by_months(cur_start, 1) @@ -242,6 +239,7 @@ class TradingEnvironment(object): self.trading_days = [] self.trading_day_map = {} self.treasury_curves = treasury_curves + self.benchmark_returns = benchmark_returns for bm in benchmark_returns: self.trading_days.append(bm.date) self.trading_day_map[bm.date] = bm diff --git a/zipline/finance/trading.py b/zipline/finance/trading.py index 2ecdc197..5c4817de 100644 --- a/zipline/finance/trading.py +++ b/zipline/finance/trading.py @@ -15,11 +15,20 @@ class TradeSimulationClient(qmsg.Component): self.received_count = 0 self.prev_dt = None self.event_queue = [] + self.event_callbacks = [] + self.txn_count = 0 @property def get_id(self): return str(zp.FINANCE_COMPONENT.TRADING_CLIENT) + def add_event_callback(self, callback): + """ + :param callable callback: must be a function with the signature + f(frame). + """ + self.event_callbacks.append(callback) + def open(self): self.result_feed = self.connect_result() self.order_socket = self.connect_order() @@ -28,7 +37,9 @@ class TradeSimulationClient(qmsg.Component): #next feed event socks = dict(self.poll.poll(self.heartbeat_timeout)) - if self.result_feed in socks and socks[self.result_feed] == self.zmq.POLLIN: + if self.result_feed in socks and \ + socks[self.result_feed] == self.zmq.POLLIN: + msg = self.result_feed.recv() if msg == str(zp.CONTROL_PROTOCOL.DONE): @@ -37,19 +48,19 @@ class TradeSimulationClient(qmsg.Component): return event = zp.MERGE_UNFRAME(msg) - self._handle_event(event) + + if(event.TRANSACTION != None): + self.txn_count += 1 + + for cb in self.event_callbacks: + cb(event) + + #signal done to order source. + self.order_socket.send(str(zp.ORDER_PROTOCOL.BREAK)) def connect_order(self): return self.connect_push_socket(self.addresses['order_address']) - def _handle_event(self, event): - self.handle_event(event) - #signal done to order source. - self.order_socket.send(str(zp.ORDER_PROTOCOL.BREAK)) - - def handle_event(self, event): - raise NotImplementedError - def order(self, sid, amount): self.order_socket.send(zp.ORDER_FRAME(sid, amount)) @@ -61,7 +72,8 @@ class OrderDataSource(qmsg.DataSource): def __init__(self, simulation_dt): """ - :param simulation_time: datetime in UTC timezone, sets the start time of simulation. orders + :param simulation_time: datetime in UTC timezone, sets the start + time of simulation. orders will be timestamped relative to this datetime. event = { 'sid' : an integer for security id, @@ -93,21 +105,31 @@ class OrderDataSource(qmsg.DataSource): #TODO: if this is the first iteration, break deadlock by sending a dummy order if(self.sent_count == 0): - self.send_dummy() + self.send(zp.namedict({})) #pull all orders from client. orders = [] order_dt = None count = 0 while True: - (rlist, wlist, xlist) = select([self.order_socket], - [], - [self.order_socket], - timeout=self.heartbeat_timeout/1000) #select timeout is in sec + + (rlist, wlist, xlist) = select( + [self.order_socket], + [], + [self.order_socket], + #allow half the time of a heartbeat for the order + #timeout, so we have time to signal we are done. + timeout=self.heartbeat_timeout/2000 + ) + #no more orders, should this be an error condition? if len(rlist) == 0 or len(xlist) > 0: - continue + #no order message means there was a timeout above, + #and the client is done sending orders (but isn't + #telling us himself!). + self.signal_done() + return order_msg = rlist[0].recv() if order_msg == str(zp.ORDER_PROTOCOL.DONE): @@ -115,29 +137,22 @@ class OrderDataSource(qmsg.DataSource): return if order_msg == str(zp.ORDER_PROTOCOL.BREAK): - qutil.LOGGER.info("order loop finished") break sid, amount = zp.ORDER_UNFRAME(order_msg) #send the order along - self.last_iteration_duration = datetime.datetime.utcnow() - self.event_start dt = self.simulation_dt + self.last_iteration_duration - order_event = zp.namedict({"sid":sid, "amount":amount, "dt":dt}) + order = zp.namedict({"dt":dt, 'sid':sid, 'amount':amount}) - self.send(order_event) + self.send(order) count += 1 self.sent_count += 1 - #TODO: we have to send at least one dummy order per do_work iteration or the feed will block waiting for our messages. + #TODO: we have to send at least one dummy order per do_work iteration + # or the feed will block waiting for our messages. if(count == 0): - self.send_dummy() - self.sent_count += 1 - - def send_dummy(self): - dt = self.simulation_dt + self.last_iteration_duration - dummy_order = zp.namedict({"sid":0, "amount":0, "dt":dt}) - self.send(dummy_order) + self.send(zp.namedict({})) @@ -147,6 +162,7 @@ class TransactionSimulator(qmsg.BaseTransform): qmsg.BaseTransform.__init__(self, zp.TRANSFORM_TYPE.TRANSACTION) self.open_orders = {} self.order_count = 0 + self.txn_count = 0 self.trade_windwo = datetime.timedelta(seconds=30) self.orderTTL = datetime.timedelta(days=1) self.volume_share = 0.05 @@ -157,7 +173,6 @@ class TransactionSimulator(qmsg.BaseTransform): Pulls one message from the event feed, then loops on orders until client sends DONE message. """ - #TODO: need a way to send a placeholder txn, to avoid blocking merge... maybe customize merge to not block on txn? if(event.type == zp.DATASOURCE_TYPE.ORDER): self.add_open_order(event) self.state['value'] = None @@ -190,7 +205,8 @@ class TransactionSimulator(qmsg.BaseTransform): def apply_trade_to_open_orders(self, event): if(event.volume == 0): - #there are zero volume events bc some stocks trade less frequently than once per minute. + #there are zero volume events bc some stocks trade + #less frequently than once per minute. return self.create_dummy_txn(event.dt) if self.open_orders.has_key(event.sid): @@ -203,7 +219,8 @@ class TransactionSimulator(qmsg.BaseTransform): dt = event.dt for order in orders: - #we're using minute bars, so allow orders within 30 seconds of the trade + #we're using minute bars, so allow orders within + #30 seconds of the trade if((order.dt - event.dt) < self.trade_windwo): total_order += order.amount if(order.dt > dt): @@ -224,10 +241,17 @@ class TransactionSimulator(qmsg.BaseTransform): volume_share = .25 amount = volume_share * event.volume * direction impact = (volume_share)**2 * .1 * direction * event.price - return self.create_transaction(event.sid, amount, event.price + impact, dt.replace(tzinfo = pytz.utc), direction) + return self.create_transaction( + event.sid, + amount, + event.price + impact, + dt.replace(tzinfo = pytz.utc), + direction + ) - def create_transaction(self, sid, amount, price, dt, direction): + def create_transaction(self, sid, amount, price, dt, direction): + self.txn_count += 1 txn = {'sid' : sid, 'amount' : int(amount), 'dt' : dt, diff --git a/zipline/messaging.py b/zipline/messaging.py index e187ab78..5d3e5fd0 100644 --- a/zipline/messaging.py +++ b/zipline/messaging.py @@ -287,13 +287,18 @@ class ParallelBuffer(Component): cur_source = None earliest_source = None earliest_event = None - #iterate over the queues of events from all sources (1 queue per datasource) + #iterate over the queues of events from all sources + #(1 queue per datasource) for events in self.data_buffer.values(): if len(events) == 0: continue cur_source = events first_in_list = events[0] - + if first_in_list.dt == None: + #this is a filler event, discard + events.pop(0) + continue + if (earliest_event == None) or (first_in_list.dt <= earliest_event.dt): earliest_event = first_in_list earliest_source = cur_source @@ -384,7 +389,8 @@ class MergedParallelBuffer(ParallelBuffer): def append(self, event): """ - :param event: a namedict with one entry. key is the name of the transform, value is the transformed value. + :param event: a namedict with one entry. key is the name of the + transform, value is the transformed value. Add an event to the buffer for the source specified by source_id. """ @@ -398,7 +404,7 @@ class BaseTransform(Component): Top level execution entry point for the transform - connects to the feed socket to subscribe to events - - connets to the result socket (most oftened bound by a TransformsMerge) to PUSH transforms + - connects to the result socket (most oftened bound by a TransformsMerge) to PUSH transforms - processes all messages received from feed, until DONE message received - pushes all transforms - sends DONE to result socket, closes all sockets and context diff --git a/zipline/protocol.py b/zipline/protocol.py index 7286ce1d..caadfaa9 100644 --- a/zipline/protocol.py +++ b/zipline/protocol.py @@ -235,11 +235,27 @@ def DATASOURCE_FRAME(event): assert isinstance(event.source_id, basestring) assert isinstance(event.type, int), 'Unexpected type %s' % (event.type) + + #datasources will send sometimes send empty msgs to feel gaps + if len(event.keys()) == 2: + return msgpack.dumps(tuple([ + event.type, + event.source_id, + DATASOURCE_TYPE.EMPTY + ])) if(event.type == DATASOURCE_TYPE.TRADE): - return msgpack.dumps(tuple([event.type, TRADE_FRAME(event)])) + return msgpack.dumps(tuple([ + event.type, + event.source_id, + TRADE_FRAME(event) + ])) elif(event.type == DATASOURCE_TYPE.ORDER): - return msgpack.dumps(tuple([event.type, ORDER_SOURCE_FRAME(event)])) + return msgpack.dumps(tuple([ + event.type, + event.source_id, + ORDER_SOURCE_FRAME(event) + ])) else: raise INVALID_DATASOURCE_FRAME(str(event)) @@ -261,15 +277,21 @@ def DATASOURCE_UNFRAME(msg): """ try: - ds_type, payload = msgpack.loads(msg) + ds_type, source_id, payload = msgpack.loads(msg) assert isinstance(ds_type, int) - if(ds_type == DATASOURCE_TYPE.TRADE): - return TRADE_UNFRAME(payload) + rval = namedict({'source_id':source_id}) + if payload == DATASOURCE_TYPE.EMPTY: + child_value = namedict({'dt':None}) + elif(ds_type == DATASOURCE_TYPE.TRADE): + child_value = TRADE_UNFRAME(payload) elif(ds_type == DATASOURCE_TYPE.ORDER): - return ORDER_SOURCE_UNFRAME(payload) + child_value = ORDER_SOURCE_UNFRAME(payload) else: raise INVALID_DATASOURCE_FRAME(msg) - + + rval.merge(child_value) + return rval + except TypeError: raise INVALID_DATASOURCE_FRAME(msg) except ValueError: @@ -401,7 +423,6 @@ def TRADE_FRAME(event): """ assert isinstance(event, namedict) - assert isinstance(event.source_id, basestring) assert event.type == DATASOURCE_TYPE.TRADE assert isinstance(event.sid, int) assert isinstance(event.price, numbers.Real) @@ -411,16 +432,14 @@ def TRADE_FRAME(event): event.sid, event.price, event.volume, - event.epoch, - event.micros, + event.dt, event.type, - event.source_id ])) def TRADE_UNFRAME(msg): try: packed = msgpack.loads(msg) - sid, price, volume, epoch, micros, source_type, source_id = packed + sid, price, volume, dt, source_type = packed assert isinstance(sid, int) assert isinstance(price, numbers.Real) @@ -429,10 +448,8 @@ def TRADE_UNFRAME(msg): 'sid' : sid, 'price' : price, 'volume' : volume, - 'epoch' : epoch, - 'micros' : micros, - 'type' : source_type, - 'source_id' : source_id + 'dt' : dt, + 'type' : source_type }) UNPACK_DATE(rval) return rval @@ -480,13 +497,12 @@ def TRANSACTION_FRAME(event): event.price, event.amount, event.commission, - event.epoch, - event.micros + event.dt ])) def TRANSACTION_UNFRAME(msg): try: - sid, price, amount, commission, epoch, micros = msgpack.loads(msg) + sid, price, amount, commission, dt = msgpack.loads(msg) assert isinstance(sid, int) assert isinstance(price, numbers.Real) @@ -497,8 +513,7 @@ def TRANSACTION_UNFRAME(msg): 'price' : price, 'amount' : amount, 'commission' : commission, - 'epoch' : epoch, - 'micros' : micros + 'dt' : dt }) UNPACK_DATE(rval) @@ -523,8 +538,7 @@ def ORDER_SOURCE_FRAME(event): return msgpack.dumps(tuple([ event.sid, event.amount, - event.epoch, - event.micros, + event.dt, event.source_id, event.type ])) @@ -532,12 +546,11 @@ def ORDER_SOURCE_FRAME(event): def ORDER_SOURCE_UNFRAME(msg): try: - sid, amount, epoch, micros, source_id, source_type = msgpack.loads(msg) + sid, amount, dt, source_id, source_type = msgpack.loads(msg) event = namedict({ "sid" : sid, "amount" : amount, - "epoch" : epoch, - "micros" : micros, + "dt" : dt, "source_id" : source_id, "type" : source_type }) @@ -560,9 +573,8 @@ def PACK_DATE(event): """ Packs the datetime property of event into msgpack'able longs. This function should be called purely for its side effects. - The event's 'dt' property is replaced by two longs: epoch and micros. - Epoch is the unix epoch time in UTC, and micros is the microsecond - property of the original event.dt datetime object. + The event's 'dt' property is replaced by a tuple of integers:: + - year, month, day, hour, minute, second, microsecond PACK_DATE and UNPACK_DATE are inverse operations. @@ -571,44 +583,44 @@ def PACK_DATE(event): """ assert isinstance(event.dt, datetime.datetime) assert event.dt.tzinfo == pytz.utc #utc only please - epoch = long(event.dt.strftime('%s')) - event['epoch'] = epoch - event['micros'] = event.dt.microsecond - event.delete('dt') + year, month, day, hour, minute, second = event.dt.timetuple()[0:6] + micros = event.dt.microsecond + event['dt'] = tuple([year, month, day, hour, minute, second, micros]) def UNPACK_DATE(event): """ Unpacks the datetime property of event from msgpack'able longs. This function should be called purely for its side effects. - The event's 'dt' property is created by reading and then combining two longs: epoch and micros. - The epoch and micros properties are removed after dt is added. + The event's 'dt' property is converted to a datetime by reading and then + combining a tuple of integers. UNPACK_DATE and PACK_DATE are inverse operations. - :param event: event must a namedict with:: - - a property named 'epoch' that is an integral representing the unix \ - epoch time in UTC - - a property named 'micros' that is an integral the microsecond \ - property of the original event.dt datetime object + :param tuple event: event must a namedict with:: + - a property named 'dt_tuple' that is a tuple of integers + representing the date and time in UTC. dt_tumple must have year, + month, day, hour, minute, second, and microsecond :rtype: None """ - assert isinstance(event.epoch, numbers.Integral) - assert isinstance(event.micros, numbers.Integral) - dt = datetime.datetime.fromtimestamp(event.epoch) - dt = dt.replace(microsecond = event.micros, tzinfo = pytz.utc) - event.delete('epoch') - event.delete('micros') + assert isinstance(event.dt, tuple) + assert len(event.dt) == 7 + for item in event.dt: + assert isinstance(item, numbers.Integral) + year, month, day, hour, minute, second, micros = event.dt + dt = datetime.datetime(year, month, day, hour, minute, second) + dt = dt.replace(microsecond = micros, tzinfo = pytz.utc) event.dt = dt DATASOURCE_TYPE = Enum( 'ORDER', - 'TRADE' + 'TRADE', + 'EMPTY', ) ORDER_PROTOCOL = Enum( 'DONE', - 'BREAK' + 'BREAK', ) diff --git a/zipline/simulator.py b/zipline/simulator.py index 1e3e3f5f..b63d6d9b 100644 --- a/zipline/simulator.py +++ b/zipline/simulator.py @@ -61,10 +61,8 @@ class Simulator(ComponentHost): if not self.running: return - try: - self.controller.shutdown(context=self.context) - except: - import pdb; pdb.set_trace() + #if self.controller: + #self.controller.shutdown() for component in self.components.itervalues(): component.shutdown() diff --git a/zipline/test/client.py b/zipline/test/client.py index 205cabff..c63f68b8 100644 --- a/zipline/test/client.py +++ b/zipline/test/client.py @@ -66,20 +66,23 @@ class TestClient(qmsg.Component): return zp.MERGE_UNFRAME(msg) -class TestTradingClient(TradeSimulationClient): +class TestAlgorithm(): - def __init__(self, sid, amount, order_count): - TradeSimulationClient.__init__(self) + def __init__(self, sid, amount, order_count, trading_client): + self.trading_client = trading_client + self.trading_client.add_event_callback(self.handle_event) self.count = order_count self.sid = sid self.amount = amount self.incr = 0 + self.done = False def handle_event(self, event): #place an order for 100 shares of sid:133 - if(self.incr < self.count): - self.order(self.sid, self.amount) - self.incr += 1 - else: - self.signal_order_done() - self.signal_done() + if self.incr < self.count: + if event.source_id != zp.FINANCE_COMPONENT.ORDER_SOURCE: + self.trading_client.order(self.sid, self.amount) + self.incr += 1 + elif not self.done: + self.trading_client.signal_order_done() + self.done = True diff --git a/zipline/test/factory.py b/zipline/test/factory.py index 68ed322b..48b01bb2 100644 --- a/zipline/test/factory.py +++ b/zipline/test/factory.py @@ -11,26 +11,37 @@ def load_market_data(): bm_map = msgpack.loads(fp_bm.read()) bm_returns = [] for epoch, returns in bm_map.iteritems(): - bm_returns.append(risk.daily_return(date=datetime.datetime.fromtimestamp(epoch).replace(hour=0, minute=0, second=0, tzinfo=pytz.utc), returns=returns)) + event_dt = datetime.datetime.fromtimestamp(epoch) + event_dt = event_dt.replace( + hour=0, + minute=0, + second=0, + tzinfo=pytz.utc + ) + + daily_return = risk.daily_return(date=event_dt, returns=returns) + bm_returns.append(daily_return) bm_returns = sorted(bm_returns, key=lambda(x): x.date) fp_tr = open("./zipline/test/treasury_curves.msgpack", "rb") tr_map = msgpack.loads(fp_tr.read()) tr_curves = {} for epoch, curve in tr_map.iteritems(): - tr_curves[datetime.datetime.fromtimestamp(epoch).replace(hour=0, minute=0, second=0, tzinfo=pytz.utc)] = curve + tr_dt = datetime.datetime.fromtimestamp(epoch) + tr_dt = tr_dt.replace(hour=0, minute=0, second=0, tzinfo=pytz.utc) + tr_curves[tr_dt] = curve return bm_returns, tr_curves def create_trade(sid, price, amount, datetime): - row = { + row = zp.namedict({ 'source_id' : "test_factory", 'type' : zp.DATASOURCE_TYPE.TRADE, 'sid' : sid, 'dt' : datetime, 'price' : price, 'volume' : amount - } + }) return row def create_trade_history(sid, prices, amounts, start_time, interval, trading_calendar): @@ -50,19 +61,23 @@ def create_trade_history(sid, prices, amounts, start_time, interval, trading_cal return trades -def createTxn(sid, price, amount, datetime, btrid=None): - txn = Transaction(sid=sid, amount=amount, dt = datetime, - price=price, transaction_cost=-1*price*amount) +def create_txn(sid, price, amount, datetime, btrid=None): + txn = zp.namedict({ + 'sid':sid, + 'amount':amount, + 'dt':datetime, + 'price':price, + }) return txn -def create_transaction_history(sid, priceList, amtList, startTime, interval, trading_calendar): +def create_txn_history(sid, priceList, amtList, startTime, interval, trading_calendar): txns = [] current = startTime for price, amount in zip(priceList, amtList): if trading_calendar.is_trading_day(current): - txns.append(createTxn(sid, price, amount, current)) + txns.append(create_txn(sid, price, amount, current)) current = current + interval else: diff --git a/zipline/test/test_finance.py b/zipline/test/test_finance.py index 829f5fa2..bf26ae0e 100644 --- a/zipline/test/test_finance.py +++ b/zipline/test/test_finance.py @@ -1,8 +1,12 @@ """Tests for the zipline.finance package""" import mock import pytz + from unittest2 import TestCase from datetime import datetime, timedelta +from collections import defaultdict + +from nose.tools import timed import zipline.test.factory as factory import zipline.util as qutil @@ -10,26 +14,48 @@ import zipline.finance.risk as risk import zipline.protocol as zp import zipline.finance.performance as perf -from zipline.test.client import TestTradingClient +from zipline.test.client import TestAlgorithm from zipline.sources import SpecificEquityTrades -from zipline.finance.trading import TransactionSimulator, OrderDataSource +from zipline.finance.trading import TransactionSimulator, OrderDataSource, \ +TradeSimulationClient from zipline.simulator import AddressAllocator, Simulator from zipline.monitor import Controller +DEFAULT_TIMEOUT = 5 # seconds + +allocator = AddressAllocator(1000) class FinanceTestCase(TestCase): + leased_sockets = defaultdict(list) + def setUp(self): qutil.configure_logging() self.benchmark_returns, self.treasury_curves = \ factory.load_market_data() - + self.trading_environment = risk.TradingEnvironment( - self.benchmark_returns, + self.benchmark_returns, self.treasury_curves ) - + self.allocator = allocator + + def allocate_sockets(self, n): + """ + Allocate sockets local to this test case, track them so + we can gc after test run. + """ + + assert isinstance(n, int) + assert n > 0 + + leased = self.allocator.lease(n) + + self.leased_sockets[self.id()].extend(leased) + return leased + + @timed(DEFAULT_TIMEOUT) def test_trade_feed_protocol(self): # TODO: Perhaps something more self-documenting for variables names? @@ -82,6 +108,7 @@ class FinanceTestCase(TestCase): self.assertEqual(zp.namedict(trade), event) + @timed(DEFAULT_TIMEOUT) def test_order_protocol(self): #client places an order order_msg = zp.ORDER_FRAME(133, 100) @@ -126,14 +153,14 @@ class FinanceTestCase(TestCase): self.assertEqual(recovered_tx.sid, 133) self.assertEqual(recovered_tx.amount, 100) + @timed(DEFAULT_TIMEOUT) def test_orders(self): # Just verify sending and receiving orders. # -------------- # Allocate sockets for the simulator components - allocator = AddressAllocator(8) - sockets = allocator.lease(8) + sockets = self.allocate_sockets(8) addresses = { 'sync_address' : sockets[0], @@ -172,15 +199,21 @@ class FinanceTestCase(TestCase): ) set1 = SpecificEquityTrades("flat-133", trade_history) - - #client sill send 10 orders for 100 shares of 133 - client = TestTradingClient(133, 100, 10) + + trading_client = TradeSimulationClient() + #client will send 10 orders for 100 shares of 133 + test_algo = TestAlgorithm(133, 100, 10, trading_client) ts = datetime.strptime("02/1/2012","%m/%d/%Y").replace(tzinfo=pytz.utc) order_source = OrderDataSource(ts) transaction_sim = TransactionSimulator() - sim.register_components([client, order_source, transaction_sim, set1]) + sim.register_components([ + trading_client, + order_source, + transaction_sim, + set1 + ]) sim.register_controller( con ) # Simulation @@ -188,6 +221,8 @@ class FinanceTestCase(TestCase): sim_context = sim.simulate() sim_context.join() + self.assertTrue(sim.ready()) + self.assertFalse(sim.exception) # TODO: Make more assertions about the final state of the components. self.assertEqual(sim.feed.pending_messages(), 0, \ @@ -195,14 +230,14 @@ class FinanceTestCase(TestCase): .format(n=sim.feed.pending_messages())) - def test_performance(self): + @timed(DEFAULT_TIMEOUT) + def test_performance(self): # verify order -> transaction -> portfolio position. # -------------- # Allocate sockets for the simulator components - allocator = AddressAllocator(8) - sockets = allocator.lease(8) + sockets = self.allocate_sockets(8) addresses = { 'sync_address' : sockets[0], @@ -225,9 +260,10 @@ class FinanceTestCase(TestCase): # --------------------- # TODO: Perhaps something more self-documenting for variables names? + trade_count = 100 sid = 133 - price = [10.1] * 16 - volume = [100] * 16 + price = [10.1] * trade_count + volume = [100] * trade_count start_date = datetime.strptime("02/1/2012","%m/%d/%Y") trade_time_increment = timedelta(days=1) @@ -242,24 +278,27 @@ class FinanceTestCase(TestCase): set1 = SpecificEquityTrades("flat-133", trade_history) #client sill send 10 orders for 100 shares of 133 - client = TestTradingClient(133, 100, 10) + trading_client = TradeSimulationClient() + test_algo = TestAlgorithm(133, 100, 10, trading_client) ts = datetime.strptime("02/1/2012","%m/%d/%Y") ts = ts.replace(tzinfo=pytz.utc) order_source = OrderDataSource(ts) transaction_sim = TransactionSimulator() - portfolio_client = perf.PortfolioClient( + perf_tracker = perf.PerformanceTracker( trade_history[0]['dt'], trade_history[-1]['dt'], 1000000.0, self.trading_environment) + + #register perf_tracker to receive callbacks from the client. + trading_client.add_event_callback(perf_tracker.update) sim.register_components([ - client, + trading_client, order_source, transaction_sim, set1, - portfolio_client, ]) sim.register_controller( con ) @@ -268,8 +307,46 @@ class FinanceTestCase(TestCase): sim_context = sim.simulate() sim_context.join() - - # TODO: Make more assertions about the final state of the components. - self.assertEqual(sim.feed.pending_messages(), 0, \ + self.assertEqual( + sim.feed.pending_messages(), + 0, "The feed should be drained of all messages, found {n} remaining." \ - .format(n=sim.feed.pending_messages())) \ No newline at end of file + .format(n=sim.feed.pending_messages()) + ) + + self.assertEqual( + sim.merge.pending_messages(), + 0, + "The merge should be drained of all messages, found {n} remaining." \ + .format(n=sim.merge.pending_messages()) + ) + + self.assertEqual( + test_algo.count, + test_algo.incr, + "The test algorithm should send as many orders as specified.") + + self.assertEqual( + order_source.sent_count, + test_algo.count, + "The order source should have sent as many orders as the algo." + ) + + self.assertEqual( + transaction_sim.txn_count, + perf_tracker.txn_count, + "The perf tracker should handle the same number of transactions as\ +as the simulator emits." + ) + + self.assertEqual( + len(perf_tracker.cumulative_performance.positions), + 1, + "Portfolio should have one position." + ) + + self.assertEqual( + perf_tracker.cumulative_performance.positions[133].sid, + 133, + "Portfolio should have one position in 133." + ) diff --git a/zipline/test/test_messaging.py b/zipline/test/test_messaging.py index 187b68b8..c901f984 100644 --- a/zipline/test/test_messaging.py +++ b/zipline/test/test_messaging.py @@ -17,6 +17,15 @@ from nose.tools import timed # it up as a test. Its a Mixin of sorts at this point. class SimulatorTestCase(object): + # Leased sockets is a defaultdict keyed by the test case. + # This lets you debug the sockets being allocated in the + # specific test cases and tear them down appropriately. + # + # { + # 'test_orders' : ['tcp : //127.0.0.1 : 1000', ... ], + # 'test_performance' : ['tcp : //127.0.0.1 : 1025', ... ], + # } + leased_sockets = defaultdict(list) def setUp(self): diff --git a/zipline/test/test_perf_tracking.py b/zipline/test/test_perf_tracking.py new file mode 100644 index 00000000..b8ab30c0 --- /dev/null +++ b/zipline/test/test_perf_tracking.py @@ -0,0 +1,525 @@ +import unittest +import copy +import random +import datetime + +import zipline.test.factory as factory +import zipline.util as qutil +import zipline.finance.performance as perf +import zipline.finance.risk as risk +class PerformanceTestCase(unittest.TestCase): + + def setUp(self): + qutil.configure_logging() + self.benchmark_returns, self.treasury_curves = \ + factory.load_market_data() + + self.trading_environment = risk.TradingEnvironment( + self.benchmark_returns, + self.treasury_curves + ) + + self.onesec = datetime.timedelta(seconds=1) + self.oneday = datetime.timedelta(days=1) + self.tradingday = datetime.timedelta(hours=6, minutes=30) + random_index = random.randint( + 0, + len(self.trading_environment.trading_days) + ) + + self.dt = self.trading_environment.trading_days[random_index] + + def tearDown(self): + pass + + def test_long_position(self): + """ + verify that the performance period calculates properly for a \ +single buy transaction + """ + #post some trades in the market + trades = factory.create_trade_history( + 1, + [10,10,10,11], + [100,100,100,100], + self.dt, + self.onesec, + self.trading_environment + ) + + txn = factory.create_txn(1,10.0,100,self.dt + self.onesec) + pp = perf.PerformancePeriod({}, 0.0, 1000.0) + + pp.execute_transaction(txn) + for trade in trades: + pp.update_last_sale(trade) + + pp.calculate_performance() + + self.assertEqual( + pp.period_capital_used, + -1 * txn.price * txn.amount, + "capital used should be equal to the opposite of the transaction \ + cost of sole txn in test" + ) + + self.assertEqual(len(pp.positions),1,"should be just one position") + + self.assertEqual( + pp.positions[1].sid, + txn.sid, + "position should be in security with id 1") + + self.assertEqual( + pp.positions[1].amount, + txn.amount, + "should have a position of {sharecount} shares".format( + sharecount=txn.amount + ) + ) + + self.assertEqual( + pp.positions[1].cost_basis, + txn.price, + "should have a cost basis of 10" + ) + + self.assertEqual( + pp.positions[1].last_sale, + trades[-1]['price'], + "last sale should be same as last trade. \ + expected {exp} actual {act}".format( + exp=trades[-1]['price'], + act=pp.positions[1].last_sale + ) + ) + + self.assertEqual( + pp.ending_value, + 1100, + "ending value should be price of last trade times number of \ + shares in position" + ) + + self.assertEqual(pp.pnl, 100, "gain of 1 on 100 shares should be 100") + + def test_short_position(self): + """verify that the performance period calculates properly for a \ +single short-sale transaction""" + trades_1 = factory.create_trade_history( + 1, + [10,10,10,11], + [100,100,100,100], + self.dt, + self.onesec, + self.trading_environment + ) + + txn = factory.create_txn(1, 10.0, -100, self.dt + self.onesec) + pp = perf.PerformancePeriod({}, 0.0, 1000.0) + + pp.execute_transaction(txn) + for trade in trades_1: + pp.update_last_sale(trade) + + pp.calculate_performance() + + self.assertEqual( + pp.period_capital_used, + -1 * txn.price * txn.amount, + "capital used should be equal to the opposite of the transaction\ + cost of sole txn in test" + ) + + self.assertEqual( + len(pp.positions), + 1, + "should be just one position") + + self.assertEqual( + pp.positions[1].sid, + txn.sid, + "position should be in security from the transaction" + ) + + self.assertEqual( + pp.positions[1].amount, + -100, + "should have a position of -100 shares" + ) + + self.assertEqual( + pp.positions[1].cost_basis, + txn.price, + "should have a cost basis of 10" + ) + + self.assertEqual( + pp.positions[1].last_sale, + trades_1[-1]['price'], + "last sale should be price of last trade" + ) + + self.assertEqual( + pp.ending_value, + -1100, + "ending value should be price of last trade times number of \ + shares in position" + ) + + self.assertEqual(pp.pnl,-100,"gain of 1 on 100 shares should be 100") + + #simulate additional trades, and ensure that the position value + #reflects the new price + trades_2 = factory.create_trade_history( + 1, + [10,9], + [100,100], + trades_1[-1]['dt'] + self.onesec, + self.onesec, + self.trading_environment + ) + + #simulate a rollover to a new period + pp2 = perf.PerformancePeriod( + pp.positions, + pp.ending_value, + pp.ending_cash + ) + + for trade in trades_2: + pp2.update_last_sale(trade) + + pp2.calculate_performance() + + self.assertEqual( + pp2.period_capital_used, + 0, + "capital used should be zero, there were no transactions in \ + performance period" + ) + + self.assertEqual( + len(pp2.positions), + 1, + "should be just one position" + ) + + self.assertEqual( + pp2.positions[1].sid, + txn.sid, + "position should be in security from the transaction" + ) + + self.assertEqual( + pp2.positions[1].amount, + -100, + "should have a position of -100 shares" + ) + + self.assertEqual( + pp2.positions[1].cost_basis, + txn.price, + "should have a cost basis of 10" + ) + + self.assertEqual( + pp2.positions[1].last_sale, + trades_2[-1].price, + "last sale should be price of last trade" + ) + + self.assertEqual( + pp2.ending_value, + -900, + "ending value should be price of last trade times number of \ + shares in position") + + self.assertEqual( + pp2.pnl, + 200, + "drop of 2 on -100 shares should be 200" + ) + + #now run a performance period encompassing the entire trade sample. + ppTotal = perf.PerformancePeriod({}, 0.0, 1000.0) + + for trade in trades_1: + ppTotal.update_last_sale(trade) + + ppTotal.execute_transaction(txn) + + for trade in trades_2: + ppTotal.update_last_sale(trade) + + ppTotal.calculate_performance() + + self.assertEqual( + ppTotal.period_capital_used, + -1 * txn.price * txn.amount, + "capital used should be equal to the opposite of the transaction \ +cost of sole txn in test" + ) + + self.assertEqual( + len(ppTotal.positions), + 1, + "should be just one position" + ) + self.assertEqual( + ppTotal.positions[1].sid, + txn.sid, + "position should be in security from the transaction" + ) + + self.assertEqual( + ppTotal.positions[1].amount, + -100, + "should have a position of -100 shares" + ) + + self.assertEqual( + ppTotal.positions[1].cost_basis, + txn.price, + "should have a cost basis of 10" + ) + + self.assertEqual( + ppTotal.positions[1].last_sale, + trades_2[-1].price, + "last sale should be price of last trade" + ) + + self.assertEqual( + ppTotal.ending_value, + -900, + "ending value should be price of last trade times number of \ + shares in position") + + self.assertEqual( + ppTotal.pnl, + 100, + "drop of 1 on -100 shares should be 100" + ) + + def test_covering_short(self): + """verify performance where short is bought and covered, and shares \ +trade after cover""" + + trades = factory.create_trade_history( + 1, + [10,10,10,11,9,8,7,8,9,10], + [100,100,100,100,100,100,100,100,100,100], + self.dt, + self.onesec, + self.trading_environment + ) + + short_txn = factory.create_txn( + 1, + 10.0, + -100, + self.dt + self.onesec + ) + + cover_txn = factory.create_txn(1,7.0,100,self.dt + self.onesec * 6) + pp = perf.PerformancePeriod({}, 0.0, 1000.0) + + pp.execute_transaction(short_txn) + pp.execute_transaction(cover_txn) + + for trade in trades: + pp.update_last_sale(trade) + + pp.calculate_performance() + + short_txn_cost = short_txn.price * short_txn.amount + cover_txn_cost = cover_txn.price * cover_txn.amount + + self.assertEqual( + pp.period_capital_used, + -1 * short_txn_cost - cover_txn_cost, + "capital used should be equal to the net transaction costs" + ) + + self.assertEqual( + len(pp.positions), + 1, + "should be just one position" + ) + + self.assertEqual( + pp.positions[1].sid, + short_txn.sid, + "position should be in security from the transaction" + ) + + self.assertEqual( + pp.positions[1].amount, + 0, + "should have a position of -100 shares" + ) + + self.assertEqual( + pp.positions[1].cost_basis, + 0, + "a covered position should have a cost basis of 0" + ) + + self.assertEqual( + pp.positions[1].last_sale, + trades[-1].price, + "last sale should be price of last trade" + ) + + self.assertEqual( + pp.ending_value, + 0, + "ending value should be price of last trade times number of \ +shares in position" + ) + + self.assertEqual( + pp.pnl, + 300, + "gain of 1 on 100 shares should be 300" + ) + + def test_cost_basis_calc(self): + trades = factory.create_trade_history( + 1, + [10,11,11,12], + [100,100,100,100], + self.dt, + self.onesec, + self.trading_environment + ) + + transactions = factory.create_txn_history( + 1, + [10,11,11,12], + [100,100,100,100], + self.dt, + self.onesec, + self.trading_environment + ) + + pp = perf.PerformancePeriod({}, 0.0, 1000.0) + + for txn in transactions: + pp.execute_transaction(txn) + + for trade in trades: + pp.update_last_sale(trade) + + pp.calculate_performance() + + self.assertEqual( + pp.positions[1].last_sale, + trades[-1].price, + "should have a last sale of 12, got {val}".format( + val=pp.positions[1].last_sale + ) + ) + + self.assertEqual( + pp.positions[1].cost_basis, + 11, + "should have a cost basis of 11" + ) + + self.assertEqual( + pp.pnl, + 400 + ) + + saleTxn = factory.create_txn( + 1, + 10.0, + -100, + self.dt + self.onesec * 4) + + down_tick = factory.create_trade( + 1, + 10.0, + 100, + trades[-1].dt + self.onesec) + + pp2 = perf.PerformancePeriod( + copy.deepcopy(pp.positions), + pp.ending_value, + pp.ending_cash + ) + + pp2.execute_transaction(saleTxn) + pp2.update_last_sale(down_tick) + + pp2.calculate_performance() + self.assertEqual( + pp2.positions[1].last_sale, + 10, + "should have a last sale of 10, was {val}".format(val=pp2.positions[1].last_sale) + ) + + self.assertEqual( + round(pp2.positions[1].cost_basis,2), + 11.33, + "should have a cost basis of 11.33" + ) + + #print "second period pnl is {pnl}".format(pnl=pp2.pnl) + self.assertEqual(pp2.pnl, -800, "this period goes from +400 to -400") + + pp3 = perf.PerformancePeriod({}, 0.0, 1000.0) + + transactions.append(saleTxn) + for txn in transactions: + pp3.execute_transaction(txn) + + trades.append(down_tick) + for trade in trades: + pp3.update_last_sale(trade) + + pp3.calculate_performance() + self.assertEqual( + pp3.positions[1].last_sale, + 10, + "should have a last sale of 10" + ) + + self.assertEqual( + round(pp3.positions[1].cost_basis,2), + 11.33, + "should have a cost basis of 11.33" + ) + + self.assertEqual( + pp3.pnl, + -400, + "should be -400 for all trades and transactions in period" + ) + + + def dtest_daily_performance_calc(self): + hostedAlgo = factories.createAlgo("workingAlgo.py") + btRecord = BackTestRun(duration_unit="Days",duration_count=5,capital_base=25000000) + bt = BackTest(hostedAlgo,btRecord) + start = bt.periodStart + end = bt.periodEnd + #print "{start} to {end}".format(start=start, end=end) + + trades = factories.createTradeHistory(1,[10,11,12,11],[100,100,100,100],start, self.oneday) + #createTransaction(self, sid, amount, price, dt, order_id) + bt.createTransaction(1, 100, 10.0, trades[0].dt + 30*self.onesec, None) + curPeriod = start + bt.positions = {} + dailyPeriods = [] + bt.initialValue = 0.0 + while (bt.mktClose) <= bt.periodEnd: + bt.updatePerformance() + dailyPeriods.append(bt.curPeriod) + bt.nextMarketDay() + + self.assertEqual(dailyPeriods[0].pnl,0,"the first day's performance should be zero") + self.assertEqual(dailyPeriods[1].pnl,100,"the second day's pnl should be 100 but was {pnl}".format(pnl=dailyPeriods[1].pnl)) + \ No newline at end of file diff --git a/zipline/test/test_risk.py b/zipline/test/test_risk.py index 81aa0de9..bdcf575f 100644 --- a/zipline/test/test_risk.py +++ b/zipline/test/test_risk.py @@ -41,7 +41,7 @@ class Risk(unittest.TestCase): start_date = datetime.datetime(year=2006, month=1, day=1) returns = factory.create_returns_from_list([1.0,-0.5,0.8,.17,1.0,-0.1,-0.45], start_date, self.trading_calendar) #200, 100, 180, 210.6, 421.2, 379.8, 208.494 - metrics = risk.RiskMetrics(returns[0].date, returns[-1].date, returns, self.benchmark_returns, self.treasury_curves, self.trading_calendar) + metrics = risk.RiskMetrics(returns[0].date, returns[-1].date, returns, self.trading_calendar) self.assertEqual(metrics.max_drawdown, 0.505) def test_benchmark_returns_06(self):