mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 02:28:41 +08:00
cf4d7ca353
Changed all unittest2 imports to unittest
276 lines
8.0 KiB
Python
276 lines
8.0 KiB
Python
#
|
|
# Copyright 2012 Quantopian, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import pytz
|
|
|
|
from unittest import TestCase
|
|
from itertools import chain, izip_longest
|
|
from datetime import datetime, timedelta
|
|
from collections import deque
|
|
|
|
from zipline import ndict
|
|
from zipline.gens.sort import (
|
|
date_sort,
|
|
ready,
|
|
done,
|
|
queue_is_ready,
|
|
queue_is_done
|
|
)
|
|
from zipline.gens.utils import alternate, done_message
|
|
from zipline.sources import SpecificEquityTrades
|
|
from zipline.gens.composites import date_sorted_sources
|
|
|
|
|
|
class HelperTestCase(TestCase):
    """Tests for the queue readiness/completion helpers in zipline.gens.sort.

    Checks use unittest assertion methods instead of bare ``assert`` so they
    still run (and can still fail) when Python is started with ``-O``.
    """

    def test_individual_queue_logic(self):
        """Exercise queue_is_ready / queue_is_done on a single deque."""
        queue = deque()

        # Empty queues are neither done nor ready.
        self.assertFalse(queue_is_ready(queue))
        self.assertFalse(queue_is_done(queue))

        queue.append(to_dt('foo'))
        self.assertTrue(queue_is_ready(queue))
        self.assertFalse(queue_is_done(queue))

        # The queue is now [DONE, foo]: a DONE marker followed by a message.
        queue.appendleft(to_dt('DONE'))
        self.assertTrue(queue_is_ready(queue))

        # Checking done when we have a message after done will trip an assert.
        self.assertRaises(AssertionError, queue_is_done, queue)

        # Dropping the trailing message leaves only the DONE marker.
        queue.pop()
        self.assertTrue(queue_is_ready(queue))
        self.assertTrue(queue_is_done(queue))

    def test_pop_logic(self):
        """Exercise the aggregate ready / done helpers over a dict of queues."""
        ids = ['a', 'b', 'c']
        sources = {source_id: deque() for source_id in ids}

        self.assertFalse(ready(sources))
        self.assertFalse(done(sources))

        # All sources must have a message to be ready/done.
        sources['a'].append(to_dt("datetime"))
        self.assertFalse(ready(sources))
        self.assertFalse(done(sources))
        sources['a'].pop()

        for source_id in ids:
            sources[source_id].append(to_dt("datetime"))

        self.assertTrue(ready(sources))
        self.assertFalse(done(sources))

        for source_id in ids:
            sources[source_id].appendleft(to_dt("DONE"))

        # ["DONE", message] will trip an assert in queue_is_done.
        self.assertTrue(ready(sources))
        self.assertRaises(AssertionError, done, sources)

        for source_id in ids:
            sources[source_id].pop()

        self.assertTrue(ready(sources))
        self.assertTrue(done(sources))
