mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 00:50:45 +08:00
ENH: allow columns from different datasets
This commit is contained in:
@@ -126,7 +126,6 @@ from __future__ import division, absolute_import
|
||||
from abc import ABCMeta, abstractproperty
|
||||
from collections import namedtuple, defaultdict
|
||||
from itertools import count
|
||||
from operator import attrgetter
|
||||
import warnings
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
@@ -148,10 +147,12 @@ from toolz import (
|
||||
compose,
|
||||
concat,
|
||||
flip,
|
||||
groupby,
|
||||
identity,
|
||||
memoize,
|
||||
)
|
||||
from six import with_metaclass, PY2
|
||||
import toolz.curried.operator as op
|
||||
from six import with_metaclass, PY2, iteritems
|
||||
|
||||
|
||||
from ..data.dataset import DataSet, Column
|
||||
@@ -172,7 +173,7 @@ traversable_nodes = (
|
||||
bz.expr.Label,
|
||||
)
|
||||
is_invalid_deltas_node = complement(flip(isinstance, valid_deltas_node_types))
|
||||
getname = attrgetter('__name__')
|
||||
getname = op.attrgetter('__name__')
|
||||
|
||||
|
||||
class _ExprRepr(object):
|
||||
@@ -553,8 +554,8 @@ def from_blaze(expr,
|
||||
return ds
|
||||
|
||||
|
||||
getdataset = attrgetter('dataset')
|
||||
dataset_name = attrgetter('name')
|
||||
getdataset = op.attrgetter('dataset')
|
||||
dataset_name = op.attrgetter('name')
|
||||
|
||||
|
||||
def inline_novel_deltas(base, deltas, dates):
|
||||
@@ -712,6 +713,17 @@ class BlazeLoader(dict):
|
||||
return cls()
|
||||
|
||||
def load_adjusted_array(self, columns, dates, assets, mask):
|
||||
return map(
|
||||
op.getitem(
|
||||
dict(concat(
|
||||
self._load_dataset(cs, dates, assets, mask)
|
||||
for _, cs in iteritems(groupby(getdataset, columns))
|
||||
)),
|
||||
),
|
||||
columns,
|
||||
)
|
||||
|
||||
def _load_dataset(self, columns, dates, assets, mask):
|
||||
try:
|
||||
dataset, = set(map(getdataset, columns))
|
||||
except ValueError:
|
||||
@@ -826,7 +838,7 @@ class BlazeLoader(dict):
|
||||
|
||||
for column_idx, column in enumerate(columns):
|
||||
column_name = column.name
|
||||
yield adjusted_array(
|
||||
yield column, adjusted_array(
|
||||
column_view(
|
||||
dense_output[column_name].values.astype(column.dtype),
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user