From d1f4b6ba7506462a638d3a1f2ab16e2a2e7dc883 Mon Sep 17 00:00:00 2001 From: Eric Xihui Lin Date: Thu, 18 May 2017 13:28:17 -0400 Subject: [PATCH] CRF for keras 2.x (#76) * implemented CRF * added test for CRF * added a chunking example for CRF * changed to use up2date tensorflow * added conll2000 data * minimize package dependency * using logsumexp in keras instead --- examples/conll2000_chunking_crf.py | 94 +++++ keras_contrib/datasets/conll2000.py | 54 +++ keras_contrib/layers/__init__.py | 1 + keras_contrib/layers/crf.py | 535 +++++++++++++++++++++++++ tests/keras_contrib/layers/crf_test.py | 83 ++++ 5 files changed, 767 insertions(+) create mode 100644 examples/conll2000_chunking_crf.py create mode 100644 keras_contrib/datasets/conll2000.py create mode 100644 keras_contrib/layers/crf.py create mode 100644 tests/keras_contrib/layers/crf_test.py diff --git a/examples/conll2000_chunking_crf.py b/examples/conll2000_chunking_crf.py new file mode 100644 index 0000000..4ab20b0 --- /dev/null +++ b/examples/conll2000_chunking_crf.py @@ -0,0 +1,94 @@ +"""Train CRF and BiLSTM-CRF on CONLL2000 chunking data, similar to https://arxiv.org/pdf/1508.01991v1.pdf. 
+"""
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import numpy
+from collections import Counter
+
+from keras.models import Sequential
+from keras.layers import Embedding, Bidirectional, LSTM
+from keras.preprocessing.sequence import pad_sequences
+from keras_contrib.layers import CRF
+from keras_contrib.datasets import conll2000
+
+EPOCHS = 10
+EMBED_DIM = 200
+BiRNN_UNITS = 200
+
+
+def classification_report(y_true, y_pred, labels):
+    '''Print per-class recall, precision, F1 and support (like sklearn.metrics); `labels[i]` names class index i.'''
+    y_true = numpy.asarray(y_true).ravel()
+    y_pred = numpy.asarray(y_pred).ravel()
+    corrects = Counter(yt for yt, yp in zip(y_true, y_pred) if yt == yp)
+    y_true_counts = Counter(y_true)
+    y_pred_counts = Counter(y_pred)
+    report = ((lab,  # label
+               corrects[i] / max(1, y_true_counts[i]),  # recall
+               corrects[i] / max(1, y_pred_counts[i]),  # precision
+               y_true_counts[i]  # support
+               ) for i, lab in enumerate(labels))
+    report = [(l, r, p, 2 * r * p / max(1e-9, r + p), s) for l, r, p, s in report]
+
+    print('{:<15}{:>10}{:>10}{:>10}{:>10}\n'.format('', 'recall', 'precision', 'f1-score', 'support'))
+    formatter = '{:<15}{:>10.2f}{:>10.2f}{:>10.2f}{:>10d}'.format
+    for r in report:
+        print(formatter(*r))
+    print('')
+    report2 = list(zip(*[(r * s, p * s, f1 * s) for l, r, p, f1, s in report]))  # list(): zip is not indexable on Python 3
+    N = len(y_true)
+    print(formatter('avg / total', sum(report2[0]) / N, sum(report2[1]) / N, sum(report2[2]) / N, N) + '\n')
+
+
+# ------
+# Data
+# -----
+
+# conll2000 has two different targets; here we only use the IOB-like chunking tags as an example
+(train_x, _, train_y), (test_x, _, test_y), (vocab, _, class_labels) = conll2000.load_data()
+
+
+# --------------
+# 1. Regular CRF
+# --------------
+
+print('==== training CRF ====')
+
+model = Sequential()
+model.add(Embedding(len(vocab), EMBED_DIM, mask_zero=True))  # Random embedding
+crf = CRF(len(class_labels), sparse_target=True)
+model.add(crf)
+model.summary()
+
+model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])
+model.fit(train_x, train_y, epochs=EPOCHS, validation_data=[test_x, test_y])
+
+test_y_pred = model.predict(test_x).argmax(-1)[test_x > 0]
+test_y_true = test_y[test_x > 0]
+
+print('\n---- Result of CRF ----\n')
+classification_report(test_y_true, test_y_pred, class_labels)
+
+# -------------
+# 2. BiLSTM-CRF
+# -------------
+
+print('==== training BiLSTM-CRF ====')
+
+model = Sequential()
+model.add(Embedding(len(vocab), EMBED_DIM, mask_zero=True))  # Random embedding
+model.add(Bidirectional(LSTM(BiRNN_UNITS // 2, return_sequences=True)))
+crf = CRF(len(class_labels), sparse_target=True)
+model.add(crf)
+model.summary()
+
+model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])
+model.fit(train_x, train_y, epochs=EPOCHS, validation_data=[test_x, test_y])
+
+test_y_pred = model.predict(test_x).argmax(-1)[test_x > 0]
+test_y_true = test_y[test_x > 0]
+
+print('\n---- Result of BiLSTM-CRF ----\n')
+classification_report(test_y_true, test_y_pred, class_labels)
diff --git a/keras_contrib/datasets/conll2000.py b/keras_contrib/datasets/conll2000.py
new file mode 100644
index 0000000..d97f90b
--- /dev/null
+++ b/keras_contrib/datasets/conll2000.py
@@ -0,0 +1,54 @@
+import numpy
+from keras.utils.data_utils import get_file
+from zipfile import ZipFile
+from collections import Counter
+from keras.preprocessing.sequence import pad_sequences
+from keras.datasets import cifar10
+
+
+def load_data(path='conll2000.zip', min_freq=2):
+    path = get_file(path, origin='https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/corpora/conll2000.zip')
+    print(path)
+    archive = ZipFile(path, 'r')
+    train = _parse_data(archive.open('conll2000/train.txt'))
+    test = _parse_data(archive.open('conll2000/test.txt'))
+    archive.close()
+
+    word_counts = Counter(row[0].lower() for sample in train for row in sample)
+    vocab = ['', ''] + [w for w, f in word_counts.items() if f >= min_freq]  # index 0: padding, index 1: out-of-vocabulary
+    pos_tags = sorted(list(set(row[1] for sample in train + test for row in sample)))  # in alphabetic order
+    chunk_tags = sorted(list(set(row[2] for sample in train + test for row in sample)))  # in alphabetic order
+
+    train = _process_data(train, vocab, pos_tags, chunk_tags)
+    test = _process_data(test, vocab, pos_tags, chunk_tags)
+    return train, test, (vocab, pos_tags, chunk_tags)
+
+
+def _parse_data(fh):
+    string = fh.read().decode('utf-8')  # ZipFile.open() yields bytes; decode so the str splits below work on Python 3
+    data = [[row.split() for row in sample.split('\n')] for sample in string.strip().split('\n\n')]
+    fh.close()
+    return data
+
+
+def _process_data(data, vocab, pos_tags, chunk_tags, maxlen=None, onehot=False):
+    if maxlen is None:
+        maxlen = max(len(s) for s in data)
+    word2idx = dict((w, i) for i, w in enumerate(vocab))
+    x = [[word2idx.get(w[0].lower(), 1) for w in s] for s in data]  # out-of-vocabulary words map to index 1
+
+    y_pos = [[pos_tags.index(w[1]) for w in s] for s in data]
+    y_chunk = [[chunk_tags.index(w[2]) for w in s] for s in data]
+
+    x = pad_sequences(x, maxlen)  # left padding
+
+    y_pos = pad_sequences(y_pos, maxlen, value=-1)  # left padded with -1. Indeed, any integer works as it will be masked
+    y_chunk = pad_sequences(y_chunk, maxlen, value=-1)
+
+    if onehot:
+        y_pos = numpy.eye(len(pos_tags), dtype='float32')[y_pos]
+        y_chunk = numpy.eye(len(chunk_tags), dtype='float32')[y_chunk]
+    else:
+        y_pos = numpy.expand_dims(y_pos, 2)
+        y_chunk = numpy.expand_dims(y_chunk, 2)
+    return x, y_pos, y_chunk
diff --git a/keras_contrib/layers/__init__.py b/keras_contrib/layers/__init__.py
index 1ee566e..415cb2e 100644
--- a/keras_contrib/layers/__init__.py
+++ b/keras_contrib/layers/__init__.py
@@ -10,3 +10,4 @@ from .noise import *
 from .advanced_activations import *
 from .wrappers import *
 from .convolutional_recurrent import *
+from .crf import *
diff --git a/keras_contrib/layers/crf.py b/keras_contrib/layers/crf.py
new file mode 100644
index 0000000..9cc06eb
--- /dev/null
+++ b/keras_contrib/layers/crf.py
@@ -0,0 +1,535 @@
+from __future__ import absolute_import
+from __future__ import division
+
+from .. import backend as K
+from .. import activations
+from .. import initializers
+from .. import regularizers
+from .. import constraints
+from keras.engine import Layer
+from keras.engine import InputSpec
+from keras.objectives import categorical_crossentropy
+from keras.objectives import sparse_categorical_crossentropy
+
+
+class CRF(Layer):
+    r"""An implementation of linear chain conditional random field (CRF).
+
+    A linear chain CRF is defined to maximize the following likelihood function:
+
+    $$ L(W, U, b; y_1, ..., y_n) := \frac{1}{Z} \exp(-a_1' y_1 - a_n' y_n
+        - \sum_{k=1}^n f(x_k' W + b)' y_k - \sum_{k=1}^{n-1} y_k' U y_{k+1}), $$
+
+    where:
+        $Z$: normalization constant
+        $x_k, y_k$: inputs and outputs
+
+    This implementation has two modes for optimization:
+    1. (`join mode`) optimized by maximizing the joint likelihood, which is optimal in theory of statistics.
+       Note that in this case, CRF must be the output/last layer.
+    2. (`marginal mode`) return marginal probabilities on each time step and optimized via composition
+       likelihood (product of marginal likelihood), i.e., using `categorical_crossentropy` loss.
+       Note that in this case, CRF can be either the last layer or an intermediate layer (though not explored).
+
+    For prediction (test phase), one can choose either Viterbi best path (class indices) or marginal
+    probabilities if probabilities are needed. However, if one chooses *join mode* for training,
+    Viterbi output is typically better than marginal output, but the marginal output will still perform
+    reasonably close, while if *marginal mode* is used for training, marginal output usually performs
+    much better. The default behavior is set according to this observation.
+
+    In addition, this implementation supports masking and accepts either onehot or sparse target.
+
+
+    # Examples
+
+    ```python
+        model = Sequential()
+        model.add(Embedding(3001, 300, mask_zero=True))
+
+        # use learn_mode = 'join', test_mode = 'viterbi', sparse_target = True (label index output)
+        crf = CRF(10, sparse_target=True)
+        model.add(crf)
+
+        # crf.accuracy defaults to Viterbi accuracy when using join mode (the default).
+        # One can add crf.marginal_acc if interested, but it may slow down learning
+        model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])
+
+        # y must be label indices (with shape 1 at dim 3) here, since `sparse_target=True`
+        model.fit(x, y)
+
+        # prediction gives the onehot representation of the Viterbi best path
+        y_hat = model.predict(x_test)
+    ```
+
+
+    # Arguments
+        units: Positive integer, dimensionality of the output space.
+        learn_mode: Either 'join' or 'marginal'.
+            The former trains the model by maximizing the joint likelihood while the latter
+            maximizes the product of marginal likelihoods over all time steps.
+        test_mode: Either 'viterbi' or 'marginal'.
+            The former is recommended and is the default when `learn_mode = 'join'`; it
+            gives one-hot representation of the best path at test (prediction) time,
+            while the latter is recommended and chosen as default when `learn_mode = 'marginal'`,
+            which produces marginal probabilities for each time step.
+        sparse_target: Boolean (default False) indicating if provided labels are one-hot or
+            indices (with shape 1 at dim 3).
+        use_boundary: Boolean (default True) indicating if trainable start-end chain energies
+            should be added to model.
+        use_bias: Boolean, whether the layer uses a bias vector.
+        kernel_initializer: Initializer for the `kernel` weights matrix,
+            used for the linear transformation of the inputs.
+            (see [initializers](../initializers.md)).
+        chain_initializer: Initializer for the `chain_kernel` weights matrix,
+            used for the CRF chain energy.
+            (see [initializers](../initializers.md)).
+        boundary_initializer: Initializer for the `left_boundary`, `right_boundary` weights vectors,
+            used for the start/left and end/right boundary energy.
+            (see [initializers](../initializers.md)).
+        bias_initializer: Initializer for the bias vector
+            (see [initializers](../initializers.md)).
+        activation: Activation function to use
+            (see [activations](../activations.md)).
+            If you pass None, no activation is applied
+            (ie. "linear" activation: `a(x) = x`).
+        kernel_regularizer: Regularizer function applied to
+            the `kernel` weights matrix
+            (see [regularizer](../regularizers.md)).
+        chain_regularizer: Regularizer function applied to
+            the `chain_kernel` weights matrix
+            (see [regularizer](../regularizers.md)).
+        boundary_regularizer: Regularizer function applied to
+            the `left_boundary`, `right_boundary` weight vectors
+            (see [regularizer](../regularizers.md)).
+        bias_regularizer: Regularizer function applied to the bias vector
+            (see [regularizer](../regularizers.md)).
+        kernel_constraint: Constraint function applied to
+            the `kernel` weights matrix
+            (see [constraints](../constraints.md)).
+        chain_constraint: Constraint function applied to
+            the `chain_kernel` weights matrix
+            (see [constraints](../constraints.md)).
+        boundary_constraint: Constraint function applied to
+            the `left_boundary`, `right_boundary` weights vectors
+            (see [constraints](../constraints.md)).
+        bias_constraint: Constraint function applied to the bias vector
+            (see [constraints](../constraints.md)).
+        input_dim: dimensionality of the input (integer).
+            This argument (or alternatively, the keyword argument `input_shape`)
+            is required when using this layer as the first layer in a model.
+        unroll: Boolean (default False). If True, the network will be unrolled, else a symbolic loop will be used.
+            Unrolling can speed up an RNN, although it tends to be more memory-intensive.
+            Unrolling is only suitable for short sequences.
+
+    # Input shape
+        3D tensor with shape `(nb_samples, timesteps, input_dim)`.
+
+    # Output shape
+        3D tensor with shape `(nb_samples, timesteps, units)`.
+
+    # Masking
+        This layer supports masking for input data with a variable number
+        of timesteps. To introduce masks to your data,
+        use an [Embedding](embeddings.md) layer with the `mask_zero` parameter
+        set to `True`.
+
+    """
+
+    def __init__(self, units,
+                 learn_mode='join',
+                 test_mode=None,
+                 sparse_target=False,
+                 use_boundary=True,
+                 use_bias=True,
+                 activation='linear',
+                 kernel_initializer='glorot_uniform',
+                 chain_initializer='orthogonal',
+                 bias_initializer='zeros',
+                 boundary_initializer='zeros',
+                 kernel_regularizer=None,
+                 chain_regularizer=None,
+                 boundary_regularizer=None,
+                 bias_regularizer=None,
+                 kernel_constraint=None,
+                 chain_constraint=None,
+                 boundary_constraint=None,
+                 bias_constraint=None,
+                 input_dim=None,
+                 unroll=False,
+                 **kwargs):
+        super(CRF, self).__init__(**kwargs)
+        self.supports_masking = True
+        self.units = units
+        self.learn_mode = learn_mode
+        assert self.learn_mode in ['join', 'marginal']
+        self.test_mode = test_mode
+        if self.test_mode is None:
+            self.test_mode = 'viterbi' if self.learn_mode == 'join' else 'marginal'
+        else:
+            assert self.test_mode in ['viterbi', 'marginal']
+        self.sparse_target = sparse_target
+        self.use_boundary = use_boundary
+        self.use_bias = use_bias
+
+        self.activation = activations.get(activation)
+
+        self.kernel_initializer = initializers.get(kernel_initializer)
+        self.chain_initializer = initializers.get(chain_initializer)
+        self.boundary_initializer = initializers.get(boundary_initializer)
+        self.bias_initializer = initializers.get(bias_initializer)
+
+        self.kernel_regularizer = regularizers.get(kernel_regularizer)
+        self.chain_regularizer = regularizers.get(chain_regularizer)
+        self.boundary_regularizer = regularizers.get(boundary_regularizer)
+        self.bias_regularizer = regularizers.get(bias_regularizer)
+
+        self.kernel_constraint = constraints.get(kernel_constraint)
+        self.chain_constraint = constraints.get(chain_constraint)
+        self.boundary_constraint = constraints.get(boundary_constraint)
+        self.bias_constraint = constraints.get(bias_constraint)
+
+        self.unroll = unroll
+
+    def build(self, input_shape):
+        self.input_spec = [InputSpec(shape=input_shape)]
+        self.input_dim = input_shape[-1]
+
+        self.kernel = self.add_weight(shape=(self.input_dim, self.units),
+                                      name='kernel',
+                                      initializer=self.kernel_initializer,
+                                      regularizer=self.kernel_regularizer,
+                                      constraint=self.kernel_constraint)
+        self.chain_kernel = self.add_weight(shape=(self.units, self.units),
+                                            name='chain_kernel',
+                                            initializer=self.chain_initializer,
+                                            regularizer=self.chain_regularizer,
+                                            constraint=self.chain_constraint)
+        if self.use_bias:
+            self.bias = self.add_weight(shape=(self.units,),
+                                        name='bias',
+                                        initializer=self.bias_initializer,
+                                        regularizer=self.bias_regularizer,
+                                        constraint=self.bias_constraint)
+        else:
+            self.bias = 0  # scalar zero (not None) so that `K.dot(X, self.kernel) + self.bias` stays valid
+
+        if self.use_boundary:
+            self.left_boundary = self.add_weight(shape=(self.units,),
+                                                 name='left_boundary',
+                                                 initializer=self.boundary_initializer,
+                                                 regularizer=self.boundary_regularizer,
+                                                 constraint=self.boundary_constraint)
+            self.right_boundary = self.add_weight(shape=(self.units,),
+                                                  name='right_boundary',
+                                                  initializer=self.boundary_initializer,
+                                                  regularizer=self.boundary_regularizer,
+                                                  constraint=self.boundary_constraint)
+        self.built = True
+
+    def call(self, X, mask=None):
+        if mask is not None:
+            assert K.ndim(mask) == 2, 'Input mask to CRF must have dim 2 if not None'
+
+        if self.test_mode == 'viterbi':
+            test_output = self.viterbi_decoding(X, mask)
+        else:
+            test_output = self.get_marginal_prob(X, mask)
+
+        self.uses_learning_phase = True
+        if self.learn_mode == 'join':
+            train_output = K.zeros_like(K.dot(X, self.kernel))  # placeholder: the join-mode loss ignores y_pred
+            out = K.in_train_phase(train_output, test_output)
+        else:
+            if self.test_mode == 'viterbi':
+                train_output = self.get_marginal_prob(X, mask)
+                out = K.in_train_phase(train_output, test_output)
+            else:
+                out = test_output
+        return out
+
+    def compute_output_shape(self, input_shape):
+        return input_shape[:2] + (self.units,)
+
+    def compute_mask(self, input, mask=None):
+        if mask is not None and self.learn_mode == 'join':
+            return K.any(mask, axis=1)
+        return mask
+
+    def get_config(self):
+        config = {'units': self.units,
+                  'learn_mode': self.learn_mode,
+                  'test_mode': self.test_mode,
+                  'use_boundary': self.use_boundary,
+                  'use_bias': self.use_bias,
+                  'sparse_target': self.sparse_target,
+                  'kernel_initializer': initializers.serialize(self.kernel_initializer),
+                  'chain_initializer': initializers.serialize(self.chain_initializer),
+                  'boundary_initializer': initializers.serialize(self.boundary_initializer),
+                  'bias_initializer': initializers.serialize(self.bias_initializer),
+                  'activation': activations.serialize(self.activation),
+                  'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
+                  'chain_regularizer': regularizers.serialize(self.chain_regularizer),
+                  'boundary_regularizer': regularizers.serialize(self.boundary_regularizer),
+                  'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+                  'kernel_constraint': constraints.serialize(self.kernel_constraint),
+                  'chain_constraint': constraints.serialize(self.chain_constraint),
+                  'boundary_constraint': constraints.serialize(self.boundary_constraint),
+                  'bias_constraint': constraints.serialize(self.bias_constraint),
+                  'input_dim': self.input_dim,
+                  'unroll': self.unroll}
+        base_config = super(CRF, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+    @property
+    def loss_function(self):
+        if self.learn_mode == 'join':
+            def loss(y_true, y_pred):
+                assert self.inbound_nodes, 'CRF has not connected to any layer.'
+                assert not self.outbound_nodes, 'When learn_mode="join", CRF must be the last layer.'
+                if self.sparse_target:
+                    y_true = K.one_hot(K.cast(y_true[:, :, 0], 'int32'), self.units)
+                X = self.inbound_nodes[0].input_tensors[0]
+                mask = self.inbound_nodes[0].input_masks[0]
+                nloglik = self.get_negative_log_likelihood(y_true, X, mask)
+                return nloglik
+            return loss
+        else:
+            if self.sparse_target:
+                return sparse_categorical_crossentropy
+            else:
+                return categorical_crossentropy
+
+    @property
+    def accuracy(self):
+        if self.test_mode == 'viterbi':
+            return self.viterbi_acc
+        else:
+            return self.marginal_acc
+
+    @staticmethod
+    def _get_accuracy(y_true, y_pred, mask, sparse_target=False):
+        y_pred = K.argmax(y_pred, -1)
+        if sparse_target:
+            y_true = K.cast(y_true[:, :, 0], K.dtype(y_pred))
+        else:
+            y_true = K.argmax(y_true, -1)
+        judge = K.cast(K.equal(y_pred, y_true), K.floatx())
+        if mask is None:
+            return K.mean(judge)
+        else:
+            mask = K.cast(mask, K.floatx())
+            return K.sum(judge * mask) / K.sum(mask)
+
+    @property
+    def viterbi_acc(self):
+        def acc(y_true, y_pred):
+            X = self.inbound_nodes[0].input_tensors[0]
+            mask = self.inbound_nodes[0].input_masks[0]
+            y_pred = self.viterbi_decoding(X, mask)
+            return self._get_accuracy(y_true, y_pred, mask, self.sparse_target)
+        acc.__name__ = 'viterbi_acc'  # `func_name` only renames on Python 2; `__name__` works on both
+        return acc
+
+    @property
+    def marginal_acc(self):
+        def acc(y_true, y_pred):
+            X = self.inbound_nodes[0].input_tensors[0]
+            mask = self.inbound_nodes[0].input_masks[0]
+            y_pred = self.get_marginal_prob(X, mask)
+            return self._get_accuracy(y_true, y_pred, mask, self.sparse_target)
+        acc.__name__ = 'marginal_acc'  # `func_name` only renames on Python 2; `__name__` works on both
+        return acc
+
+    @staticmethod
+    def softmaxNd(x, axis=-1):
+        m = K.max(x, axis=axis, keepdims=True)  # subtract the max before exp for numerical stability
+        exp_x = K.exp(x - m)
+        prob_x = exp_x / K.sum(exp_x, axis=axis, keepdims=True)
+        return prob_x
+
+    @staticmethod
+    def shift_left(x, offset=1):
+        assert offset > 0
+        return K.concatenate([x[:, offset:], K.zeros_like(x[:, :offset])], axis=1)
+
+    @staticmethod
+    def shift_right(x, offset=1):
+        assert offset > 0
+        return K.concatenate([K.zeros_like(x[:, :offset]), x[:, :-offset]], axis=1)
+
+    def add_boundary_energy(self, energy, mask, start, end):
+        start = K.expand_dims(K.expand_dims(start, 0), 0)
+        end = K.expand_dims(K.expand_dims(end, 0), 0)
+        if mask is None:
+            energy = K.concatenate([energy[:, :1, :] + start, energy[:, 1:, :]], axis=1)
+            energy = K.concatenate([energy[:, :-1, :], energy[:, -1:, :] + end], axis=1)
+        else:
+            mask = K.expand_dims(K.cast(mask, K.floatx()))
+            start_mask = K.cast(K.greater(mask, self.shift_right(mask)), K.floatx())  # 1 at the first unmasked step
+            end_mask = K.cast(K.greater(mask, self.shift_left(mask)), K.floatx())  # 1 at the last unmasked step
+            energy = energy + start_mask * start
+            energy = energy + end_mask * end
+        return energy
+
+    def get_log_normalization_constant(self, input_energy, mask, **kwargs):
+        """Compute logarithm of the normalization constant Z, where
+        Z = sum exp(-E) -> logZ = log sum exp(-E) =: -nlogZ
+        """
+        # should have logZ[:, i] == logZ[:, j] for any i, j
+        logZ = self.recursion(input_energy, mask, return_sequences=False, **kwargs)
+        return logZ[:, 0]
+
+    def get_energy(self, y_true, input_energy, mask):
+        """Energy = a1' y1 + u1' y1 + y1' U y2 + u2' y2 + y2' U y3 + u3' y3 + an' y3
+        """
+        input_energy = K.sum(input_energy * y_true, 2)  # (B, T)
+        chain_energy = K.sum(K.dot(y_true[:, :-1, :], self.chain_kernel) * y_true[:, 1:, :], 2)  # (B, T-1)
+
+        if mask is not None:
+            mask = K.cast(mask, K.floatx())
+            chain_mask = mask[:, :-1] * mask[:, 1:]  # (B, T-1), mask[:,:-1]*mask[:,1:] makes it work with any padding
+            input_energy = input_energy * mask
+            chain_energy = chain_energy * chain_mask
+        total_energy = K.sum(input_energy, -1) + K.sum(chain_energy, -1)  # (B, )
+
+        return total_energy
+
+    def get_negative_log_likelihood(self, y_true, X, mask):
+        """Compute the loss, i.e., negative log likelihood (normalize by number of time steps)
+           likelihood = 1/Z * exp(-E) ->  neg_log_like = - log(1/Z * exp(-E)) = logZ + E
+        """
+        input_energy = self.activation(K.dot(X, self.kernel) + self.bias)
+        if self.use_boundary:
+            input_energy = self.add_boundary_energy(input_energy, mask, self.left_boundary, self.right_boundary)
+        energy = self.get_energy(y_true, input_energy, mask)
+        logZ = self.get_log_normalization_constant(input_energy, mask, input_length=K.int_shape(X)[1])
+        nloglik = logZ + energy
+        if mask is not None:
+            nloglik = nloglik / K.sum(K.cast(mask, K.floatx()), 1)
+        else:
+            nloglik = nloglik / K.cast(K.shape(X)[1], K.floatx())
+        return nloglik
+
+    def step(self, input_energy_t, states, return_logZ=True):
+        # note: in the following, `prev_target_val` has shape = (B, F)
+        # where B = batch_size, F = output feature dim
+        # Note: `i` is of float32, due to the behavior of `K.rnn`
+        prev_target_val, i, chain_energy = states[:3]
+        t = K.cast(i[0, 0], dtype='int32')
+        if len(states) > 3:
+            if K.backend() == 'theano':
+                m = states[3][:, t:(t + 2)]
+            else:
+                m = K.tf.slice(states[3], [0, t], [-1, 2])
+            input_energy_t = input_energy_t * K.expand_dims(m[:, 0])
+            chain_energy = chain_energy * K.expand_dims(K.expand_dims(m[:, 0] * m[:, 1]))  # (1, F, F)*(B, 1, 1) -> (B, F, F)
+        if return_logZ:
+            energy = chain_energy + K.expand_dims(input_energy_t - prev_target_val, 2)  # shapes: (1, F, F) + (B, F, 1) -> (B, F, F)
+            new_target_val = K.logsumexp(-energy, 1)  # shapes: (B, F)
+            return new_target_val, [new_target_val, i + 1]
+        else:
+            energy = chain_energy + K.expand_dims(input_energy_t + prev_target_val, 2)
+            min_energy = K.min(energy, 1)
+            argmin_table = K.cast(K.argmin(energy, 1), K.floatx())  # cast for tf-version `K.rnn`
+            return argmin_table, [min_energy, i + 1]
+
+    def recursion(self, input_energy, mask=None, go_backwards=False, return_sequences=True, return_logZ=True, input_length=None):
+        r"""Forward (alpha) or backward (beta) recursion
+
+        If `return_logZ = True`, compute the logZ, the normalization constant:
+
+        \[ Z = \sum_{y1, y2, y3} exp(-E) # energy
+          = \sum_{y1, y2, y3} exp(-(u1' y1 + y1' W y2 + u2' y2 + y2' W y3 + u3' y3))
+          = sum_{y2, y3} (exp(-(u2' y2 + y2' W y3 + u3' y3)) sum_{y1} exp(-(u1' y1 + y1' W y2))) \]
+
+        Denote:
+            \[ S(y2) := sum_{y1} exp(-(u1' y1 + y1' W y2)), \]
+            \[ Z = sum_{y2, y3} exp(log S(y2) - (u2' y2 + y2' W y3 + u3' y3)) \]
+            \[ logS(y2) = log S(y2) = log_sum_exp(-(u1' y1 + y1' W y2)) \]
+        Note that:
+              yi's are one-hot vectors
+              u1, u3: boundary energies have been merged
+
+        If `return_logZ = False`, compute the Viterbi's best path lookup table.
+        """
+        chain_energy = self.chain_kernel
+        chain_energy = K.expand_dims(chain_energy, 0)  # shape=(1, F, F): F=num of output features. 1st F is for t-1, 2nd F for t
+        prev_target_val = K.zeros_like(input_energy[:, 0, :])  # shape=(B, F), dtype=float32
+
+        if go_backwards:
+            input_energy = K.reverse(input_energy, 1)
+            if mask is not None:
+                mask = K.reverse(mask, 1)
+
+        initial_states = [prev_target_val, K.zeros_like(prev_target_val[:, :1])]
+        constants = [chain_energy]
+
+        if mask is not None:
+            mask2 = K.cast(K.concatenate([mask, K.zeros_like(mask[:, :1])], axis=1), K.floatx())
+            constants.append(mask2)
+
+        def _step(input_energy_i, states):
+            return self.step(input_energy_i, states, return_logZ)
+
+        target_val_last, target_val_seq, _ = K.rnn(_step, input_energy, initial_states, constants=constants,
+                                                   input_length=input_length, unroll=self.unroll)
+
+        if return_sequences:
+            if go_backwards:
+                target_val_seq = K.reverse(target_val_seq, 1)
+            return target_val_seq
+        else:
+            return target_val_last
+
+    def forward_recursion(self, input_energy, **kwargs):
+        return self.recursion(input_energy, **kwargs)
+
+    def backward_recursion(self, input_energy, **kwargs):
+        return self.recursion(input_energy, go_backwards=True, **kwargs)
+
+    def get_marginal_prob(self, X, mask=None):
+        input_energy = self.activation(K.dot(X, self.kernel) + self.bias)
+        if self.use_boundary:
+            input_energy = self.add_boundary_energy(input_energy, mask, self.left_boundary, self.right_boundary)
+        input_length = K.int_shape(X)[1]
+        alpha = self.forward_recursion(input_energy, mask=mask, input_length=input_length)
+        beta = self.backward_recursion(input_energy, mask=mask, input_length=input_length)
+        if mask is not None:
+            input_energy = input_energy * K.expand_dims(K.cast(mask, K.floatx()))
+        margin = -(self.shift_right(alpha) + input_energy + self.shift_left(beta))
+        return self.softmaxNd(margin)
+
+    def viterbi_decoding(self, X, mask=None):
+        input_energy = self.activation(K.dot(X, self.kernel) + self.bias)
+        if self.use_boundary:
+            input_energy = self.add_boundary_energy(input_energy, mask, self.left_boundary, self.right_boundary)
+
+        argmin_tables = self.recursion(input_energy, mask, return_logZ=False)
+        argmin_tables = K.cast(argmin_tables, 'int32')
+
+        # backward to find best path, `initial_best_idx` can be any, as all elements in the last argmin_table are the same
+        argmin_tables = K.reverse(argmin_tables, 1)
+        initial_best_idx = [K.expand_dims(argmin_tables[:, 0, 0])]  # matrix instead of vector is required by tf `K.rnn`
+        if K.backend() == 'theano':
+            initial_best_idx = [K.T.unbroadcast(initial_best_idx[0], 1)]
+
+        def gather_each_row(params, indices):
+            n = K.shape(indices)[0]
+            if K.backend() == 'theano':
+                return params[K.T.arange(n), indices]
+            else:
+                indices = K.transpose(K.stack([K.tf.range(n), indices]))
+                return K.tf.gather_nd(params, indices)
+
+        def find_path(argmin_table, best_idx):
+            next_best_idx = gather_each_row(argmin_table, best_idx[0][:, 0])
+            next_best_idx = K.expand_dims(next_best_idx)
+            if K.backend() == 'theano':
+                next_best_idx = K.T.unbroadcast(next_best_idx, 1)
+            return next_best_idx, [next_best_idx]
+
+        _, best_paths, _ = K.rnn(find_path, argmin_tables, initial_best_idx, input_length=K.int_shape(X)[1], unroll=self.unroll)
+        best_paths = K.reverse(best_paths, 1)
+        best_paths = K.squeeze(best_paths, 2)
+
+        return K.one_hot(best_paths, self.units)
diff --git a/tests/keras_contrib/layers/crf_test.py b/tests/keras_contrib/layers/crf_test.py
new file mode 100644
index 0000000..a54b5ae
--- /dev/null
+++ b/tests/keras_contrib/layers/crf_test.py
@@ -0,0 +1,83 @@
+import pytest
+import numpy as np
+from numpy.testing import assert_allclose
+
+from keras.utils.test_utils import keras_test
+from keras.layers import Embedding
+from keras_contrib.layers import CRF
+from keras.models import Sequential, model_from_json
+
+nb_samples, timesteps, embedding_dim, output_dim = 2, 10, 4, 5
+embedding_num = 12
+
+
+@keras_test
+def test_CRF():
+    # data
+    x = np.random.randint(1, embedding_num, nb_samples * timesteps).reshape((nb_samples, timesteps))
+    x[0, -4:] = 0  # right padding
+    x[1, :5] = 0  # left padding
+    y = np.random.randint(0, output_dim, nb_samples * timesteps).reshape((nb_samples, timesteps))
+    y_onehot = np.eye(output_dim)[y]
+    y = np.expand_dims(y, 2)  # .astype('float32')
+
+    # test with no masking, onehot, fixed length
+    model = Sequential()
+    model.add(Embedding(embedding_num, embedding_dim, input_length=timesteps))
+    crf = CRF(output_dim)
+    model.add(crf)
+    model.compile(optimizer='rmsprop', loss=crf.loss_function)
+    model.fit(x, y_onehot, epochs=1, batch_size=10)
+
+    # test with masking, sparse target, dynamic length; test crf.viterbi_acc, crf.marginal_acc
+
+    model = Sequential()
+    model.add(Embedding(embedding_num, embedding_dim, mask_zero=True))
+    crf = CRF(output_dim, sparse_target=True)
+    model.add(crf)
+    model.compile(optimizer='rmsprop', loss=crf.loss_function, metrics=[crf.viterbi_acc, crf.marginal_acc])
+    model.fit(x, y, epochs=1, batch_size=10)
+
+    # check mask
+    y_pred = model.predict(x).argmax(-1)
+    assert (y_pred[0, -4:] == 0).all()  # right padding
+    assert (y_pred[1, :5] == 0).all()  # left padding
+
+    # test `viterbi_acc`
+    _, v_acc, _ = model.evaluate(x, y)
+    np_acc = (y_pred[x > 0] == y[:, :, 0][x > 0]).astype('float32').mean()
+    assert np.abs(v_acc - np_acc) < 1e-4
+
+    # test config
+    model.get_config()
+
+    # test marginal learn mode, fixed length
+
+    model = Sequential()
+    model.add(Embedding(embedding_num, embedding_dim, input_length=timesteps, mask_zero=True))
+    crf = CRF(output_dim, learn_mode='marginal', unroll=True)
+    model.add(crf)
+    model.compile(optimizer='rmsprop', loss=crf.loss_function)
+    model.fit(x, y_onehot, epochs=1, batch_size=10)
+
+    # check mask (marginal output)
+    y_pred = model.predict(x)
+    assert_allclose(y_pred[0, -4:], 1. / output_dim, atol=1e-6)
+    assert_allclose(y_pred[1, :5], 1. / output_dim, atol=1e-6)
+
+    # test marginal learn mode, but with Viterbi test_mode
+    model = Sequential()
+    model.add(Embedding(embedding_num, embedding_dim, input_length=timesteps, mask_zero=True))
+    crf = CRF(output_dim, learn_mode='marginal', test_mode='viterbi')
+    model.add(crf)
+    model.compile(optimizer='rmsprop', loss=crf.loss_function, metrics=[crf.accuracy])
+    model.fit(x, y_onehot, epochs=1, batch_size=10)
+
+    y_pred = model.predict(x)
+
+    # check y_pred is onehot vector (output from 'viterbi' test mode)
+    assert_allclose(np.eye(output_dim)[y_pred.argmax(-1)], y_pred, atol=1e-6)
+
+
+if __name__ == '__main__':
+    pytest.main([__file__])