diff --git a/backtester/datahandler/schema.py b/backtester/datahandler/schema.py index 9cfa040..5c60964 100644 --- a/backtester/datahandler/schema.py +++ b/backtester/datahandler/schema.py @@ -53,26 +53,55 @@ class Field: self.name = name self.mapping = mapping - def _create_filter(self, op, value): - query = Field._format_query(self.mapping, op, value) + def _create_filter(self, op, other): + if isinstance(other, Field): + query = Field._format_query(self.mapping, op, other.mapping) + else: + query = Field._format_query(self.mapping, op, other) return Filter(query) - def _combine_fields(self, op, other): - name = Field._format_query(self.name, op, other.name) - mapping = Field._format_query(self.mapping, op, other.mapping) + def _combine_fields(self, op, other, invert=False): + if isinstance(other, Field): + name = Field._format_query(self.name, op, other.name, invert) + mapping = Field._format_query(self.mapping, op, other.mapping, + invert) + elif isinstance(other, (int, float)): + name = Field._format_query(self.name, op, other, invert) + mapping = Field._format_query(self.mapping, op, other, invert) + else: + raise TypeError + return Field(name, mapping) - def _format_query(left, op, right): + def _format_query(left, op, right, invert=False): + if invert: + left, right = right, left query = "{left} {op} {right}".format(left=left, op=op, right=right) return query - def __add__(self, field): - assert isinstance(field, Field) - return self._combine_fields("+", field) + def __add__(self, value): + return self._combine_fields("+", value) - def __sub__(self, field): - assert isinstance(field, Field) - return self._create_filter("-", field) + def __radd__(self, value): + return self._combine_fields("+", value, invert=True) + + def __sub__(self, value): + return self._combine_fields("-", value) + + def __rsub__(self, value): + return self._combine_fields("-", value, invert=True) + + def __mul__(self, value): + return self._combine_fields("*", value) + + def __rmul__(self, value): + return self._combine_fields("*", value, invert=True) + + def __truediv__(self, value): + return self._combine_fields("/", value) + + def __rtruediv__(self, value): + return self._combine_fields("/", value, invert=True) def __lt__(self, value): return self._create_filter("<", value)