Code reorganisation - segpy is now a proper Python package. Rework of the extended textual header for symmetry with the reader. Some additional tests.

This commit is contained in:
Robert Smallshire
2015-01-28 21:42:52 +01:00
parent 246b8515e4
commit 255a490914
19 changed files with 669 additions and 292 deletions
-42
View File
@@ -1,42 +0,0 @@
COMMON = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789:_- '
EBCDIC = set(COMMON.encode('cp037'))
ASCII = set(COMMON.encode('ascii'))
def guess_encoding(bs, threshold=0.5):
"""Try to determine whether the encoding of byte stream b is an ASCII string or an EBCDIC string.
Args:
bs: A byte string (Python 2 - str; Python 3 - bytes)
Returns:
A string which can be used with the Python encoding functions: 'cp037' for EBCDIC, 'ascii' for ASCII or None
if neither.
"""
ebcdic_count = 0
ascii_count = 0
count = 0
for b in bs:
if b in EBCDIC:
ebcdic_count +=1
if b in ASCII:
ascii_count +=1
count += 1
if count == 0:
return None
ebcdic_freq = ebcdic_count / count
ascii_freq = ascii_count / count
if ebcdic_freq < threshold and ascii_freq < threshold:
return None
if ebcdic_freq < threshold and ascii_freq >= threshold:
return 'ascii'
if ebcdic_freq >= threshold and ascii_freq < threshold:
return 'cp037'
return None
+1
View File
@@ -0,0 +1 @@
@@ -2,7 +2,7 @@
SEG Y Header Definition
"""
from revisions import SEGY_REVISION_0, SEGY_REVISION_1
from segpy.revisions import SEGY_REVISION_0, SEGY_REVISION_1
HEADER_DEF = {"Job": {"pos": 3200, "type": "int32", "def": 0}}
HEADER_DEF["Line"] = {"pos": 3204, "type": "int32", "def": 0}
+120 -64
View File
@@ -1,7 +1,9 @@
from abc import abstractmethod, ABCMeta
from collections import Mapping, Sequence, OrderedDict
from fractions import Fraction
from portability import reprlib
from util import contains_duplicates, measure_stride, minmax
from segpy.portability import reprlib
from segpy.util import contains_duplicates, measure_stride, minmax
class CatalogBuilder:
@@ -155,7 +157,51 @@ class CatalogBuilder:
return True, diff
class RowMajorCatalog(Mapping):
class Catalog(Mapping):
"""An abstract base class for Catalogs which provides min and max keys and values."""
__metaclass__ = ABCMeta
@abstractmethod
def __init__(self, key_min=None, key_max=None, value_min=None, value_max=None):
"""Must be overridden and called by subclasses.
Args:
key_min: Optional minimum key.
key_max: Optional maximum key.
value_min: Optional minimum value.
value_max: Optional maximum value.
"""
self._min_key = key_min
self._max_key = key_max
self._min_value = value_min
self._max_value = value_max
def key_min(self):
"""Minimum key"""
if self._min_key is None:
self._min_key = min(self.keys())
return self._min_key
def key_max(self):
"""Maximum key"""
if self._max_key is None:
self._max_key = max(self.keys())
return self._max_key
def value_min(self):
"""Minimum value"""
if self._min_value is None:
self._min_value = min(self.values())
return self._min_value
def value_max(self):
"""Maximum value"""
if self._max_value is None:
self._max_value = max(self.values())
return self._max_value
class RowMajorCatalog(Catalog):
"""A mapping which assumes a row-major ordering of a two-dimensional matrix.
This is the ordering of items in a two-dimensional matrix where in
@@ -184,6 +230,7 @@ class RowMajorCatalog(Mapping):
j_max (int): The maximum j value.
c (int): The constant offset
"""
super().__init__()
self._i_min = i_min
self._i_max = i_max
self._j_min = j_min
@@ -210,13 +257,21 @@ class RowMajorCatalog(Mapping):
"""Maximum j value"""
return self._j_max
def min(self):
"""Minimum (i, j) value"""
def key_min(self):
"""Minimum (i, j) key"""
return self._i_min, self._j_min
def max(self):
"""Maximum (i, j) value"""
return self._j_min, self._j_max
def key_max(self):
"""Maximum (i, j) key"""
return self._i_max, self._j_max
def value_min(self):
"""Minimum value at key_min"""
return self[self.key_min()]
def value_max(self):
"""Maximum value at key_max"""
return self[self.key_max()]
def __getitem__(self, key):
i, j = key
@@ -245,11 +300,12 @@ class RowMajorCatalog(Mapping):
self._i_min, self._i_max, self._j_min, self._j_max, self._c)
class DictionaryCatalog(Mapping):
class DictionaryCatalog(Catalog):
"""An immutable, ordered, dictionary mapping.
"""
def __init__(self, items):
super().__init__()
self._items = OrderedDict(items)
def __getitem__(self, key):
@@ -269,7 +325,7 @@ class DictionaryCatalog(Mapping):
self.__class__.__name__, reprlib.repr(self._items.items()))
class RegularConstantCatalog(Mapping):
class RegularConstantCatalog(Catalog):
"""Mapping with keys ordered with regular spacing along the number line.
The values associated with the keys are constant.
@@ -293,41 +349,40 @@ class RegularConstantCatalog(Mapping):
raise ValueError("RegularIndex key range {!r} is not "
"a multiple of stride {!r}".format(
key_stride, key_range))
self._key_min = key_min
self._key_max = key_max
super().__init__(key_min=key_min,
key_max=key_max,
value_min=value,
value_max=value)
self._key_stride = key_stride
self._value = value
def __getitem__(self, key):
if key not in self:
raise KeyError("{!r} does not contain key {!r}".format(self, key))
return self._value
return self.value_min()
def __len__(self):
return 1 + (self._key_max - self._key_min) / self._key_stride
return 1 + (self.key_max() - self.key_min()) / self._key_stride
def __contains__(self, key):
return (self._key_min <= key <= self._key_max) and \
((key - self._key_min) % self._key_stride == 0)
return (self.key_min() <= key <= self.key_max()) and \
((key - self.key_min()) % self._key_stride == 0)
def __iter__(self):
return iter(range(self._key_min,
self._key_max + 1,
return iter(range(self.key_min(),
self.key_max() + 1,
self._key_stride))
def __repr__(self):
return '{}({}, {}, {}, {})'.format(
self.__class__.__name__,
self._key_min,
self._key_max,
self.key_min(),
self.key_max(),
self._key_stride,
self._value)
self.value_min())
class ConstantCatalog(Mapping):
"""Mapping with keys ordered with regular spacing along the number line.
The values associated with the keys are constant.
class ConstantCatalog(Catalog):
"""Mapping with arbitrary keys and a single constant value.
"""
def __init__(self, keys, value):
@@ -342,13 +397,13 @@ class ConstantCatalog(Mapping):
key_stride: The difference between successive keys.
value: A value associated with all keys.
"""
super().__init__(value_min=value, value_max=value)
self._items = frozenset(keys)
self._value = value
def __getitem__(self, key):
if key not in self:
raise KeyError("{!r} does not contain key {!r}".format(self, key))
return self._value
return self.value_min()
def __len__(self):
return len(self._items)
@@ -363,10 +418,10 @@ class ConstantCatalog(Mapping):
return '{}({}, {})'.format(
self.__class__.__name__,
reprlib.repr(self._items),
self._value)
self.value_min())
class RegularCatalog(Mapping):
class RegularCatalog(Catalog):
"""Mapping with keys ordered with regular spacing along the number line.
The values associated with the keys are arbitrary.
@@ -391,21 +446,20 @@ class RegularCatalog(Mapping):
"""
key_range = key_max - key_min
if key_range % key_stride != 0:
raise ValueError("RegularIndex key range {!r} is not "
"a multiple of stride {!r}".format(
key_stride, key_range))
self._key_min = key_min
self._key_max = key_max
raise ValueError("{} key range {!r} is not "
"a multiple of stride {!r}".format(self.__class__.__name__,
key_stride, key_range))
super().__init__(key_min=key_min, key_max=key_max)
self._key_stride = key_stride
self._values = list(values)
num_keys = key_range // key_stride
if num_keys != len(self._values):
raise ValueError("RegularIndex key range and values inconsistent")
raise ValueError("{} key range and values inconsistent".format(self.__class__.__name__))
def __getitem__(self, key):
if not (self._key_min <= key <= self._key_max):
if not (self.key_min() <= key <= self.key_max()):
raise KeyError("{!r} key {!r} out of range".format(self, key))
offset = key - self._key_min
offset = key - self.key_min()
if offset % self._key_stride != 0:
raise KeyError("{!r} does not contain key {!r}".format(self, key))
index = offset // self._key_stride
@@ -415,24 +469,24 @@ class RegularCatalog(Mapping):
return len(self._values)
def __contains__(self, key):
return (self._key_min <= key <= self._key_max) and \
((key - self._key_min) % self._key_stride == 0)
return (self.key_min() <= key <= self.key_max()) and \
((key - self.key_min()) % self._key_stride == 0)
def __iter__(self):
return iter(range(self._key_min,
self._key_max + 1,
return iter(range(self.key_min(),
self.key_max() + 1,
self._key_stride))
def __repr__(self):
return '{}({}, {}, {}, {})'.format(
self.__class__.__name__,
self._key_min,
self._key_max,
self.key_min(),
self.key_max(),
self._key_stride,
reprlib.repr(self._values))
class LinearRegularCatalog(Mapping):
class LinearRegularCatalog(Catalog):
"""A mapping which assumes a linear relationship between keys and values.
This is the ordering of items in a two-dimensional matrix where in
@@ -472,8 +526,6 @@ class LinearRegularCatalog(Mapping):
self.__class__.__name__,
key_stride,
key_range))
self._key_min = key_min
self._key_max = key_max
self._key_stride = key_stride
value_range = value_max - value_min
@@ -483,12 +535,16 @@ class LinearRegularCatalog(Mapping):
self.__class__.__name__,
value_stride,
value_range))
self._value_min = value_min
self._value_max = value_max
self._value_stride = value_stride
num_keys = (self._key_max - self._key_min) // self._key_stride
num_values = (self._value_max - self._value_min) // self._value_stride
super().__init__(key_min=key_min,
key_max=key_max,
value_min=value_min,
value_max=value_max)
num_keys = (self.key_max() - self.key_min()) // self._key_stride
num_values = (self.value_max() - self.value_min()) // self._value_stride
if num_keys != num_values:
raise ValueError("{} inconsistent number of "
"keys {} and values {}".format(
@@ -496,36 +552,36 @@ class LinearRegularCatalog(Mapping):
num_keys,
num_values))
self._m = Fraction(self._value_max - self._value_min,
self._key_max - self._key_min)
self._m = Fraction(self.value_max() - self.value_min(),
self.key_max() - self.key_min())
def __getitem__(self, key):
if not (self._key_min <= key <= self._key_max):
if not (self.key_min() <= key <= self.key_max()):
raise KeyError("{!r} key {!r} out of range".format(self, key))
offset = key - self._key_min
offset = key - self.key_min()
if offset % self._key_stride != 0:
raise KeyError("{!r} does not contain key {!r}".format(self, key))
v = self._m * (key - self._key_min) + self._value_min
v = self._m * (key - self.key_min()) + self.value_min()
assert v.denominator == 1
return v.numerator
def __len__(self):
return 1 + (self._key_max - self._key_min) // self._key_stride
return 1 + (self.key_max() - self.key_min()) // self._key_stride
def __contains__(self, key):
return (self._key_min <= key <= self._key_max) and \
((key - self._key_min) % self._key_stride == 0)
return (self.key_min() <= key <= self.key_max()) and \
((key - self.key_min()) % self._key_stride == 0)
def __iter__(self):
return iter(range(self._key_min, self._key_max + 1, self._key_stride))
return iter(range(self.key_min(), self.key_max() + 1, self._key_stride))
def __repr__(self):
return '{}({}, {}, {}, {}, {}, {})'.format(
self.__class__.__name__,
self._key_min,
self._key_max,
self.key_min(),
self.key_max(),
self._key_stride,
self._value_min,
self._value_max,
self.value_min(),
self.value_max(),
self._value_stride)
View File
+69
View File
@@ -0,0 +1,69 @@
ASCII = 'ascii'
EBCDIC = 'cp037'
SUPPORTED_ENCODINGS = (ASCII, EBCDIC)
class UnsupportedEncodingError(Exception):
def __init__(self, text, encoding):
self._encoding = encoding
super(UnsupportedEncodingError, self).__init__(text)
@property
def encoding(self):
return self._encoding
def __str__(self):
return "{} not supported for encoding {}".format(self.args[0], self._encoding)
def __repr__(self):
return "{}({!r}, {!r}".format(self.__class__.__name__, self.args[0], self._encoding)
def is_supported_encoding(encoding):
return encoding in SUPPORTED_ENCODINGS
COMMON_CHARS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789:_- '
COMMON_EBCDIC_CHARS = set(COMMON_CHARS.encode(EBCDIC))
COMMON_ASCII_CHARS = set(COMMON_CHARS.encode(ASCII))
def guess_encoding(bs, threshold=0.5):
"""Try to determine whether the encoding of byte stream b is an ASCII string or an EBCDIC string.
Args:
bs: A byte string (Python 2 - str; Python 3 - bytes)
Returns:
A string which can be used with the Python encoding functions: 'cp037' for EBCDIC, 'ascii' for ASCII or None
if neither.
"""
ebcdic_count = 0
ascii_count = 0
count = 0
for b in bs:
if b in COMMON_EBCDIC_CHARS:
ebcdic_count +=1
if b in COMMON_ASCII_CHARS:
ascii_count +=1
count += 1
if count == 0:
return None
ebcdic_freq = ebcdic_count / count
ascii_freq = ascii_count / count
if ebcdic_freq < threshold and ascii_freq < threshold:
return None
if ebcdic_freq < threshold and ascii_freq >= threshold:
return ASCII
if ebcdic_freq >= threshold and ascii_freq < threshold:
return EBCDIC
return None
+1
View File
@@ -0,0 +1 @@
+32
View File
@@ -0,0 +1,32 @@
"""Optional interoperability with Numpy."""
import numpy
NUMPY_DTYPES = {'ibm': numpy.dtype('f4'),
'l': numpy.dtype('i4'),
'h': numpy.dtype('i2'),
'f': numpy.dtype('f4'),
'b': numpy.dtype('i1')}
def make_dtype(data_sample_format):
"""Convert a SEG Y data sample format to a compatible numpy dtype.
Note :
IBM float data sample formats ('ibm') will correspond to IEEE float data types.
Args:
data_sample_format: A data sample format string.
Returns:
A numpy.dtype instance.
Raises:
ValueError: For unrecognised data sample format strings.
"""
try:
return NUMPY_DTYPES[data_sample_format]
except KeyError:
raise ValueError("Unknown data sample format string {!r}".format(data_sample_format))
+2 -2
View File
@@ -1,8 +1,8 @@
from __future__ import print_function
import sys
from math import frexp, isnan, isinf
from portability import long_int, byte_string, four_bytes
from segpy.portability import long_int, byte_string, four_bytes
_IBM_FLOAT32_BITS_PRECISION = 24
+171 -88
View File
@@ -1,18 +1,19 @@
from __future__ import print_function
from array import array
from portability import seekable
from util import file_length, filename_from_handle
from datatypes import DATA_SAMPLE_FORMAT, CTYPE_DESCRIPTION, CTYPES, size_in_bytes
from toolkit import (extract_revision,
bytes_per_sample,
read_binary_reel_header,
read_trace_header,
catalog_traces,
read_binary_values,
compile_trace_header_format,
REEL_HEADER_NUM_BYTES,
TRACE_HEADER_NUM_BYTES, read_textual_reel_header, read_extended_textual_headers)
from segpy.portability import seekable
from segpy.util import file_length, filename_from_handle
from segpy.datatypes import DATA_SAMPLE_FORMAT, CTYPE_DESCRIPTION, CTYPES, size_in_bytes
from segpy.toolkit import (extract_revision,
bytes_per_sample,
read_binary_reel_header,
read_trace_header,
catalog_traces,
read_binary_values,
compile_trace_header_format,
REEL_HEADER_NUM_BYTES,
TRACE_HEADER_NUM_BYTES,
read_textual_reel_header,
read_extended_textual_headers)
def create_reader(fh, encoding=None, endian='>', progress=None):
@@ -107,18 +108,17 @@ def create_reader(fh, encoding=None, endian='>', progress=None):
class SegYReader(object):
"""A basic SEG Y reader.
Use to obtain read the reel header, the trace headers or trace
values. Traces can be accessed only by trace index.
Use to obtain read the reel header, the trace_samples headers or trace_samples
values. Traces can be accessed only by trace_samples index.
"""
def __init__(self,
fh,
textual_reel_header,
binary_reel_header,
extended_textual_header,
extended_textual_headers,
trace_offset_catalog,
trace_length_catalog,
endian='>'):
@@ -137,13 +137,13 @@ class SegYReader(object):
binary_reel_header: A dictionary containing reel header data.
extended_textual_header: A Unicode string (which may be empty).
extended_textual_headers: A sequence of sequences of Unicode strings.
trace_catalog: A mapping from zero-based trace index to
trace_offset_catalog: A mapping from zero-based trace_samples index to
the byte-offset to individual traces within the file.
trace_length_catalog: A mapping from zero-based trace index to the
number of samples in that trace.
trace_length_catalog: A mapping from zero-based trace_samples index to the
number of samples in that trace_samples.
endian: '>' for big-endian data (the standard and default), '<' for
little-endian (non-standard)
@@ -155,7 +155,8 @@ class SegYReader(object):
self._textual_reel_header = textual_reel_header
self._binary_reel_header = binary_reel_header
self._extended_textual_header = extended_textual_header
self._extended_textual_headers = extended_textual_headers
self._trace_offset_catalog = trace_offset_catalog
self._trace_length_catalog = trace_length_catalog
@@ -164,7 +165,7 @@ class SegYReader(object):
self._binary_reel_header, self.revision)
def trace_indexes(self):
"""An iterator over zero-based trace indexes.
"""An iterator over zero-based trace_samples indexes.
Returns:
An iterator which yields integers in the range zero to
@@ -176,8 +177,16 @@ class SegYReader(object):
"""The number of traces"""
return len(self._trace_offset_catalog)
def read_trace(self, trace_index, start=None, stop=None):
"""Read a specific trace.
def max_num_trace_samples(self):
"""The number of samples in the trace_samples with the most samples."""
return self._trace_length_catalog.value_max()
def num_trace_samples(self, trace_index):
"""The number of samples in the specified trace_samples."""
return self._trace_length_catalog[trace_index]
def trace_samples(self, trace_index, start=None, stop=None):
"""Read a specific trace_samples.
Args:
trace_index: An integer in the range zero to num_traces() - 1
@@ -189,27 +198,27 @@ class SegYReader(object):
slice convention this is one beyond the end.
Returns:
A sequence of numeric trace samples.
A sequence of numeric trace_samples samples.
Example:
first_trace_samples = segy_reader.read_trace(0)
part_of_second_trace_samples = segy_reader.read_trace(1, 1000, 2000)
first_trace_samples = segy_reader.trace_samples(0)
part_of_second_trace_samples = segy_reader.trace_samples(1, 1000, 2000)
"""
if not (0 <= trace_index < self.num_traces()):
raise ValueError("Trace index out of range.")
num_samples_in_trace = self._trace_length_catalog[trace_index]
num_samples_in_trace = self.num_trace_samples(trace_index)
start_sample = start if start is not None else 0
stop_sample = stop if stop is not None else num_samples_in_trace
if not (0 <= stop_sample <= num_samples_in_trace):
raise ValueError("read_trace(): stop value {} out of range 0 to {}"
raise ValueError("trace_samples(): stop value {} out of range 0 to {}"
.format(stop, num_samples_in_trace))
if not (0 <= start_sample <= stop_sample):
raise ValueError("read_trace(): start value {} out of range 0 to {}"
raise ValueError("trace_samples(): start value {} out of range 0 to {}"
.format(start, stop_sample))
dsf = self._binary_reel_header['DataSampleFormat']
@@ -223,18 +232,18 @@ class SegYReader(object):
self._fh, start_pos, ctype, num_samples_to_read, self._endian)
return trace_values
def read_trace_header(self, trace_index):
"""Read a specific trace.
def trace_header(self, trace_index):
"""Read a specific trace_samples.
Args:
trace_index: An integer in the range zero to num_traces() - 1
Returns:
A TraceHeader corresponding to the requested trace.
A TraceHeader corresponding to the requested trace_samples.
Example:
first_trace_header, first_trace_samples = segy_reader.read_trace(0)
first_trace_header, first_trace_samples = segy_reader.trace_samples(0)
"""
if not (0 <= trace_index < self.num_traces()):
raise ValueError("Trace index {} out of range".format(trace_index))
@@ -248,7 +257,7 @@ class SegYReader(object):
Returns:
3 for 3D seismic volumes, 2 for 2D seismic lines, 1 for a
single trace, otherwise 0.
single trace_samples, otherwise 0.
"""
return self._dimensionality()
@@ -274,11 +283,12 @@ class SegYReader(object):
@property
def extended_textual_header(self):
"""The concatenation of any extended textual headers as a Unicode string.
"""A sequence of sequences of Unicode strings.
If there were no headers, the string may be empty.
If there were no headers, the sequence will be empty.
"""
return self._extended_textual_header
return self._extended_textual_headers
@property
def filename(self):
@@ -300,7 +310,7 @@ class SegYReader(object):
@property
def bytes_per_sample(self):
"""The number of bytes per trace sample.
"""The number of bytes per trace_samples sample.
"""
return self._bytes_per_sample
@@ -332,7 +342,7 @@ class SegYReader3D(SegYReader):
fh,
textual_reel_header,
binary_reel_header,
extended_textual_header,
extended_textual_headers,
trace_offset_catalog,
trace_length_catalog,
line_catalog,
@@ -349,11 +359,11 @@ class SegYReader3D(SegYReader):
binary_reel_header: A dictionary containing reel header data.
trace_offset_catalog: A mapping from zero-based trace indexes to
trace_offset_catalog: A mapping from zero-based trace_samples indexes to
the byte-offset to individual traces within the file.
trace_length_catalog: A mapping from zero-based trace indexes to
the number of samples in that trace.
trace_length_catalog: A mapping from zero-based trace_samples indexes to
the number of samples in that trace_samples.
line_catalog: A mapping from (xline, inline) tuples to
trace_indexes.
@@ -361,30 +371,70 @@ class SegYReader3D(SegYReader):
endian: '>' for big-endian data (the standard and default), '<' for
little-endian (non-standard)
"""
super(SegYReader3D, self).__init__(fh, textual_reel_header, binary_reel_header, extended_textual_header,
super(SegYReader3D, self).__init__(fh, textual_reel_header, binary_reel_header, extended_textual_headers,
trace_offset_catalog, trace_length_catalog, endian)
self._line_catalog = line_catalog
self._num_inlines = None
self._num_xlines = None
def _dimensionality(self):
return 3
def num_inlines(self):
"""The number of distinct inlines in the survey
def inline_range(self):
"""A range encompassing inline numbers.
The number of inlines within this range can be found with len(reader.inline_range()).
Returns:
A range() object with start set to the first inline number and stop set to
one beyond the last inline number. The range always has a step of one, although
this should not be taken as meaning that any intermediate inline number generated
by the range is valid.
"""
try:
return self._line_catalog.i_max - self._line_catalog.i_min + 1
except AttributeError:
# TODO: Memoize
return len(set(i for i, j in self._line_catalog))
start = self._line_catalog.key_min()[0]
stop = self._line_catalog.key_max()[0] + 1
return range(start, stop)
def num_inlines(self):
"""The number of distinct inlines in the survey.
This number is not necessarily the same as the value returned by
len(reader.inline_range()) as there may be missing inlines within the range.
"""
if self._num_inlines is None:
try:
self._num_inlines = self._line_catalog.i_max - self._line_catalog.i_min + 1
except AttributeError:
self._num_inlines = len(set(i for i, j in self._line_catalog))
return self._num_inlines
def xline_range(self):
"""A range encompassing crossline numbers.
The number of crosslines within this range can be found with len(reader.crossline_range()).
Returns:
A range() object with start set to the first crossline number and stop set to
one beyond the last crossline number. The range always has a step of one, although
this should not be taken as meaning that any intermediate crossline number generated
by the range is valid.
"""
start = self._line_catalog.key_min()[1]
stop = self._line_catalog.key_max()[1] + 1
return range(start, stop)
def num_xlines(self):
"""The number of distinct crosslines in the survey
"""The number of distinct crosslines in the survey.
This number is not necessarily the same as the value returned by
len(reader.xline_range()) as there may be missing crosslines within the range.
"""
try:
return self._line_catalog.j_max - self._line_catalog.j_min + 1
except AttributeError:
# TODO: Memoize
return len(set(j for i, j in self._line_catalog))
if self._num_xlines is None:
try:
self._num_xlines = self._line_catalog.j_max - self._line_catalog.j_min + 1
except AttributeError:
self._num_xlines = len(set(j for i, j in self._line_catalog))
return self._num_xlines
def inline_xline_numbers(self):
"""An iterator over all (inline_number, xline_number) tuples
@@ -392,8 +442,19 @@ class SegYReader3D(SegYReader):
"""
return iter(self._line_catalog)
def has_trace_index(self, inline_xline):
"""Determine whether a specific trace_samples exists.
Args:
inline_xline: A 2-tuple of inline number, crossline number.
Returns:
True if the specified trace_samples exists, otherwise False.
"""
return inline_xline in self._line_catalog
def trace_index(self, inline_xline):
"""Obtain the trace index given an xline and a inline.
"""Obtain the trace_samples index given an xline and a inline.
Note:
Do not assume that all combinations of crossline and
@@ -409,24 +470,23 @@ class SegYReader3D(SegYReader):
inline_xline: A 2-tuple of inline number, crossline number.
Returns:
A trace index which can be used with read_trace().
A trace_samples index which can be used with trace_samples().
"""
return self._line_catalog[inline_xline]
class SegYReader2D(SegYReader):
def __init__(self,
fh,
textual_reel_header,
binary_reel_header,
extended_textual_header,
extended_textual_headers,
trace_offset_catalog,
trace_length_catalog,
cdp_catalog, endian='>'):
"""Initialize a SegYReader2D around a file-like-object.
Note:
Note:
Usually a SegYReader is most easily constructed using the
create_reader() function.
@@ -436,18 +496,18 @@ class SegYReader2D(SegYReader):
binary_reel_header: A dictionary containing reel header data.
trace_catalog_offset: A mapping from zero-based trace index to
trace_catalog_offset: A mapping from zero-based trace_samples index to
the byte-offset to individual traces within the file.
trace_length_catalog: A mapping from zero-based trace indexes to
the number of samples in that trace.
trace_length_catalog: A mapping from zero-based trace_samples indexes to
the number of samples in that trace_samples.
cdp_catalog: A mapping from CDP numbers to trace_indexes.
endian: '>' for big-endian data (the standard and default), '<' for
little-endian (non-standard)
"""
super(SegYReader2D, self).__init__(fh, textual_reel_header, binary_reel_header, extended_textual_header,
super(SegYReader2D, self).__init__(fh, textual_reel_header, binary_reel_header, extended_textual_headers,
trace_offset_catalog, trace_length_catalog, endian)
self._cdp_catalog = cdp_catalog
@@ -459,34 +519,55 @@ class SegYReader2D(SegYReader):
"""
return iter(self._cdp_catalog)
def num_cdps(self):
return len(self._cdp_catalog)
def cdp_range(self):
"""A range encompassing CDP numbers.
def trace_index(self, cdp_number):
"""Obtain the trace index given an xline and a inline.
Note:
Do not assume that all combinations of crossline and
inline co-ordinates are valid. The volume may not be
rectangular. Valid values can be obtained from the
inline_xline_numbers() iterator.
Furthermore, inline and crossline numbers should not be
relied upon to be zero- or one-based indexes (although
they may be).
Args:
xline: A crossline number.
inline: An inline number.
The number of CDPs within this range can be found with len(reader.cdp_range()).
Returns:
A trace index which can be used with read_trace().
A range() object with start set to the first CDP number and stop set to
one beyond the last CDP number. The range always has a step of one, although
this should not be taken as meaning that any intermediate CDP number generated
by the range is valid.
"""
start = self._cdp_catalog.value_min()
stop = self._cdp_catalog.value_max() + 1
return range(start, stop)
def num_cdps(self):
"""The number of distinct CDPs.
This number is not necessarily the same as the value returned by
len(reader.cdp_range()) as there may be missing CDPs.
"""
return len(self._cdp_catalog)
def has_trace_index(self, cdp_number):
"""Determine whether a specified trace_samples exists.
Args:
cdp_number: A CDP number.
Returns:
True if the trace_samples exists, otherwise False.
"""
return self._cdp_catalog[cdp_number]
def trace_index(self, cdp_number):
"""Obtain the trace_samples index given an xline and a inline.
Args:
cdp_number: A CDP number.
Returns:
A trace_samples index which can be used with trace_samples().
"""
return self._cdp_catalog[cdp_number]
def main(argv=None):
import sys
if argv is None:
argv = sys.argv[1:]
@@ -539,8 +620,10 @@ def main(argv=None):
print("=== END EXTENDED TEXTUAL_HEADER ===")
for trace_index in segy_reader.trace_indexes():
trace_header = segy_reader.read_trace_header(trace_index)
print("Inline {}, Crossline {}, Shotpoint {}".format(trace_header.Inline3D, trace_header.Crossline3D, trace_header.ShotPoint))
trace_header = segy_reader.trace_header(trace_index)
print("Inline {}, Crossline {}, Shotpoint {}".format(trace_header.Inline3D, trace_header.Crossline3D,
trace_header.ShotPoint))
if __name__ == '__main__':
main()
View File
@@ -1,4 +1,4 @@
from revisions import SEGY_REVISION_0, SEGY_REVISION_1
from segpy.revisions import SEGY_REVISION_0, SEGY_REVISION_1
TEMPLATE = """
C 1 CLIENT { client } COMPANY { company } CREW NO {crew }
+135 -85
View File
@@ -6,18 +6,19 @@ import os
import struct
import re
from catalog import CatalogBuilder
from datatypes import CTYPES, size_in_bytes
from encoding import guess_encoding
from binary_reel_header_definition import HEADER_DEF
from ibm_float import ibm2ieee, ieee2ibm
from revisions import canonicalize_revision
import textual_reel_header_definition
from trace_header_definition import TRACE_HEADER_DEF
from util import file_length, batched, pad, round_up, complementary_slices
from portability import EMPTY_BYTE_STRING
from segpy import textual_reel_header_definition
from segpy.catalog import CatalogBuilder
from segpy.datatypes import CTYPES, size_in_bytes
from segpy.encoding import guess_encoding, is_supported_encoding, UnsupportedEncodingError
from segpy.binary_reel_header_definition import HEADER_DEF
from segpy.ibm_float import ibm2ieee, ieee2ibm
from segpy.revisions import canonicalize_revision
from segpy.trace_header_definition import TRACE_HEADER_DEF
from segpy.util import file_length, batched, pad, round_up, complementary_slices
from segpy.portability import EMPTY_BYTE_STRING
HEADER_NEWLINE = '\r\n'
CARD_LENGTH = 80
CARDS_PER_HEADER = 40
@@ -29,6 +30,7 @@ TRACE_HEADER_NUM_BYTES = 240
END_TEXT_STANZA = "((SEG: EndText))"
def extract_revision(binary_reel_header):
"""Obtain the SEG Y revision from the reel header.
@@ -77,32 +79,32 @@ def bytes_per_sample(binary_reel_header, revision):
def samples_per_trace(binary_reel_header):
"""Determine the number of samples per trace from the reel header.
"""Determine the number of samples per trace_samples from the reel header.
Note: There is no requirement for all traces to be of the same length,
so this value should be considered indicative only, and as such is
mostly useful in the absence of other information. The actual number
of samples for a specific trace should be retrieved from individual
trace headers.
of samples for a specific trace_samples should be retrieved from individual
trace_samples headers.
Args:
binary_reel_header: A dictionary containing a reel header, such as obtained
from read_binary_reel_header()
Returns:
An integer number of samples per trace
An integer number of samples per trace_samples
"""
return binary_reel_header['ns']
def trace_length_bytes(binary_reel_header, bps):
"""Determine the trace length in bytes from the reel header.
"""Determine the trace_samples length in bytes from the reel header.
Note: There is no requirement for all traces to be of the same length,
so this value should be considered indicative only, and as such is
mostly useful in the absence of other information. The actual number
of samples for a specific trace should be retrieved from individual
trace headers.
of samples for a specific trace_samples should be retrieved from individual
trace_samples headers.
Args:
binary_reel_header: A dictionary containing a reel header, such as obtained
@@ -163,7 +165,6 @@ def read_binary_reel_header(fh, endian='>'):
return reel_header
def has_end_text_stanza(ext_header):
"""Determine whether the header is the end text stanza.
@@ -190,14 +191,15 @@ def read_extended_headers_until_end(fh, encoding):
Typically 'cp037' for EBCDIC or 'ascii' for ASCII.
Returns:
A list of tuples each containing forty CARD_LENGTH character Unicode strings.
A list of tuples each containing forty CARD_LENGTH character Unicode strings. If present, the end_text
stanza is excluded.
"""
extended_headers = []
while True:
ext_header = read_textual_reel_header(fh, encoding)
extended_headers.append(ext_header)
if has_end_text_stanza(ext_header):
break
extended_headers.append(ext_header)
return extended_headers
@@ -250,24 +252,46 @@ def read_extended_textual_headers(fh, binary_reel_header, encoding):
Typically 'cp037' for EBCDIC or 'ascii' for ASCII.
Returns:
A Unicode string containing the concatenated contents of any extended headers. If there
were no extended headers, the string will be empty.
A sequence of sequences of Unicode strings representing headers of lines of characters. The length of the
outer sequence will be equal to the number of extended headers read. Each item in the outer sequence will be
a sequence of exactly forty Unicode strings. To combine the headers into a single string, consider using
concatenate_extended_textual_headers().
Postcondition:
As a post-condition to this function, the file-pointer of fh will be
positioned immediately after the last extended textual header, which
should be the start of the first trace header.
should be the start of the first trace_samples header.
"""
fh.seek(REEL_HEADER_NUM_BYTES)
declared_num_ext_headers = num_extended_textual_headers(binary_reel_header)
extended_headers = []
if declared_num_ext_headers == -1:
extended_headers.extend(read_extended_headers_until_end(fh, encoding))
elif declared_num_ext_headers > 0:
extended_headers.extend(read_extended_headers_counted(fh, declared_num_ext_headers, encoding))
if declared_num_ext_headers < 0:
return read_extended_headers_until_end(fh, encoding)
return read_extended_headers_counted(fh, declared_num_ext_headers, encoding)
def concatenate_extended_textual_headers(extended_textual_headers):
"""Combine extended textual headers.
Args:
extended_textual_headers: A sequence of sequences of Unicode strings, such as that returned
by read_extended_textual_headers().
Returns:
A Unicode string containing the concatenated contents of any extended headers. If there
were no extended headers, the string will be empty.
"""
if len(extended_textual_headers) == 0:
return ""
# Remove the end text header if it is present
if has_end_text_stanza(extended_textual_headers[-1]):
del extended_textual_headers[-1]
# Concatenate the extended headers
extended_textual_header = ''.join(line for header in extended_headers for line in header).strip(' ')
extended_textual_header = ''.join(line for header in extended_textual_headers for line in header).strip(' ')
return extended_textual_header
@@ -276,7 +300,7 @@ _READ_PROPORTION = 0.75 # The proportion of time spent in catalog_traces
def catalog_traces(fh, bps, endian='>', progress=None):
"""Build catalogs to facilitate random access to trace data.
"""Build catalogs to facilitate random access to trace_samples data.
Note:
This function can take significant time to run, proportional
@@ -284,20 +308,20 @@ def catalog_traces(fh, bps, endian='>', progress=None):
Four catalogs will be build:
1. A catalog mapping trace index (0-based) to the position of that
trace header in the file.
1. A catalog mapping trace_samples index (0-based) to the position of that
trace_samples header in the file.
2. A catalog mapping trace index (0-based) to the number of
samples in that trace.
2. A catalog mapping trace_samples index (0-based) to the number of
samples in that trace_samples.
3. A catalog mapping CDP number to the trace index.
3. A catalog mapping CDP number to the trace_samples index.
4. A catalog mapping an (inline, crossline) number 2-tuple to
trace index.
trace_samples index.
Args:
fh: A file-like-object open in binary mode, positioned at the
start of the first trace header.
start of the first trace_samples header.
bps: The number of bytes per sample, such as obtained by a call
to bytes_per_sample()
@@ -311,8 +335,8 @@ def catalog_traces(fh, bps, endian='>', progress=None):
an argument equal to 1
Returns:
A 4-tuple of the form (trace-offset-catalog,
trace-length-catalog,
A 4-tuple of the form (trace_samples-offset-catalog,
trace_samples-length-catalog,
cdp-catalog,
line-catalog)` where
each catalog is an instance of ``collections.Mapping`` or None
@@ -374,7 +398,6 @@ def catalog_traces(fh, bps, endian='>', progress=None):
# Some 3D files put Inline and Crossline numbers in (TraceSequenceFile, cdp) pair
line_catalog = alt_line_catalog_builder.create()
progress_callback(1)
return (trace_offset_catalog,
@@ -384,7 +407,7 @@ def catalog_traces(fh, bps, endian='>', progress=None):
def read_trace_header(fh, trace_header_format, pos=None):
"""Read a trace header.
"""Read a trace_samples header.
Args:
fh: A file-like-object open in binary mode.
@@ -583,7 +606,7 @@ def write_textual_reel_header(fh, lines, encoding):
standard) although this is not enforced by this function, since
many widespread SEG Y readers and writers do not adhere to this
constraint. To produce a SEG Y compliant series of header lines
consider using the standard_textual_header() function.
consider using the format_standard_textual_header() function.
Any lines longer than CARD_LENGTH characters will be truncated without
warning. Any excess lines over CARDS_PER_HEADER will be discarded. Short
@@ -591,6 +614,7 @@ def write_textual_reel_header(fh, lines, encoding):
encoding: Typically 'cp037' for EBCDIC or 'ascii' for ASCII.
"""
# TODO: Seek
padded_lines = [line.encode(encoding).ljust(CARD_LENGTH, ' '.encode(encoding))[:CARD_LENGTH]
for line in pad(lines, padding='', size=CARDS_PER_HEADER)]
header = ''.join(padded_lines)
@@ -608,6 +632,7 @@ def write_binary_reel_header(fh, binary_reel_header, endian='>'):
in binary_reel_header_definition.HEADER_DEF associated with
compatible values.
"""
# TODO: Seek
for key in HEADER_DEF:
pos = HEADER_DEF[key]['pos']
ctype = HEADER_DEF[key]['type']
@@ -615,64 +640,89 @@ def write_binary_reel_header(fh, binary_reel_header, endian='>'):
write_binary_values(fh, [value], ctype, pos)
def page_buffer(padded_buffer, page_size):
return [padded_buffer[i:i + page_size] for i in
range(0, len(padded_buffer), page_size)]
def format_extended_textual_header(text, encoding, include_text_stop=False):
"""Format an extended textual header into 3200 byte pages.
"""Format a string into pages and line suitable for an extended textual header.
Args:
text: A Unicode string to be written to the extended headers.
Args
text: An arbitrary text string. Any universal newlines will be preserved.
encoding: Typically 'cp037' for EBCDIC or 'ascii' for ASCII.
encoding: Either ASCII ('ascii') or EBCDIC ('cp037')
include_text_stop: If True, a text-stop header will be written.
Returns:
A sequence of byte strings, each of which will be exactly 3200 bytes in length.
include_text_stop: If True, a text stop stanza header will be appended, otherwise not.
"""
buffer = text.encode(encoding)
padded_buffer = buffer.ljust(round_up(len(buffer), TEXTUAL_HEADER_NUM_BYTES), ' '.encode(encoding))
pages = page_buffer(padded_buffer, TEXTUAL_HEADER_NUM_BYTES)
if not is_supported_encoding(encoding):
raise UnsupportedEncodingError("Extended textual header", encoding)
# According to the standard: "The Extended Textual File Header consists of one or more 3200-byte records, each
# record containing 40 lines of textual card-image text." It goes on "... Each line in an Extented Textual File
# Header ends in carriage return and linefeed (EBCDIX 0D25 or ASCII 0D0A)." Given that we're dealing with fixed-
# length (80 byte) lines, this implies that we have 78 bytes of space into which we can encode the content of each
# line, which must be left-justified and padded with spaces.
width = CARD_LENGTH - len(HEADER_NEWLINE)
original_lines = text.splitlines()
# Split overly long lines (i.e. > 78) and pad too-short lines with spaces
lines = []
for original_line in original_lines:
padded_lines = (pad_and_terminate_header_line(original_line[i:i+width], width)
for i in range(0, len(original_line), width))
lines.extend(padded_lines)
pages = list(batched(lines, 40, pad_and_terminate_header_line('', width)))
if include_text_stop:
pages.append(text_stop_page(encoding))
stop_page = format_extended_textual_header(END_TEXT_STANZA, encoding)[0]
pages.append(stop_page)
return pages
def write_extended_textual_headers(fh, pages):
def pad_and_terminate_header_line(line, width):
return line.ljust(width, ' ') + HEADER_NEWLINE
def write_extended_textual_headers(fh, pages, encoding):
"""Write extended textual headers.
Args:
fh: fh: A file-like object open in binary mode for writing.
pages: A sequence of byte strings each of which is exactly
TEXTUAL_HEADER_NUM_BYTES in length. To produce such a
sequence of pages, consider calling the
format_extended_textual_header() function.
"""
if any(len(page) != TEXTUAL_HEADER_NUM_BYTES for page in pages):
raise ValueError("Page length must be {} bytes".format(TEXTUAL_HEADER_NUM_BYTES))
for page in pages:
fh.write(page)
pages: An iterables series of sequences of Unicode strings, where the outer iterable
represents 3200 byte pages, each comprised of a sequence of exactly 40 strings of nominally 80 characters
each. Although Unicode strings are accepted, and when encoded they should result in exact 80 bytes
sequences. To produce a valid data structure for pages, consider using format_extended_textual_header()
_text_stop_pages = {}
def text_stop_page(encoding):
"""Produce a text-stop extended textual header page.
Args:
encoding: Typically 'cp037' for EBCDIC or 'ascii' for ASCII.
Raises:
ValueError:
"""
if encoding not in _text_stop_pages:
_text_stop_pages[encoding] = (END_TEXT_STANZA + '\r\n') \
.encode(encoding) \
.ljust(TEXTUAL_HEADER_NUM_BYTES, ' '.encode(encoding))
return _text_stop_pages[encoding]
# TODO: Seek
encoded_pages = []
for page_index, page in enumerate(pages):
encoded_page = []
# TODO: Share some of this code with writing the textual reel header.
for line_index, line in enumerate(page):
encoded_line = line.encode(encoding)
num_encoded_bytes = len(encoded_line)
if num_encoded_bytes != CARD_LENGTH:
raise ValueError("Extended textual header line {} of page {} at {} bytes is not "
"{} bytes".format(line_index, page_index, num_encoded_bytes, CARD_LENGTH))
encoded_page.append(encoded_line)
num_encoded_lines = len(encoded_page)
if num_encoded_lines != CARDS_PER_HEADER:
raise ValueError("Extended textual header page {} number of "
"lines {} is not {}".format(num_encoded_lines, CARDS_PER_HEADER))
encoded_pages.append(encoded_page)
for encoded_page in encoded_pages:
concatenated_page = b''.join(encoded_page)
assert(len(concatenated_page) == TEXTUAL_HEADER_NUM_BYTES)
fh.write(concatenated_page)
def write_trace_header(fh, trace_header, trace_header_format, pos=None):
@@ -695,7 +745,7 @@ def write_trace_header(fh, trace_header, trace_header_format, pos=None):
fh.write(buf)
def write_trace_values(fh, values, ctype='l', pos=None):
def write_trace_samples(fh, values, ctype='l', pos=None):
write_binary_values(fh, values, ctype, pos)
@@ -763,7 +813,7 @@ _TraceAttributeSpec = namedtuple('Record', ['name', 'pos', 'type'])
def compile_trace_header_format(endian='>'):
"""Compile a format string for use with the struct module from the
trace header definition.
trace_samples header definition.
Args:
endian: '>' for big-endian data (the standard and default), '<' for
@@ -771,7 +821,7 @@ def compile_trace_header_format(endian='>'):
Returns:
A string which can be used with the struct module for parsing
trace headers.
trace_samples headers.
"""
@@ -801,7 +851,7 @@ def compile_trace_header_format(endian='>'):
def _compile_trace_header_record():
"""Build a TraceHeader namedtuple from the trace header definition"""
"""Build a TraceHeader namedtuple from the trace_samples header definition"""
record_specs = sorted(
[_TraceAttributeSpec(name,
TRACE_HEADER_DEF[name]['pos'],
@@ -1,4 +1,4 @@
from revisions import SEGY_REVISION_0, SEGY_REVISION_1
from segpy.revisions import SEGY_REVISION_0, SEGY_REVISION_1
TRACE_HEADER_DEF = {"TraceSequenceLine": {"pos": 0, "type": "int32"}}
TRACE_HEADER_DEF["TraceSequenceFile"] = {"pos": 4, "type": "int32"}
+26 -6
View File
@@ -2,8 +2,9 @@ import itertools
import time
import os
from portability import izip
from segpy.portability import izip
UNSET = object()
def pairwise(iterable):
"""Pairwise iteration.
@@ -19,11 +20,27 @@ def pairwise(iterable):
return izip(a, b)
def batched(iterable, batch_size):
"""
def batched(iterable, batch_size, padding=UNSET):
"""Batch an iterable series into equal sized batches.
Args:
iterable: The series to be batched.
batch_size: The size of the batch. Must be at least one.
padding: Optional value used to pad the final batch to batch_size. If
omitted, the final batch may be smaller than batch_size.
Yields:
A series of lists, each containing batch_size items from iterable.
Raises:
ValueError: If batch_size is less than one.
"""
if batch_size < 1:
raise ValueError("Batch size {} is not at least one.".format(batch_size))
pending = []
batch = pending
for item in iterable:
pending.append(item)
@@ -32,8 +49,11 @@ def batched(iterable, batch_size):
pending = []
yield batch
if len(pending) > 0:
yield batch
num_left_over = len(pending)
if num_left_over > 0:
if padding is not UNSET:
pending.extend([padding] * (batch_size - num_left_over))
yield pending
def pad(iterable, padding=None, size=None):
+64
View File
@@ -0,0 +1,64 @@
from collections import namedtuple, Counter
import random
import unittest
from hypothesis import given
from hypothesis.descriptors import one_of, SampledFrom, Just, sampled_from, just
from hypothesis.searchstrategy import MappedSearchStrategy, StringStrategy
from hypothesis.strategytable import StrategyTable
from segpy.encoding import EBCDIC, ASCII
from segpy.toolkit import format_extended_textual_header, CARDS_PER_HEADER, END_TEXT_STANZA, CARD_LENGTH
# class MultiLineString(str):
# pass
#
#
# class MultiLineStringStrategy(MappedSearchStrategy):
#
# def pack(self, x):
# return '\n'.join(x)
#
# def unpack(self, x):
# return ''.join(x.splitlines())
#
#
# StrategyTable.default().define_specification_for(
# MultiLineString,
# lambda s, d: MultiLineStringStrategy(
# strategy=s.strategy([str]),
# descriptor=MultiLineString))
class TestFormatExtendedTextualHeader(unittest.TestCase):
@given(str,
sampled_from([ASCII, EBCDIC]),
bool)
def test_forty_lines_per_page(self, text, encoding, include_text_stop):
pages = format_extended_textual_header(text, encoding, include_text_stop)
self.assertTrue(all(len(page) == CARDS_PER_HEADER for page in pages))
@given(str,
sampled_from([ASCII, EBCDIC]),
bool)
def test_eighty_bytes_per_encoded_line(self, text, encoding, include_text_stop):
pages = format_extended_textual_header(text, encoding, include_text_stop)
self.assertTrue(all([len(line.encode(encoding)) == CARD_LENGTH for page in pages for line in page]))
@given(str,
sampled_from([ASCII, EBCDIC]),
bool)
def test_lines_end_with_cr_lf(self, text, encoding, include_text_stop):
pages = format_extended_textual_header(text, encoding, include_text_stop)
self.assertTrue(all([line.endswith('\r\n') for page in pages for line in page]))
@given(str,
sampled_from([ASCII, EBCDIC]),
just(True))
def test_lines_end_with_cr_lf(self, text, encoding, include_text_stop):
pages = format_extended_textual_header(text, encoding, include_text_stop)
self.assertTrue(pages[-1][0].startswith(END_TEXT_STANZA))
if __name__ == '__main__':
unittest.main()
+2 -2
View File
@@ -1,5 +1,6 @@
import unittest
from ibm_float import ieee2ibm, ibm2ieee
from segpy.ibm_float import ieee2ibm, ibm2ieee
class Ibm2Ieee(unittest.TestCase):
@@ -21,7 +22,6 @@ class Ibm2Ieee(unittest.TestCase):
self.assertEqual(ibm2ieee((0b11000010, 0b01110110, 0b10100000, 0b00000000)), -118.625)
class Ieee2Ibm(unittest.TestCase):
def test_zero(self):
+43
View File
@@ -0,0 +1,43 @@
import unittest
from hypothesis import given, assume
from segpy.util import batched
class TestBatched(unittest.TestCase):
@given([int], int)
def test_batch_sizes_unpadded(self, items, batch_size):
assume(batch_size > 0)
batches = list(batched(items, batch_size))
self.assertTrue(all(len(batch) == batch_size for batch in batches[:-1]))
@given([int], int)
def test_final_batch_sizes(self, items, batch_size):
assume(len(items) > 0)
assume(batch_size > 0)
batches = list(batched(items, batch_size))
self.assertTrue(len(batches[-1]) <= batch_size)
@given([int], int, int)
def test_batch_sizes_padded(self, items, batch_size, pad):
assume(batch_size > 0)
batches = list(batched(items, batch_size, padding=pad))
self.assertTrue(all(len(batch) == batch_size for batch in batches))
@given([int], int, int)
def test_pad_contents(self, items, batch_size, pad):
assume(len(items) > 0)
assume(0 < batch_size < 1000)
num_left_over = len(items) % batch_size
pad_length = batch_size - num_left_over if num_left_over != 0 else 0
assume(pad_length != 0)
batches = list(batched(items, batch_size, padding=pad))
self.assertEqual(batches[-1][batch_size - pad_length:], [pad] * pad_length)
def test_pad(self):
batches = list(batched([0, 0], 3, 42))
self.assertEqual(batches[-1], [0, 0, 42])
if __name__ == '__main__':
unittest.main()