diff --git a/zipline/finance/performance.py b/zipline/finance/performance.py index df17fe17..cf761c55 100644 --- a/zipline/finance/performance.py +++ b/zipline/finance/performance.py @@ -45,6 +45,7 @@ class PortfolioClient(qmsg.Component): msg = self.result_feed.recv() if msg == str(zp.CONTROL_PROTOCOL.DONE): + self.handle_simulation_end() qutil.LOGGER.info("Portfolio Client is DONE!") self.signal_done() return @@ -74,7 +75,7 @@ class PortfolioClient(qmsg.Component): self.todays_performance.calculate_performance() - # + def handle_market_close(self): self.market_open = self.market_open + self.calendar_day while not self.trading_environment.is_trading_day(self.market_open): @@ -99,7 +100,9 @@ class PortfolioClient(qmsg.Component): self.todays_performance.ending_value, self.capital_base) - # + def handle_simulation_end(self): + self.risk_report = risk.riskmetrics(self.returns, self.trading_environment) + def round_to_nearest(self, x, base=5): return int(base * round(float(x)/base)) diff --git a/zipline/finance/risk.py b/zipline/finance/risk.py index b1959187..56aceaf0 100644 --- a/zipline/finance/risk.py +++ b/zipline/finance/risk.py @@ -18,15 +18,14 @@ class daily_return(): class periodmetrics(): def __init__(self, start_date, end_date, returns, trading_environment): - """ - :param treasury_curves: {datetime in utc -> {duration label -> interest rate}} - """ + + self.trading_environment = trading_environment + bm_returns = [x for x in self.trading_environment.benchmark_returns if x.date >= returns[0].date and x.date <= returns[-1].date] self.start_date = start_date self.end_date = end_date - self.trading_environment = trading_environment self.algorithm_period_returns, self.algorithm_returns = self.calculate_period_returns(returns) - self.benchmark_period_returns, self.benchmark_returns = self.calculate_period_returns(trading_environment.benchmark_returns) + self.benchmark_period_returns, self.benchmark_returns = self.calculate_period_returns(bm_returns) if(len(self.benchmark_returns) != len(self.algorithm_returns)): raise Exception("Mismatch between benchmark_returns ({bm_count}) and algorithm_returns ({algo_count}) in range {start} : {end}".format( bm_count=len(self.benchmark_returns), @@ -168,14 +167,6 @@ class riskmetrics(): self.algorithm_returns = algorithm_returns self.trading_environment = trading_environment - self.bm_returns = [x for x in self.trading_environmentself.benchmark_returns if x.date >= self.algorithm_returns[0].date and x.date <= self.algorithm_returns[-1].date] - self.treasury_curves = self.trading_environment.treasury_curves - - - qutil.LOGGER.debug("#### {start} thru {end} with {count} trading_days of {total} possible".format(start=self.algorithm_returns[0].date, - end=self.algorithm_returns[-1].date, - count=len(self.bm_returns), - total=len(benchmark_returns))) #calculate month ends self.month_periods = self.periods_in_range(1, self.algorithm_returns[0].date, self.algorithm_returns[-1].date) @@ -205,8 +196,6 @@ class riskmetrics(): cur_period_metrics = periodmetrics(start_date=cur_start, end_date=cur_end, returns=self.algorithm_returns, - benchmark_returns=self.bm_returns, - treasury_curves=self.treasury_curves, trading_environment=self.trading_environment) ends.append(cur_period_metrics) cur_start = advance_by_months(cur_start, 1) diff --git a/zipline/test/test_finance.py b/zipline/test/test_finance.py index 4a1b5da7..96196a7f 100644 --- a/zipline/test/test_finance.py +++ b/zipline/test/test_finance.py @@ -21,11 +21,13 @@ class FinanceTestCase(TestCase): def setUp(self): qutil.configure_logging() - benchmark_returns, treasury_curves = factory.load_market_data() - self.trading_env = risk.TradingEnvironment(benchmark_returns, treasury_curves) + self.benchmark_returns, self.treasury_curves = factory.load_market_data() + self.trading_environment = risk.TradingEnvironment(self.benchmark_returns, self.treasury_curves) def test_trade_feed_protocol(self): + + # TODO: Perhaps something more self-documenting for variables names? sid = 133 price = [10.0] * 4 volume = [100] * 4 @@ -33,7 +35,7 @@ class FinanceTestCase(TestCase): start_date = datetime.strptime("02/15/2012","%m/%d/%Y") one_day_td = timedelta(days=1) - trades = factory.create_trade_history(sid, price, volume, start_date, one_day_td, self.trading_env) + trades = factory.create_trade_history(sid, price, volume, start_date, one_day_td, self.trading_environment) for trade in trades: #simulate data source sending frame @@ -148,7 +150,7 @@ class FinanceTestCase(TestCase): start_date = datetime.strptime("02/1/2012","%m/%d/%Y") trade_time_increment = timedelta(days=1) - trade_history = factory.create_trade_history( sid, price, volume, start_date, trade_time_increment, self.trading_env ) + trade_history = factory.create_trade_history( sid, price, volume, start_date, trade_time_increment, self.trading_environment ) set1 = SpecificEquityTrades("flat-133", trade_history) @@ -210,7 +212,7 @@ class FinanceTestCase(TestCase): start_date = datetime.strptime("02/1/2012","%m/%d/%Y") trade_time_increment = timedelta(days=1) - trade_history = factory.create_trade_history( sid, price, volume, start_date, trade_time_increment, self.trading_env ) + trade_history = factory.create_trade_history( sid, price, volume, start_date, trade_time_increment, self.trading_environment ) set1 = SpecificEquityTrades("flat-133", trade_history) @@ -220,7 +222,7 @@ class FinanceTestCase(TestCase): order_source = OrderDataSource(ts) transaction_sim = TransactionSimulator() - portfolio_client = perf.PortfolioClient(trade_history[0]['dt'], trade_history[-1]['dt'], 1000000.0, self.trading_env) + portfolio_client = perf.PortfolioClient(trade_history[0]['dt'], trade_history[-1]['dt'], 1000000.0, self.trading_environment) sim.register_components([client, order_source, transaction_sim, set1, portfolio_client]) sim.register_controller( con ) diff --git a/zipline/test/test_risk.py b/zipline/test/test_risk.py index 5620d3e1..de7dad11 100644 --- a/zipline/test/test_risk.py +++ b/zipline/test/test_risk.py @@ -12,7 +12,7 @@ class Risk(unittest.TestCase): def setUp(self): qutil.configure_logging() self.benchmark_returns, self.treasury_curves = factory.load_market_data() - self.trading_calendar = risk.TradingCalendar(self.benchmark_returns, self.treasury_curves) + self.trading_calendar = risk.TradingEnvironment(self.benchmark_returns, self.treasury_curves) self.onesec = datetime.timedelta(seconds=1) self.oneday = datetime.timedelta(days=1)