diff --git a/zipline/gens/tradesimulation.py b/zipline/gens/tradesimulation.py index ddd8af74..2919596a 100644 --- a/zipline/gens/tradesimulation.py +++ b/zipline/gens/tradesimulation.py @@ -21,8 +21,7 @@ from itertools import chain from logbook import Logger, Processor from collections import defaultdict -from zipline import ndict -from zipline.protocol import SIDData, DATASOURCE_TYPE +from zipline.protocol import BarData, DATASOURCE_TYPE from zipline.finance.performance import PerformanceTracker from zipline.gens.utils import hash_args @@ -248,10 +247,10 @@ class AlgorithmSimulator(object): # Snapshot Setup # ============== - # The algorithm's universe as of our most recent event. - # We want an ndict that will have empty objects as default + # The algorithm's data as of our most recent event. + # We want an object that will have empty objects as default # values on missing keys. - self.universe = ndict(internal=defaultdict(SIDData)) + self.current_data = BarData() # We don't have a datetime for the current snapshot until we # receive a message. @@ -412,7 +411,7 @@ class AlgorithmSimulator(object): Update the universe with new event information. """ # Update our knowledge of this event's sid - sid_data = self.universe[event.sid] + sid_data = self.current_data[event.sid] sid_data.__dict__.update(event.__dict__) def simulate_snapshot(self, date): @@ -426,4 +425,4 @@ class AlgorithmSimulator(object): self.algo.set_datetime(self.snapshot_dt) # Update the simulation time. self.simulation_dt = date - self.algo.handle_data(self.universe) + self.algo.handle_data(self.current_data) diff --git a/zipline/protocol.py b/zipline/protocol.py index cfa27979..f0e79244 100644 --- a/zipline/protocol.py +++ b/zipline/protocol.py @@ -12,8 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - +from collections import defaultdict import datetime from utils.protocol_utils import Enum @@ -121,6 +120,54 @@ class SIDData(object): return "SIDData({0})".format(self.__dict__) +class BarData(object): + """ + Holds the event data for all sids for a given dt. + + This is what is passed as `data` to the `handle_data` function. + + Note: Many methods are analogues of dictionary because of historical + usage of what this replaced as a dictionary subclass. + """ + + def __init__(self): + self._data = defaultdict(SIDData) + self._contains_override = None + + def __contains__(self, name): + if self._contains_override: + return self._contains_override(name) + else: + return name in self.__dict__ + + def __setitem__(self, name, value): + self._data[name] = value + + def __getitem__(self, name): + return self._data[name] + + def __delitem__(self, name): + del self._data[name] + + def __iter__(self): + return self._data.iterkeys() + + def keys(self): + return self._data.keys() + + def iterkeys(self): + return self._data.iterkeys() + + def itervalues(self): + return self._data.itervalues() + + def iteritems(self): + return self._data.iteritems() + + def items(self): + return self._data.items() + + class DailyReturn(object): def __init__(self, date, returns):