diff --git a/segpy/catalog.py b/segpy/catalog.py index ad1983b..f4300f7 100644 --- a/segpy/catalog.py +++ b/segpy/catalog.py @@ -9,7 +9,6 @@ the CatalogBuilder class which will analyse the contents of the mapping to find a space and time efficient representation. """ -from abc import abstractmethod, ABCMeta from collections import Mapping, Sequence, OrderedDict from fractions import Fraction import reprlib @@ -88,14 +87,14 @@ class CatalogBuilder(object): index_max = self._catalog[-1][0] index_stride = measure_stride(index for index, value in self._catalog) - if index_stride is None: - # Dictionary strategy - arbitrary keys and values - return DictionaryCatalog(self._catalog) - value_start = self._catalog[0][1] value_stop = self._catalog[-1][1] value_stride = measure_stride(value for index, value in self._catalog) + if index_stride is None and value_stride is None: + # Dictionary strategy - arbitrary keys and values + return DictionaryCatalog(self._catalog) + if index_stride is not None and value_stride == 0: assert value_start == value_stop return RegularConstantCatalog(index_min, @@ -168,34 +167,7 @@ class CatalogBuilder(object): return True, diff -class Catalog(Mapping, metaclass=ABCMeta): - """An abstract base class for Catalogs which provides min and max keys and values.""" - - @abstractmethod - def __init__(self, key_min=None, key_max=None): - """Must be overridden and called by subclasses. - - Args: - key_min: Optional minimum key. - key_max: Optional maximum key. - """ - self._min_key = key_min - self._max_key = key_max - - def key_min(self): - """Minimum key""" - if self._min_key is None: - self._min_key = min(self.keys()) - return self._min_key - - def key_max(self): - """Maximum key""" - if self._max_key is None: - self._max_key = max(self.keys()) - return self._max_key - - -class RowMajorCatalog(Catalog): +class RowMajorCatalog(Mapping): """A mapping which assumes a row-major ordering of a two-dimensional matrix. This is the ordering of items in a two-dimensional matrix where in @@ -224,7 +196,6 @@ class RowMajorCatalog(Catalog): j_max (int): The maximum j value. c (int): The constant offset """ - super().__init__() self._i_min = i_min self._i_max = i_max self._j_min = j_min @@ -294,12 +265,11 @@ class RowMajorCatalog(Catalog): self._i_min, self._i_max, self._j_min, self._j_max, self._c) -class DictionaryCatalog(Catalog): +class DictionaryCatalog(Mapping): """An immutable, ordered, dictionary mapping. """ def __init__(self, items): - super().__init__() self._items = OrderedDict(items) def __getitem__(self, key): @@ -319,7 +289,7 @@ class DictionaryCatalog(Catalog): self.__class__.__name__, reprlib.repr(self._items.items())) -class RegularConstantCatalog(Catalog): +class RegularConstantCatalog(Mapping): """Mapping with keys ordered with regular spacing along the number line. The values associated with the keys are constant. @@ -343,10 +313,9 @@ class RegularConstantCatalog(Catalog): raise ValueError("RegularIndex key range {!r} is not " "a multiple of stride {!r}".format( key_stride, key_range)) - super().__init__( - key_min=key_min, - key_max=key_max) + self._key_min = key_min + self._key_max = key_max self._key_stride = key_stride self._value = value @@ -356,27 +325,27 @@ class RegularConstantCatalog(Catalog): return self._value def __len__(self): - return 1 + (self.key_max() - self.key_min()) / self._key_stride + return 1 + (self._key_max - self._key_min) / self._key_stride def __contains__(self, key): - return (self.key_min() <= key <= self.key_max()) and \ - ((key - self.key_min()) % self._key_stride == 0) + return (self._key_min <= key <= self._key_max) and \ + ((key - self._key_min) % self._key_stride == 0) def __iter__(self): - return iter(range(self.key_min(), - self.key_max() + 1, + return iter(range(self._key_min, + self._key_max + 1, self._key_stride)) def __repr__(self): return '{}(key_min={}, key_max={}, key_stride={}, value={})'.format( self.__class__.__name__, - self.key_min(), - self.key_max(), + self._key_min, + self._key_max, self._key_stride, self._value) -class ConstantCatalog(Catalog): +class ConstantCatalog(Mapping): """Mapping with arbitrary keys and a single constant value. """ @@ -392,8 +361,7 @@ class ConstantCatalog(Catalog): key_stride: The difference between successive keys. value: A value associated with all keys. """ - super().__init__() - self._items = SortedFrozenSet(keys) + self._keys = SortedFrozenSet(keys) self._value = value def __getitem__(self, key): @@ -402,22 +370,22 @@ class ConstantCatalog(Catalog): return self._value def __len__(self): - return len(self._items) + return len(self._keys) def __contains__(self, key): - return key in self._items + return key in self._keys def __iter__(self): - return iter(self._items) + return iter(self._keys) def __repr__(self): return '{}(keys={}, value={})'.format( self.__class__.__name__, - reprlib.repr(self._items), + reprlib.repr(self._keys), self._value) -class RegularCatalog(Catalog): +class RegularCatalog(Mapping): """Mapping with keys ordered with regular spacing along the number line. The values associated with the keys are arbitrary. @@ -445,17 +413,18 @@ class RegularCatalog(Catalog): raise ValueError("{} key range {!r} is not " "a multiple of stride {!r}".format(self.__class__.__name__, key_stride, key_range)) - super(RegularCatalog, self).__init__(key_min=key_min, key_max=key_max) + self._key_min = key_min + self._key_max = key_max self._key_stride = key_stride self._values = list(values) - num_keys = key_range // key_stride + num_keys = 1 + key_range // key_stride if num_keys != len(self._values): raise ValueError("{} key range and values inconsistent".format(self.__class__.__name__)) def __getitem__(self, key): - if not (self.key_min() <= key <= self.key_max()): + if not (self._key_min <= key <= self._key_max): raise KeyError("{!r} key {!r} out of range".format(self, key)) - offset = key - self.key_min() + offset = key - self._key_min if offset % self._key_stride != 0: raise KeyError("{!r} does not contain key {!r}".format(self, key)) index = offset // self._key_stride @@ -465,24 +434,24 @@ class RegularCatalog(Catalog): return len(self._values) def __contains__(self, key): - return (self.key_min() <= key <= self.key_max()) and \ - ((key - self.key_min()) % self._key_stride == 0) + return (self._key_min <= key <= self._key_max) and \ + ((key - self._key_min) % self._key_stride == 0) def __iter__(self): - return iter(range(self.key_min(), - self.key_max() + 1, + return iter(range(self._key_min, + self._key_max + 1, self._key_stride)) def __repr__(self): return '{}(key_min={}, key_max={}, key_stride={}, values={})'.format( self.__class__.__name__, - self.key_min(), - self.key_max(), + self._key_min, + self._key_max, self._key_stride, reprlib.repr(self._values)) -class LinearRegularCatalog(Catalog): +class LinearRegularCatalog(Mapping): """A mapping which assumes a linear relationship between keys and values. A LinearRegularCatalog predicts the value v from the key according to the @@ -529,14 +498,12 @@ class LinearRegularCatalog(Catalog): value_range)) self._value_stride = value_stride - super().__init__( - key_min=key_min, - key_max=key_max) - + self._key_min = key_min + self._key_max = key_max self._value_start = value_start self._value_stop = value_stop - num_keys = 1 + (self.key_max() - self.key_min()) // self._key_stride + num_keys = 1 + (self._key_max - self._key_min) // self._key_stride num_values = 1 + (self._value_stop - self._value_start) // self._value_stride if num_keys != num_values: raise ValueError("{} inconsistent number of " @@ -546,34 +513,34 @@ class LinearRegularCatalog(Catalog): num_values)) self._m = Fraction(self._value_stop - self._value_start, - self.key_max() - self.key_min()) + self._key_max - self._key_min) def __getitem__(self, key): - if not (self.key_min() <= key <= self.key_max()): + if not (self._key_min <= key <= self._key_max): raise KeyError("{!r} key {!r} out of range".format(self, key)) - offset = key - self.key_min() + offset = key - self._key_min if offset % self._key_stride != 0: raise KeyError("{!r} does not contain key {!r}".format(self, key)) - v = self._m * (key - self.key_min()) + self._value_start + v = self._m * (key - self._key_min) + self._value_start assert v.denominator == 1 return v.numerator def __len__(self): - return 1 + (self.key_max() - self.key_min()) // self._key_stride + return 1 + (self._key_max - self._key_min) // self._key_stride def __contains__(self, key): - return (self.key_min() <= key <= self.key_max()) and \ - ((key - self.key_min()) % self._key_stride == 0) + return (self._key_min <= key <= self._key_max) and \ + ((key - self._key_min) % self._key_stride == 0) def __iter__(self): - return iter(range(self.key_min(), self.key_max() + 1, self._key_stride)) + return iter(range(self._key_min, self._key_max + 1, self._key_stride)) def __repr__(self): return '{}(key_min={}, key_max{}, key_stride={}, value_start={}, value_stop={}, value_stride={})'.format( self.__class__.__name__, - self.key_min(), - self.key_max(), + self._key_min, + self._key_max, self._key_stride, self._value_start, self._value_stop, diff --git a/test/test_catalog.py b/test/test_catalog.py index 3d372ff..ec1c89e 100644 --- a/test/test_catalog.py +++ b/test/test_catalog.py @@ -1,18 +1,22 @@ import unittest from hypothesis import given, example -from hypothesis.specifiers import dictionary +from hypothesis.specifiers import dictionary, just from segpy.catalog import CatalogBuilder class TestCatalogBuilder(unittest.TestCase): @given(dictionary(int, int)) - def test_constructor(self, mapping): + def test_arbitrary_mapping(self, mapping): builder = CatalogBuilder(mapping) catalog = builder.create() shared_items = set(mapping.items()) & set(catalog.items()) self.assertEqual(len(shared_items), len(mapping)) - - + @given(dictionary(int, just(42))) + def test_constant_mapping(self, mapping): + builder = CatalogBuilder(mapping) + catalog = builder.create() + shared_items = set(mapping.items()) & set(catalog.items()) + self.assertEqual(len(shared_items), len(mapping))