From 5fcb2de59b44968512edc28cdfd5c4490853f609 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Fri, 6 Jul 2012 11:54:40 -0400 Subject: [PATCH] ndicts are now deepcopyable. --- tests/test_ndict.py | 32 ++++++++++++++++++++++++++++---- zipline/utils/protocol_utils.py | 3 +++ 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/tests/test_ndict.py b/tests/test_ndict.py index 9d72e732..845bb97d 100644 --- a/tests/test_ndict.py +++ b/tests/test_ndict.py @@ -1,6 +1,8 @@ from datetime import datetime import pytz +from copy import deepcopy + from zipline.utils.protocol_utils import ndict def test_ndict(): @@ -54,10 +56,32 @@ def test_ndict(): del nd['x'] assert not nd.has_key('x') assert nd.get('x') is None - - + + for n in xrange(1000): dt = datetime.utcnow().replace(tzinfo=pytz.utc) nd2 = ndict({"dt":dt, "otherdata":"ishere"*1000, "maybeanint":3}) - - nd2.dt2 = dt \ No newline at end of file + + nd2.dt2 = dt + +def test_ndict_deepcopy(): + def assert_correctly_copied(orig, copy): + assert nd == nd_dc, "Deepcopied ndict should have same keys and values." + + nd_dc.z = 3 + assert 'z' not in nd, "'z' also added to original ndict." + + nd_dc.y = 10 + assert nd_dc.y == 10, "value of copied ndict not correctly set." + assert nd.y != 10, "value also set of original ndict." + + nd = ndict({'x': 1, 'y': 2}) + nd_dc = deepcopy(nd) + assert_correctly_copied(nd, nd_dc) + + nd = ndict({'x':[1,2,3], 'y': {1: 1}}) + nd_dc = deepcopy(nd) + assert_correctly_copied(nd, nd_dc) + nd_dc.x.append(4) + assert nd_dc.x[-1] == 4, "not correctly appended to copied." + assert nd.x[-1] != 4, "also copied to original." \ No newline at end of file diff --git a/zipline/utils/protocol_utils.py b/zipline/utils/protocol_utils.py index 7149a7c1..c74c81f2 100644 --- a/zipline/utils/protocol_utils.py +++ b/zipline/utils/protocol_utils.py @@ -55,6 +55,9 @@ class ndict(MutableMapping): # Abstact Overloads # ----------------- + def __deepcopy__(self, memo): + return ndict(copy.deepcopy(self.__internal)) + def __setattr__(self, key, value): if key == 'cls' or key == '__internal' or '_ndict' in key: super(ndict, self).__setattr__(key, value)