Removes Catalog base class and fixes defects in CatalogBuilder and RegularCatalog detected by Hypothesis.

This commit is contained in:
Robert Smallshire
2015-05-07 21:15:19 +02:00
parent 4c222b96a6
commit c713475e07
2 changed files with 56 additions and 85 deletions
+48 -81
View File
@@ -9,7 +9,6 @@ the CatalogBuilder class which will analyse the contents of the
mapping to find a space and time efficient representation.
"""
from abc import abstractmethod, ABCMeta
from collections import Mapping, Sequence, OrderedDict
from fractions import Fraction
import reprlib
@@ -88,14 +87,14 @@ class CatalogBuilder(object):
index_max = self._catalog[-1][0]
index_stride = measure_stride(index for index, value in self._catalog)
if index_stride is None:
# Dictionary strategy - arbitrary keys and values
return DictionaryCatalog(self._catalog)
value_start = self._catalog[0][1]
value_stop = self._catalog[-1][1]
value_stride = measure_stride(value for index, value in self._catalog)
if index_stride is None and value_stride is None:
# Dictionary strategy - arbitrary keys and values
return DictionaryCatalog(self._catalog)
if index_stride is not None and value_stride == 0:
assert value_start == value_stop
return RegularConstantCatalog(index_min,
@@ -168,34 +167,7 @@ class CatalogBuilder(object):
return True, diff
class Catalog(Mapping, metaclass=ABCMeta):
"""An abstract base class for Catalogs which provides min and max keys and values."""
@abstractmethod
def __init__(self, key_min=None, key_max=None):
"""Must be overridden and called by subclasses.
Args:
key_min: Optional minimum key.
key_max: Optional maximum key.
"""
self._min_key = key_min
self._max_key = key_max
def key_min(self):
"""Minimum key"""
if self._min_key is None:
self._min_key = min(self.keys())
return self._min_key
def key_max(self):
"""Maximum key"""
if self._max_key is None:
self._max_key = max(self.keys())
return self._max_key
class RowMajorCatalog(Catalog):
class RowMajorCatalog(Mapping):
"""A mapping which assumes a row-major ordering of a two-dimensional matrix.
This is the ordering of items in a two-dimensional matrix where in
@@ -224,7 +196,6 @@ class RowMajorCatalog(Catalog):
j_max (int): The maximum j value.
c (int): The constant offset
"""
super().__init__()
self._i_min = i_min
self._i_max = i_max
self._j_min = j_min
@@ -294,12 +265,11 @@ class RowMajorCatalog(Catalog):
self._i_min, self._i_max, self._j_min, self._j_max, self._c)
class DictionaryCatalog(Catalog):
class DictionaryCatalog(Mapping):
"""An immutable, ordered, dictionary mapping.
"""
def __init__(self, items):
super().__init__()
self._items = OrderedDict(items)
def __getitem__(self, key):
@@ -319,7 +289,7 @@ class DictionaryCatalog(Catalog):
self.__class__.__name__, reprlib.repr(self._items.items()))
class RegularConstantCatalog(Catalog):
class RegularConstantCatalog(Mapping):
"""Mapping with keys ordered with regular spacing along the number line.
The values associated with the keys are constant.
@@ -343,10 +313,9 @@ class RegularConstantCatalog(Catalog):
raise ValueError("RegularIndex key range {!r} is not "
"a multiple of stride {!r}".format(
key_stride, key_range))
super().__init__(
key_min=key_min,
key_max=key_max)
self._key_min = key_min
self._key_max = key_max
self._key_stride = key_stride
self._value = value
@@ -356,27 +325,27 @@ class RegularConstantCatalog(Catalog):
return self._value
def __len__(self):
return 1 + (self.key_max() - self.key_min()) / self._key_stride
return 1 + (self._key_max - self._key_min) / self._key_stride
def __contains__(self, key):
return (self.key_min() <= key <= self.key_max()) and \
((key - self.key_min()) % self._key_stride == 0)
return (self._key_min <= key <= self._key_max) and \
((key - self._key_min) % self._key_stride == 0)
def __iter__(self):
return iter(range(self.key_min(),
self.key_max() + 1,
return iter(range(self._key_min,
self._key_max + 1,
self._key_stride))
def __repr__(self):
return '{}(key_min={}, key_max={}, key_stride={}, value={})'.format(
self.__class__.__name__,
self.key_min(),
self.key_max(),
self._key_min,
self._key_max,
self._key_stride,
self._value)
class ConstantCatalog(Catalog):
class ConstantCatalog(Mapping):
"""Mapping with arbitrary keys and a single constant value.
"""
@@ -392,8 +361,7 @@ class ConstantCatalog(Catalog):
key_stride: The difference between successive keys.
value: A value associated with all keys.
"""
super().__init__()
self._items = SortedFrozenSet(keys)
self._keys = SortedFrozenSet(keys)
self._value = value
def __getitem__(self, key):
@@ -402,22 +370,22 @@ class ConstantCatalog(Catalog):
return self._value
def __len__(self):
return len(self._items)
return len(self._keys)
def __contains__(self, key):
return key in self._items
return key in self._keys
def __iter__(self):
return iter(self._items)
return iter(self._keys)
def __repr__(self):
return '{}(keys={}, value={})'.format(
self.__class__.__name__,
reprlib.repr(self._items),
reprlib.repr(self._keys),
self._value)
class RegularCatalog(Catalog):
class RegularCatalog(Mapping):
"""Mapping with keys ordered with regular spacing along the number line.
The values associated with the keys are arbitrary.
@@ -445,17 +413,18 @@ class RegularCatalog(Catalog):
raise ValueError("{} key range {!r} is not "
"a multiple of stride {!r}".format(self.__class__.__name__,
key_stride, key_range))
super(RegularCatalog, self).__init__(key_min=key_min, key_max=key_max)
self._key_min = key_min
self._key_max = key_max
self._key_stride = key_stride
self._values = list(values)
num_keys = key_range // key_stride
num_keys = 1 + key_range // key_stride
if num_keys != len(self._values):
raise ValueError("{} key range and values inconsistent".format(self.__class__.__name__))
def __getitem__(self, key):
if not (self.key_min() <= key <= self.key_max()):
if not (self._key_min <= key <= self._key_max):
raise KeyError("{!r} key {!r} out of range".format(self, key))
offset = key - self.key_min()
offset = key - self._key_min
if offset % self._key_stride != 0:
raise KeyError("{!r} does not contain key {!r}".format(self, key))
index = offset // self._key_stride
@@ -465,24 +434,24 @@ class RegularCatalog(Catalog):
return len(self._values)
def __contains__(self, key):
return (self.key_min() <= key <= self.key_max()) and \
((key - self.key_min()) % self._key_stride == 0)
return (self._key_min <= key <= self._key_max) and \
((key - self._key_min) % self._key_stride == 0)
def __iter__(self):
return iter(range(self.key_min(),
self.key_max() + 1,
return iter(range(self._key_min,
self._key_max + 1,
self._key_stride))
def __repr__(self):
return '{}(key_min={}, key_max={}, key_stride={}, values={})'.format(
self.__class__.__name__,
self.key_min(),
self.key_max(),
self._key_min,
self._key_max,
self._key_stride,
reprlib.repr(self._values))
class LinearRegularCatalog(Catalog):
class LinearRegularCatalog(Mapping):
"""A mapping which assumes a linear relationship between keys and values.
A LinearRegularCatalog predicts the value v from the key according to the
@@ -529,14 +498,12 @@ class LinearRegularCatalog(Catalog):
value_range))
self._value_stride = value_stride
super().__init__(
key_min=key_min,
key_max=key_max)
self._key_min = key_min
self._key_max = key_max
self._value_start = value_start
self._value_stop = value_stop
num_keys = 1 + (self.key_max() - self.key_min()) // self._key_stride
num_keys = 1 + (self._key_max - self._key_min) // self._key_stride
num_values = 1 + (self._value_stop - self._value_start) // self._value_stride
if num_keys != num_values:
raise ValueError("{} inconsistent number of "
@@ -546,34 +513,34 @@ class LinearRegularCatalog(Catalog):
num_values))
self._m = Fraction(self._value_stop - self._value_start,
self.key_max() - self.key_min())
self._key_max - self._key_min)
def __getitem__(self, key):
if not (self.key_min() <= key <= self.key_max()):
if not (self._key_min <= key <= self._key_max):
raise KeyError("{!r} key {!r} out of range".format(self, key))
offset = key - self.key_min()
offset = key - self._key_min
if offset % self._key_stride != 0:
raise KeyError("{!r} does not contain key {!r}".format(self, key))
v = self._m * (key - self.key_min()) + self._value_start
v = self._m * (key - self._key_min) + self._value_start
assert v.denominator == 1
return v.numerator
def __len__(self):
return 1 + (self.key_max() - self.key_min()) // self._key_stride
return 1 + (self._key_max - self._key_min) // self._key_stride
def __contains__(self, key):
return (self.key_min() <= key <= self.key_max()) and \
((key - self.key_min()) % self._key_stride == 0)
return (self._key_min <= key <= self._key_max) and \
((key - self._key_min) % self._key_stride == 0)
def __iter__(self):
return iter(range(self.key_min(), self.key_max() + 1, self._key_stride))
return iter(range(self._key_min, self._key_max + 1, self._key_stride))
def __repr__(self):
return '{}(key_min={}, key_max{}, key_stride={}, value_start={}, value_stop={}, value_stride={})'.format(
self.__class__.__name__,
self.key_min(),
self.key_max(),
self._key_min,
self._key_max,
self._key_stride,
self._value_start,
self._value_stop,
+8 -4
View File
@@ -1,18 +1,22 @@
import unittest
from hypothesis import given, example
from hypothesis.specifiers import dictionary
from hypothesis.specifiers import dictionary, just
from segpy.catalog import CatalogBuilder
class TestCatalogBuilder(unittest.TestCase):
@given(dictionary(int, int))
def test_constructor(self, mapping):
def test_arbitrary_mapping(self, mapping):
builder = CatalogBuilder(mapping)
catalog = builder.create()
shared_items = set(mapping.items()) & set(catalog.items())
self.assertEqual(len(shared_items), len(mapping))
@given(dictionary(int, just(42)))
def test_constant_mapping(self, mapping):
builder = CatalogBuilder(mapping)
catalog = builder.create()
shared_items = set(mapping.items()) & set(catalog.items())
self.assertEqual(len(shared_items), len(mapping))