|
|
|
|
|
|
class DateSortTestCase(TestCase):
    """Tests for ``date_sort`` and the ``date_sorted_sources`` composite.

    Expected orderings are computed by sorting on dt with source_id as a
    tiebreaker, using the module-level cmp-style comparator ``comp``.
    """

    def _sorted_by_dt_source_id(self, events):
        """Return a new list of ``events`` ordered by dt, then source_id.

        ``comp`` is a cmp-style function. Wrapping it with
        ``functools.cmp_to_key`` works on Python 2.7 and on Python 3, where
        the positional ``cmp`` argument to ``sorted`` no longer exists.
        """
        from functools import cmp_to_key
        return sorted(events, key=cmp_to_key(comp))

    def _with_done(self, source):
        """Chain ``source`` with a trailing done message for its hash.

        The raw source doesn't handle done messaging, so one is appended
        for the sort to work properly.
        """
        return chain(source, [done_message(source.get_hash())])

    def run_date_sort(self, event_stream, expected, source_ids):
        """
        Take a list of events, their source_ids, and an expected sorting.
        Assert that date_sort's output agrees with expected.
        """
        sort_out = date_sort(event_stream, source_ids)
        # Whole-list comparison fails on any element mismatch or length
        # difference, and reports where the sequences diverge.
        self.assertEqual(list(sort_out), list(expected))

    def test_single_source(self):
        # Just using the built-in defaults. See zipline.sources.py
        source = SpecificEquityTrades()
        expected = list(source)
        source.rewind()

        self.run_date_sort(
            self._with_done(source), expected, [source.get_hash()])

    def test_multi_source(self):
        sid_filter = [2, 3]
        start = datetime(2012, 1, 3, 15, tzinfo=pytz.utc)

        source_a = SpecificEquityTrades(
            count=100,
            sids=[1, 2, 3],
            start=start,
            delta=timedelta(minutes=6),
            filter=sid_filter
        )
        source_b = SpecificEquityTrades(
            count=100,
            sids=[2, 3, 4],
            start=start,
            delta=timedelta(minutes=5),
            filter=sid_filter
        )

        all_events = list(chain(source_a, source_b))

        # The expected output is all events, sorted by dt with
        # source_id as a tiebreaker.
        expected = self._sorted_by_dt_source_id(all_events)
        source_ids = [source_a.get_hash(), source_b.get_hash()]

        # Generating the events list consumes the sources. Rewind them
        # for testing.
        source_a.rewind()
        source_b.rewind()

        # Test sort with alternating messages from source_a and source_b.
        interleaved = alternate(
            self._with_done(source_a), self._with_done(source_b))
        self.run_date_sort(interleaved, expected, source_ids)

        source_a.rewind()
        source_b.rewind()

        # Test sort with all messages from a, followed by all messages
        # from b.
        sequential = chain(
            self._with_done(source_a), self._with_done(source_b))
        self.run_date_sort(sequential, expected, source_ids)

    def test_sort_composite(self):
        sid_filter = [1, 2]
        start = datetime(2012, 6, 6, 0)

        # Source a: sid 1, one hour between events.
        source_a = SpecificEquityTrades(
            count=100,
            sids=[1],
            start=start,
            delta=timedelta(hours=1),
            filter=sid_filter
        )
        # Source b: sid 2, one day between events.
        source_b = SpecificEquityTrades(
            count=50,
            sids=[2],
            start=start,
            delta=timedelta(days=1),
            filter=sid_filter
        )
        # Source c: sids 1 and 2, one minute between events.
        source_c = SpecificEquityTrades(
            count=150,
            sids=[1, 2],
            start=start,
            delta=timedelta(minutes=1),
            filter=sid_filter
        )
        # Source d. This should produce no events because the
        # internal sids don't match the filter.
        source_d = SpecificEquityTrades(
            count=50,
            sids=[3],
            start=start,
            delta=timedelta(minutes=1),
            filter=sid_filter
        )

        sources = [source_a, source_b, source_c, source_d]
        hashes = [source.get_hash() for source in sources]

        sort_out = date_sorted_sources(*sources)

        # Read all the values from sort and assert that they arrive in
        # the correct sorting with the expected hash values.
        to_list = list(sort_out)

        # We should have 300 events (100 from a, 50 from b, 150 from c,
        # none from d).
        self.assertEqual(len(to_list), 300)

        d_hash = source_d.get_hash()
        for e in to_list:
            # All events should match one of our expected source_ids.
            self.assertIn(e.source_id, hashes)
            # But none of them should match source_d.
            self.assertNotEqual(e.source_id, d_hash)

        # The events should be sorted by dt, with source_id as tiebreaker.
        self.assertEqual(to_list, self._sorted_by_dt_source_id(to_list))
|
|
|
|
|
|
def compare_by_dt_source_id(x, y):
    """cmp-style comparator ordering events by ``dt``, ties by ``source_id``.

    Returns -1 if ``x`` sorts first, 1 if ``y`` sorts first, 0 otherwise.
    ``source_id`` is only consulted when the two ``dt`` values are neither
    less than nor greater than each other.
    """
    for field in ('dt', 'source_id'):
        lhs = getattr(x, field)
        rhs = getattr(y, field)
        if lhs < rhs:
            return -1
        if lhs > rhs:
            return 1
    return 0
|
|
|
|
# Short name for compare_by_dt_source_id, used as the sort comparator
# when the tests build their expected event orderings.
comp = compare_by_dt_source_id
|
|
|
|
|
|
def to_dt(msg):
    """Return an ``ndict`` holding ``msg`` under the single key ``'dt'``.

    The queue-helper tests use it to build minimal messages; the stored
    value may be a plain marker string such as ``'DONE'``.
    """
    payload = dict(dt=msg)
    return ndict(payload)
|