diff --git a/.gitignore b/.gitignore index ade7f374f..2dc978c9c 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,9 @@ # Python byte code files *.pyc +# Backup files +*.bak + # Emacs temporary files *~ *# diff --git a/.travis.yml b/.travis.yml index 92ca5db2c..c29b3ce06 100644 --- a/.travis.yml +++ b/.travis.yml @@ -37,8 +37,7 @@ matrix: - sphinx-build -W -b html -d _build/doctrees source _build/html - cd .. # Run Python linting. - - flake8 --ignore=E111,E114 - --exclude=python/ray/core/src/common/flatbuffers_ep-prefix/,python/ray/core/generated/,src/numbuf/thirdparty/,src/common/format/,doc/source/conf.py + - flake8 --exclude=python/ray/core/src/common/flatbuffers_ep-prefix/,python/ray/core/generated/,src/numbuf/thirdparty/,src/common/format/,doc/source/conf.py - os: linux dist: trusty env: VALGRIND=1 PYTHON=2.7 diff --git a/examples/hyperopt/hyperopt_adaptive.py b/examples/hyperopt/hyperopt_adaptive.py index d0411e746..12357e7dc 100644 --- a/examples/hyperopt/hyperopt_adaptive.py +++ b/examples/hyperopt/hyperopt_adaptive.py @@ -27,124 +27,128 @@ parser.add_argument("--redis-address", default=None, type=str, if __name__ == "__main__": - args = parser.parse_args() + args = parser.parse_args() - ray.init(redis_address=args.redis_address) + ray.init(redis_address=args.redis_address) - # The number of training passes over the dataset to use for network. - steps = args.steps_per_segment + # The number of training passes over the dataset to use for network. + steps = args.steps_per_segment - # Load the mnist data and turn the data into remote objects. - print("Downloading the MNIST dataset. This may take a minute.") - mnist = input_data.read_data_sets("MNIST_data", one_hot=True) - train_images = ray.put(mnist.train.images) - train_labels = ray.put(mnist.train.labels) - validation_images = ray.put(mnist.validation.images) - validation_labels = ray.put(mnist.validation.labels) + # Load the mnist data and turn the data into remote objects. 
+ print("Downloading the MNIST dataset. This may take a minute.") + mnist = input_data.read_data_sets("MNIST_data", one_hot=True) + train_images = ray.put(mnist.train.images) + train_labels = ray.put(mnist.train.labels) + validation_images = ray.put(mnist.validation.images) + validation_labels = ray.put(mnist.validation.labels) - # Keep track of the accuracies that we've seen at different numbers of - # iterations. - accuracies_by_num_steps = defaultdict(lambda: []) + # Keep track of the accuracies that we've seen at different numbers of + # iterations. + accuracies_by_num_steps = defaultdict(lambda: []) - # Define a method to determine if an experiment looks promising or not. - def is_promising(experiment_info): - accuracies = experiment_info["accuracies"] - total_num_steps = experiment_info["total_num_steps"] - comparable_accuracies = accuracies_by_num_steps[total_num_steps] - if len(comparable_accuracies) == 0: - if len(accuracies) == 1: - # This means that we haven't seen anything finish yet, so keep running - # this experiment. - return True - else: - # The experiment is promising if the second half of the accuracies are - # better than the first half of the accuracies. - return (np.mean(accuracies[:len(accuracies) // 2]) < - np.mean(accuracies[len(accuracies) // 2:])) - # Otherwise, continue running the experiment if it is in the top half of - # experiments we've seen so far at this point in time. - return np.mean(accuracy > np.array(comparable_accuracies)) > 0.5 + # Define a method to determine if an experiment looks promising or not. + def is_promising(experiment_info): + accuracies = experiment_info["accuracies"] + total_num_steps = experiment_info["total_num_steps"] + comparable_accuracies = accuracies_by_num_steps[total_num_steps] + if len(comparable_accuracies) == 0: + if len(accuracies) == 1: + # This means that we haven't seen anything finish yet, so keep + # running this experiment. 
+ return True + else: + # The experiment is promising if the second half of the + # accuracies are better than the first half of the accuracies. + return (np.mean(accuracies[:len(accuracies) // 2]) < + np.mean(accuracies[len(accuracies) // 2:])) + # Otherwise, continue running the experiment if it is in the top half + # of experiments we've seen so far at this point in time. + return np.mean(accuracy > np.array(comparable_accuracies)) > 0.5 - # Keep track of all of the experiment segments that we're running. This - # dictionary uses the object ID of the experiment as the key. - experiment_info = {} - # Keep track of the curently running experiment IDs. - remaining_ids = [] + # Keep track of all of the experiment segments that we're running. This + # dictionary uses the object ID of the experiment as the key. + experiment_info = {} + # Keep track of the currently running experiment IDs. + remaining_ids = [] - # Keep track of the best hyperparameters and the best accuracy. - best_hyperparameters = None - best_accuracy = 0 + # Keep track of the best hyperparameters and the best accuracy. + best_hyperparameters = None + best_accuracy = 0 - # A function for generating random hyperparameters. - def generate_hyperparameters(): - return {"learning_rate": 10 ** np.random.uniform(-5, 5), - "batch_size": np.random.randint(1, 100), - "dropout": np.random.uniform(0, 1), - "stddev": 10 ** np.random.uniform(-5, 5)} + # A function for generating random hyperparameters. + def generate_hyperparameters(): + return {"learning_rate": 10 ** np.random.uniform(-5, 5), + "batch_size": np.random.randint(1, 100), + "dropout": np.random.uniform(0, 1), + "stddev": 10 ** np.random.uniform(-5, 5)} - # Launch some initial experiments. 
- for _ in range(args.num_starting_segments): - hyperparameters = generate_hyperparameters() - experiment_id = objective.train_cnn_and_compute_accuracy.remote( - hyperparameters, steps, train_images, train_labels, validation_images, - validation_labels) - experiment_info[experiment_id] = {"hyperparameters": hyperparameters, - "total_num_steps": steps, - "accuracies": []} - remaining_ids.append(experiment_id) + # Launch some initial experiments. + for _ in range(args.num_starting_segments): + hyperparameters = generate_hyperparameters() + experiment_id = objective.train_cnn_and_compute_accuracy.remote( + hyperparameters, steps, train_images, train_labels, + validation_images, validation_labels) + experiment_info[experiment_id] = {"hyperparameters": hyperparameters, + "total_num_steps": steps, + "accuracies": []} + remaining_ids.append(experiment_id) - for _ in range(args.num_segments): - # Wait for a segment of an experiment to finish. - ready_ids, remaining_ids = ray.wait(remaining_ids, num_returns=1) - experiment_id = ready_ids[0] - # Get the accuracy and the weights. - accuracy, weights = ray.get(experiment_id) - # Update the experiment info. - previous_info = experiment_info[experiment_id] - previous_info["accuracies"].append(accuracy) + for _ in range(args.num_segments): + # Wait for a segment of an experiment to finish. + ready_ids, remaining_ids = ray.wait(remaining_ids, num_returns=1) + experiment_id = ready_ids[0] + # Get the accuracy and the weights. + accuracy, weights = ray.get(experiment_id) + # Update the experiment info. + previous_info = experiment_info[experiment_id] + previous_info["accuracies"].append(accuracy) - # Update the best accuracy and best hyperparameters. - if accuracy > best_accuracy: - best_hyperparameters = hyperparameters - best_accuracy = accuracy + # Update the best accuracy and best hyperparameters. 
+ if accuracy > best_accuracy: + best_hyperparameters = hyperparameters + best_accuracy = accuracy - if is_promising(previous_info): - # If the experiment still looks promising, then continue running it. - print("Continuing to run the experiment with hyperparameters {}.".format( - previous_info["hyperparameters"])) - new_hyperparameters = previous_info["hyperparameters"] - new_info = {"hyperparameters": new_hyperparameters, - "total_num_steps": previous_info["total_num_steps"] + steps, - "accuracies": previous_info["accuracies"][:]} - starting_weights = weights - else: - # If the experiment does not look promising, start a new experiment. - print("Ending the experiment with hyperparameters {}.".format( - previous_info["hyperparameters"])) - new_hyperparameters = generate_hyperparameters() - new_info = {"hyperparameters": new_hyperparameters, - "total_num_steps": steps, - "accuracies": []} - starting_weights = None + if is_promising(previous_info): + # If the experiment still looks promising, then continue running + # it. + print("Continuing to run the experiment with hyperparameters {}." + .format(previous_info["hyperparameters"])) + new_hyperparameters = previous_info["hyperparameters"] + new_info = {"hyperparameters": new_hyperparameters, + "total_num_steps": (previous_info["total_num_steps"] + + steps), + "accuracies": previous_info["accuracies"][:]} + starting_weights = weights + else: + # If the experiment does not look promising, start a new + # experiment. + print("Ending the experiment with hyperparameters {}." + .format(previous_info["hyperparameters"])) + new_hyperparameters = generate_hyperparameters() + new_info = {"hyperparameters": new_hyperparameters, + "total_num_steps": steps, + "accuracies": []} + starting_weights = None - # Start running the next segment. 
- new_experiment_id = objective.train_cnn_and_compute_accuracy.remote( - new_hyperparameters, steps, train_images, train_labels, - validation_images, validation_labels, weights=starting_weights) - experiment_info[new_experiment_id] = new_info - remaining_ids.append(new_experiment_id) + # Start running the next segment. + new_experiment_id = objective.train_cnn_and_compute_accuracy.remote( + new_hyperparameters, steps, train_images, train_labels, + validation_images, validation_labels, weights=starting_weights) + experiment_info[new_experiment_id] = new_info + remaining_ids.append(new_experiment_id) - # Update the set of all accuracies that we've seen. - accuracies_by_num_steps[previous_info["total_num_steps"]].append(accuracy) + # Update the set of all accuracies that we've seen. + accuracies_by_num_steps[previous_info["total_num_steps"]].append( + accuracy) - # Record the best performing set of hyperparameters. - print("""Best accuracy was {:.3} with - learning_rate: {:.2} - batch_size: {} - dropout: {:.2} - stddev: {:.2} - """.format(100 * best_accuracy, - best_hyperparameters["learning_rate"], - best_hyperparameters["batch_size"], - best_hyperparameters["dropout"], - best_hyperparameters["stddev"])) + # Record the best performing set of hyperparameters. 
+ print("""Best accuracy was {:.3} with + learning_rate: {:.2} + batch_size: {} + dropout: {:.2} + stddev: {:.2} + """.format(100 * best_accuracy, + best_hyperparameters["learning_rate"], + best_hyperparameters["batch_size"], + best_hyperparameters["dropout"], + best_hyperparameters["stddev"])) diff --git a/examples/hyperopt/hyperopt_simple.py b/examples/hyperopt/hyperopt_simple.py index 03c6b81b5..d1a400477 100644 --- a/examples/hyperopt/hyperopt_simple.py +++ b/examples/hyperopt/hyperopt_simple.py @@ -21,80 +21,80 @@ parser.add_argument("--redis-address", default=None, type=str, if __name__ == "__main__": - args = parser.parse_args() + args = parser.parse_args() - ray.init(redis_address=args.redis_address) + ray.init(redis_address=args.redis_address) - # The number of sets of random hyperparameters to try. - trials = args.trials - # The number of training passes over the dataset to use for network. - steps = args.steps + # The number of sets of random hyperparameters to try. + trials = args.trials + # The number of training passes over the dataset to use for network. + steps = args.steps - # Load the mnist data and turn the data into remote objects. - print("Downloading the MNIST dataset. This may take a minute.") - mnist = input_data.read_data_sets("MNIST_data", one_hot=True) - train_images = ray.put(mnist.train.images) - train_labels = ray.put(mnist.train.labels) - validation_images = ray.put(mnist.validation.images) - validation_labels = ray.put(mnist.validation.labels) + # Load the mnist data and turn the data into remote objects. + print("Downloading the MNIST dataset. This may take a minute.") + mnist = input_data.read_data_sets("MNIST_data", one_hot=True) + train_images = ray.put(mnist.train.images) + train_labels = ray.put(mnist.train.labels) + validation_images = ray.put(mnist.validation.images) + validation_labels = ray.put(mnist.validation.labels) - # Keep track of the best hyperparameters and the best accuracy. 
- best_hyperparamemeters = None - best_accuracy = 0 - # This list holds the object IDs for all of the experiments that we have - # launched and that have not yet been processed. - remaining_ids = [] - # This is a dictionary mapping the object ID of an experiment to the - # hyerparameters used for that experiment. - hyperparameters_mapping = {} + # Keep track of the best hyperparameters and the best accuracy. + best_hyperparamemeters = None + best_accuracy = 0 + # This list holds the object IDs for all of the experiments that we have + # launched and that have not yet been processed. + remaining_ids = [] + # This is a dictionary mapping the object ID of an experiment to the + # hyperparameters used for that experiment. + hyperparameters_mapping = {} - # A function for generating random hyperparameters. - def generate_hyperparameters(): - return {"learning_rate": 10 ** np.random.uniform(-5, 5), - "batch_size": np.random.randint(1, 100), - "dropout": np.random.uniform(0, 1), - "stddev": 10 ** np.random.uniform(-5, 5)} + # A function for generating random hyperparameters. + def generate_hyperparameters(): + return {"learning_rate": 10 ** np.random.uniform(-5, 5), + "batch_size": np.random.randint(1, 100), + "dropout": np.random.uniform(0, 1), + "stddev": 10 ** np.random.uniform(-5, 5)} - # Randomly generate some hyperparameters, and launch a task for each set. - for i in range(trials): - hyperparameters = generate_hyperparameters() - accuracy_id = objective.train_cnn_and_compute_accuracy.remote( - hyperparameters, steps, train_images, train_labels, validation_images, - validation_labels) - remaining_ids.append(accuracy_id) - # Keep track of which hyperparameters correspond to this experiment. - hyperparameters_mapping[accuracy_id] = hyperparameters + # Randomly generate some hyperparameters, and launch a task for each set. 
+ for i in range(trials): + hyperparameters = generate_hyperparameters() + accuracy_id = objective.train_cnn_and_compute_accuracy.remote( + hyperparameters, steps, train_images, train_labels, + validation_images, validation_labels) + remaining_ids.append(accuracy_id) + # Keep track of which hyperparameters correspond to this experiment. + hyperparameters_mapping[accuracy_id] = hyperparameters - # Fetch and print the results of the tasks in the order that they complete. - for i in range(trials): - # Use ray.wait to get the object ID of the first task that completes. - ready_ids, remaining_ids = ray.wait(remaining_ids) - # Process the output of this task. - result_id = ready_ids[0] - hyperparameters = hyperparameters_mapping[result_id] - accuracy, _ = ray.get(result_id) - print("""We achieve accuracy {:.3}% with - learning_rate: {:.2} - batch_size: {} - dropout: {:.2} - stddev: {:.2} - """.format(100 * accuracy, - hyperparameters["learning_rate"], - hyperparameters["batch_size"], - hyperparameters["dropout"], - hyperparameters["stddev"])) - if accuracy > best_accuracy: - best_hyperparameters = hyperparameters - best_accuracy = accuracy + # Fetch and print the results of the tasks in the order that they complete. + for i in range(trials): + # Use ray.wait to get the object ID of the first task that completes. + ready_ids, remaining_ids = ray.wait(remaining_ids) + # Process the output of this task. + result_id = ready_ids[0] + hyperparameters = hyperparameters_mapping[result_id] + accuracy, _ = ray.get(result_id) + print("""We achieve accuracy {:.3}% with + learning_rate: {:.2} + batch_size: {} + dropout: {:.2} + stddev: {:.2} + """.format(100 * accuracy, + hyperparameters["learning_rate"], + hyperparameters["batch_size"], + hyperparameters["dropout"], + hyperparameters["stddev"])) + if accuracy > best_accuracy: + best_hyperparameters = hyperparameters + best_accuracy = accuracy - # Record the best performing set of hyperparameters. 
- print("""Best accuracy over {} trials was {:.3} with - learning_rate: {:.2} - batch_size: {} - dropout: {:.2} - stddev: {:.2} - """.format(trials, 100 * best_accuracy, - best_hyperparameters["learning_rate"], - best_hyperparameters["batch_size"], - best_hyperparameters["dropout"], - best_hyperparameters["stddev"])) + # Record the best performing set of hyperparameters. + print("""Best accuracy over {} trials was {:.3} with + learning_rate: {:.2} + batch_size: {} + dropout: {:.2} + stddev: {:.2} + """.format(trials, 100 * best_accuracy, + best_hyperparameters["learning_rate"], + best_hyperparameters["batch_size"], + best_hyperparameters["dropout"], + best_hyperparameters["stddev"])) diff --git a/examples/hyperopt/objective.py b/examples/hyperopt/objective.py index efee7510d..8662b77e7 100644 --- a/examples/hyperopt/objective.py +++ b/examples/hyperopt/objective.py @@ -11,59 +11,59 @@ import tensorflow as tf def get_batch(data, batch_index, batch_size): - # This method currently drops data when num_data is not divisible by - # batch_size. - num_data = data.shape[0] - num_batches = num_data // batch_size - batch_index %= num_batches - return data[(batch_index * batch_size):((batch_index + 1) * batch_size)] + # This method currently drops data when num_data is not divisible by + # batch_size. 
+ num_data = data.shape[0] + num_batches = num_data // batch_size + batch_index %= num_batches + return data[(batch_index * batch_size):((batch_index + 1) * batch_size)] def weight(shape, stddev): - initial = tf.truncated_normal(shape, stddev=stddev) - return tf.Variable(initial) + initial = tf.truncated_normal(shape, stddev=stddev) + return tf.Variable(initial) def bias(shape): - initial = tf.constant(0.1, shape=shape) - return tf.Variable(initial) + initial = tf.constant(0.1, shape=shape) + return tf.Variable(initial) def conv2d(x, W): - return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding="SAME") + return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding="SAME") def max_pool_2x2(x): - return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], - padding="SAME") + return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], + padding="SAME") def cnn_setup(x, y, keep_prob, lr, stddev): - first_hidden = 32 - second_hidden = 64 - fc_hidden = 1024 - W_conv1 = weight([5, 5, 1, first_hidden], stddev) - B_conv1 = bias([first_hidden]) - x_image = tf.reshape(x, [-1, 28, 28, 1]) - h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + B_conv1) - h_pool1 = max_pool_2x2(h_conv1) - W_conv2 = weight([5, 5, first_hidden, second_hidden], stddev) - b_conv2 = bias([second_hidden]) - h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) - h_pool2 = max_pool_2x2(h_conv2) - W_fc1 = weight([7 * 7 * second_hidden, fc_hidden], stddev) - b_fc1 = bias([fc_hidden]) - h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * second_hidden]) - h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) - h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) - W_fc2 = weight([fc_hidden, 10], stddev) - b_fc2 = bias([10]) - y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) - cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_conv), - reduction_indices=[1])) - correct_pred = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y, 1)) - return (tf.train.AdamOptimizer(lr).minimize(cross_entropy), - 
tf.reduce_mean(tf.cast(correct_pred, tf.float32)), cross_entropy) + first_hidden = 32 + second_hidden = 64 + fc_hidden = 1024 + W_conv1 = weight([5, 5, 1, first_hidden], stddev) + B_conv1 = bias([first_hidden]) + x_image = tf.reshape(x, [-1, 28, 28, 1]) + h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + B_conv1) + h_pool1 = max_pool_2x2(h_conv1) + W_conv2 = weight([5, 5, first_hidden, second_hidden], stddev) + b_conv2 = bias([second_hidden]) + h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) + h_pool2 = max_pool_2x2(h_conv2) + W_fc1 = weight([7 * 7 * second_hidden, fc_hidden], stddev) + b_fc1 = bias([fc_hidden]) + h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * second_hidden]) + h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) + h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) + W_fc2 = weight([fc_hidden, 10], stddev) + b_fc2 = bias([10]) + y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) + cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_conv), + reduction_indices=[1])) + correct_pred = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y, 1)) + return (tf.train.AdamOptimizer(lr).minimize(cross_entropy), + tf.reduce_mean(tf.cast(correct_pred, tf.float32)), cross_entropy) # Define a remote function that takes a set of hyperparameters as well as the @@ -72,42 +72,42 @@ def cnn_setup(x, y, keep_prob, lr, stddev): def train_cnn_and_compute_accuracy(params, steps, train_images, train_labels, validation_images, validation_labels, weights=None): - # Extract the hyperparameters from the params dictionary. - learning_rate = params["learning_rate"] - batch_size = params["batch_size"] - keep = 1 - params["dropout"] - stddev = params["stddev"] - # Create the network and related variables. - with tf.Graph().as_default(): - # Create the input placeholders for the network. - x = tf.placeholder(tf.float32, shape=[None, 784]) - y = tf.placeholder(tf.float32, shape=[None, 10]) - keep_prob = tf.placeholder(tf.float32) - # Create the network. 
- train_step, accuracy, loss = cnn_setup(x, y, keep_prob, learning_rate, - stddev) - # Do the training and evaluation. - with tf.Session() as sess: - # Use the TensorFlowVariables utility. This is only necessary if we want - # to set and get the weights. - variables = ray.experimental.TensorFlowVariables(loss, sess) - # Initialize the network weights. - sess.run(tf.global_variables_initializer()) - # If some network weights were passed in, set those. - if weights is not None: - variables.set_weights(weights) - # Do some steps of training. - for i in range(1, steps + 1): - # Fetch the next batch of data. - image_batch = get_batch(train_images, i, batch_size) - label_batch = get_batch(train_labels, i, batch_size) - # Do one step of training. - sess.run(train_step, feed_dict={x: image_batch, y: label_batch, - keep_prob: keep}) - # Training is done, so compute the validation accuracy and the current - # weights and return. - totalacc = accuracy.eval(feed_dict={x: validation_images, - y: validation_labels, - keep_prob: 1.0}) - new_weights = variables.get_weights() - return float(totalacc), new_weights + # Extract the hyperparameters from the params dictionary. + learning_rate = params["learning_rate"] + batch_size = params["batch_size"] + keep = 1 - params["dropout"] + stddev = params["stddev"] + # Create the network and related variables. + with tf.Graph().as_default(): + # Create the input placeholders for the network. + x = tf.placeholder(tf.float32, shape=[None, 784]) + y = tf.placeholder(tf.float32, shape=[None, 10]) + keep_prob = tf.placeholder(tf.float32) + # Create the network. + train_step, accuracy, loss = cnn_setup(x, y, keep_prob, learning_rate, + stddev) + # Do the training and evaluation. + with tf.Session() as sess: + # Use the TensorFlowVariables utility. This is only necessary if we + # want to set and get the weights. + variables = ray.experimental.TensorFlowVariables(loss, sess) + # Initialize the network weights. 
+ sess.run(tf.global_variables_initializer()) + # If some network weights were passed in, set those. + if weights is not None: + variables.set_weights(weights) + # Do some steps of training. + for i in range(1, steps + 1): + # Fetch the next batch of data. + image_batch = get_batch(train_images, i, batch_size) + label_batch = get_batch(train_labels, i, batch_size) + # Do one step of training. + sess.run(train_step, feed_dict={x: image_batch, y: label_batch, + keep_prob: keep}) + # Training is done, so compute the validation accuracy and the + # current weights and return. + totalacc = accuracy.eval(feed_dict={x: validation_images, + y: validation_labels, + keep_prob: 1.0}) + new_weights = variables.get_weights() + return float(totalacc), new_weights diff --git a/examples/lbfgs/driver.py b/examples/lbfgs/driver.py index e81c8b52d..1a05f0013 100644 --- a/examples/lbfgs/driver.py +++ b/examples/lbfgs/driver.py @@ -12,128 +12,129 @@ from tensorflow.examples.tutorials.mnist import input_data class LinearModel(object): - """Simple class for a one layer neural network. + """Simple class for a one layer neural network. - Note that this code does not initialize the network weights. Instead weights - are set via self.variables.set_weights. + Note that this code does not initialize the network weights. Instead + weights are set via self.variables.set_weights. - Example: - net = LinearModel([10, 10]) - weights = [np.random.normal(size=[10, 10]), np.random.normal(size=[10])] - variable_names = [v.name for v in net.variables] - net.variables.set_weights(dict(zip(variable_names, weights))) + Example: + net = LinearModel([10, 10]) + weights = [np.random.normal(size=[10, 10]), + np.random.normal(size=[10])] + variable_names = [v.name for v in net.variables] + net.variables.set_weights(dict(zip(variable_names, weights))) - Attributes: - x (tf.placeholder): Input vector. - w (tf.Variable): Weight matrix. - b (tf.Variable): Bias vector. - y_ (tf.placeholder): Input result vector. 
- cross_entropy (tf.Operation): Final layer of network. - cross_entropy_grads (tf.Operation): Gradient computation. - sess (tf.Session): Session used for training. - variables (TensorFlowVariables): Extracted variables and methods to - manipulate them. - """ - def __init__(self, shape): - """Creates a LinearModel object.""" - x = tf.placeholder(tf.float32, [None, shape[0]]) - w = tf.Variable(tf.zeros(shape)) - b = tf.Variable(tf.zeros(shape[1])) - self.x = x - self.w = w - self.b = b - y = tf.nn.softmax(tf.matmul(x, w) + b) - y_ = tf.placeholder(tf.float32, [None, shape[1]]) - self.y_ = y_ - cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), - reduction_indices=[1])) - self.cross_entropy = cross_entropy - self.cross_entropy_grads = tf.gradients(cross_entropy, [w, b]) - self.sess = tf.Session() - # In order to get and set the weights, we pass in the loss function to - # Ray's TensorFlowVariables to automatically create methods to modify the - # weights. - self.variables = ray.experimental.TensorFlowVariables(cross_entropy, - self.sess) + Attributes: + x (tf.placeholder): Input vector. + w (tf.Variable): Weight matrix. + b (tf.Variable): Bias vector. + y_ (tf.placeholder): Input result vector. + cross_entropy (tf.Operation): Final layer of network. + cross_entropy_grads (tf.Operation): Gradient computation. + sess (tf.Session): Session used for training. + variables (TensorFlowVariables): Extracted variables and methods to + manipulate them. 
+ """ + def __init__(self, shape): + """Creates a LinearModel object.""" + x = tf.placeholder(tf.float32, [None, shape[0]]) + w = tf.Variable(tf.zeros(shape)) + b = tf.Variable(tf.zeros(shape[1])) + self.x = x + self.w = w + self.b = b + y = tf.nn.softmax(tf.matmul(x, w) + b) + y_ = tf.placeholder(tf.float32, [None, shape[1]]) + self.y_ = y_ + cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), + reduction_indices=[1])) + self.cross_entropy = cross_entropy + self.cross_entropy_grads = tf.gradients(cross_entropy, [w, b]) + self.sess = tf.Session() + # In order to get and set the weights, we pass in the loss function to + # Ray's TensorFlowVariables to automatically create methods to modify + # the weights. + self.variables = ray.experimental.TensorFlowVariables(cross_entropy, + self.sess) - def loss(self, xs, ys): - """Computes the loss of the network.""" - return float(self.sess.run(self.cross_entropy, - feed_dict={self.x: xs, self.y_: ys})) + def loss(self, xs, ys): + """Computes the loss of the network.""" + return float(self.sess.run(self.cross_entropy, + feed_dict={self.x: xs, self.y_: ys})) - def grad(self, xs, ys): - """Computes the gradients of the network.""" - return self.sess.run(self.cross_entropy_grads, - feed_dict={self.x: xs, self.y_: ys}) + def grad(self, xs, ys): + """Computes the gradients of the network.""" + return self.sess.run(self.cross_entropy_grads, + feed_dict={self.x: xs, self.y_: ys}) @ray.remote class NetActor(object): - def __init__(self, xs, ys): - os.environ["CUDA_VISIBLE_DEVICES"] = "" - with tf.device("/cpu:0"): - self.net = LinearModel([784, 10]) - self.xs = xs - self.ys = ys + def __init__(self, xs, ys): + os.environ["CUDA_VISIBLE_DEVICES"] = "" + with tf.device("/cpu:0"): + self.net = LinearModel([784, 10]) + self.xs = xs + self.ys = ys - # Compute the loss on a batch of data. 
- def loss(self, theta): - net = self.net - net.variables.set_flat(theta) - return net.loss(self.xs, self.ys) + # Compute the loss on a batch of data. + def loss(self, theta): + net = self.net + net.variables.set_flat(theta) + return net.loss(self.xs, self.ys) - # Compute the gradient of the loss on a batch of data. - def grad(self, theta): - net = self.net - net.variables.set_flat(theta) - gradients = net.grad(self.xs, self.ys) - return np.concatenate([g.flatten() for g in gradients]) + # Compute the gradient of the loss on a batch of data. + def grad(self, theta): + net = self.net + net.variables.set_flat(theta) + gradients = net.grad(self.xs, self.ys) + return np.concatenate([g.flatten() for g in gradients]) - def get_flat_size(self): - return self.net.variables.get_flat_size() + def get_flat_size(self): + return self.net.variables.get_flat_size() # Compute the loss on the entire dataset. def full_loss(theta): - theta_id = ray.put(theta) - loss_ids = [actor.loss.remote(theta_id) for actor in actors] - return sum(ray.get(loss_ids)) + theta_id = ray.put(theta) + loss_ids = [actor.loss.remote(theta_id) for actor in actors] + return sum(ray.get(loss_ids)) # Compute the gradient of the loss on the entire dataset. def full_grad(theta): - theta_id = ray.put(theta) - grad_ids = [actor.grad.remote(theta_id) for actor in actors] - # The float64 conversion is necessary for use with fmin_l_bfgs_b. - return sum(ray.get(grad_ids)).astype("float64") + theta_id = ray.put(theta) + grad_ids = [actor.grad.remote(theta_id) for actor in actors] + # The float64 conversion is necessary for use with fmin_l_bfgs_b. + return sum(ray.get(grad_ids)).astype("float64") if __name__ == "__main__": - ray.init(redirect_output=True) + ray.init(redirect_output=True) - # From the perspective of scipy.optimize.fmin_l_bfgs_b, full_loss is simply a - # function which takes some parameters theta, and computes a loss. 
Similarly, - # full_grad is a function which takes some parameters theta, and computes the - # gradient of the loss. Internally, these functions use Ray to distribute the - # computation of the loss and the gradient over the data that is represented - # by the remote object IDs x_batches and y_batches and which is potentially - # distributed over a cluster. However, these details are hidden from - # scipy.optimize.fmin_l_bfgs_b, which simply uses it to run the L-BFGS - # algorithm. + # From the perspective of scipy.optimize.fmin_l_bfgs_b, full_loss is simply + # a function which takes some parameters theta, and computes a loss. + # Similarly, full_grad is a function which takes some parameters theta, and + # computes the gradient of the loss. Internally, these functions use Ray to + # distribute the computation of the loss and the gradient over the data + # that is represented by the remote object IDs x_batches and y_batches and + # which is potentially distributed over a cluster. However, these details + # are hidden from scipy.optimize.fmin_l_bfgs_b, which simply uses it to run + # the L-BFGS algorithm. - # Load the mnist data and turn the data into remote objects. - print("Downloading the MNIST dataset. This may take a minute.") - mnist = input_data.read_data_sets("MNIST_data", one_hot=True) - num_batches = 10 - batch_size = mnist.train.num_examples // num_batches - batches = [mnist.train.next_batch(batch_size) for _ in range(num_batches)] - print("Putting MNIST in the object store.") - actors = [NetActor.remote(xs, ys) for (xs, ys) in batches] - # Initialize the weights for the network to the vector of all zeros. - dim = ray.get(actors[0].get_flat_size.remote()) - theta_init = 1e-2 * np.random.normal(size=dim) + # Load the mnist data and turn the data into remote objects. + print("Downloading the MNIST dataset. 
This may take a minute.") + mnist = input_data.read_data_sets("MNIST_data", one_hot=True) + num_batches = 10 + batch_size = mnist.train.num_examples // num_batches + batches = [mnist.train.next_batch(batch_size) for _ in range(num_batches)] + print("Putting MNIST in the object store.") + actors = [NetActor.remote(xs, ys) for (xs, ys) in batches] + # Initialize the weights for the network to the vector of all zeros. + dim = ray.get(actors[0].get_flat_size.remote()) + theta_init = 1e-2 * np.random.normal(size=dim) - # Use L-BFGS to minimize the loss function. - print("Running L-BFGS.") - result = scipy.optimize.fmin_l_bfgs_b(full_loss, theta_init, maxiter=10, - fprime=full_grad, disp=True) + # Use L-BFGS to minimize the loss function. + print("Running L-BFGS.") + result = scipy.optimize.fmin_l_bfgs_b(full_loss, theta_init, maxiter=10, + fprime=full_grad, disp=True) diff --git a/examples/resnet/cifar_input.py b/examples/resnet/cifar_input.py index 3775b656e..4bf2d42ff 100644 --- a/examples/resnet/cifar_input.py +++ b/examples/resnet/cifar_input.py @@ -10,108 +10,111 @@ import tensorflow as tf def build_data(data_path, size, dataset): - """Creates the queue and preprocessing operations for the dataset. + """Creates the queue and preprocessing operations for the dataset. - Args: - data_path: Filename for cifar10 data. - size: The number of images in the dataset. - dataset: The dataset we are using. + Args: + data_path: Filename for cifar10 data. + size: The number of images in the dataset. + dataset: The dataset we are using. - Returns: - queue: A Tensorflow queue for extracting the images and labels. - """ - image_size = 32 - if dataset == "cifar10": - label_bytes = 1 - label_offset = 0 - elif dataset == "cifar100": - label_bytes = 1 - label_offset = 1 - depth = 3 - image_bytes = image_size * image_size * depth - record_bytes = label_bytes + label_offset + image_bytes + Returns: + queue: A Tensorflow queue for extracting the images and labels. 
+ """ + image_size = 32 + if dataset == "cifar10": + label_bytes = 1 + label_offset = 0 + elif dataset == "cifar100": + label_bytes = 1 + label_offset = 1 + depth = 3 + image_bytes = image_size * image_size * depth + record_bytes = label_bytes + label_offset + image_bytes - data_files = tf.gfile.Glob(data_path) - file_queue = tf.train.string_input_producer(data_files, shuffle=True) - # Read examples from files in the filename queue. - reader = tf.FixedLengthRecordReader(record_bytes=record_bytes) - _, value = reader.read(file_queue) + data_files = tf.gfile.Glob(data_path) + file_queue = tf.train.string_input_producer(data_files, shuffle=True) + # Read examples from files in the filename queue. + reader = tf.FixedLengthRecordReader(record_bytes=record_bytes) + _, value = reader.read(file_queue) - # Convert these examples to dense labels and processed images. - record = tf.reshape(tf.decode_raw(value, tf.uint8), [record_bytes]) - label = tf.cast(tf.slice(record, [label_offset], [label_bytes]), tf.int32) - # Convert from string to [depth * height * width] to [depth, height, width]. - depth_major = tf.reshape(tf.slice(record, [label_bytes], [image_bytes]), - [depth, image_size, image_size]) - # Convert from [depth, height, width] to [height, width, depth]. - image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32) - queue = tf.train.shuffle_batch([image, label], size, size, 0, num_threads=16) - return queue + # Convert these examples to dense labels and processed images. + record = tf.reshape(tf.decode_raw(value, tf.uint8), [record_bytes]) + label = tf.cast(tf.slice(record, [label_offset], [label_bytes]), tf.int32) + # Convert from string to [depth * height * width] to + # [depth, height, width]. + depth_major = tf.reshape(tf.slice(record, [label_bytes], [image_bytes]), + [depth, image_size, image_size]) + # Convert from [depth, height, width] to [height, width, depth]. 
+ image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32) + queue = tf.train.shuffle_batch([image, label], size, size, 0, + num_threads=16) + return queue def build_input(data, batch_size, dataset, train): - """Build CIFAR image and labels. + """Build CIFAR image and labels. - Args: - data_path: Filename for cifar10 data. - batch_size: Input batch size. - train: True if we are training and false if we are testing. + Args: + data_path: Filename for cifar10 data. + batch_size: Input batch size. + train: True if we are training and false if we are testing. - Returns: - images: Batches of images of size [batch_size, image_size, image_size, 3]. - labels: Batches of labels of size [batch_size, num_classes]. + Returns: + images: Batches of images of size + [batch_size, image_size, image_size, 3]. + labels: Batches of labels of size [batch_size, num_classes]. - Raises: - ValueError: When the specified dataset is not supported. - """ - images_constant = tf.constant(data[0]) - labels_constant = tf.constant(data[1]) - image_size = 32 - depth = 3 - num_classes = 10 if dataset == "cifar10" else 100 - image, label = tf.train.slice_input_producer([images_constant, - labels_constant], - capacity=16 * batch_size) - if train: - image = tf.image.resize_image_with_crop_or_pad(image, image_size + 4, - image_size + 4) - image = tf.random_crop(image, [image_size, image_size, 3]) - image = tf.image.random_flip_left_right(image) - image = tf.image.per_image_standardization(image) - example_queue = tf.RandomShuffleQueue( - capacity=16 * batch_size, - min_after_dequeue=8 * batch_size, - dtypes=[tf.float32, tf.int32], - shapes=[[image_size, image_size, depth], [1]]) - num_threads = 16 - else: - image = tf.image.resize_image_with_crop_or_pad(image, image_size, - image_size) - image = tf.image.per_image_standardization(image) - example_queue = tf.FIFOQueue( - 3 * batch_size, - dtypes=[tf.float32, tf.int32], - shapes=[[image_size, image_size, depth], [1]]) - num_threads = 1 + Raises: + 
ValueError: When the specified dataset is not supported. + """ + images_constant = tf.constant(data[0]) + labels_constant = tf.constant(data[1]) + image_size = 32 + depth = 3 + num_classes = 10 if dataset == "cifar10" else 100 + image, label = tf.train.slice_input_producer([images_constant, + labels_constant], + capacity=16 * batch_size) + if train: + image = tf.image.resize_image_with_crop_or_pad(image, image_size + 4, + image_size + 4) + image = tf.random_crop(image, [image_size, image_size, 3]) + image = tf.image.random_flip_left_right(image) + image = tf.image.per_image_standardization(image) + example_queue = tf.RandomShuffleQueue( + capacity=16 * batch_size, + min_after_dequeue=8 * batch_size, + dtypes=[tf.float32, tf.int32], + shapes=[[image_size, image_size, depth], [1]]) + num_threads = 16 + else: + image = tf.image.resize_image_with_crop_or_pad(image, image_size, + image_size) + image = tf.image.per_image_standardization(image) + example_queue = tf.FIFOQueue( + 3 * batch_size, + dtypes=[tf.float32, tf.int32], + shapes=[[image_size, image_size, depth], [1]]) + num_threads = 1 - example_enqueue_op = example_queue.enqueue([image, label]) - tf.train.add_queue_runner(tf.train.queue_runner.QueueRunner( - example_queue, [example_enqueue_op] * num_threads)) + example_enqueue_op = example_queue.enqueue([image, label]) + tf.train.add_queue_runner(tf.train.queue_runner.QueueRunner( + example_queue, [example_enqueue_op] * num_threads)) - # Read "batch" labels + images from the example queue. - images, labels = example_queue.dequeue_many(batch_size) - labels = tf.reshape(labels, [batch_size, 1]) - indices = tf.reshape(tf.range(0, batch_size, 1), [batch_size, 1]) - labels = tf.sparse_to_dense( - tf.concat([indices, labels], 1), - [batch_size, num_classes], 1.0, 0.0) + # Read "batch" labels + images from the example queue. 
+ images, labels = example_queue.dequeue_many(batch_size) + labels = tf.reshape(labels, [batch_size, 1]) + indices = tf.reshape(tf.range(0, batch_size, 1), [batch_size, 1]) + labels = tf.sparse_to_dense( + tf.concat([indices, labels], 1), + [batch_size, num_classes], 1.0, 0.0) - assert len(images.get_shape()) == 4 - assert images.get_shape()[0] == batch_size - assert images.get_shape()[-1] == 3 - assert len(labels.get_shape()) == 2 - assert labels.get_shape()[0] == batch_size - assert labels.get_shape()[1] == num_classes - if not train: - tf.summary.image("images", images) - return images, labels + assert len(images.get_shape()) == 4 + assert images.get_shape()[0] == batch_size + assert images.get_shape()[-1] == 3 + assert len(labels.get_shape()) == 2 + assert labels.get_shape()[0] == batch_size + assert labels.get_shape()[1] == num_classes + if not train: + tf.summary.image("images", images) + return images, labels diff --git a/examples/resnet/resnet_main.py b/examples/resnet/resnet_main.py index 7d29ff694..f6be545c5 100644 --- a/examples/resnet/resnet_main.py +++ b/examples/resnet/resnet_main.py @@ -17,8 +17,8 @@ import resnet_model # Tensorflow must be at least version 1.0.0 for the example to work. if int(tf.__version__.split(".")[0]) < 1: - raise Exception("Your Tensorflow version is less than 1.0.0. Please update " - "Tensorflow to the latest version.") + raise Exception("Your Tensorflow version is less than 1.0.0. Please " + "update Tensorflow to the latest version.") parser = argparse.ArgumentParser(description="Run the ResNet example.") parser.add_argument("--dataset", default="cifar10", type=str, @@ -44,194 +44,199 @@ use_gpu = 1 if int(FLAGS.num_gpus) > 0 else 0 @ray.remote def get_data(path, size, dataset): - # Retrieves all preprocessed images and labels using a tensorflow queue. - # This only uses the cpu. 
- os.environ["CUDA_VISIBLE_DEVICES"] = "" - with tf.device("/cpu:0"): - queue = cifar_input.build_data(path, size, dataset) - sess = tf.Session() - coord = tf.train.Coordinator() - tf.train.start_queue_runners(sess, coord=coord) - images, labels = sess.run(queue) - coord.request_stop() - sess.close() - return images, labels + # Retrieves all preprocessed images and labels using a tensorflow queue. + # This only uses the cpu. + os.environ["CUDA_VISIBLE_DEVICES"] = "" + with tf.device("/cpu:0"): + queue = cifar_input.build_data(path, size, dataset) + sess = tf.Session() + coord = tf.train.Coordinator() + tf.train.start_queue_runners(sess, coord=coord) + images, labels = sess.run(queue) + coord.request_stop() + sess.close() + return images, labels @ray.remote(num_gpus=use_gpu) class ResNetTrainActor(object): - def __init__(self, data, dataset, num_gpus): - if num_gpus > 0: - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i - in ray.get_gpu_ids()]) - hps = resnet_model.HParams( - batch_size=128, - num_classes=100 if dataset == "cifar100" else 10, - min_lrn_rate=0.0001, - lrn_rate=0.1, - num_residual_units=5, - use_bottleneck=False, - weight_decay_rate=0.0002, - relu_leakiness=0.1, - optimizer="mom", - num_gpus=num_gpus) + def __init__(self, data, dataset, num_gpus): + if num_gpus > 0: + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( + [str(i) for i in ray.get_gpu_ids()]) + hps = resnet_model.HParams( + batch_size=128, + num_classes=100 if dataset == "cifar100" else 10, + min_lrn_rate=0.0001, + lrn_rate=0.1, + num_residual_units=5, + use_bottleneck=False, + weight_decay_rate=0.0002, + relu_leakiness=0.1, + optimizer="mom", + num_gpus=num_gpus) - # We seed each actor differently so that each actor operates on a different - # subset of data. - if num_gpus > 0: - tf.set_random_seed(ray.get_gpu_ids()[0] + 1) - else: - # Only a single actor in this case. 
- tf.set_random_seed(1) + # We seed each actor differently so that each actor operates on a + # different subset of data. + if num_gpus > 0: + tf.set_random_seed(ray.get_gpu_ids()[0] + 1) + else: + # Only a single actor in this case. + tf.set_random_seed(1) - input_images = data[0] - input_labels = data[1] - with tf.device("/gpu:0" if num_gpus > 0 else "/cpu:0"): - # Build the model. - images, labels = cifar_input.build_input([input_images, input_labels], - hps.batch_size, dataset, False) - self.model = resnet_model.ResNet(hps, images, labels, "train") - self.model.build_graph() - config = tf.ConfigProto(allow_soft_placement=True) - sess = tf.Session(config=config) - self.model.variables.set_session(sess) - self.coord = tf.train.Coordinator() - tf.train.start_queue_runners(sess, coord=self.coord) - init = tf.global_variables_initializer() - sess.run(init) - self.steps = 10 + input_images = data[0] + input_labels = data[1] + with tf.device("/gpu:0" if num_gpus > 0 else "/cpu:0"): + # Build the model. + images, labels = cifar_input.build_input([input_images, + input_labels], + hps.batch_size, dataset, + False) + self.model = resnet_model.ResNet(hps, images, labels, "train") + self.model.build_graph() + config = tf.ConfigProto(allow_soft_placement=True) + sess = tf.Session(config=config) + self.model.variables.set_session(sess) + self.coord = tf.train.Coordinator() + tf.train.start_queue_runners(sess, coord=self.coord) + init = tf.global_variables_initializer() + sess.run(init) + self.steps = 10 - def compute_steps(self, weights): - # This method sets the weights in the network, trains the network - # self.steps times, and returns the new weights. - self.model.variables.set_weights(weights) - for i in range(self.steps): - self.model.variables.sess.run(self.model.train_op) - return self.model.variables.get_weights() + def compute_steps(self, weights): + # This method sets the weights in the network, trains the network + # self.steps times, and returns the new weights. 
+ self.model.variables.set_weights(weights) + for i in range(self.steps): + self.model.variables.sess.run(self.model.train_op) + return self.model.variables.get_weights() - def get_weights(self): - # Note that the driver cannot directly access fields of the class, - # so helper methods must be created. - return self.model.variables.get_weights() + def get_weights(self): + # Note that the driver cannot directly access fields of the class, + # so helper methods must be created. + return self.model.variables.get_weights() @ray.remote class ResNetTestActor(object): - def __init__(self, data, dataset, eval_batch_count, eval_dir): - hps = resnet_model.HParams( - batch_size=100, - num_classes=100 if dataset == "cifar100" else 10, - min_lrn_rate=0.0001, - lrn_rate=0.1, - num_residual_units=5, - use_bottleneck=False, - weight_decay_rate=0.0002, - relu_leakiness=0.1, - optimizer="mom", - num_gpus=0) - input_images = data[0] - input_labels = data[1] - with tf.device("/cpu:0"): - # Builds the testing network. - images, labels = cifar_input.build_input([input_images, input_labels], - hps.batch_size, dataset, False) - self.model = resnet_model.ResNet(hps, images, labels, "eval") - self.model.build_graph() - config = tf.ConfigProto(allow_soft_placement=True) - sess = tf.Session(config=config) - self.model.variables.set_session(sess) - self.coord = tf.train.Coordinator() - tf.train.start_queue_runners(sess, coord=self.coord) - init = tf.global_variables_initializer() - sess.run(init) + def __init__(self, data, dataset, eval_batch_count, eval_dir): + hps = resnet_model.HParams( + batch_size=100, + num_classes=100 if dataset == "cifar100" else 10, + min_lrn_rate=0.0001, + lrn_rate=0.1, + num_residual_units=5, + use_bottleneck=False, + weight_decay_rate=0.0002, + relu_leakiness=0.1, + optimizer="mom", + num_gpus=0) + input_images = data[0] + input_labels = data[1] + with tf.device("/cpu:0"): + # Builds the testing network. 
+ images, labels = cifar_input.build_input([input_images, + input_labels], + hps.batch_size, dataset, + False) + self.model = resnet_model.ResNet(hps, images, labels, "eval") + self.model.build_graph() + config = tf.ConfigProto(allow_soft_placement=True) + sess = tf.Session(config=config) + self.model.variables.set_session(sess) + self.coord = tf.train.Coordinator() + tf.train.start_queue_runners(sess, coord=self.coord) + init = tf.global_variables_initializer() + sess.run(init) - # Initializing parameters for tensorboard. - self.best_precision = 0.0 - self.eval_batch_count = eval_batch_count - self.summary_writer = tf.summary.FileWriter(eval_dir, sess.graph) - # The IP address where tensorboard logs will be on. - self.ip_addr = ray.services.get_node_ip_address() + # Initializing parameters for tensorboard. + self.best_precision = 0.0 + self.eval_batch_count = eval_batch_count + self.summary_writer = tf.summary.FileWriter(eval_dir, sess.graph) + # The IP address where tensorboard logs will be on. + self.ip_addr = ray.services.get_node_ip_address() - def accuracy(self, weights, train_step): - # Sets the weights, computes the accuracy and other metrics - # over eval_batches, and outputs to tensorboard. - self.model.variables.set_weights(weights) - total_prediction, correct_prediction = 0, 0 - model = self.model - sess = self.model.variables.sess - for _ in range(self.eval_batch_count): - summaries, loss, predictions, truth = sess.run( - [model.summaries, model.cost, model.predictions, - model.labels]) + def accuracy(self, weights, train_step): + # Sets the weights, computes the accuracy and other metrics + # over eval_batches, and outputs to tensorboard. 
+ self.model.variables.set_weights(weights) + total_prediction, correct_prediction = 0, 0 + model = self.model + sess = self.model.variables.sess + for _ in range(self.eval_batch_count): + summaries, loss, predictions, truth = sess.run( + [model.summaries, model.cost, model.predictions, + model.labels]) - truth = np.argmax(truth, axis=1) - predictions = np.argmax(predictions, axis=1) - correct_prediction += np.sum(truth == predictions) - total_prediction += predictions.shape[0] + truth = np.argmax(truth, axis=1) + predictions = np.argmax(predictions, axis=1) + correct_prediction += np.sum(truth == predictions) + total_prediction += predictions.shape[0] - precision = 1.0 * correct_prediction / total_prediction - self.best_precision = max(precision, self.best_precision) - precision_summ = tf.Summary() - precision_summ.value.add( - tag="Precision", simple_value=precision) - self.summary_writer.add_summary(precision_summ, train_step) - best_precision_summ = tf.Summary() - best_precision_summ.value.add( - tag="Best Precision", simple_value=self.best_precision) - self.summary_writer.add_summary(best_precision_summ, train_step) - self.summary_writer.add_summary(summaries, train_step) - tf.logging.info("loss: %.3f, precision: %.3f, best precision: %.3f" % - (loss, precision, self.best_precision)) - self.summary_writer.flush() - return precision + precision = 1.0 * correct_prediction / total_prediction + self.best_precision = max(precision, self.best_precision) + precision_summ = tf.Summary() + precision_summ.value.add( + tag="Precision", simple_value=precision) + self.summary_writer.add_summary(precision_summ, train_step) + best_precision_summ = tf.Summary() + best_precision_summ.value.add( + tag="Best Precision", simple_value=self.best_precision) + self.summary_writer.add_summary(best_precision_summ, train_step) + self.summary_writer.add_summary(summaries, train_step) + tf.logging.info("loss: %.3f, precision: %.3f, best precision: %.3f" % + (loss, precision, 
self.best_precision)) + self.summary_writer.flush() + return precision - def get_ip_addr(self): - # As above, a helper method must be created to access the field from the - # driver. - return self.ip_addr + def get_ip_addr(self): + # As above, a helper method must be created to access the field from + # the driver. + return self.ip_addr def train(): - num_gpus = FLAGS.num_gpus - ray.init(num_gpus=num_gpus, redirect_output=True) - train_data = get_data.remote(FLAGS.train_data_path, 50000, FLAGS.dataset) - test_data = get_data.remote(FLAGS.eval_data_path, 10000, FLAGS.dataset) - # Creates an actor for each gpu, or one if only using the cpu. Each actor has - # access to the dataset. - if FLAGS.num_gpus > 0: - train_actors = [ResNetTrainActor.remote(train_data, FLAGS.dataset, - num_gpus) for _ in range(num_gpus)] - else: - train_actors = [ResNetTrainActor.remote(train_data, FLAGS.dataset, 0)] - test_actor = ResNetTestActor.remote(test_data, FLAGS.dataset, - FLAGS.eval_batch_count, FLAGS.eval_dir) - print("The log files for tensorboard are stored at ip {}." - .format(ray.get(test_actor.get_ip_addr.remote()))) - step = 0 - weight_id = train_actors[0].get_weights.remote() - acc_id = test_actor.accuracy.remote(weight_id, step) - # Correction for dividing the weights by the number of gpus. - if num_gpus == 0: - num_gpus = 1 - print("Starting training loop. Use Ctrl-C to exit.") - try: - while True: - all_weights = ray.get([actor.compute_steps.remote(weight_id) - for actor in train_actors]) - mean_weights = {k: (sum([weights[k] for weights in all_weights]) / - num_gpus) - for k in all_weights[0]} - weight_id = ray.put(mean_weights) - step += 10 - if step % 200 == 0: - # Retrieves the previously computed accuracy and launches a new - # testing task with the current weights every 200 steps. 
- acc = ray.get(acc_id) - acc_id = test_actor.accuracy.remote(weight_id, step) - print("Step {0}: {1:.6f}".format(step - 200, acc)) - except KeyboardInterrupt: - pass + num_gpus = FLAGS.num_gpus + ray.init(num_gpus=num_gpus, redirect_output=True) + train_data = get_data.remote(FLAGS.train_data_path, 50000, FLAGS.dataset) + test_data = get_data.remote(FLAGS.eval_data_path, 10000, FLAGS.dataset) + # Creates an actor for each gpu, or one if only using the cpu. Each actor + # has access to the dataset. + if FLAGS.num_gpus > 0: + train_actors = [ResNetTrainActor.remote(train_data, FLAGS.dataset, + num_gpus) + for _ in range(num_gpus)] + else: + train_actors = [ResNetTrainActor.remote(train_data, FLAGS.dataset, 0)] + test_actor = ResNetTestActor.remote(test_data, FLAGS.dataset, + FLAGS.eval_batch_count, FLAGS.eval_dir) + print("The log files for tensorboard are stored at ip {}." + .format(ray.get(test_actor.get_ip_addr.remote()))) + step = 0 + weight_id = train_actors[0].get_weights.remote() + acc_id = test_actor.accuracy.remote(weight_id, step) + # Correction for dividing the weights by the number of gpus. + if num_gpus == 0: + num_gpus = 1 + print("Starting training loop. Use Ctrl-C to exit.") + try: + while True: + all_weights = ray.get([actor.compute_steps.remote(weight_id) + for actor in train_actors]) + mean_weights = {k: (sum([weights[k] for weights in all_weights]) / + num_gpus) + for k in all_weights[0]} + weight_id = ray.put(mean_weights) + step += 10 + if step % 200 == 0: + # Retrieves the previously computed accuracy and launches a new + # testing task with the current weights every 200 steps. 
+ acc = ray.get(acc_id) + acc_id = test_actor.accuracy.remote(weight_id, step) + print("Step {0}: {1:.6f}".format(step - 200, acc)) + except KeyboardInterrupt: + pass if __name__ == "__main__": - train() + train() diff --git a/examples/resnet/resnet_model.py b/examples/resnet/resnet_model.py index 9671a4d36..1246f10fc 100644 --- a/examples/resnet/resnet_model.py +++ b/examples/resnet/resnet_model.py @@ -24,260 +24,271 @@ HParams = namedtuple('HParams', class ResNet(object): - """ResNet model.""" + """ResNet model.""" - def __init__(self, hps, images, labels, mode): - """ResNet constructor. + def __init__(self, hps, images, labels, mode): + """ResNet constructor. - Args: - hps: Hyperparameters. - images: Batches of images of size [batch_size, image_size, image_size, - 3]. - labels: Batches of labels of size [batch_size, num_classes]. - mode: One of 'train' and 'eval'. - """ - self.hps = hps - self._images = images - self.labels = labels - self.mode = mode + Args: + hps: Hyperparameters. + images: Batches of images of size [batch_size, image_size, + image_size, 3]. + labels: Batches of labels of size [batch_size, num_classes]. + mode: One of 'train' and 'eval'. + """ + self.hps = hps + self._images = images + self.labels = labels + self.mode = mode - self._extra_train_ops = [] + self._extra_train_ops = [] - def build_graph(self): - """Build a whole graph for the model.""" - self.global_step = tf.Variable(0, trainable=False) - self._build_model() - if self.mode == 'train': - self._build_train_op() - else: - # Additional initialization for the test network. - self.variables = ray.experimental.TensorFlowVariables(self.cost) - self.summaries = tf.summary.merge_all() + def build_graph(self): + """Build a whole graph for the model.""" + self.global_step = tf.Variable(0, trainable=False) + self._build_model() + if self.mode == 'train': + self._build_train_op() + else: + # Additional initialization for the test network. 
+ self.variables = ray.experimental.TensorFlowVariables(self.cost) + self.summaries = tf.summary.merge_all() - def _stride_arr(self, stride): - """Map a stride scalar to the stride array for tf.nn.conv2d.""" - return [1, stride, stride, 1] + def _stride_arr(self, stride): + """Map a stride scalar to the stride array for tf.nn.conv2d.""" + return [1, stride, stride, 1] - def _build_model(self): - """Build the core model within the graph.""" + def _build_model(self): + """Build the core model within the graph.""" - with tf.variable_scope('init'): - x = self._conv('init_conv', self._images, 3, 3, 16, self._stride_arr(1)) + with tf.variable_scope('init'): + x = self._conv('init_conv', self._images, 3, 3, 16, + self._stride_arr(1)) - strides = [1, 2, 2] - activate_before_residual = [True, False, False] - if self.hps.use_bottleneck: - res_func = self._bottleneck_residual - filters = [16, 64, 128, 256] - else: - res_func = self._residual - filters = [16, 16, 32, 64] + strides = [1, 2, 2] + activate_before_residual = [True, False, False] + if self.hps.use_bottleneck: + res_func = self._bottleneck_residual + filters = [16, 64, 128, 256] + else: + res_func = self._residual + filters = [16, 16, 32, 64] - with tf.variable_scope('unit_1_0'): - x = res_func(x, filters[0], filters[1], self._stride_arr(strides[0]), - activate_before_residual[0]) - for i in range(1, self.hps.num_residual_units): - with tf.variable_scope('unit_1_%d' % i): - x = res_func(x, filters[1], filters[1], self._stride_arr(1), False) + with tf.variable_scope('unit_1_0'): + x = res_func(x, filters[0], filters[1], + self._stride_arr(strides[0]), + activate_before_residual[0]) + for i in range(1, self.hps.num_residual_units): + with tf.variable_scope('unit_1_%d' % i): + x = res_func(x, filters[1], filters[1], self._stride_arr(1), + False) - with tf.variable_scope('unit_2_0'): - x = res_func(x, filters[1], filters[2], self._stride_arr(strides[1]), - activate_before_residual[1]) - for i in range(1, 
self.hps.num_residual_units): - with tf.variable_scope('unit_2_%d' % i): - x = res_func(x, filters[2], filters[2], self._stride_arr(1), False) + with tf.variable_scope('unit_2_0'): + x = res_func(x, filters[1], filters[2], + self._stride_arr(strides[1]), + activate_before_residual[1]) + for i in range(1, self.hps.num_residual_units): + with tf.variable_scope('unit_2_%d' % i): + x = res_func(x, filters[2], filters[2], + self._stride_arr(1), False) - with tf.variable_scope('unit_3_0'): - x = res_func(x, filters[2], filters[3], self._stride_arr(strides[2]), - activate_before_residual[2]) - for i in range(1, self.hps.num_residual_units): - with tf.variable_scope('unit_3_%d' % i): - x = res_func(x, filters[3], filters[3], self._stride_arr(1), False) - with tf.variable_scope('unit_last'): - x = self._batch_norm('final_bn', x) - x = self._relu(x, self.hps.relu_leakiness) - x = self._global_avg_pool(x) + with tf.variable_scope('unit_3_0'): + x = res_func(x, filters[2], filters[3], + self._stride_arr(strides[2]), + activate_before_residual[2]) + for i in range(1, self.hps.num_residual_units): + with tf.variable_scope('unit_3_%d' % i): + x = res_func(x, filters[3], filters[3], self._stride_arr(1), + False) + with tf.variable_scope('unit_last'): + x = self._batch_norm('final_bn', x) + x = self._relu(x, self.hps.relu_leakiness) + x = self._global_avg_pool(x) - with tf.variable_scope('logit'): - logits = self._fully_connected(x, self.hps.num_classes) - self.predictions = tf.nn.softmax(logits) + with tf.variable_scope('logit'): + logits = self._fully_connected(x, self.hps.num_classes) + self.predictions = tf.nn.softmax(logits) - with tf.variable_scope('costs'): - xent = tf.nn.softmax_cross_entropy_with_logits( - logits=logits, labels=self.labels) - self.cost = tf.reduce_mean(xent, name='xent') - self.cost += self._decay() + with tf.variable_scope('costs'): + xent = tf.nn.softmax_cross_entropy_with_logits( + logits=logits, labels=self.labels) + self.cost = tf.reduce_mean(xent, 
name='xent') + self.cost += self._decay() - if self.mode == 'eval': - tf.summary.scalar('cost', self.cost) + if self.mode == 'eval': + tf.summary.scalar('cost', self.cost) - def _build_train_op(self): - """Build training specific ops for the graph.""" - num_gpus = self.hps.num_gpus if self.hps.num_gpus != 0 else 1 - # The learning rate schedule is dependent on the number of gpus. - boundaries = [int(20000 * i / np.sqrt(num_gpus)) for i in range(2, 5)] - values = [0.1, 0.01, 0.001, 0.0001] - self.lrn_rate = tf.train.piecewise_constant(self.global_step, boundaries, - values) - tf.summary.scalar('learning rate', self.lrn_rate) + def _build_train_op(self): + """Build training specific ops for the graph.""" + num_gpus = self.hps.num_gpus if self.hps.num_gpus != 0 else 1 + # The learning rate schedule is dependent on the number of gpus. + boundaries = [int(20000 * i / np.sqrt(num_gpus)) for i in range(2, 5)] + values = [0.1, 0.01, 0.001, 0.0001] + self.lrn_rate = tf.train.piecewise_constant(self.global_step, + boundaries, values) + tf.summary.scalar('learning rate', self.lrn_rate) - if self.hps.optimizer == 'sgd': - optimizer = tf.train.GradientDescentOptimizer(self.lrn_rate) - elif self.hps.optimizer == 'mom': - optimizer = tf.train.MomentumOptimizer(self.lrn_rate, 0.9) + if self.hps.optimizer == 'sgd': + optimizer = tf.train.GradientDescentOptimizer(self.lrn_rate) + elif self.hps.optimizer == 'mom': + optimizer = tf.train.MomentumOptimizer(self.lrn_rate, 0.9) - apply_op = optimizer.minimize(self.cost, global_step=self.global_step) - train_ops = [apply_op] + self._extra_train_ops - self.train_op = tf.group(*train_ops) - self.variables = ray.experimental.TensorFlowVariables(self.train_op) + apply_op = optimizer.minimize(self.cost, global_step=self.global_step) + train_ops = [apply_op] + self._extra_train_ops + self.train_op = tf.group(*train_ops) + self.variables = ray.experimental.TensorFlowVariables(self.train_op) - def _batch_norm(self, name, x): - """Batch 
normalization.""" - with tf.variable_scope(name): - params_shape = [x.get_shape()[-1]] + def _batch_norm(self, name, x): + """Batch normalization.""" + with tf.variable_scope(name): + params_shape = [x.get_shape()[-1]] - beta = tf.get_variable( - 'beta', params_shape, tf.float32, - initializer=tf.constant_initializer(0.0, tf.float32)) - gamma = tf.get_variable( - 'gamma', params_shape, tf.float32, - initializer=tf.constant_initializer(1.0, tf.float32)) + beta = tf.get_variable( + 'beta', params_shape, tf.float32, + initializer=tf.constant_initializer(0.0, tf.float32)) + gamma = tf.get_variable( + 'gamma', params_shape, tf.float32, + initializer=tf.constant_initializer(1.0, tf.float32)) - if self.mode == 'train': - mean, variance = tf.nn.moments(x, [0, 1, 2], name='moments') + if self.mode == 'train': + mean, variance = tf.nn.moments(x, [0, 1, 2], name='moments') - moving_mean = tf.get_variable( - 'moving_mean', params_shape, tf.float32, - initializer=tf.constant_initializer(0.0, tf.float32), - trainable=False) - moving_variance = tf.get_variable( - 'moving_variance', params_shape, tf.float32, - initializer=tf.constant_initializer(1.0, tf.float32), - trainable=False) + moving_mean = tf.get_variable( + 'moving_mean', params_shape, tf.float32, + initializer=tf.constant_initializer(0.0, tf.float32), + trainable=False) + moving_variance = tf.get_variable( + 'moving_variance', params_shape, tf.float32, + initializer=tf.constant_initializer(1.0, tf.float32), + trainable=False) - self._extra_train_ops.append(moving_averages.assign_moving_average( - moving_mean, mean, 0.9)) - self._extra_train_ops.append(moving_averages.assign_moving_average( - moving_variance, variance, 0.9)) - else: - mean = tf.get_variable( - 'moving_mean', params_shape, tf.float32, - initializer=tf.constant_initializer(0.0, tf.float32), - trainable=False) - variance = tf.get_variable( - 'moving_variance', params_shape, tf.float32, - initializer=tf.constant_initializer(1.0, tf.float32), - 
trainable=False) - tf.summary.histogram(mean.op.name, mean) - tf.summary.histogram(variance.op.name, variance) - # elipson used to be 1e-5. Maybe 0.001 solves NaN problem in deeper net. - y = tf.nn.batch_normalization( - x, mean, variance, beta, gamma, 0.001) - y.set_shape(x.get_shape()) - return y + self._extra_train_ops.append( + moving_averages.assign_moving_average(moving_mean, mean, + 0.9)) + self._extra_train_ops.append( + moving_averages.assign_moving_average(moving_variance, + variance, 0.9)) + else: + mean = tf.get_variable( + 'moving_mean', params_shape, tf.float32, + initializer=tf.constant_initializer(0.0, tf.float32), + trainable=False) + variance = tf.get_variable( + 'moving_variance', params_shape, tf.float32, + initializer=tf.constant_initializer(1.0, tf.float32), + trainable=False) + tf.summary.histogram(mean.op.name, mean) + tf.summary.histogram(variance.op.name, variance) + # elipson used to be 1e-5. Maybe 0.001 solves NaN problem in deeper + # net. + y = tf.nn.batch_normalization( + x, mean, variance, beta, gamma, 0.001) + y.set_shape(x.get_shape()) + return y - def _residual(self, x, in_filter, out_filter, stride, - activate_before_residual=False): - """Residual unit with 2 sub layers.""" - if activate_before_residual: - with tf.variable_scope('shared_activation'): - x = self._batch_norm('init_bn', x) - x = self._relu(x, self.hps.relu_leakiness) - orig_x = x - else: - with tf.variable_scope('residual_only_activation'): - orig_x = x - x = self._batch_norm('init_bn', x) - x = self._relu(x, self.hps.relu_leakiness) + def _residual(self, x, in_filter, out_filter, stride, + activate_before_residual=False): + """Residual unit with 2 sub layers.""" + if activate_before_residual: + with tf.variable_scope('shared_activation'): + x = self._batch_norm('init_bn', x) + x = self._relu(x, self.hps.relu_leakiness) + orig_x = x + else: + with tf.variable_scope('residual_only_activation'): + orig_x = x + x = self._batch_norm('init_bn', x) + x = self._relu(x, 
self.hps.relu_leakiness) - with tf.variable_scope('sub1'): - x = self._conv('conv1', x, 3, in_filter, out_filter, stride) + with tf.variable_scope('sub1'): + x = self._conv('conv1', x, 3, in_filter, out_filter, stride) - with tf.variable_scope('sub2'): - x = self._batch_norm('bn2', x) - x = self._relu(x, self.hps.relu_leakiness) - x = self._conv('conv2', x, 3, out_filter, out_filter, [1, 1, 1, 1]) + with tf.variable_scope('sub2'): + x = self._batch_norm('bn2', x) + x = self._relu(x, self.hps.relu_leakiness) + x = self._conv('conv2', x, 3, out_filter, out_filter, [1, 1, 1, 1]) - with tf.variable_scope('sub_add'): - if in_filter != out_filter: - orig_x = tf.nn.avg_pool(orig_x, stride, stride, 'VALID') - orig_x = tf.pad( - orig_x, [[0, 0], [0, 0], [0, 0], - [(out_filter - in_filter) // 2, - (out_filter - in_filter) // 2]]) - x += orig_x + with tf.variable_scope('sub_add'): + if in_filter != out_filter: + orig_x = tf.nn.avg_pool(orig_x, stride, stride, 'VALID') + orig_x = tf.pad( + orig_x, [[0, 0], [0, 0], [0, 0], + [(out_filter - in_filter) // 2, + (out_filter - in_filter) // 2]]) + x += orig_x - return x + return x - def _bottleneck_residual(self, x, in_filter, out_filter, stride, - activate_before_residual=False): - """Bottleneck residual unit with 3 sub layers.""" - if activate_before_residual: - with tf.variable_scope('common_bn_relu'): - x = self._batch_norm('init_bn', x) - x = self._relu(x, self.hps.relu_leakiness) - orig_x = x - else: - with tf.variable_scope('residual_bn_relu'): - orig_x = x - x = self._batch_norm('init_bn', x) - x = self._relu(x, self.hps.relu_leakiness) + def _bottleneck_residual(self, x, in_filter, out_filter, stride, + activate_before_residual=False): + """Bottleneck residual unit with 3 sub layers.""" + if activate_before_residual: + with tf.variable_scope('common_bn_relu'): + x = self._batch_norm('init_bn', x) + x = self._relu(x, self.hps.relu_leakiness) + orig_x = x + else: + with tf.variable_scope('residual_bn_relu'): + orig_x = x + x 
= self._batch_norm('init_bn', x) + x = self._relu(x, self.hps.relu_leakiness) - with tf.variable_scope('sub1'): - x = self._conv('conv1', x, 1, in_filter, out_filter / 4, stride) + with tf.variable_scope('sub1'): + x = self._conv('conv1', x, 1, in_filter, out_filter / 4, stride) - with tf.variable_scope('sub2'): - x = self._batch_norm('bn2', x) - x = self._relu(x, self.hps.relu_leakiness) - x = self._conv('conv2', x, 3, out_filter / 4, out_filter / 4, - [1, 1, 1, 1]) + with tf.variable_scope('sub2'): + x = self._batch_norm('bn2', x) + x = self._relu(x, self.hps.relu_leakiness) + x = self._conv('conv2', x, 3, out_filter / 4, out_filter / 4, + [1, 1, 1, 1]) - with tf.variable_scope('sub3'): - x = self._batch_norm('bn3', x) - x = self._relu(x, self.hps.relu_leakiness) - x = self._conv('conv3', x, 1, out_filter / 4, out_filter, [1, 1, 1, 1]) + with tf.variable_scope('sub3'): + x = self._batch_norm('bn3', x) + x = self._relu(x, self.hps.relu_leakiness) + x = self._conv('conv3', x, 1, out_filter / 4, out_filter, + [1, 1, 1, 1]) - with tf.variable_scope('sub_add'): - if in_filter != out_filter: - orig_x = self._conv('project', orig_x, 1, in_filter, out_filter, - stride) - x += orig_x + with tf.variable_scope('sub_add'): + if in_filter != out_filter: + orig_x = self._conv('project', orig_x, 1, in_filter, + out_filter, stride) + x += orig_x - return x + return x - def _decay(self): - """L2 weight decay loss.""" - costs = [] - for var in tf.trainable_variables(): - if var.op.name.find(r'DW') > 0: - costs.append(tf.nn.l2_loss(var)) + def _decay(self): + """L2 weight decay loss.""" + costs = [] + for var in tf.trainable_variables(): + if var.op.name.find(r'DW') > 0: + costs.append(tf.nn.l2_loss(var)) - return tf.multiply(self.hps.weight_decay_rate, tf.add_n(costs)) + return tf.multiply(self.hps.weight_decay_rate, tf.add_n(costs)) - def _conv(self, name, x, filter_size, in_filters, out_filters, strides): - """Convolution.""" - with tf.variable_scope(name): - n = filter_size * 
filter_size * out_filters - kernel = tf.get_variable( - 'DW', [filter_size, filter_size, in_filters, out_filters], - tf.float32, initializer=tf.random_normal_initializer( - stddev=np.sqrt(2.0 / n))) - return tf.nn.conv2d(x, kernel, strides, padding='SAME') + def _conv(self, name, x, filter_size, in_filters, out_filters, strides): + """Convolution.""" + with tf.variable_scope(name): + n = filter_size * filter_size * out_filters + kernel = tf.get_variable( + 'DW', [filter_size, filter_size, in_filters, out_filters], + tf.float32, initializer=tf.random_normal_initializer( + stddev=np.sqrt(2.0 / n))) + return tf.nn.conv2d(x, kernel, strides, padding='SAME') - def _relu(self, x, leakiness=0.0): - """Relu, with optional leaky support.""" - return tf.where(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu') + def _relu(self, x, leakiness=0.0): + """Relu, with optional leaky support.""" + return tf.where(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu') - def _fully_connected(self, x, out_dim): - """FullyConnected layer for final output.""" - x = tf.reshape(x, [self.hps.batch_size, -1]) - w = tf.get_variable( - 'DW', [x.get_shape()[1], out_dim], - initializer=tf.uniform_unit_scaling_initializer(factor=1.0)) - b = tf.get_variable('biases', [out_dim], - initializer=tf.constant_initializer()) - return tf.nn.xw_plus_b(x, w, b) + def _fully_connected(self, x, out_dim): + """FullyConnected layer for final output.""" + x = tf.reshape(x, [self.hps.batch_size, -1]) + w = tf.get_variable( + 'DW', [x.get_shape()[1], out_dim], + initializer=tf.uniform_unit_scaling_initializer(factor=1.0)) + b = tf.get_variable('biases', [out_dim], + initializer=tf.constant_initializer()) + return tf.nn.xw_plus_b(x, w, b) - def _global_avg_pool(self, x): - assert x.get_shape().ndims == 4 - return tf.reduce_mean(x, [1, 2]) + def _global_avg_pool(self, x): + assert x.get_shape().ndims == 4 + return tf.reduce_mean(x, [1, 2]) diff --git a/examples/rl_pong/driver.py b/examples/rl_pong/driver.py index 
4c414f417..1e4f9db52 100644 --- a/examples/rl_pong/driver.py +++ b/examples/rl_pong/driver.py @@ -28,170 +28,173 @@ D = 80 * 80 def sigmoid(x): - # Sigmoid "squashing" function to interval [0, 1]. - return 1.0 / (1.0 + np.exp(-x)) + # Sigmoid "squashing" function to interval [0, 1]. + return 1.0 / (1.0 + np.exp(-x)) def preprocess(I): - """Preprocess 210x160x3 uint8 frame into 6400 (80x80) 1D float vector.""" - # Crop the image. - I = I[35:195] - # Downsample by factor of 2. - I = I[::2, ::2, 0] - # Erase background (background type 1). - I[I == 144] = 0 - # Erase background (background type 2). - I[I == 109] = 0 - # Set everything else (paddles, ball) to 1. - I[I != 0] = 1 - return I.astype(np.float).ravel() + """Preprocess 210x160x3 uint8 frame into 6400 (80x80) 1D float vector.""" + # Crop the image. + I = I[35:195] + # Downsample by factor of 2. + I = I[::2, ::2, 0] + # Erase background (background type 1). + I[I == 144] = 0 + # Erase background (background type 2). + I[I == 109] = 0 + # Set everything else (paddles, ball) to 1. + I[I != 0] = 1 + return I.astype(np.float).ravel() def discount_rewards(r): - """take 1D float array of rewards and compute discounted reward""" - discounted_r = np.zeros_like(r) - running_add = 0 - for t in reversed(range(0, r.size)): - # Reset the sum, since this was a game boundary (pong specific!). - if r[t] != 0: - running_add = 0 - running_add = running_add * gamma + r[t] - discounted_r[t] = running_add - return discounted_r + """take 1D float array of rewards and compute discounted reward""" + discounted_r = np.zeros_like(r) + running_add = 0 + for t in reversed(range(0, r.size)): + # Reset the sum, since this was a game boundary (pong specific!). + if r[t] != 0: + running_add = 0 + running_add = running_add * gamma + r[t] + discounted_r[t] = running_add + return discounted_r def policy_forward(x, model): - h = np.dot(model["W1"], x) - h[h < 0] = 0 # ReLU nonlinearity. 
- logp = np.dot(model["W2"], h) - p = sigmoid(logp) - # Return probability of taking action 2, and hidden state. - return p, h + h = np.dot(model["W1"], x) + h[h < 0] = 0 # ReLU nonlinearity. + logp = np.dot(model["W2"], h) + p = sigmoid(logp) + # Return probability of taking action 2, and hidden state. + return p, h def policy_backward(eph, epx, epdlogp, model): - """backward pass. (eph is array of intermediate hidden states)""" - dW2 = np.dot(eph.T, epdlogp).ravel() - dh = np.outer(epdlogp, model["W2"]) - # Backprop relu. - dh[eph <= 0] = 0 - dW1 = np.dot(dh.T, epx) - return {"W1": dW1, "W2": dW2} + """backward pass. (eph is array of intermediate hidden states)""" + dW2 = np.dot(eph.T, epdlogp).ravel() + dh = np.outer(epdlogp, model["W2"]) + # Backprop relu. + dh[eph <= 0] = 0 + dW1 = np.dot(dh.T, epx) + return {"W1": dW1, "W2": dW2} @ray.remote class PongEnv(object): - def __init__(self): - # Tell numpy to only use one core. If we don't do this, each actor may try - # to use all of the cores and the resulting contention may result in no - # speedup over the serial version. Note that if numpy is using OpenBLAS, - # then you need to set OPENBLAS_NUM_THREADS=1, and you probably need to do - # it from the command line (so it happens before numpy is imported). - os.environ["MKL_NUM_THREADS"] = "1" - self.env = gym.make("Pong-v0") + def __init__(self): + # Tell numpy to only use one core. If we don't do this, each actor may + # try to use all of the cores and the resulting contention may result + # in no speedup over the serial version. Note that if numpy is using + # OpenBLAS, then you need to set OPENBLAS_NUM_THREADS=1, and you + # probably need to do it from the command line (so it happens before + # numpy is imported). + os.environ["MKL_NUM_THREADS"] = "1" + self.env = gym.make("Pong-v0") - def compute_gradient(self, model): - # Reset the game. - observation = self.env.reset() - # Note that prev_x is used in computing the difference frame. 
- prev_x = None - xs, hs, dlogps, drs = [], [], [], [] - reward_sum = 0 - done = False - while not done: - cur_x = preprocess(observation) - x = cur_x - prev_x if prev_x is not None else np.zeros(D) - prev_x = cur_x + def compute_gradient(self, model): + # Reset the game. + observation = self.env.reset() + # Note that prev_x is used in computing the difference frame. + prev_x = None + xs, hs, dlogps, drs = [], [], [], [] + reward_sum = 0 + done = False + while not done: + cur_x = preprocess(observation) + x = cur_x - prev_x if prev_x is not None else np.zeros(D) + prev_x = cur_x - aprob, h = policy_forward(x, model) - # Sample an action. - action = 2 if np.random.uniform() < aprob else 3 + aprob, h = policy_forward(x, model) + # Sample an action. + action = 2 if np.random.uniform() < aprob else 3 - # The observation. - xs.append(x) - # The hidden state. - hs.append(h) - y = 1 if action == 2 else 0 # A "fake label". - # The gradient that encourages the action that was taken to be taken (see - # http://cs231n.github.io/neural-networks-2/#losses if confused). - dlogps.append(y - aprob) + # The observation. + xs.append(x) + # The hidden state. + hs.append(h) + y = 1 if action == 2 else 0 # A "fake label". + # The gradient that encourages the action that was taken to be + # taken (see http://cs231n.github.io/neural-networks-2/#losses if + # confused). + dlogps.append(y - aprob) - observation, reward, done, info = self.env.step(action) - reward_sum += reward + observation, reward, done, info = self.env.step(action) + reward_sum += reward - # Record reward (has to be done after we call step() to get reward for - # previous action). - drs.append(reward) + # Record reward (has to be done after we call step() to get reward + # for previous action). + drs.append(reward) - epx = np.vstack(xs) - eph = np.vstack(hs) - epdlogp = np.vstack(dlogps) - epr = np.vstack(drs) - # Reset the array memory. 
- xs, hs, dlogps, drs = [], [], [], [] + epx = np.vstack(xs) + eph = np.vstack(hs) + epdlogp = np.vstack(dlogps) + epr = np.vstack(drs) + # Reset the array memory. + xs, hs, dlogps, drs = [], [], [], [] - # Compute the discounted reward backward through time. - discounted_epr = discount_rewards(epr) - # Standardize the rewards to be unit normal (helps control the gradient - # estimator variance). - discounted_epr -= np.mean(discounted_epr) - discounted_epr /= np.std(discounted_epr) - # Modulate the gradient with advantage (the policy gradient magic happens - # right here). - epdlogp *= discounted_epr - return policy_backward(eph, epx, epdlogp, model), reward_sum + # Compute the discounted reward backward through time. + discounted_epr = discount_rewards(epr) + # Standardize the rewards to be unit normal (helps control the gradient + # estimator variance). + discounted_epr -= np.mean(discounted_epr) + discounted_epr /= np.std(discounted_epr) + # Modulate the gradient with advantage (the policy gradient magic + # happens right here). 
+ epdlogp *= discounted_epr + return policy_backward(eph, epx, epdlogp, model), reward_sum if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Train an RL agent on Pong.") - parser.add_argument("--batch-size", default=10, type=int, - help="The number of rollouts to do per batch.") - parser.add_argument("--redis-address", default=None, type=str, - help="The Redis address of the cluster.") + parser = argparse.ArgumentParser(description="Train an RL agent on Pong.") + parser.add_argument("--batch-size", default=10, type=int, + help="The number of rollouts to do per batch.") + parser.add_argument("--redis-address", default=None, type=str, + help="The Redis address of the cluster.") - args = parser.parse_args() - batch_size = args.batch_size + args = parser.parse_args() + batch_size = args.batch_size - ray.init(redis_address=args.redis_address, redirect_output=True) + ray.init(redis_address=args.redis_address, redirect_output=True) - # Run the reinforcement learning. + # Run the reinforcement learning. - running_reward = None - batch_num = 1 - model = {} - # "Xavier" initialization. - model["W1"] = np.random.randn(H, D) / np.sqrt(D) - model["W2"] = np.random.randn(H) / np.sqrt(H) - # Update buffers that add up gradients over a batch. - grad_buffer = {k: np.zeros_like(v) for k, v in model.items()} - # Update the rmsprop memory. - rmsprop_cache = {k: np.zeros_like(v) for k, v in model.items()} - actors = [PongEnv.remote() for _ in range(batch_size)] - while True: - model_id = ray.put(model) - actions = [] - # Launch tasks to compute gradients from multiple rollouts in parallel. - start_time = time.time() - for i in range(batch_size): - action_id = actors[i].compute_gradient.remote(model_id) - actions.append(action_id) - for i in range(batch_size): - action_id, actions = ray.wait(actions) - grad, reward_sum = ray.get(action_id[0]) - # Accumulate the gradient over batch. 
- for k in model: - grad_buffer[k] += grad[k] - running_reward = (reward_sum if running_reward is None - else running_reward * 0.99 + reward_sum * 0.01) - end_time = time.time() - print("Batch {} computed {} rollouts in {} seconds, " - "running mean is {}".format(batch_num, batch_size, - end_time - start_time, running_reward)) - for k, v in model.items(): - g = grad_buffer[k] - rmsprop_cache[k] = (decay_rate * rmsprop_cache[k] + - (1 - decay_rate) * g ** 2) - model[k] += learning_rate * g / (np.sqrt(rmsprop_cache[k]) + 1e-5) - # Reset the batch gradient buffer. - grad_buffer[k] = np.zeros_like(v) - batch_num += 1 + running_reward = None + batch_num = 1 + model = {} + # "Xavier" initialization. + model["W1"] = np.random.randn(H, D) / np.sqrt(D) + model["W2"] = np.random.randn(H) / np.sqrt(H) + # Update buffers that add up gradients over a batch. + grad_buffer = {k: np.zeros_like(v) for k, v in model.items()} + # Update the rmsprop memory. + rmsprop_cache = {k: np.zeros_like(v) for k, v in model.items()} + actors = [PongEnv.remote() for _ in range(batch_size)] + while True: + model_id = ray.put(model) + actions = [] + # Launch tasks to compute gradients from multiple rollouts in parallel. + start_time = time.time() + for i in range(batch_size): + action_id = actors[i].compute_gradient.remote(model_id) + actions.append(action_id) + for i in range(batch_size): + action_id, actions = ray.wait(actions) + grad, reward_sum = ray.get(action_id[0]) + # Accumulate the gradient over batch. 
+ for k in model: + grad_buffer[k] += grad[k] + running_reward = (reward_sum if running_reward is None + else running_reward * 0.99 + reward_sum * 0.01) + end_time = time.time() + print("Batch {} computed {} rollouts in {} seconds, " + "running mean is {}".format(batch_num, batch_size, + end_time - start_time, + running_reward)) + for k, v in model.items(): + g = grad_buffer[k] + rmsprop_cache[k] = (decay_rate * rmsprop_cache[k] + + (1 - decay_rate) * g ** 2) + model[k] += learning_rate * g / (np.sqrt(rmsprop_cache[k]) + 1e-5) + # Reset the batch gradient buffer. + grad_buffer[k] = np.zeros_like(v) + batch_num += 1 diff --git a/python/ray/__init__.py b/python/ray/__init__.py index 60ed81dfd..980a8c758 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -21,9 +21,9 @@ __all__ = ["register_class", "error_info", "init", "connect", "disconnect", import ctypes # Windows only if hasattr(ctypes, "windll"): - # Makes sure that all child processes die when we die. Also makes sure that - # fatal crashes result in process termination rather than an error dialog - # (the latter is annoying since we have a lot of processes). This is done by - # associating all child processes with a "job" object that imposes this - # behavior. - (lambda kernel32: (lambda job: (lambda n: kernel32.SetInformationJobObject(job, 9, "\0" * 17 + chr(0x8 | 0x4 | 0x20) + "\0" * (n - 18), n))(0x90 if ctypes.sizeof(ctypes.c_void_p) > ctypes.sizeof(ctypes.c_int) else 0x70) and kernel32.AssignProcessToJobObject(job, ctypes.c_void_p(kernel32.GetCurrentProcess())))(ctypes.c_void_p(kernel32.CreateJobObjectW(None, None))) if kernel32 is not None else None)(ctypes.windll.kernel32) # noqa: E501 + # Makes sure that all child processes die when we die. Also makes sure that + # fatal crashes result in process termination rather than an error dialog + # (the latter is annoying since we have a lot of processes). 
This is done + # by associating all child processes with a "job" object that imposes this + # behavior. + (lambda kernel32: (lambda job: (lambda n: kernel32.SetInformationJobObject(job, 9, "\0" * 17 + chr(0x8 | 0x4 | 0x20) + "\0" * (n - 18), n))(0x90 if ctypes.sizeof(ctypes.c_void_p) > ctypes.sizeof(ctypes.c_int) else 0x70) and kernel32.AssignProcessToJobObject(job, ctypes.c_void_p(kernel32.GetCurrentProcess())))(ctypes.c_void_p(kernel32.CreateJobObjectW(None, None))) if kernel32 is not None else None)(ctypes.windll.kernel32) # noqa: E501 diff --git a/python/ray/actor.py b/python/ray/actor.py index 1b2458e8c..7fd9fad62 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -18,406 +18,419 @@ from ray.utils import (FunctionProperties, binary_to_hex, hex_to_binary, def random_actor_id(): - return ray.local_scheduler.ObjectID(random_string()) + return ray.local_scheduler.ObjectID(random_string()) def random_actor_class_id(): - return random_string() + return random_string() def get_actor_method_function_id(attr): - """Get the function ID corresponding to an actor method. + """Get the function ID corresponding to an actor method. - Args: - attr (str): The attribute name of the method. + Args: + attr (str): The attribute name of the method. - Returns: - Function ID corresponding to the method. - """ - function_id_hash = hashlib.sha1() - function_id_hash.update(attr.encode("ascii")) - function_id = function_id_hash.digest() - assert len(function_id) == 20 - return ray.local_scheduler.ObjectID(function_id) + Returns: + Function ID corresponding to the method. + """ + function_id_hash = hashlib.sha1() + function_id_hash.update(attr.encode("ascii")) + function_id = function_id_hash.digest() + assert len(function_id) == 20 + return ray.local_scheduler.ObjectID(function_id) def fetch_and_register_actor(actor_class_key, worker): - """Import an actor. + """Import an actor. 
- This will be called by the worker's import thread when the worker receives - the actor_class export, assuming that the worker is an actor for that class. - """ - actor_id_str = worker.actor_id - (driver_id, class_id, class_name, - module, pickled_class, actor_method_names) = worker.redis_client.hmget( - actor_class_key, ["driver_id", "class_id", "class_name", "module", - "class", "actor_method_names"]) + This will be called by the worker's import thread when the worker receives + the actor_class export, assuming that the worker is an actor for that + class. + """ + actor_id_str = worker.actor_id + (driver_id, class_id, class_name, + module, pickled_class, actor_method_names) = worker.redis_client.hmget( + actor_class_key, ["driver_id", "class_id", "class_name", "module", + "class", "actor_method_names"]) - actor_name = class_name.decode("ascii") - module = module.decode("ascii") - actor_method_names = json.loads(actor_method_names.decode("ascii")) + actor_name = class_name.decode("ascii") + module = module.decode("ascii") + actor_method_names = json.loads(actor_method_names.decode("ascii")) - # Create a temporary actor with some temporary methods so that if the actor - # fails to be unpickled, the temporary actor can be used (just to produce - # error messages and to prevent the driver from hanging). - class TemporaryActor(object): - pass - worker.actors[actor_id_str] = TemporaryActor() + # Create a temporary actor with some temporary methods so that if the actor + # fails to be unpickled, the temporary actor can be used (just to produce + # error messages and to prevent the driver from hanging). 
+ class TemporaryActor(object): + pass + worker.actors[actor_id_str] = TemporaryActor() - def temporary_actor_method(*xs): - raise Exception("The actor with name {} failed to be imported, and so " - "cannot execute this method".format(actor_name)) - for actor_method_name in actor_method_names: - function_id = get_actor_method_function_id(actor_method_name).id() - worker.functions[driver_id][function_id] = (actor_method_name, - temporary_actor_method) - worker.function_properties[driver_id][function_id] = FunctionProperties( - num_return_vals=1, - num_cpus=1, - num_gpus=0, - max_calls=0) - worker.num_task_executions[driver_id][function_id] = 0 + def temporary_actor_method(*xs): + raise Exception("The actor with name {} failed to be imported, and so " + "cannot execute this method".format(actor_name)) + for actor_method_name in actor_method_names: + function_id = get_actor_method_function_id(actor_method_name).id() + worker.functions[driver_id][function_id] = (actor_method_name, + temporary_actor_method) + worker.function_properties[driver_id][function_id] = ( + FunctionProperties(num_return_vals=1, + num_cpus=1, + num_gpus=0, + max_calls=0)) + worker.num_task_executions[driver_id][function_id] = 0 - try: - unpickled_class = pickle.loads(pickled_class) - except Exception: - # If an exception was thrown when the actor was imported, we record the - # traceback and notify the scheduler of the failure. - traceback_str = ray.worker.format_error_message(traceback.format_exc()) - # Log the error message. - worker.push_error_to_driver(driver_id, "register_actor", traceback_str, - data={"actor_id": actor_id_str}) - else: - # TODO(pcm): Why is the below line necessary? 
- unpickled_class.__module__ = module - worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class) - for (k, v) in inspect.getmembers( - unpickled_class, predicate=(lambda x: (inspect.isfunction(x) or - inspect.ismethod(x)))): - function_id = get_actor_method_function_id(k).id() - worker.functions[driver_id][function_id] = (k, v) - # We do not set worker.function_properties[driver_id][function_id] - # because we currently do need the actor worker to submit new tasks for - # the actor. + try: + unpickled_class = pickle.loads(pickled_class) + except Exception: + # If an exception was thrown when the actor was imported, we record the + # traceback and notify the scheduler of the failure. + traceback_str = ray.worker.format_error_message(traceback.format_exc()) + # Log the error message. + worker.push_error_to_driver(driver_id, "register_actor", traceback_str, + data={"actor_id": actor_id_str}) + else: + # TODO(pcm): Why is the below line necessary? + unpickled_class.__module__ = module + worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class) + for (k, v) in inspect.getmembers( + unpickled_class, predicate=(lambda x: (inspect.isfunction(x) or + inspect.ismethod(x)))): + function_id = get_actor_method_function_id(k).id() + worker.functions[driver_id][function_id] = (k, v) + # We do not set worker.function_properties[driver_id][function_id] + # because we currently do need the actor worker to submit new tasks + # for the actor. def attempt_to_reserve_gpus(num_gpus, driver_id, local_scheduler, worker): - """Attempt to acquire GPUs on a particular local scheduler for an actor. + """Attempt to acquire GPUs on a particular local scheduler for an actor. - Args: - num_gpus: The number of GPUs to acquire. - driver_id: The ID of the driver responsible for creating the actor. - local_scheduler: Information about the local scheduler. + Args: + num_gpus: The number of GPUs to acquire. + driver_id: The ID of the driver responsible for creating the actor. 
+ local_scheduler: Information about the local scheduler. - Returns: - True if the GPUs were successfully reserved and false otherwise. - """ - assert num_gpus != 0 - local_scheduler_id = local_scheduler["DBClientID"] - local_scheduler_total_gpus = int(local_scheduler["NumGPUs"]) + Returns: + True if the GPUs were successfully reserved and false otherwise. + """ + assert num_gpus != 0 + local_scheduler_id = local_scheduler["DBClientID"] + local_scheduler_total_gpus = int(local_scheduler["NumGPUs"]) - success = False + success = False - # Attempt to acquire GPU IDs atomically. - with worker.redis_client.pipeline() as pipe: - while True: - try: - # If this key is changed before the transaction below (the multi/exec - # block), then the transaction will not take place. - pipe.watch(local_scheduler_id) + # Attempt to acquire GPU IDs atomically. + with worker.redis_client.pipeline() as pipe: + while True: + try: + # If this key is changed before the transaction below (the + # multi/exec block), then the transaction will not take place. + pipe.watch(local_scheduler_id) - # Figure out which GPUs are currently in use. - result = worker.redis_client.hget(local_scheduler_id, "gpus_in_use") - gpus_in_use = dict() if result is None else json.loads( - result.decode("ascii")) - num_gpus_in_use = 0 - for key in gpus_in_use: - num_gpus_in_use += gpus_in_use[key] - assert num_gpus_in_use <= local_scheduler_total_gpus + # Figure out which GPUs are currently in use. + result = worker.redis_client.hget(local_scheduler_id, + "gpus_in_use") + gpus_in_use = dict() if result is None else json.loads( + result.decode("ascii")) + num_gpus_in_use = 0 + for key in gpus_in_use: + num_gpus_in_use += gpus_in_use[key] + assert num_gpus_in_use <= local_scheduler_total_gpus - pipe.multi() + pipe.multi() - if local_scheduler_total_gpus - num_gpus_in_use >= num_gpus: - # There are enough available GPUs, so try to reserve some. 
We use the - # hex driver ID in hex as a dictionary key so that the dictionary is - # JSON serializable. - driver_id_hex = binary_to_hex(driver_id) - if driver_id_hex not in gpus_in_use: - gpus_in_use[driver_id_hex] = 0 - gpus_in_use[driver_id_hex] += num_gpus + if local_scheduler_total_gpus - num_gpus_in_use >= num_gpus: + # There are enough available GPUs, so try to reserve some. + # We use the hex driver ID in hex as a dictionary key so + # that the dictionary is JSON serializable. + driver_id_hex = binary_to_hex(driver_id) + if driver_id_hex not in gpus_in_use: + gpus_in_use[driver_id_hex] = 0 + gpus_in_use[driver_id_hex] += num_gpus - # Stick the updated GPU IDs back in Redis - pipe.hset(local_scheduler_id, "gpus_in_use", json.dumps(gpus_in_use)) - success = True + # Stick the updated GPU IDs back in Redis + pipe.hset(local_scheduler_id, "gpus_in_use", + json.dumps(gpus_in_use)) + success = True - pipe.execute() - # If a WatchError is not raised, then the operations should have gone - # through atomically. - break - except redis.WatchError: - # Another client must have changed the watched key between the time we - # started WATCHing it and the pipeline's execution. We should just - # retry. - success = False - continue + pipe.execute() + # If a WatchError is not raised, then the operations should + # have gone through atomically. + break + except redis.WatchError: + # Another client must have changed the watched key between the + # time we started WATCHing it and the pipeline's execution. We + # should just retry. + success = False + continue - return success + return success def select_local_scheduler(local_schedulers, num_gpus, worker): - """Select a local scheduler to assign this actor to. + """Select a local scheduler to assign this actor to. - Args: - local_schedulers: A list of dictionaries of information about the local - schedulers. - num_gpus (int): The number of GPUs that must be reserved for this actor. 
+ Args: + local_schedulers: A list of dictionaries of information about the local + schedulers. + num_gpus (int): The number of GPUs that must be reserved for this + actor. - Returns: - The ID of the local scheduler that has been chosen. + Returns: + The ID of the local scheduler that has been chosen. - Raises: - Exception: An exception is raised if no local scheduler can be found with - sufficient resources. - """ - driver_id = worker.task_driver_id.id() + Raises: + Exception: An exception is raised if no local scheduler can be found + with sufficient resources. + """ + driver_id = worker.task_driver_id.id() - local_scheduler_id = None - # Loop through all of the local schedulers in a random order. - local_schedulers = np.random.permutation(local_schedulers) - for local_scheduler in local_schedulers: - if local_scheduler["NumCPUs"] < 1: - continue - if local_scheduler["NumGPUs"] < num_gpus: - continue - if num_gpus == 0: - local_scheduler_id = hex_to_binary(local_scheduler["DBClientID"]) - break - else: - # Try to reserve enough GPUs on this local scheduler. - success = attempt_to_reserve_gpus(num_gpus, driver_id, local_scheduler, - worker) - if success: - local_scheduler_id = hex_to_binary(local_scheduler["DBClientID"]) - break + local_scheduler_id = None + # Loop through all of the local schedulers in a random order. + local_schedulers = np.random.permutation(local_schedulers) + for local_scheduler in local_schedulers: + if local_scheduler["NumCPUs"] < 1: + continue + if local_scheduler["NumGPUs"] < num_gpus: + continue + if num_gpus == 0: + local_scheduler_id = hex_to_binary(local_scheduler["DBClientID"]) + break + else: + # Try to reserve enough GPUs on this local scheduler. 
+ success = attempt_to_reserve_gpus(num_gpus, driver_id, + local_scheduler, worker) + if success: + local_scheduler_id = hex_to_binary( + local_scheduler["DBClientID"]) + break - if local_scheduler_id is None: - raise Exception("Could not find a node with enough GPUs or other " - "resources to create this actor. The local scheduler " - "information is {}.".format(local_schedulers)) + if local_scheduler_id is None: + raise Exception("Could not find a node with enough GPUs or other " + "resources to create this actor. The local scheduler " + "information is {}.".format(local_schedulers)) - return local_scheduler_id + return local_scheduler_id def export_actor_class(class_id, Class, actor_method_names, worker): - if worker.mode is None: - raise NotImplemented("TODO(pcm): Cache actors") - key = b"ActorClass:" + class_id - d = {"driver_id": worker.task_driver_id.id(), - "class_name": Class.__name__, - "module": Class.__module__, - "class": pickle.dumps(Class), - "actor_method_names": json.dumps(list(actor_method_names))} - worker.redis_client.hmset(key, d) - worker.redis_client.rpush("Exports", key) + if worker.mode is None: + raise NotImplemented("TODO(pcm): Cache actors") + key = b"ActorClass:" + class_id + d = {"driver_id": worker.task_driver_id.id(), + "class_name": Class.__name__, + "module": Class.__module__, + "class": pickle.dumps(Class), + "actor_method_names": json.dumps(list(actor_method_names))} + worker.redis_client.hmset(key, d) + worker.redis_client.rpush("Exports", key) def export_actor(actor_id, class_id, actor_method_names, num_cpus, num_gpus, worker): - """Export an actor to redis. + """Export an actor to redis. - Args: - actor_id: The ID of the actor. - actor_method_names (list): A list of the names of this actor's methods. - num_cpus (int): The number of CPUs that this actor requires. - num_gpus (int): The number of GPUs that this actor requires. 
- """ - ray.worker.check_main_thread() - if worker.mode is None: - raise Exception("Actors cannot be created before Ray has been started. " - "You can start Ray with 'ray.init()'.") - key = b"Actor:" + actor_id.id() + Args: + actor_id: The ID of the actor. + actor_method_names (list): A list of the names of this actor's methods. + num_cpus (int): The number of CPUs that this actor requires. + num_gpus (int): The number of GPUs that this actor requires. + """ + ray.worker.check_main_thread() + if worker.mode is None: + raise Exception("Actors cannot be created before Ray has been " + "started. You can start Ray with 'ray.init()'.") + key = b"Actor:" + actor_id.id() - # For now, all actor methods have 1 return value. - driver_id = worker.task_driver_id.id() - for actor_method_name in actor_method_names: - # TODO(rkn): When we create a second actor, we are probably overwriting - # the values from the first actor here. This may or may not be a problem. - function_id = get_actor_method_function_id(actor_method_name).id() - worker.function_properties[driver_id][function_id] = FunctionProperties( - num_return_vals=1, - num_cpus=1, - num_gpus=0, - max_calls=0) + # For now, all actor methods have 1 return value. + driver_id = worker.task_driver_id.id() + for actor_method_name in actor_method_names: + # TODO(rkn): When we create a second actor, we are probably overwriting + # the values from the first actor here. This may or may not be a + # problem. + function_id = get_actor_method_function_id(actor_method_name).id() + worker.function_properties[driver_id][function_id] = ( + FunctionProperties(num_return_vals=1, + num_cpus=1, + num_gpus=0, + max_calls=0)) - # Get a list of the local schedulers from the client table. 
- client_table = ray.global_state.client_table() - local_schedulers = [] - for ip_address, clients in client_table.items(): - for client in clients: - if client["ClientType"] == "local_scheduler" and not client["Deleted"]: - local_schedulers.append(client) - # Select a local scheduler for the actor. - local_scheduler_id = select_local_scheduler(local_schedulers, num_gpus, - worker) - assert local_scheduler_id is not None + # Get a list of the local schedulers from the client table. + client_table = ray.global_state.client_table() + local_schedulers = [] + for ip_address, clients in client_table.items(): + for client in clients: + if (client["ClientType"] == "local_scheduler" and + not client["Deleted"]): + local_schedulers.append(client) + # Select a local scheduler for the actor. + local_scheduler_id = select_local_scheduler(local_schedulers, num_gpus, + worker) + assert local_scheduler_id is not None - # We must put the actor information in Redis before publishing the actor - # notification so that when the newly created actor attempts to fetch the - # information from Redis, it is already there. - worker.redis_client.hmset(key, {"class_id": class_id, - "num_gpus": num_gpus}) + # We must put the actor information in Redis before publishing the actor + # notification so that when the newly created actor attempts to fetch the + # information from Redis, it is already there. + worker.redis_client.hmset(key, {"class_id": class_id, + "num_gpus": num_gpus}) - # Really we should encode this message as a flatbuffer object. However, we're - # having trouble getting that to work. It almost works, but in Python 2.7, - # builder.CreateString fails on byte strings that contain characters outside - # range(128). + # Really we should encode this message as a flatbuffer object. However, + # we're having trouble getting that to work. It almost works, but in Python + # 2.7, builder.CreateString fails on byte strings that contain characters + # outside range(128). 
- # TODO(rkn): There is actually no guarantee that the local scheduler that we - # are publishing to has already subscribed to the actor_notifications - # channel. Therefore, this message may be missed and the workload will hang. - # This is a bug. - worker.redis_client.publish("actor_notifications", - actor_id.id() + driver_id + local_scheduler_id) + # TODO(rkn): There is actually no guarantee that the local scheduler that + # we are publishing to has already subscribed to the actor_notifications + # channel. Therefore, this message may be missed and the workload will + # hang. This is a bug. + worker.redis_client.publish("actor_notifications", + actor_id.id() + driver_id + local_scheduler_id) def actor(*args, **kwargs): - raise Exception("The @ray.actor decorator is deprecated. Instead, please " - "use @ray.remote.") + raise Exception("The @ray.actor decorator is deprecated. Instead, please " + "use @ray.remote.") def make_actor(cls, num_cpus, num_gpus): - # Modify the class to have an additional method that will be used for - # terminating the worker. - class Class(cls): - def __ray_terminate__(self): - ray.worker.global_worker.local_scheduler_client.disconnect() - import os - os._exit(0) + # Modify the class to have an additional method that will be used for + # terminating the worker. + class Class(cls): + def __ray_terminate__(self): + ray.worker.global_worker.local_scheduler_client.disconnect() + import os + os._exit(0) - Class.__module__ = cls.__module__ - Class.__name__ = cls.__name__ + Class.__module__ = cls.__module__ + Class.__name__ = cls.__name__ - class_id = random_actor_class_id() - # The list exported will have length 0 if the class has not been exported - # yet, and length one if it has. This is just implementing a bool, but we - # don't use a bool because we need to modify it inside of the NewClass - # constructor. 
- exported = [] + class_id = random_actor_class_id() + # The list exported will have length 0 if the class has not been exported + # yet, and length one if it has. This is just implementing a bool, but we + # don't use a bool because we need to modify it inside of the NewClass + # constructor. + exported = [] - # The function actor_method_call gets called if somebody tries to call a - # method on their local actor stub object. - def actor_method_call(actor_id, attr, function_signature, *args, **kwargs): - ray.worker.check_connected() - ray.worker.check_main_thread() - args = signature.extend_args(function_signature, args, kwargs) + # The function actor_method_call gets called if somebody tries to call a + # method on their local actor stub object. + def actor_method_call(actor_id, attr, function_signature, *args, **kwargs): + ray.worker.check_connected() + ray.worker.check_main_thread() + args = signature.extend_args(function_signature, args, kwargs) - function_id = get_actor_method_function_id(attr) - object_ids = ray.worker.global_worker.submit_task(function_id, "", args, - actor_id=actor_id) - if len(object_ids) == 1: - return object_ids[0] - elif len(object_ids) > 1: - return object_ids + function_id = get_actor_method_function_id(attr) + object_ids = ray.worker.global_worker.submit_task(function_id, "", + args, + actor_id=actor_id) + if len(object_ids) == 1: + return object_ids[0] + elif len(object_ids) > 1: + return object_ids - class ActorMethod(object): - def __init__(self, method_name, actor_id, method_signature): - self.method_name = method_name - self.actor_id = actor_id - self.method_signature = method_signature + class ActorMethod(object): + def __init__(self, method_name, actor_id, method_signature): + self.method_name = method_name + self.actor_id = actor_id + self.method_signature = method_signature - def __call__(self, *args, **kwargs): - raise Exception("Actor methods cannot be called directly. 
Instead " - "of running 'object.{}()', try 'object.{}.remote()'." - .format(self.method_name, self.method_name)) + def __call__(self, *args, **kwargs): + raise Exception("Actor methods cannot be called directly. Instead " + "of running 'object.{}()', try " + "'object.{}.remote()'." + .format(self.method_name, self.method_name)) - def remote(self, *args, **kwargs): - return actor_method_call(self.actor_id, self.method_name, - self.method_signature, *args, **kwargs) + def remote(self, *args, **kwargs): + return actor_method_call(self.actor_id, self.method_name, + self.method_signature, *args, **kwargs) - class NewClass(object): - def __init__(self, *args, **kwargs): - raise Exception("Actor classes cannot be instantiated directly. " - "Instead of running '{}()', try '{}.remote()'." - .format(Class.__name__, Class.__name__)) + class NewClass(object): + def __init__(self, *args, **kwargs): + raise Exception("Actor classes cannot be instantiated directly. " + "Instead of running '{}()', try '{}.remote()'." + .format(Class.__name__, Class.__name__)) - @classmethod - def remote(cls, *args, **kwargs): - actor_object = cls.__new__(cls) - actor_object._manual_init(*args, **kwargs) - return actor_object + @classmethod + def remote(cls, *args, **kwargs): + actor_object = cls.__new__(cls) + actor_object._manual_init(*args, **kwargs) + return actor_object - def _manual_init(self, *args, **kwargs): - self._ray_actor_id = random_actor_id() - self._ray_actor_methods = { - k: v for (k, v) in inspect.getmembers( - Class, predicate=(lambda x: (inspect.isfunction(x) or - inspect.ismethod(x))))} - # Extract the signatures of each of the methods. This will be used to - # catch some errors if the methods are called with inappropriate - # arguments. - self._ray_method_signatures = dict() - for k, v in self._ray_actor_methods.items(): - # Print a warning message if the method signature is not supported. 
- # We don't raise an exception because if the actor inherits from a - # class that has a method whose signature we don't support, we - # there may not be much the user can do about it. - signature.check_signature_supported(v, warn=True) - self._ray_method_signatures[k] = signature.extract_signature( - v, ignore_first=True) + def _manual_init(self, *args, **kwargs): + self._ray_actor_id = random_actor_id() + self._ray_actor_methods = { + k: v for (k, v) in inspect.getmembers( + Class, predicate=(lambda x: (inspect.isfunction(x) or + inspect.ismethod(x))))} + # Extract the signatures of each of the methods. This will be used + # to catch some errors if the methods are called with inappropriate + # arguments. + self._ray_method_signatures = dict() + for k, v in self._ray_actor_methods.items(): + # Print a warning message if the method signature is not + # supported. We don't raise an exception because if the actor + # inherits from a class that has a method whose signature we + # don't support, we there may not be much the user can do about + # it. + signature.check_signature_supported(v, warn=True) + self._ray_method_signatures[k] = signature.extract_signature( + v, ignore_first=True) - # Create objects to wrap method invocations. This is done so that we - # can invoke methods with actor.method.remote() instead of - # actor.method(). - self._actor_method_invokers = dict() - for k, v in self._ray_actor_methods.items(): - self._actor_method_invokers[k] = ActorMethod( - k, self._ray_actor_id, self._ray_method_signatures[k]) + # Create objects to wrap method invocations. This is done so that + # we can invoke methods with actor.method.remote() instead of + # actor.method(). + self._actor_method_invokers = dict() + for k, v in self._ray_actor_methods.items(): + self._actor_method_invokers[k] = ActorMethod( + k, self._ray_actor_id, self._ray_method_signatures[k]) - # Export the actor class if it has not been exported yet. 
- if len(exported) == 0: - export_actor_class(class_id, Class, self._ray_actor_methods.keys(), - ray.worker.global_worker) - exported.append(0) - # Export the actor. - export_actor(self._ray_actor_id, class_id, - self._ray_actor_methods.keys(), num_cpus, num_gpus, - ray.worker.global_worker) - # Call __init__ as a remote function. - if "__init__" in self._ray_actor_methods.keys(): - actor_method_call(self._ray_actor_id, "__init__", - self._ray_method_signatures["__init__"], - *args, **kwargs) - else: - print("WARNING: this object has no __init__ method.") + # Export the actor class if it has not been exported yet. + if len(exported) == 0: + export_actor_class(class_id, Class, + self._ray_actor_methods.keys(), + ray.worker.global_worker) + exported.append(0) + # Export the actor. + export_actor(self._ray_actor_id, class_id, + self._ray_actor_methods.keys(), num_cpus, num_gpus, + ray.worker.global_worker) + # Call __init__ as a remote function. + if "__init__" in self._ray_actor_methods.keys(): + actor_method_call(self._ray_actor_id, "__init__", + self._ray_method_signatures["__init__"], + *args, **kwargs) + else: + print("WARNING: this object has no __init__ method.") - # Make tab completion work. - def __dir__(self): - return self._ray_actor_methods + # Make tab completion work. + def __dir__(self): + return self._ray_actor_methods - def __getattribute__(self, attr): - # The following is needed so we can still access self.actor_methods. - if attr in ["_manual_init", "_ray_actor_id", "_ray_actor_methods", - "_actor_method_invokers", "_ray_method_signatures"]: - return object.__getattribute__(self, attr) - if attr in self._ray_actor_methods.keys(): - return self._actor_method_invokers[attr] - # There is no method with this name, so raise an exception. - raise AttributeError("'{}' Actor object has no attribute '{}'" - .format(Class, attr)) + def __getattribute__(self, attr): + # The following is needed so we can still access + # self.actor_methods. 
+ if attr in ["_manual_init", "_ray_actor_id", "_ray_actor_methods", + "_actor_method_invokers", "_ray_method_signatures"]: + return object.__getattribute__(self, attr) + if attr in self._ray_actor_methods.keys(): + return self._actor_method_invokers[attr] + # There is no method with this name, so raise an exception. + raise AttributeError("'{}' Actor object has no attribute '{}'" + .format(Class, attr)) - def __repr__(self): - return "Actor(" + self._ray_actor_id.hex() + ")" + def __repr__(self): + return "Actor(" + self._ray_actor_id.hex() + ")" - def __reduce__(self): - raise Exception("Actor objects cannot be pickled.") + def __reduce__(self): + raise Exception("Actor objects cannot be pickled.") - def __del__(self): - """Kill the worker that is running this actor.""" - if ray.worker.global_worker.connected: - actor_method_call(self._ray_actor_id, "__ray_terminate__", - self._ray_method_signatures["__ray_terminate__"]) + def __del__(self): + """Kill the worker that is running this actor.""" + if ray.worker.global_worker.connected: + actor_method_call( + self._ray_actor_id, "__ray_terminate__", + self._ray_method_signatures["__ray_terminate__"]) - return NewClass + return NewClass ray.worker.global_worker.fetch_and_register_actor = fetch_and_register_actor diff --git a/python/ray/common/redis_module/runtest.py b/python/ray/common/redis_module/runtest.py index 42e6057cd..d826b2056 100644 --- a/python/ray/common/redis_module/runtest.py +++ b/python/ray/common/redis_module/runtest.py @@ -23,423 +23,436 @@ OBJECT_CHANNEL_PREFIX = "OC:" def integerToAsciiHex(num, numbytes): - retstr = b"" - # Support 32 and 64 bit architecture. - assert(numbytes == 4 or numbytes == 8) - for i in range(numbytes): - curbyte = num & 0xff - if sys.version_info >= (3, 0): - retstr += bytes([curbyte]) - else: - retstr += chr(curbyte) - num = num >> 8 + retstr = b"" + # Support 32 and 64 bit architecture. 
+ assert(numbytes == 4 or numbytes == 8) + for i in range(numbytes): + curbyte = num & 0xff + if sys.version_info >= (3, 0): + retstr += bytes([curbyte]) + else: + retstr += chr(curbyte) + num = num >> 8 - return retstr + return retstr def get_next_message(pubsub_client, timeout_seconds=10): - """Block until the next message is available on the pubsub channel.""" - start_time = time.time() - while True: - message = pubsub_client.get_message() - if message is not None: - return message - time.sleep(0.1) - if time.time() - start_time > timeout_seconds: - raise Exception("Timed out while waiting for next message.") + """Block until the next message is available on the pubsub channel.""" + start_time = time.time() + while True: + message = pubsub_client.get_message() + if message is not None: + return message + time.sleep(0.1) + if time.time() - start_time > timeout_seconds: + raise Exception("Timed out while waiting for next message.") class TestGlobalStateStore(unittest.TestCase): - def setUp(self): - redis_port, _ = ray.services.start_redis_instance() - self.redis = redis.StrictRedis(host="localhost", port=redis_port, db=0) + def setUp(self): + redis_port, _ = ray.services.start_redis_instance() + self.redis = redis.StrictRedis(host="localhost", port=redis_port, db=0) - def tearDown(self): - ray.services.cleanup() + def tearDown(self): + ray.services.cleanup() - def testInvalidObjectTableAdd(self): - # Check that Redis returns an error when RAY.OBJECT_TABLE_ADD is called - # with the wrong arguments. 
- with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD") - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "hello") - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", "one", - "hash2", "manager_id1") - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", 1, - "hash2", "manager_id1", "extra argument") - # Check that Redis returns an error when RAY.OBJECT_TABLE_ADD adds an - # object ID that is already present with a different hash. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id1") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id1"}) - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash2", "manager_id2") - # Check that the second manager was added, even though the hash was - # mismatched. - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) - # Check that it is fine if we add the same object ID multiple times with - # the most recent hash. 
- self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash2", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash2", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash2", "manager_id2") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 2, - "hash2", "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) + def testInvalidObjectTableAdd(self): + # Check that Redis returns an error when RAY.OBJECT_TABLE_ADD is called + # with the wrong arguments. + with self.assertRaises(redis.ResponseError): + self.redis.execute_command("RAY.OBJECT_TABLE_ADD") + with self.assertRaises(redis.ResponseError): + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "hello") + with self.assertRaises(redis.ResponseError): + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", + "one", "hash2", "manager_id1") + with self.assertRaises(redis.ResponseError): + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", 1, + "hash2", "manager_id1", + "extra argument") + # Check that Redis returns an error when RAY.OBJECT_TABLE_ADD adds an + # object ID that is already present with a different hash. + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, + "hash1", "manager_id1") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") + self.assertEqual(set(response), {b"manager_id1"}) + with self.assertRaises(redis.ResponseError): + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, + "hash2", "manager_id2") + # Check that the second manager was added, even though the hash was + # mismatched. 
+ response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") + self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) + # Check that it is fine if we add the same object ID multiple times + # with the most recent hash. + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, + "hash2", "manager_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, + "hash2", "manager_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, + "hash2", "manager_id2") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 2, + "hash2", "manager_id2") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") + self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) - def testObjectTableAddAndLookup(self): - # Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not been - # added yet. - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(response, None) - # Add some managers and try again. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) - # Add a manager that already exists again and try again. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) - # Check that we properly handle NULL characters. In the past, NULL - # characters were handled improperly causing a "hash mismatch" error if two - # object IDs that agreed up to the NULL character were inserted with - # different hashes. 
- self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id3", 1, - "hash1", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id4", 1, - "hash2", "manager_id1") - # Check that NULL characters in the hash are handled properly. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1, - "\x00hash1", "manager_id1") - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1, - "\x00hash2", "manager_id1") + def testObjectTableAddAndLookup(self): + # Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not + # been added yet. + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") + self.assertEqual(response, None) + # Add some managers and try again. + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, + "hash1", "manager_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, + "hash1", "manager_id2") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") + self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) + # Add a manager that already exists again and try again. + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, + "hash1", "manager_id2") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") + self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) + # Check that we properly handle NULL characters. In the past, NULL + # characters were handled improperly causing a "hash mismatch" error if + # two object IDs that agreed up to the NULL character were inserted + # with different hashes. + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id3", 1, + "hash1", "manager_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id4", 1, + "hash2", "manager_id1") + # Check that NULL characters in the hash are handled properly. 
+ self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1, + "\x00hash1", "manager_id1") + with self.assertRaises(redis.ResponseError): + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1, + "\x00hash2", "manager_id1") - def testObjectTableAddAndRemove(self): - # Try removing a manager from an object ID that has not been added yet. - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id1") - # Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not been - # added yet. - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(response, None) - # Add some managers and try again. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) - # Remove a manager that doesn't exist, and make sure we still have the same - # set. - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id3") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) - # Remove a manager that does exist. Make sure it gets removed the first - # time and does nothing the second time. 
- self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id1") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id2"}) - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id1") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id2"}) - # Remove the last manager, and make sure we have an empty set. - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), set()) - # Remove a manager from an empty set, and make sure we now have an empty - # set. - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id3") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), set()) + def testObjectTableAddAndRemove(self): + # Try removing a manager from an object ID that has not been added yet. + with self.assertRaises(redis.ResponseError): + self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", + "manager_id1") + # Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not + # been added yet. + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") + self.assertEqual(response, None) + # Add some managers and try again. + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, + "hash1", "manager_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, + "hash1", "manager_id2") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") + self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) + # Remove a manager that doesn't exist, and make sure we still have the + # same set. 
+ self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", + "manager_id3") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") + self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) + # Remove a manager that does exist. Make sure it gets removed the first + # time and does nothing the second time. + self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", + "manager_id1") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") + self.assertEqual(set(response), {b"manager_id2"}) + self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", + "manager_id1") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") + self.assertEqual(set(response), {b"manager_id2"}) + # Remove the last manager, and make sure we have an empty set. + self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", + "manager_id2") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") + self.assertEqual(set(response), set()) + # Remove a manager from an empty set, and make sure we now have an + # empty set. + self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", + "manager_id3") + response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", + "object_id1") + self.assertEqual(set(response), set()) - def testObjectTableSubscribeToNotifications(self): - # Define a helper method for checking the contents of object notifications. 
- def check_object_notification(notification_message, object_id, object_size, - manager_ids): - notification_object = (SubscribeToNotificationsReply - .GetRootAsSubscribeToNotificationsReply( - notification_message, 0)) - self.assertEqual(notification_object.ObjectId(), object_id) - self.assertEqual(notification_object.ObjectSize(), object_size) - self.assertEqual(notification_object.ManagerIdsLength(), - len(manager_ids)) - for i in range(len(manager_ids)): - self.assertEqual(notification_object.ManagerIds(i), manager_ids[i]) + def testObjectTableSubscribeToNotifications(self): + # Define a helper method for checking the contents of object + # notifications. + def check_object_notification(notification_message, object_id, + object_size, manager_ids): + notification_object = (SubscribeToNotificationsReply + .GetRootAsSubscribeToNotificationsReply( + notification_message, 0)) + self.assertEqual(notification_object.ObjectId(), object_id) + self.assertEqual(notification_object.ObjectSize(), object_size) + self.assertEqual(notification_object.ManagerIdsLength(), + len(manager_ids)) + for i in range(len(manager_ids)): + self.assertEqual(notification_object.ManagerIds(i), + manager_ids[i]) - data_size = 0xf1f0 - p = self.redis.pubsub() - # Subscribe to an object ID. - p.psubscribe("{}manager_id1".format(OBJECT_CHANNEL_PREFIX)) - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", data_size, - "hash1", "manager_id2") - # Receive the acknowledgement message. - self.assertEqual(get_next_message(p)["data"], 1) - # Request a notification and receive the data. - self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", - "manager_id1", "object_id1") - # Verify that the notification is correct. - check_object_notification(get_next_message(p)["data"], - b"object_id1", - data_size, - [b"manager_id2"]) + data_size = 0xf1f0 + p = self.redis.pubsub() + # Subscribe to an object ID. 
+ p.psubscribe("{}manager_id1".format(OBJECT_CHANNEL_PREFIX)) + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", + data_size, "hash1", "manager_id2") + # Receive the acknowledgement message. + self.assertEqual(get_next_message(p)["data"], 1) + # Request a notification and receive the data. + self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", + "manager_id1", "object_id1") + # Verify that the notification is correct. + check_object_notification(get_next_message(p)["data"], + b"object_id1", + data_size, + [b"manager_id2"]) - # Request a notification for an object that isn't there. Then add the - # object and receive the data. Only the first call to RAY.OBJECT_TABLE_ADD - # should trigger notifications. - self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", - "manager_id1", "object_id2", "object_id3") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, - "hash1", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, - "hash1", "manager_id2") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, - "hash1", "manager_id3") - # Verify that the notification is correct. - check_object_notification(get_next_message(p)["data"], - b"object_id3", - data_size, - [b"manager_id1"]) - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", data_size, - "hash1", "manager_id3") - # Verify that the notification is correct. - check_object_notification(get_next_message(p)["data"], - b"object_id2", - data_size, - [b"manager_id3"]) - # Request notifications for object_id3 again. - self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", - "manager_id1", "object_id3") - # Verify that the notification is correct. - check_object_notification(get_next_message(p)["data"], - b"object_id3", - data_size, - [b"manager_id1", b"manager_id2", b"manager_id3"]) + # Request a notification for an object that isn't there. 
Then add the + # object and receive the data. Only the first call to + # RAY.OBJECT_TABLE_ADD should trigger notifications. + self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", + "manager_id1", "object_id2", "object_id3") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", + data_size, "hash1", "manager_id1") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", + data_size, "hash1", "manager_id2") + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", + data_size, "hash1", "manager_id3") + # Verify that the notification is correct. + check_object_notification(get_next_message(p)["data"], + b"object_id3", + data_size, + [b"manager_id1"]) + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", + data_size, "hash1", "manager_id3") + # Verify that the notification is correct. + check_object_notification(get_next_message(p)["data"], + b"object_id2", + data_size, + [b"manager_id3"]) + # Request notifications for object_id3 again. + self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", + "manager_id1", "object_id3") + # Verify that the notification is correct. + check_object_notification(get_next_message(p)["data"], + b"object_id3", + data_size, + [b"manager_id1", b"manager_id2", + b"manager_id3"]) - def testResultTableAddAndLookup(self): - def check_result_table_entry(message, task_id, is_put): - result_table_reply = ResultTableReply.GetRootAsResultTableReply(message, - 0) - self.assertEqual(result_table_reply.TaskId(), task_id) - self.assertEqual(result_table_reply.IsPut(), is_put) + def testResultTableAddAndLookup(self): + def check_result_table_entry(message, task_id, is_put): + result_table_reply = ResultTableReply.GetRootAsResultTableReply( + message, 0) + self.assertEqual(result_table_reply.TaskId(), task_id) + self.assertEqual(result_table_reply.IsPut(), is_put) - # Try looking up something in the result table before anything is added. 
- response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", - "object_id1") - self.assertIsNone(response) - # Adding the object to the object table should have no effect. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id1") - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", - "object_id1") - self.assertIsNone(response) - # Add the result to the result table. The lookup now returns the task ID. - task_id = b"task_id1" - self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1", task_id, - 0) - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", - "object_id1") - check_result_table_entry(response, task_id, False) - # Doing it again should still work. - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", - "object_id1") - check_result_table_entry(response, task_id, False) - # Try another result table lookup. This should succeed. - task_id = b"task_id2" - self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2", task_id, - 1) - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", - "object_id2") - check_result_table_entry(response, task_id, True) + # Try looking up something in the result table before anything is + # added. + response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", + "object_id1") + self.assertIsNone(response) + # Adding the object to the object table should have no effect. + self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, + "hash1", "manager_id1") + response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", + "object_id1") + self.assertIsNone(response) + # Add the result to the result table. The lookup now returns the task + # ID. 
+ task_id = b"task_id1" + self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1", + task_id, 0) + response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", + "object_id1") + check_result_table_entry(response, task_id, False) + # Doing it again should still work. + response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", + "object_id1") + check_result_table_entry(response, task_id, False) + # Try another result table lookup. This should succeed. + task_id = b"task_id2" + self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2", + task_id, 1) + response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", + "object_id2") + check_result_table_entry(response, task_id, True) - def testInvalidTaskTableAdd(self): - # Check that Redis returns an error when RAY.TASK_TABLE_ADD is called with - # the wrong arguments. - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.TASK_TABLE_ADD") - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.TASK_TABLE_ADD", "hello") - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", 3, "node_id") - with self.assertRaises(redis.ResponseError): - # Non-integer scheduling states should not be added. - self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", - "invalid_state", "node_id", "task_spec") - with self.assertRaises(redis.ResponseError): - # Should not be able to update a non-existent task. - self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", 10, - "node_id") + def testInvalidTaskTableAdd(self): + # Check that Redis returns an error when RAY.TASK_TABLE_ADD is called + # with the wrong arguments. 
+ with self.assertRaises(redis.ResponseError): + self.redis.execute_command("RAY.TASK_TABLE_ADD") + with self.assertRaises(redis.ResponseError): + self.redis.execute_command("RAY.TASK_TABLE_ADD", "hello") + with self.assertRaises(redis.ResponseError): + self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", 3, + "node_id") + with self.assertRaises(redis.ResponseError): + # Non-integer scheduling states should not be added. + self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", + "invalid_state", "node_id", "task_spec") + with self.assertRaises(redis.ResponseError): + # Should not be able to update a non-existent task. + self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", 10, + "node_id") - def testTaskTableAddAndLookup(self): - TASK_STATUS_WAITING = 1 - TASK_STATUS_SCHEDULED = 2 - TASK_STATUS_QUEUED = 4 + def testTaskTableAddAndLookup(self): + TASK_STATUS_WAITING = 1 + TASK_STATUS_SCHEDULED = 2 + TASK_STATUS_QUEUED = 4 - # make sure somebody will get a notification (checked in the redis module) - p = self.redis.pubsub() - p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX)) + # make sure somebody will get a notification (checked in the redis + # module) + p = self.redis.pubsub() + p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX)) - def check_task_reply(message, task_args, updated=False): - task_status, local_scheduler_id, task_spec = task_args - task_reply_object = TaskReply.GetRootAsTaskReply(message, 0) - self.assertEqual(task_reply_object.State(), task_status) - self.assertEqual(task_reply_object.LocalSchedulerId(), - local_scheduler_id) - self.assertEqual(task_reply_object.TaskSpec(), task_spec) - self.assertEqual(task_reply_object.Updated(), updated) + def check_task_reply(message, task_args, updated=False): + task_status, local_scheduler_id, task_spec = task_args + task_reply_object = TaskReply.GetRootAsTaskReply(message, 0) + self.assertEqual(task_reply_object.State(), task_status) + 
self.assertEqual(task_reply_object.LocalSchedulerId(), + local_scheduler_id) + self.assertEqual(task_reply_object.TaskSpec(), task_spec) + self.assertEqual(task_reply_object.Updated(), updated) - # Check that task table adds, updates, and lookups work correctly. - task_args = [TASK_STATUS_WAITING, b"node_id", b"task_spec"] - response = self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", - *task_args) - response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") - check_task_reply(response, task_args) + # Check that task table adds, updates, and lookups work correctly. + task_args = [TASK_STATUS_WAITING, b"node_id", b"task_spec"] + response = self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", + *task_args) + response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") + check_task_reply(response, task_args) - task_args[0] = TASK_STATUS_SCHEDULED - self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", - *task_args[:2]) - response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") - check_task_reply(response, task_args) + task_args[0] = TASK_STATUS_SCHEDULED + self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", + *task_args[:2]) + response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") + check_task_reply(response, task_args) - # If the current value, test value, and set value are all the same, the - # update happens, and the response is still the same task. - task_args = [task_args[0]] + task_args - response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", - "task_id", - *task_args[:3]) - check_task_reply(response, task_args[1:], updated=True) - # Check that the task entry is still the same. - get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") - check_task_reply(get_response, task_args[1:]) + # If the current value, test value, and set value are all the same, the + # update happens, and the response is still the same task. 
+ task_args = [task_args[0]] + task_args + response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", + "task_id", + *task_args[:3]) + check_task_reply(response, task_args[1:], updated=True) + # Check that the task entry is still the same. + get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", + "task_id") + check_task_reply(get_response, task_args[1:]) - # If the current value is the same as the test value, and the set value is - # different, the update happens, and the response is the entire task. - task_args[1] = TASK_STATUS_QUEUED - response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", - "task_id", - *task_args[:3]) - check_task_reply(response, task_args[1:], updated=True) - # Check that the update happened. - get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") - check_task_reply(get_response, task_args[1:]) + # If the current value is the same as the test value, and the set value + # is different, the update happens, and the response is the entire + # task. + task_args[1] = TASK_STATUS_QUEUED + response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", + "task_id", + *task_args[:3]) + check_task_reply(response, task_args[1:], updated=True) + # Check that the update happened. + get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", + "task_id") + check_task_reply(get_response, task_args[1:]) - # If the current value is no longer the same as the test value, the - # response is the same task as before the test-and-set. - new_task_args = task_args[:] - new_task_args[1] = TASK_STATUS_WAITING - response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", - "task_id", - *new_task_args[:3]) - check_task_reply(response, task_args[1:], updated=False) - # Check that the update did not happen. 
- get_response2 = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") - self.assertEqual(get_response2, get_response) + # If the current value is no longer the same as the test value, the + # response is the same task as before the test-and-set. + new_task_args = task_args[:] + new_task_args[1] = TASK_STATUS_WAITING + response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", + "task_id", + *new_task_args[:3]) + check_task_reply(response, task_args[1:], updated=False) + # Check that the update did not happen. + get_response2 = self.redis.execute_command("RAY.TASK_TABLE_GET", + "task_id") + self.assertEqual(get_response2, get_response) - # If the test value is a bitmask that matches the current value, the update - # happens. - task_args = new_task_args - task_args[0] = TASK_STATUS_SCHEDULED | TASK_STATUS_QUEUED - response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", - "task_id", - *task_args[:3]) - check_task_reply(response, task_args[1:], updated=True) + # If the test value is a bitmask that matches the current value, the + # update happens. + task_args = new_task_args + task_args[0] = TASK_STATUS_SCHEDULED | TASK_STATUS_QUEUED + response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", + "task_id", + *task_args[:3]) + check_task_reply(response, task_args[1:], updated=True) - # If the test value is a bitmask that does not match the current value, the - # update does not happen, and the response is the same task as before the - # test-and-set. - new_task_args = task_args[:] - new_task_args[0] = TASK_STATUS_SCHEDULED - old_response = response - response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", - "task_id", - *new_task_args[:3]) - check_task_reply(response, task_args[1:], updated=False) - # Check that the update did not happen. 
- get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") - self.assertNotEqual(get_response, old_response) - check_task_reply(get_response, task_args[1:]) + # If the test value is a bitmask that does not match the current value, + # the update does not happen, and the response is the same task as + # before the test-and-set. + new_task_args = task_args[:] + new_task_args[0] = TASK_STATUS_SCHEDULED + old_response = response + response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", + "task_id", + *new_task_args[:3]) + check_task_reply(response, task_args[1:], updated=False) + # Check that the update did not happen. + get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", + "task_id") + self.assertNotEqual(get_response, old_response) + check_task_reply(get_response, task_args[1:]) - def check_task_subscription(self, p, scheduling_state, local_scheduler_id): - task_args = [b"task_id", scheduling_state, - local_scheduler_id.encode("ascii"), b"task_spec"] - self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args) - # Receive the data. - message = get_next_message(p)["data"] - # Check that the notification object is correct. - notification_object = TaskReply.GetRootAsTaskReply(message, 0) - self.assertEqual(notification_object.TaskId(), b"task_id") - self.assertEqual(notification_object.State(), scheduling_state) - self.assertEqual(notification_object.LocalSchedulerId(), - local_scheduler_id.encode("ascii")) - self.assertEqual(notification_object.TaskSpec(), b"task_spec") + def check_task_subscription(self, p, scheduling_state, local_scheduler_id): + task_args = [b"task_id", scheduling_state, + local_scheduler_id.encode("ascii"), b"task_spec"] + self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args) + # Receive the data. + message = get_next_message(p)["data"] + # Check that the notification object is correct. 
+ notification_object = TaskReply.GetRootAsTaskReply(message, 0) + self.assertEqual(notification_object.TaskId(), b"task_id") + self.assertEqual(notification_object.State(), scheduling_state) + self.assertEqual(notification_object.LocalSchedulerId(), + local_scheduler_id.encode("ascii")) + self.assertEqual(notification_object.TaskSpec(), b"task_spec") - def testTaskTableSubscribe(self): - scheduling_state = 1 - local_scheduler_id = "local_scheduler_id" - # Subscribe to the task table. - p = self.redis.pubsub() - p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX)) - # Receive acknowledgment. - self.assertEqual(get_next_message(p)["data"], 1) - self.check_task_subscription(p, scheduling_state, local_scheduler_id) - # unsubscribe to make sure there is only one subscriber at a given time - p.punsubscribe("{prefix}*:*".format(prefix=TASK_PREFIX)) - # Receive acknowledgment. - self.assertEqual(get_next_message(p)["data"], 0) + def testTaskTableSubscribe(self): + scheduling_state = 1 + local_scheduler_id = "local_scheduler_id" + # Subscribe to the task table. + p = self.redis.pubsub() + p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX)) + # Receive acknowledgment. + self.assertEqual(get_next_message(p)["data"], 1) + self.check_task_subscription(p, scheduling_state, local_scheduler_id) + # unsubscribe to make sure there is only one subscriber at a given time + p.punsubscribe("{prefix}*:*".format(prefix=TASK_PREFIX)) + # Receive acknowledgment. + self.assertEqual(get_next_message(p)["data"], 0) - p.psubscribe("{prefix}*:{state}".format( - prefix=TASK_PREFIX, state=scheduling_state)) - # Receive acknowledgment. - self.assertEqual(get_next_message(p)["data"], 1) - self.check_task_subscription(p, scheduling_state, local_scheduler_id) - p.punsubscribe("{prefix}*:{state}".format( - prefix=TASK_PREFIX, state=scheduling_state)) - # Receive acknowledgment. 
- self.assertEqual(get_next_message(p)["data"], 0) + p.psubscribe("{prefix}*:{state}".format( + prefix=TASK_PREFIX, state=scheduling_state)) + # Receive acknowledgment. + self.assertEqual(get_next_message(p)["data"], 1) + self.check_task_subscription(p, scheduling_state, local_scheduler_id) + p.punsubscribe("{prefix}*:{state}".format( + prefix=TASK_PREFIX, state=scheduling_state)) + # Receive acknowledgment. + self.assertEqual(get_next_message(p)["data"], 0) - p.psubscribe("{prefix}{local_scheduler_id}:*".format( - prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id)) - # Receive acknowledgment. - self.assertEqual(get_next_message(p)["data"], 1) - self.check_task_subscription(p, scheduling_state, local_scheduler_id) - p.punsubscribe("{prefix}{local_scheduler_id}:*".format( - prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id)) - # Receive acknowledgment. - self.assertEqual(get_next_message(p)["data"], 0) + p.psubscribe("{prefix}{local_scheduler_id}:*".format( + prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id)) + # Receive acknowledgment. + self.assertEqual(get_next_message(p)["data"], 1) + self.check_task_subscription(p, scheduling_state, local_scheduler_id) + p.punsubscribe("{prefix}{local_scheduler_id}:*".format( + prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id)) + # Receive acknowledgment. 
+ self.assertEqual(get_next_message(p)["data"], 0) if __name__ == "__main__": - unittest.main(verbosity=2) + unittest.main(verbosity=2) diff --git a/python/ray/common/test/test.py b/python/ray/common/test/test.py index 6e0a62433..d3e648e85 100644 --- a/python/ray/common/test/test.py +++ b/python/ray/common/test/test.py @@ -13,19 +13,19 @@ ID_SIZE = 20 def random_object_id(): - return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) def random_function_id(): - return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) def random_driver_id(): - return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) def random_task_id(): - return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) BASE_SIMPLE_OBJECTS = [ @@ -33,7 +33,7 @@ BASE_SIMPLE_OBJECTS = [ 990 * u"h"] if sys.version_info < (3, 0): - BASE_SIMPLE_OBJECTS += [long(0), long(1), long(100000), long(1 << 100)] # noqa: E501,F821 + BASE_SIMPLE_OBJECTS += [long(0), long(1), long(100000), long(1 << 100)] # noqa: E501,F821 LIST_SIMPLE_OBJECTS = [[obj] for obj in BASE_SIMPLE_OBJECTS] TUPLE_SIMPLE_OBJECTS = [(obj,) for obj in BASE_SIMPLE_OBJECTS] @@ -51,8 +51,8 @@ l.append(l) class Foo(object): - def __init__(self): - pass + def __init__(self): + pass BASE_COMPLEX_OBJECTS = [999 * "h", 999 * u"h", l, Foo(), @@ -70,118 +70,118 @@ COMPLEX_OBJECTS = (BASE_COMPLEX_OBJECTS + class TestSerialization(unittest.TestCase): - def test_serialize_by_value(self): + def test_serialize_by_value(self): - for val in SIMPLE_OBJECTS: - self.assertTrue(local_scheduler.check_simple_value(val)) - for val in COMPLEX_OBJECTS: - self.assertFalse(local_scheduler.check_simple_value(val)) + for val in SIMPLE_OBJECTS: + self.assertTrue(local_scheduler.check_simple_value(val)) + for val in COMPLEX_OBJECTS: + 
self.assertFalse(local_scheduler.check_simple_value(val)) class TestObjectID(unittest.TestCase): - def test_create_object_id(self): - random_object_id() + def test_create_object_id(self): + random_object_id() - def test_cannot_pickle_object_ids(self): - object_ids = [random_object_id() for _ in range(256)] + def test_cannot_pickle_object_ids(self): + object_ids = [random_object_id() for _ in range(256)] - def f(): - return object_ids + def f(): + return object_ids - def g(val=object_ids): - return 1 + def g(val=object_ids): + return 1 - def h(): - object_ids[0] - return 1 - # Make sure that object IDs cannot be pickled (including functions that - # close over object IDs). - self.assertRaises(Exception, lambda: pickle.dumps(object_ids[0])) - self.assertRaises(Exception, lambda: pickle.dumps(object_ids)) - self.assertRaises(Exception, lambda: pickle.dumps(f)) - self.assertRaises(Exception, lambda: pickle.dumps(g)) - self.assertRaises(Exception, lambda: pickle.dumps(h)) + def h(): + object_ids[0] + return 1 + # Make sure that object IDs cannot be pickled (including functions that + # close over object IDs). 
+ self.assertRaises(Exception, lambda: pickle.dumps(object_ids[0])) + self.assertRaises(Exception, lambda: pickle.dumps(object_ids)) + self.assertRaises(Exception, lambda: pickle.dumps(f)) + self.assertRaises(Exception, lambda: pickle.dumps(g)) + self.assertRaises(Exception, lambda: pickle.dumps(h)) - def test_equality_comparisons(self): - x1 = local_scheduler.ObjectID(ID_SIZE * b"a") - x2 = local_scheduler.ObjectID(ID_SIZE * b"a") - y1 = local_scheduler.ObjectID(ID_SIZE * b"b") - y2 = local_scheduler.ObjectID(ID_SIZE * b"b") - self.assertEqual(x1, x2) - self.assertEqual(y1, y2) - self.assertNotEqual(x1, y1) + def test_equality_comparisons(self): + x1 = local_scheduler.ObjectID(ID_SIZE * b"a") + x2 = local_scheduler.ObjectID(ID_SIZE * b"a") + y1 = local_scheduler.ObjectID(ID_SIZE * b"b") + y2 = local_scheduler.ObjectID(ID_SIZE * b"b") + self.assertEqual(x1, x2) + self.assertEqual(y1, y2) + self.assertNotEqual(x1, y1) - random_strings = [np.random.bytes(ID_SIZE) for _ in range(256)] - object_ids1 = [local_scheduler.ObjectID(random_strings[i]) - for i in range(256)] - object_ids2 = [local_scheduler.ObjectID(random_strings[i]) - for i in range(256)] - self.assertEqual(len(set(object_ids1)), 256) - self.assertEqual(len(set(object_ids1 + object_ids2)), 256) - self.assertEqual(set(object_ids1), set(object_ids2)) + random_strings = [np.random.bytes(ID_SIZE) for _ in range(256)] + object_ids1 = [local_scheduler.ObjectID(random_strings[i]) + for i in range(256)] + object_ids2 = [local_scheduler.ObjectID(random_strings[i]) + for i in range(256)] + self.assertEqual(len(set(object_ids1)), 256) + self.assertEqual(len(set(object_ids1 + object_ids2)), 256) + self.assertEqual(set(object_ids1), set(object_ids2)) - def test_hashability(self): - x = random_object_id() - y = random_object_id() - {x: y} - set([x, y]) + def test_hashability(self): + x = random_object_id() + y = random_object_id() + {x: y} + set([x, y]) class TestTask(unittest.TestCase): - def check_task(self, task, 
function_id, num_return_vals, args): - self.assertEqual(function_id.id(), task.function_id().id()) - retrieved_args = task.arguments() - self.assertEqual(num_return_vals, len(task.returns())) - self.assertEqual(len(args), len(retrieved_args)) - for i in range(len(retrieved_args)): - if isinstance(retrieved_args[i], local_scheduler.ObjectID): - self.assertEqual(retrieved_args[i].id(), args[i].id()) - else: - self.assertEqual(retrieved_args[i], args[i]) + def check_task(self, task, function_id, num_return_vals, args): + self.assertEqual(function_id.id(), task.function_id().id()) + retrieved_args = task.arguments() + self.assertEqual(num_return_vals, len(task.returns())) + self.assertEqual(len(args), len(retrieved_args)) + for i in range(len(retrieved_args)): + if isinstance(retrieved_args[i], local_scheduler.ObjectID): + self.assertEqual(retrieved_args[i].id(), args[i].id()) + else: + self.assertEqual(retrieved_args[i], args[i]) - def test_create_and_serialize_task(self): - # TODO(rkn): The function ID should be a FunctionID object, not an - # ObjectID. 
- driver_id = random_driver_id() - parent_id = random_task_id() - function_id = random_function_id() - object_ids = [random_object_id() for _ in range(256)] - args_list = [ - [], - 1 * [1], - 10 * [1], - 100 * [1], - 1000 * [1], - 1 * ["a"], - 10 * ["a"], - 100 * ["a"], - 1000 * ["a"], - [1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]], - object_ids[:1], - object_ids[:2], - object_ids[:3], - object_ids[:4], - object_ids[:5], - object_ids[:10], - object_ids[:100], - object_ids[:256], - [1, object_ids[0]], - [object_ids[0], "a"], - [1, object_ids[0], "a"], - [object_ids[0], 1, object_ids[1], "a"], - object_ids[:3] + [1, "hi", 2.3] + object_ids[:5], - object_ids + 100 * ["a"] + object_ids] - for args in args_list: - for num_return_vals in [0, 1, 2, 3, 5, 10, 100]: - task = local_scheduler.Task(driver_id, function_id, args, - num_return_vals, parent_id, 0) - self.check_task(task, function_id, num_return_vals, args) - data = local_scheduler.task_to_string(task) - task2 = local_scheduler.task_from_string(data) - self.check_task(task2, function_id, num_return_vals, args) + def test_create_and_serialize_task(self): + # TODO(rkn): The function ID should be a FunctionID object, not an + # ObjectID. 
+ driver_id = random_driver_id() + parent_id = random_task_id() + function_id = random_function_id() + object_ids = [random_object_id() for _ in range(256)] + args_list = [ + [], + 1 * [1], + 10 * [1], + 100 * [1], + 1000 * [1], + 1 * ["a"], + 10 * ["a"], + 100 * ["a"], + 1000 * ["a"], + [1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]], + object_ids[:1], + object_ids[:2], + object_ids[:3], + object_ids[:4], + object_ids[:5], + object_ids[:10], + object_ids[:100], + object_ids[:256], + [1, object_ids[0]], + [object_ids[0], "a"], + [1, object_ids[0], "a"], + [object_ids[0], 1, object_ids[1], "a"], + object_ids[:3] + [1, "hi", 2.3] + object_ids[:5], + object_ids + 100 * ["a"] + object_ids] + for args in args_list: + for num_return_vals in [0, 1, 2, 3, 5, 10, 100]: + task = local_scheduler.Task(driver_id, function_id, args, + num_return_vals, parent_id, 0) + self.check_task(task, function_id, num_return_vals, args) + data = local_scheduler.task_to_string(task) + task2 = local_scheduler.task_from_string(data) + self.check_task(task2, function_id, num_return_vals, args) if __name__ == "__main__": - unittest.main(verbosity=2) + unittest.main(verbosity=2) diff --git a/python/ray/experimental/array/distributed/core.py b/python/ray/experimental/array/distributed/core.py index 2b2d4badd..b576d0bff 100644 --- a/python/ray/experimental/array/distributed/core.py +++ b/python/ray/experimental/array/distributed/core.py @@ -10,271 +10,277 @@ BLOCK_SIZE = 10 class DistArray(object): - def __init__(self, shape, objectids=None): - self.shape = shape - self.ndim = len(shape) - self.num_blocks = [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in self.shape] - if objectids is not None: - self.objectids = objectids - else: - self.objectids = np.empty(self.num_blocks, dtype=object) - if self.num_blocks != list(self.objectids.shape): - raise Exception("The fields `num_blocks` and `objectids` are " - "inconsistent, `num_blocks` is {} and `objectids` has " - "shape {}".format(self.num_blocks, - 
list(self.objectids.shape))) + def __init__(self, shape, objectids=None): + self.shape = shape + self.ndim = len(shape) + self.num_blocks = [int(np.ceil(1.0 * a / BLOCK_SIZE)) + for a in self.shape] + if objectids is not None: + self.objectids = objectids + else: + self.objectids = np.empty(self.num_blocks, dtype=object) + if self.num_blocks != list(self.objectids.shape): + raise Exception("The fields `num_blocks` and `objectids` are " + "inconsistent, `num_blocks` is {} and `objectids` " + "has shape {}".format(self.num_blocks, + list(self.objectids.shape))) - @staticmethod - def compute_block_lower(index, shape): - if len(index) != len(shape): - raise Exception("The fields `index` and `shape` must have the same " - "length, but `index` is {} and `shape` is " - "{}.".format(index, shape)) - return [elem * BLOCK_SIZE for elem in index] + @staticmethod + def compute_block_lower(index, shape): + if len(index) != len(shape): + raise Exception("The fields `index` and `shape` must have the " + "same length, but `index` is {} and `shape` is " + "{}.".format(index, shape)) + return [elem * BLOCK_SIZE for elem in index] - @staticmethod - def compute_block_upper(index, shape): - if len(index) != len(shape): - raise Exception("The fields `index` and `shape` must have the same " - "length, but `index` is {} and `shape` is " - "{}.".format(index, shape)) - upper = [] - for i in range(len(shape)): - upper.append(min((index[i] + 1) * BLOCK_SIZE, shape[i])) - return upper + @staticmethod + def compute_block_upper(index, shape): + if len(index) != len(shape): + raise Exception("The fields `index` and `shape` must have the " + "same length, but `index` is {} and `shape` is " + "{}.".format(index, shape)) + upper = [] + for i in range(len(shape)): + upper.append(min((index[i] + 1) * BLOCK_SIZE, shape[i])) + return upper - @staticmethod - def compute_block_shape(index, shape): - lower = DistArray.compute_block_lower(index, shape) - upper = DistArray.compute_block_upper(index, shape) 
- return [u - l for (l, u) in zip(lower, upper)] + @staticmethod + def compute_block_shape(index, shape): + lower = DistArray.compute_block_lower(index, shape) + upper = DistArray.compute_block_upper(index, shape) + return [u - l for (l, u) in zip(lower, upper)] - @staticmethod - def compute_num_blocks(shape): - return [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in shape] + @staticmethod + def compute_num_blocks(shape): + return [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in shape] - def assemble(self): - """Assemble an array from a distributed array of object IDs.""" - first_block = ray.get(self.objectids[(0,) * self.ndim]) - dtype = first_block.dtype - result = np.zeros(self.shape, dtype=dtype) - for index in np.ndindex(*self.num_blocks): - lower = DistArray.compute_block_lower(index, self.shape) - upper = DistArray.compute_block_upper(index, self.shape) - result[[slice(l, u) for (l, u) in zip(lower, upper)]] = ray.get( - self.objectids[index]) - return result + def assemble(self): + """Assemble an array from a distributed array of object IDs.""" + first_block = ray.get(self.objectids[(0,) * self.ndim]) + dtype = first_block.dtype + result = np.zeros(self.shape, dtype=dtype) + for index in np.ndindex(*self.num_blocks): + lower = DistArray.compute_block_lower(index, self.shape) + upper = DistArray.compute_block_upper(index, self.shape) + result[[slice(l, u) for (l, u) in zip(lower, upper)]] = ray.get( + self.objectids[index]) + return result - def __getitem__(self, sliced): - # TODO(rkn): Fix this, this is just a placeholder that should work but is - # inefficient. - a = self.assemble() - return a[sliced] + def __getitem__(self, sliced): + # TODO(rkn): Fix this, this is just a placeholder that should work but + # is inefficient. + a = self.assemble() + return a[sliced] @ray.remote def assemble(a): - return a.assemble() + return a.assemble() # TODO(rkn): What should we call this method? 
@ray.remote def numpy_to_dist(a): - result = DistArray(a.shape) - for index in np.ndindex(*result.num_blocks): - lower = DistArray.compute_block_lower(index, a.shape) - upper = DistArray.compute_block_upper(index, a.shape) - result.objectids[index] = ray.put(a[[slice(l, u) for (l, u) - in zip(lower, upper)]]) - return result + result = DistArray(a.shape) + for index in np.ndindex(*result.num_blocks): + lower = DistArray.compute_block_lower(index, a.shape) + upper = DistArray.compute_block_upper(index, a.shape) + result.objectids[index] = ray.put(a[[slice(l, u) for (l, u) + in zip(lower, upper)]]) + return result @ray.remote def zeros(shape, dtype_name="float"): - result = DistArray(shape) - for index in np.ndindex(*result.num_blocks): - result.objectids[index] = ra.zeros.remote( - DistArray.compute_block_shape(index, shape), dtype_name=dtype_name) - return result + result = DistArray(shape) + for index in np.ndindex(*result.num_blocks): + result.objectids[index] = ra.zeros.remote( + DistArray.compute_block_shape(index, shape), dtype_name=dtype_name) + return result @ray.remote def ones(shape, dtype_name="float"): - result = DistArray(shape) - for index in np.ndindex(*result.num_blocks): - result.objectids[index] = ra.ones.remote( - DistArray.compute_block_shape(index, shape), dtype_name=dtype_name) - return result + result = DistArray(shape) + for index in np.ndindex(*result.num_blocks): + result.objectids[index] = ra.ones.remote( + DistArray.compute_block_shape(index, shape), dtype_name=dtype_name) + return result @ray.remote def copy(a): - result = DistArray(a.shape) - for index in np.ndindex(*result.num_blocks): - # We don't need to actually copy the objects because remote objects are - # immutable. - result.objectids[index] = a.objectids[index] - return result + result = DistArray(a.shape) + for index in np.ndindex(*result.num_blocks): + # We don't need to actually copy the objects because remote objects are + # immutable. 
+ result.objectids[index] = a.objectids[index] + return result @ray.remote def eye(dim1, dim2=-1, dtype_name="float"): - dim2 = dim1 if dim2 == -1 else dim2 - shape = [dim1, dim2] - result = DistArray(shape) - for (i, j) in np.ndindex(*result.num_blocks): - block_shape = DistArray.compute_block_shape([i, j], shape) - if i == j: - result.objectids[i, j] = ra.eye.remote(block_shape[0], block_shape[1], - dtype_name=dtype_name) - else: - result.objectids[i, j] = ra.zeros.remote(block_shape, - dtype_name=dtype_name) - return result + dim2 = dim1 if dim2 == -1 else dim2 + shape = [dim1, dim2] + result = DistArray(shape) + for (i, j) in np.ndindex(*result.num_blocks): + block_shape = DistArray.compute_block_shape([i, j], shape) + if i == j: + result.objectids[i, j] = ra.eye.remote(block_shape[0], + block_shape[1], + dtype_name=dtype_name) + else: + result.objectids[i, j] = ra.zeros.remote(block_shape, + dtype_name=dtype_name) + return result @ray.remote def triu(a): - if a.ndim != 2: - raise Exception("Input must have 2 dimensions, but a.ndim is " - "{}.".format(a.ndim)) - result = DistArray(a.shape) - for (i, j) in np.ndindex(*result.num_blocks): - if i < j: - result.objectids[i, j] = ra.copy.remote(a.objectids[i, j]) - elif i == j: - result.objectids[i, j] = ra.triu.remote(a.objectids[i, j]) - else: - result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j]) - return result + if a.ndim != 2: + raise Exception("Input must have 2 dimensions, but a.ndim is " + "{}.".format(a.ndim)) + result = DistArray(a.shape) + for (i, j) in np.ndindex(*result.num_blocks): + if i < j: + result.objectids[i, j] = ra.copy.remote(a.objectids[i, j]) + elif i == j: + result.objectids[i, j] = ra.triu.remote(a.objectids[i, j]) + else: + result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j]) + return result @ray.remote def tril(a): - if a.ndim != 2: - raise Exception("Input must have 2 dimensions, but a.ndim is " - "{}.".format(a.ndim)) - result = DistArray(a.shape) - for (i, j) 
in np.ndindex(*result.num_blocks): - if i > j: - result.objectids[i, j] = ra.copy.remote(a.objectids[i, j]) - elif i == j: - result.objectids[i, j] = ra.tril.remote(a.objectids[i, j]) - else: - result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j]) - return result + if a.ndim != 2: + raise Exception("Input must have 2 dimensions, but a.ndim is " + "{}.".format(a.ndim)) + result = DistArray(a.shape) + for (i, j) in np.ndindex(*result.num_blocks): + if i > j: + result.objectids[i, j] = ra.copy.remote(a.objectids[i, j]) + elif i == j: + result.objectids[i, j] = ra.tril.remote(a.objectids[i, j]) + else: + result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j]) + return result @ray.remote def blockwise_dot(*matrices): - n = len(matrices) - if n % 2 != 0: - raise Exception("blockwise_dot expects an even number of arguments, but " - "len(matrices) is {}.".format(n)) - shape = (matrices[0].shape[0], matrices[n // 2].shape[1]) - result = np.zeros(shape) - for i in range(n // 2): - result += np.dot(matrices[i], matrices[n // 2 + i]) - return result + n = len(matrices) + if n % 2 != 0: + raise Exception("blockwise_dot expects an even number of arguments, " + "but len(matrices) is {}.".format(n)) + shape = (matrices[0].shape[0], matrices[n // 2].shape[1]) + result = np.zeros(shape) + for i in range(n // 2): + result += np.dot(matrices[i], matrices[n // 2 + i]) + return result @ray.remote def dot(a, b): - if a.ndim != 2: - raise Exception("dot expects its arguments to be 2-dimensional, but " - "a.ndim = {}.".format(a.ndim)) - if b.ndim != 2: - raise Exception("dot expects its arguments to be 2-dimensional, but " - "b.ndim = {}.".format(b.ndim)) - if a.shape[1] != b.shape[0]: - raise Exception("dot expects a.shape[1] to equal b.shape[0], but a.shape " - "= {} and b.shape = {}.".format(a.shape, b.shape)) - shape = [a.shape[0], b.shape[1]] - result = DistArray(shape) - for (i, j) in np.ndindex(*result.num_blocks): - args = list(a.objectids[i, :]) + 
list(b.objectids[:, j]) - result.objectids[i, j] = blockwise_dot.remote(*args) - return result + if a.ndim != 2: + raise Exception("dot expects its arguments to be 2-dimensional, but " + "a.ndim = {}.".format(a.ndim)) + if b.ndim != 2: + raise Exception("dot expects its arguments to be 2-dimensional, but " + "b.ndim = {}.".format(b.ndim)) + if a.shape[1] != b.shape[0]: + raise Exception("dot expects a.shape[1] to equal b.shape[0], but " + "a.shape = {} and b.shape = {}.".format(a.shape, + b.shape)) + shape = [a.shape[0], b.shape[1]] + result = DistArray(shape) + for (i, j) in np.ndindex(*result.num_blocks): + args = list(a.objectids[i, :]) + list(b.objectids[:, j]) + result.objectids[i, j] = blockwise_dot.remote(*args) + return result @ray.remote def subblocks(a, *ranges): - """ - This function produces a distributed array from a subset of the blocks in the - `a`. The result and `a` will have the same number of dimensions.For example, - subblocks(a, [0, 1], [2, 4]) - will produce a DistArray whose objectids are - [[a.objectids[0, 2], a.objectids[0, 4]], - [a.objectids[1, 2], a.objectids[1, 4]]] - We allow the user to pass in an empty list [] to indicate the full range. - """ - ranges = list(ranges) - if len(ranges) != a.ndim: - raise Exception("sub_blocks expects to receive a number of ranges equal " - "to a.ndim, but it received {} ranges and a.ndim = " - "{}.".format(len(ranges), a.ndim)) - for i in range(len(ranges)): - # We allow the user to pass in an empty list to indicate the full range. 
- if ranges[i] == []: - ranges[i] = range(a.num_blocks[i]) - if not np.alltrue(ranges[i] == np.sort(ranges[i])): - raise Exception("Ranges passed to sub_blocks must be sorted, but the " - "{}th range is {}.".format(i, ranges[i])) - if ranges[i][0] < 0: - raise Exception("Values in the ranges passed to sub_blocks must be at " - "least 0, but the {}th range is {}.".format(i, - ranges[i])) - if ranges[i][-1] >= a.num_blocks[i]: - raise Exception("Values in the ranges passed to sub_blocks must be less " - "than the relevant number of blocks, but the {}th range " - "is {}, and a.num_blocks = {}.".format(i, ranges[i], - a.num_blocks)) - last_index = [r[-1] for r in ranges] - last_block_shape = DistArray.compute_block_shape(last_index, a.shape) - shape = [(len(ranges[i]) - 1) * BLOCK_SIZE + last_block_shape[i] - for i in range(a.ndim)] - result = DistArray(shape) - for index in np.ndindex(*result.num_blocks): - result.objectids[index] = a.objectids[tuple([ranges[i][index[i]] - for i in range(a.ndim)])] - return result + """ + This function produces a distributed array from a subset of the blocks in + the `a`. The result and `a` will have the same number of dimensions. For + example, + subblocks(a, [0, 1], [2, 4]) + will produce a DistArray whose objectids are + [[a.objectids[0, 2], a.objectids[0, 4]], + [a.objectids[1, 2], a.objectids[1, 4]]] + We allow the user to pass in an empty list [] to indicate the full range. + """ + ranges = list(ranges) + if len(ranges) != a.ndim: + raise Exception("sub_blocks expects to receive a number of ranges " + "equal to a.ndim, but it received {} ranges and " + "a.ndim = {}.".format(len(ranges), a.ndim)) + for i in range(len(ranges)): + # We allow the user to pass in an empty list to indicate the full + # range. 
+ if ranges[i] == []: + ranges[i] = range(a.num_blocks[i]) + if not np.alltrue(ranges[i] == np.sort(ranges[i])): + raise Exception("Ranges passed to sub_blocks must be sorted, but " + "the {}th range is {}.".format(i, ranges[i])) + if ranges[i][0] < 0: + raise Exception("Values in the ranges passed to sub_blocks must " + "be at least 0, but the {}th range is {}." + .format(i, ranges[i])) + if ranges[i][-1] >= a.num_blocks[i]: + raise Exception("Values in the ranges passed to sub_blocks must " + "be less than the relevant number of blocks, but " + "the {}th range is {}, and a.num_blocks = {}." + .format(i, ranges[i], a.num_blocks)) + last_index = [r[-1] for r in ranges] + last_block_shape = DistArray.compute_block_shape(last_index, a.shape) + shape = [(len(ranges[i]) - 1) * BLOCK_SIZE + last_block_shape[i] + for i in range(a.ndim)] + result = DistArray(shape) + for index in np.ndindex(*result.num_blocks): + result.objectids[index] = a.objectids[tuple([ranges[i][index[i]] + for i in range(a.ndim)])] + return result @ray.remote def transpose(a): - if a.ndim != 2: - raise Exception("transpose expects its argument to be 2-dimensional, but " - "a.ndim = {}, a.shape = {}.".format(a.ndim, a.shape)) - result = DistArray([a.shape[1], a.shape[0]]) - for i in range(result.num_blocks[0]): - for j in range(result.num_blocks[1]): - result.objectids[i, j] = ra.transpose.remote(a.objectids[j, i]) - return result + if a.ndim != 2: + raise Exception("transpose expects its argument to be 2-dimensional, " + "but a.ndim = {}, a.shape = {}.".format(a.ndim, + a.shape)) + result = DistArray([a.shape[1], a.shape[0]]) + for i in range(result.num_blocks[0]): + for j in range(result.num_blocks[1]): + result.objectids[i, j] = ra.transpose.remote(a.objectids[j, i]) + return result # TODO(rkn): support broadcasting? @ray.remote def add(x1, x2): - if x1.shape != x2.shape: - raise Exception("add expects arguments `x1` and `x2` to have the same " - "shape, but x1.shape = {}, and x2.shape = {}." 
- .format(x1.shape, x2.shape)) - result = DistArray(x1.shape) - for index in np.ndindex(*result.num_blocks): - result.objectids[index] = ra.add.remote(x1.objectids[index], - x2.objectids[index]) - return result + if x1.shape != x2.shape: + raise Exception("add expects arguments `x1` and `x2` to have the same " + "shape, but x1.shape = {}, and x2.shape = {}." + .format(x1.shape, x2.shape)) + result = DistArray(x1.shape) + for index in np.ndindex(*result.num_blocks): + result.objectids[index] = ra.add.remote(x1.objectids[index], + x2.objectids[index]) + return result # TODO(rkn): support broadcasting? @ray.remote def subtract(x1, x2): - if x1.shape != x2.shape: - raise Exception("subtract expects arguments `x1` and `x2` to have the " - "same shape, but x1.shape = {}, and x2.shape = {}." - .format(x1.shape, x2.shape)) - result = DistArray(x1.shape) - for index in np.ndindex(*result.num_blocks): - result.objectids[index] = ra.subtract.remote(x1.objectids[index], - x2.objectids[index]) - return result + if x1.shape != x2.shape: + raise Exception("subtract expects arguments `x1` and `x2` to have the " + "same shape, but x1.shape = {}, and x2.shape = {}." + .format(x1.shape, x2.shape)) + result = DistArray(x1.shape) + for index in np.ndindex(*result.num_blocks): + result.objectids[index] = ra.subtract.remote(x1.objectids[index], + x2.objectids[index]) + return result diff --git a/python/ray/experimental/array/distributed/linalg.py b/python/ray/experimental/array/distributed/linalg.py index 86a66fd95..6c6023423 100644 --- a/python/ray/experimental/array/distributed/linalg.py +++ b/python/ray/experimental/array/distributed/linalg.py @@ -13,74 +13,75 @@ __all__ = ["tsqr", "modified_lu", "tsqr_hr", "qr"] @ray.remote(num_return_vals=2) def tsqr(a): - """Perform a QR decomposition of a tall-skinny matrix. + """Perform a QR decomposition of a tall-skinny matrix. - Args: - a: A distributed matrix with shape MxN (suppose K = min(M, N)). 
+ Args: + a: A distributed matrix with shape MxN (suppose K = min(M, N)). - Returns: - A tuple of q (a DistArray) and r (a numpy array) satisfying the following. - - If q_full = ray.get(DistArray, q).assemble(), then - q_full.shape == (M, K). - - np.allclose(np.dot(q_full.T, q_full), np.eye(K)) == True. - - If r_val = ray.get(np.ndarray, r), then r_val.shape == (K, N). - - np.allclose(r, np.triu(r)) == True. - """ - if len(a.shape) != 2: - raise Exception("tsqr requires len(a.shape) == 2, but a.shape is " - "{}".format(a.shape)) - if a.num_blocks[1] != 1: - raise Exception("tsqr requires a.num_blocks[1] == 1, but a.num_blocks is " - "{}".format(a.num_blocks)) + Returns: + A tuple of q (a DistArray) and r (a numpy array) satisfying the + following. + - If q_full = ray.get(DistArray, q).assemble(), then + q_full.shape == (M, K). + - np.allclose(np.dot(q_full.T, q_full), np.eye(K)) == True. + - If r_val = ray.get(np.ndarray, r), then r_val.shape == (K, N). + - np.allclose(r, np.triu(r)) == True. 
+ """ + if len(a.shape) != 2: + raise Exception("tsqr requires len(a.shape) == 2, but a.shape is " + "{}".format(a.shape)) + if a.num_blocks[1] != 1: + raise Exception("tsqr requires a.num_blocks[1] == 1, but a.num_blocks " + "is {}".format(a.num_blocks)) - num_blocks = a.num_blocks[0] - K = int(np.ceil(np.log2(num_blocks))) + 1 - q_tree = np.empty((num_blocks, K), dtype=object) - current_rs = [] - for i in range(num_blocks): - block = a.objectids[i, 0] - q, r = ra.linalg.qr.remote(block) - q_tree[i, 0] = q - current_rs.append(r) - for j in range(1, K): - new_rs = [] - for i in range(int(np.ceil(1.0 * len(current_rs) / 2))): - stacked_rs = ra.vstack.remote(*current_rs[(2 * i):(2 * i + 2)]) - q, r = ra.linalg.qr.remote(stacked_rs) - q_tree[i, j] = q - new_rs.append(r) - current_rs = new_rs - assert len(current_rs) == 1, "len(current_rs) = " + str(len(current_rs)) - - # handle the special case in which the whole DistArray "a" fits in one block - # and has fewer rows than columns, this is a bit ugly so think about how to - # remove it - if a.shape[0] >= a.shape[1]: - q_shape = a.shape - else: - q_shape = [a.shape[0], a.shape[0]] - q_num_blocks = core.DistArray.compute_num_blocks(q_shape) - q_objectids = np.empty(q_num_blocks, dtype=object) - q_result = core.DistArray(q_shape, q_objectids) - - # reconstruct output - for i in range(num_blocks): - q_block_current = q_tree[i, 0] - ith_index = i + num_blocks = a.num_blocks[0] + K = int(np.ceil(np.log2(num_blocks))) + 1 + q_tree = np.empty((num_blocks, K), dtype=object) + current_rs = [] + for i in range(num_blocks): + block = a.objectids[i, 0] + q, r = ra.linalg.qr.remote(block) + q_tree[i, 0] = q + current_rs.append(r) for j in range(1, K): - if np.mod(ith_index, 2) == 0: - lower = [0, 0] - upper = [a.shape[1], core.BLOCK_SIZE] - else: - lower = [a.shape[1], 0] - upper = [2 * a.shape[1], core.BLOCK_SIZE] - ith_index //= 2 - q_block_current = ra.dot.remote(q_block_current, - ra.subarray.remote(q_tree[ith_index, j], - 
lower, upper)) - q_result.objectids[i] = q_block_current - r = current_rs[0] - return q_result, ray.get(r) + new_rs = [] + for i in range(int(np.ceil(1.0 * len(current_rs) / 2))): + stacked_rs = ra.vstack.remote(*current_rs[(2 * i):(2 * i + 2)]) + q, r = ra.linalg.qr.remote(stacked_rs) + q_tree[i, j] = q + new_rs.append(r) + current_rs = new_rs + assert len(current_rs) == 1, "len(current_rs) = " + str(len(current_rs)) + + # handle the special case in which the whole DistArray "a" fits in one + # block and has fewer rows than columns, this is a bit ugly so think about + # how to remove it + if a.shape[0] >= a.shape[1]: + q_shape = a.shape + else: + q_shape = [a.shape[0], a.shape[0]] + q_num_blocks = core.DistArray.compute_num_blocks(q_shape) + q_objectids = np.empty(q_num_blocks, dtype=object) + q_result = core.DistArray(q_shape, q_objectids) + + # reconstruct output + for i in range(num_blocks): + q_block_current = q_tree[i, 0] + ith_index = i + for j in range(1, K): + if np.mod(ith_index, 2) == 0: + lower = [0, 0] + upper = [a.shape[1], core.BLOCK_SIZE] + else: + lower = [a.shape[1], 0] + upper = [2 * a.shape[1], core.BLOCK_SIZE] + ith_index //= 2 + q_block_current = ra.dot.remote( + q_block_current, ra.subarray.remote(q_tree[ith_index, j], + lower, upper)) + q_result.objectids[i] = q_block_current + r = current_rs[0] + return q_result, ray.get(r) # TODO(rkn): This is unoptimized, we really want a block version of this. @@ -88,76 +89,77 @@ def tsqr(a): # http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf. @ray.remote(num_return_vals=3) def modified_lu(q): - """Perform a modified LU decomposition of a matrix. + """Perform a modified LU decomposition of a matrix. - This takes a matrix q with orthonormal columns, returns l, u, s such that - q - s = l * u. + This takes a matrix q with orthonormal columns, returns l, u, s such that + q - s = l * u. - Args: - q: A two dimensional orthonormal matrix q. + Args: + q: A two dimensional orthonormal matrix q. 
- Returns: - A tuple of a lower triangular matrix l, an upper triangular matrix u, and a - a vector representing a diagonal matrix s such that q - s = l * u. - """ - q = q.assemble() - m, b = q.shape[0], q.shape[1] - S = np.zeros(b) + Returns: + A tuple of a lower triangular matrix l, an upper triangular matrix u, + and a a vector representing a diagonal matrix s such that + q - s = l * u. + """ + q = q.assemble() + m, b = q.shape[0], q.shape[1] + S = np.zeros(b) - q_work = np.copy(q) + q_work = np.copy(q) - for i in range(b): - S[i] = -1 * np.sign(q_work[i, i]) - q_work[i, i] -= S[i] - # Scale ith column of L by diagonal element. - q_work[(i + 1):m, i] /= q_work[i, i] - # Perform Schur complement update. - q_work[(i + 1):m, (i + 1):b] -= np.outer(q_work[(i + 1):m, i], - q_work[i, (i + 1):b]) + for i in range(b): + S[i] = -1 * np.sign(q_work[i, i]) + q_work[i, i] -= S[i] + # Scale ith column of L by diagonal element. + q_work[(i + 1):m, i] /= q_work[i, i] + # Perform Schur complement update. + q_work[(i + 1):m, (i + 1):b] -= np.outer(q_work[(i + 1):m, i], + q_work[i, (i + 1):b]) - L = np.tril(q_work) - for i in range(b): - L[i, i] = 1 - U = np.triu(q_work)[:b, :] - # TODO(rkn): Get rid of the put below. - return ray.get(core.numpy_to_dist.remote(ray.put(L))), U, S + L = np.tril(q_work) + for i in range(b): + L[i, i] = 1 + U = np.triu(q_work)[:b, :] + # TODO(rkn): Get rid of the put below. 
+ return ray.get(core.numpy_to_dist.remote(ray.put(L))), U, S @ray.remote(num_return_vals=2) def tsqr_hr_helper1(u, s, y_top_block, b): - y_top = y_top_block[:b, :b] - s_full = np.diag(s) - t = -1 * np.dot(u, np.dot(s_full, np.linalg.inv(y_top).T)) - return t, y_top + y_top = y_top_block[:b, :b] + s_full = np.diag(s) + t = -1 * np.dot(u, np.dot(s_full, np.linalg.inv(y_top).T)) + return t, y_top @ray.remote def tsqr_hr_helper2(s, r_temp): - s_full = np.diag(s) - return np.dot(s_full, r_temp) + s_full = np.diag(s) + return np.dot(s_full, r_temp) # This is Algorithm 6 from # http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf. @ray.remote(num_return_vals=4) def tsqr_hr(a): - q, r_temp = tsqr.remote(a) - y, u, s = modified_lu.remote(q) - y_blocked = ray.get(y) - t, y_top = tsqr_hr_helper1.remote(u, s, y_blocked.objectids[0, 0], - a.shape[1]) - r = tsqr_hr_helper2.remote(s, r_temp) - return ray.get(y), ray.get(t), ray.get(y_top), ray.get(r) + q, r_temp = tsqr.remote(a) + y, u, s = modified_lu.remote(q) + y_blocked = ray.get(y) + t, y_top = tsqr_hr_helper1.remote(u, s, y_blocked.objectids[0, 0], + a.shape[1]) + r = tsqr_hr_helper2.remote(s, r_temp) + return ray.get(y), ray.get(t), ray.get(y_top), ray.get(r) @ray.remote def qr_helper1(a_rc, y_ri, t, W_c): - return a_rc - np.dot(y_ri, np.dot(t.T, W_c)) + return a_rc - np.dot(y_ri, np.dot(t.T, W_c)) @ray.remote def qr_helper2(y_ri, a_rc): - return np.dot(y_ri.T, a_rc) + return np.dot(y_ri.T, a_rc) # This is Algorithm 7 from @@ -165,60 +167,63 @@ def qr_helper2(y_ri, a_rc): @ray.remote(num_return_vals=2) def qr(a): - m, n = a.shape[0], a.shape[1] - k = min(m, n) + m, n = a.shape[0], a.shape[1] + k = min(m, n) - # we will store our scratch work in a_work - a_work = core.DistArray(a.shape, np.copy(a.objectids)) + # we will store our scratch work in a_work + a_work = core.DistArray(a.shape, np.copy(a.objectids)) - result_dtype = np.linalg.qr(ray.get(a.objectids[0, 0]))[0].dtype.name - # TODO(rkn): It would be 
preferable not to get this right after creating it. - r_res = ray.get(core.zeros.remote([k, n], result_dtype)) - # TODO(rkn): It would be preferable not to get this right after creating it. - y_res = ray.get(core.zeros.remote([m, k], result_dtype)) - Ts = [] + result_dtype = np.linalg.qr(ray.get(a.objectids[0, 0]))[0].dtype.name + # TODO(rkn): It would be preferable not to get this right after creating + # it. + r_res = ray.get(core.zeros.remote([k, n], result_dtype)) + # TODO(rkn): It would be preferable not to get this right after creating + # it. + y_res = ray.get(core.zeros.remote([m, k], result_dtype)) + Ts = [] - # The for loop differs from the paper, which says - # "for i in range(a.num_blocks[1])", but that doesn't seem to make any sense - # when a.num_blocks[1] > a.num_blocks[0]. - for i in range(min(a.num_blocks[0], a.num_blocks[1])): - sub_dist_array = core.subblocks.remote( - a_work, list(range(i, a_work.num_blocks[0])), [i]) - y, t, _, R = tsqr_hr.remote(sub_dist_array) - y_val = ray.get(y) + # The for loop differs from the paper, which says + # "for i in range(a.num_blocks[1])", but that doesn't seem to make any + # sense when a.num_blocks[1] > a.num_blocks[0]. 
+ for i in range(min(a.num_blocks[0], a.num_blocks[1])): + sub_dist_array = core.subblocks.remote( + a_work, list(range(i, a_work.num_blocks[0])), [i]) + y, t, _, R = tsqr_hr.remote(sub_dist_array) + y_val = ray.get(y) - for j in range(i, a.num_blocks[0]): - y_res.objectids[j, i] = y_val.objectids[j - i, 0] - if a.shape[0] > a.shape[1]: - # in this case, R needs to be square - R_shape = ray.get(ra.shape.remote(R)) - eye_temp = ra.eye.remote(R_shape[1], R_shape[0], dtype_name=result_dtype) - r_res.objectids[i, i] = ra.dot.remote(eye_temp, R) - else: - r_res.objectids[i, i] = R - Ts.append(core.numpy_to_dist.remote(t)) + for j in range(i, a.num_blocks[0]): + y_res.objectids[j, i] = y_val.objectids[j - i, 0] + if a.shape[0] > a.shape[1]: + # in this case, R needs to be square + R_shape = ray.get(ra.shape.remote(R)) + eye_temp = ra.eye.remote(R_shape[1], R_shape[0], + dtype_name=result_dtype) + r_res.objectids[i, i] = ra.dot.remote(eye_temp, R) + else: + r_res.objectids[i, i] = R + Ts.append(core.numpy_to_dist.remote(t)) - for c in range(i + 1, a.num_blocks[1]): - W_rcs = [] - for r in range(i, a.num_blocks[0]): - y_ri = y_val.objectids[r - i, 0] - W_rcs.append(qr_helper2.remote(y_ri, a_work.objectids[r, c])) - W_c = ra.sum_list.remote(*W_rcs) - for r in range(i, a.num_blocks[0]): - y_ri = y_val.objectids[r - i, 0] - A_rc = qr_helper1.remote(a_work.objectids[r, c], y_ri, t, W_c) - a_work.objectids[r, c] = A_rc - r_res.objectids[i, c] = a_work.objectids[i, c] + for c in range(i + 1, a.num_blocks[1]): + W_rcs = [] + for r in range(i, a.num_blocks[0]): + y_ri = y_val.objectids[r - i, 0] + W_rcs.append(qr_helper2.remote(y_ri, a_work.objectids[r, c])) + W_c = ra.sum_list.remote(*W_rcs) + for r in range(i, a.num_blocks[0]): + y_ri = y_val.objectids[r - i, 0] + A_rc = qr_helper1.remote(a_work.objectids[r, c], y_ri, t, W_c) + a_work.objectids[r, c] = A_rc + r_res.objectids[i, c] = a_work.objectids[i, c] - # construct q_res from Ys and Ts - q = core.eye.remote(m, k, 
dtype_name=result_dtype) - for i in range(len(Ts))[::-1]: - y_col_block = core.subblocks.remote(y_res, [], [i]) - q = core.subtract.remote( - q, core.dot.remote( - y_col_block, - core.dot.remote(Ts[i], - core.dot.remote(core.transpose.remote(y_col_block), - q)))) + # construct q_res from Ys and Ts + q = core.eye.remote(m, k, dtype_name=result_dtype) + for i in range(len(Ts))[::-1]: + y_col_block = core.subblocks.remote(y_res, [], [i]) + q = core.subtract.remote( + q, core.dot.remote( + y_col_block, + core.dot.remote( + Ts[i], + core.dot.remote(core.transpose.remote(y_col_block), q)))) - return ray.get(q), r_res + return ray.get(q), r_res diff --git a/python/ray/experimental/array/distributed/random.py b/python/ray/experimental/array/distributed/random.py index e3e25ff8c..a946df90a 100644 --- a/python/ray/experimental/array/distributed/random.py +++ b/python/ray/experimental/array/distributed/random.py @@ -11,10 +11,10 @@ from .core import DistArray @ray.remote def normal(shape): - num_blocks = DistArray.compute_num_blocks(shape) - objectids = np.empty(num_blocks, dtype=object) - for index in np.ndindex(*num_blocks): - objectids[index] = ra.random.normal.remote( - DistArray.compute_block_shape(index, shape)) - result = DistArray(shape, objectids) - return result + num_blocks = DistArray.compute_num_blocks(shape) + objectids = np.empty(num_blocks, dtype=object) + for index in np.ndindex(*num_blocks): + objectids[index] = ra.random.normal.remote( + DistArray.compute_block_shape(index, shape)) + result = DistArray(shape, objectids) + return result diff --git a/python/ray/experimental/array/remote/core.py b/python/ray/experimental/array/remote/core.py index e7f142f1b..60b6ca38b 100644 --- a/python/ray/experimental/array/remote/core.py +++ b/python/ray/experimental/array/remote/core.py @@ -8,94 +8,94 @@ import ray @ray.remote def zeros(shape, dtype_name="float", order="C"): - return np.zeros(shape, dtype=np.dtype(dtype_name), order=order) + return np.zeros(shape, 
dtype=np.dtype(dtype_name), order=order) @ray.remote def zeros_like(a, dtype_name="None", order="K", subok=True): - dtype_val = None if dtype_name == "None" else np.dtype(dtype_name) - return np.zeros_like(a, dtype=dtype_val, order=order, subok=subok) + dtype_val = None if dtype_name == "None" else np.dtype(dtype_name) + return np.zeros_like(a, dtype=dtype_val, order=order, subok=subok) @ray.remote def ones(shape, dtype_name="float", order="C"): - return np.ones(shape, dtype=np.dtype(dtype_name), order=order) + return np.ones(shape, dtype=np.dtype(dtype_name), order=order) @ray.remote def eye(N, M=-1, k=0, dtype_name="float"): - M = N if M == -1 else M - return np.eye(N, M=M, k=k, dtype=np.dtype(dtype_name)) + M = N if M == -1 else M + return np.eye(N, M=M, k=k, dtype=np.dtype(dtype_name)) @ray.remote def dot(a, b): - return np.dot(a, b) + return np.dot(a, b) @ray.remote def vstack(*xs): - return np.vstack(xs) + return np.vstack(xs) @ray.remote def hstack(*xs): - return np.hstack(xs) + return np.hstack(xs) # TODO(rkn): Instead of this, consider implementing slicing. # TODO(rkn): Be consistent about using "index" versus "indices". 
@ray.remote def subarray(a, lower_indices, upper_indices): - return a[[slice(l, u) for (l, u) in zip(lower_indices, upper_indices)]] + return a[[slice(l, u) for (l, u) in zip(lower_indices, upper_indices)]] @ray.remote def copy(a, order="K"): - return np.copy(a, order=order) + return np.copy(a, order=order) @ray.remote def tril(m, k=0): - return np.tril(m, k=k) + return np.tril(m, k=k) @ray.remote def triu(m, k=0): - return np.triu(m, k=k) + return np.triu(m, k=k) @ray.remote def diag(v, k=0): - return np.diag(v, k=k) + return np.diag(v, k=k) @ray.remote def transpose(a, axes=[]): - axes = None if axes == [] else axes - return np.transpose(a, axes=axes) + axes = None if axes == [] else axes + return np.transpose(a, axes=axes) @ray.remote def add(x1, x2): - return np.add(x1, x2) + return np.add(x1, x2) @ray.remote def subtract(x1, x2): - return np.subtract(x1, x2) + return np.subtract(x1, x2) @ray.remote def sum(x, axis=-1): - return np.sum(x, axis=axis if axis != -1 else None) + return np.sum(x, axis=axis if axis != -1 else None) @ray.remote def shape(a): - return np.shape(a) + return np.shape(a) @ray.remote def sum_list(*xs): - return np.sum(xs, axis=0) + return np.sum(xs, axis=0) diff --git a/python/ray/experimental/array/remote/linalg.py b/python/ray/experimental/array/remote/linalg.py index b1436648c..a862ed39f 100644 --- a/python/ray/experimental/array/remote/linalg.py +++ b/python/ray/experimental/array/remote/linalg.py @@ -13,99 +13,99 @@ __all__ = ["matrix_power", "solve", "tensorsolve", "tensorinv", "inv", @ray.remote def matrix_power(M, n): - return np.linalg.matrix_power(M, n) + return np.linalg.matrix_power(M, n) @ray.remote def solve(a, b): - return np.linalg.solve(a, b) + return np.linalg.solve(a, b) @ray.remote(num_return_vals=2) def tensorsolve(a): - raise NotImplementedError + raise NotImplementedError @ray.remote(num_return_vals=2) def tensorinv(a): - raise NotImplementedError + raise NotImplementedError @ray.remote def inv(a): - return 
np.linalg.inv(a) + return np.linalg.inv(a) @ray.remote def cholesky(a): - return np.linalg.cholesky(a) + return np.linalg.cholesky(a) @ray.remote def eigvals(a): - return np.linalg.eigvals(a) + return np.linalg.eigvals(a) @ray.remote def eigvalsh(a): - raise NotImplementedError + raise NotImplementedError @ray.remote def pinv(a): - return np.linalg.pinv(a) + return np.linalg.pinv(a) @ray.remote def slogdet(a): - raise NotImplementedError + raise NotImplementedError @ray.remote def det(a): - return np.linalg.det(a) + return np.linalg.det(a) @ray.remote(num_return_vals=3) def svd(a): - return np.linalg.svd(a) + return np.linalg.svd(a) @ray.remote(num_return_vals=2) def eig(a): - return np.linalg.eig(a) + return np.linalg.eig(a) @ray.remote(num_return_vals=2) def eigh(a): - return np.linalg.eigh(a) + return np.linalg.eigh(a) @ray.remote(num_return_vals=4) def lstsq(a, b): - return np.linalg.lstsq(a) + return np.linalg.lstsq(a) @ray.remote def norm(x): - return np.linalg.norm(x) + return np.linalg.norm(x) @ray.remote(num_return_vals=2) def qr(a): - return np.linalg.qr(a) + return np.linalg.qr(a) @ray.remote def cond(x): - return np.linalg.cond(x) + return np.linalg.cond(x) @ray.remote def matrix_rank(M): - return np.linalg.matrix_rank(M) + return np.linalg.matrix_rank(M) @ray.remote def multi_dot(*a): - raise NotImplementedError + raise NotImplementedError diff --git a/python/ray/experimental/array/remote/random.py b/python/ray/experimental/array/remote/random.py index bea781f2b..64346c34a 100644 --- a/python/ray/experimental/array/remote/random.py +++ b/python/ray/experimental/array/remote/random.py @@ -8,4 +8,4 @@ import ray @ray.remote def normal(shape): - return np.random.normal(size=shape) + return np.random.normal(size=shape) diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index d76d07faf..e9fd9a804 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -49,507 +49,523 @@ TASK_STATUS_MAPPING = { 
class GlobalState(object): - """A class used to interface with the Ray control state. + """A class used to interface with the Ray control state. - Attributes: - redis_client: The redis client used to query the redis server. - """ - def __init__(self): - """Create a GlobalState object.""" - self.redis_client = None - - def _check_connected(self): - """Check that the object has been initialized before it is used. - - Raises: - Exception: An exception is raised if ray.init() has not been called yet. + Attributes: + redis_client: The redis client used to query the redis server. """ - if self.redis_client is None: - raise Exception("The ray.global_state API cannot be used before " - "ray.init has been called.") + def __init__(self): + """Create a GlobalState object.""" + self.redis_client = None - def _initialize_global_state(self, redis_ip_address, redis_port): - """Initialize the GlobalState object by connecting to Redis. + def _check_connected(self): + """Check that the object has been initialized before it is used. - Args: - redis_ip_address: The IP address of the node that the Redis server lives - on. - redis_port: The port that the Redis server is listening on. - """ - self.redis_client = redis.StrictRedis(host=redis_ip_address, - port=redis_port) - self.redis_clients = [] - num_redis_shards = self.redis_client.get("NumRedisShards") - if num_redis_shards is None: - raise Exception("No entry found for NumRedisShards") - num_redis_shards = int(num_redis_shards) - if (num_redis_shards < 1): - raise Exception("Expected at least one Redis shard, found " - "{}.".format(num_redis_shards)) + Raises: + Exception: An exception is raised if ray.init() has not been called + yet. 
+ """ + if self.redis_client is None: + raise Exception("The ray.global_state API cannot be used before " + "ray.init has been called.") - ip_address_ports = self.redis_client.lrange("RedisShards", start=0, end=-1) - if len(ip_address_ports) != num_redis_shards: - raise Exception("Expected {} Redis shard addresses, found " - "{}".format(num_redis_shards, len(ip_address_ports))) + def _initialize_global_state(self, redis_ip_address, redis_port): + """Initialize the GlobalState object by connecting to Redis. - for ip_address_port in ip_address_ports: - shard_address, shard_port = ip_address_port.split(b":") - self.redis_clients.append(redis.StrictRedis(host=shard_address, - port=shard_port)) + Args: + redis_ip_address: The IP address of the node that the Redis server + lives on. + redis_port: The port that the Redis server is listening on. + """ + self.redis_client = redis.StrictRedis(host=redis_ip_address, + port=redis_port) + self.redis_clients = [] + num_redis_shards = self.redis_client.get("NumRedisShards") + if num_redis_shards is None: + raise Exception("No entry found for NumRedisShards") + num_redis_shards = int(num_redis_shards) + if (num_redis_shards < 1): + raise Exception("Expected at least one Redis shard, found " + "{}.".format(num_redis_shards)) - def _execute_command(self, key, *args): - """Execute a Redis command on the appropriate Redis shard based on key. + ip_address_ports = self.redis_client.lrange("RedisShards", start=0, + end=-1) + if len(ip_address_ports) != num_redis_shards: + raise Exception("Expected {} Redis shard addresses, found " + "{}".format(num_redis_shards, + len(ip_address_ports))) - Args: - key: The object ID or the task ID that the query is about. - args: The command to run. + for ip_address_port in ip_address_ports: + shard_address, shard_port = ip_address_port.split(b":") + self.redis_clients.append(redis.StrictRedis(host=shard_address, + port=shard_port)) - Returns: - The value returned by the Redis command. 
- """ - client = self.redis_clients[key.redis_shard_hash() % - len(self.redis_clients)] - return client.execute_command(*args) + def _execute_command(self, key, *args): + """Execute a Redis command on the appropriate Redis shard based on key. - def _keys(self, pattern): - """Execute the KEYS command on all Redis shards. + Args: + key: The object ID or the task ID that the query is about. + args: The command to run. - Args: - pattern: The KEYS pattern to query. + Returns: + The value returned by the Redis command. + """ + client = self.redis_clients[key.redis_shard_hash() % + len(self.redis_clients)] + return client.execute_command(*args) - Returns: - The concatenated list of results from all shards. - """ - result = [] - for client in self.redis_clients: - result.extend(client.keys(pattern)) - return result + def _keys(self, pattern): + """Execute the KEYS command on all Redis shards. - def _object_table(self, object_id): - """Fetch and parse the object table information for a single object ID. + Args: + pattern: The KEYS pattern to query. - Args: - object_id_binary: A string of bytes with the object ID to get information - about. + Returns: + The concatenated list of results from all shards. + """ + result = [] + for client in self.redis_clients: + result.extend(client.keys(pattern)) + return result - Returns: - A dictionary with information about the object ID in question. - """ - # Allow the argument to be either an ObjectID or a hex string. - if not isinstance(object_id, ray.local_scheduler.ObjectID): - object_id = ray.local_scheduler.ObjectID(hex_to_binary(object_id)) + def _object_table(self, object_id): + """Fetch and parse the object table information for a single object ID. - # Return information about a single object ID. 
- object_locations = self._execute_command(object_id, - "RAY.OBJECT_TABLE_LOOKUP", - object_id.id()) - if object_locations is not None: - manager_ids = [binary_to_hex(manager_id) - for manager_id in object_locations] - else: - manager_ids = None + Args: + object_id_binary: A string of bytes with the object ID to get + information about. - result_table_response = self._execute_command(object_id, - "RAY.RESULT_TABLE_LOOKUP", - object_id.id()) - result_table_message = ResultTableReply.GetRootAsResultTableReply( - result_table_response, 0) + Returns: + A dictionary with information about the object ID in question. + """ + # Allow the argument to be either an ObjectID or a hex string. + if not isinstance(object_id, ray.local_scheduler.ObjectID): + object_id = ray.local_scheduler.ObjectID(hex_to_binary(object_id)) - result = {"ManagerIDs": manager_ids, - "TaskID": binary_to_hex(result_table_message.TaskId()), - "IsPut": bool(result_table_message.IsPut()), - "DataSize": result_table_message.DataSize(), - "Hash": binary_to_hex(result_table_message.Hash())} + # Return information about a single object ID. + object_locations = self._execute_command(object_id, + "RAY.OBJECT_TABLE_LOOKUP", + object_id.id()) + if object_locations is not None: + manager_ids = [binary_to_hex(manager_id) + for manager_id in object_locations] + else: + manager_ids = None - return result + result_table_response = self._execute_command( + object_id, "RAY.RESULT_TABLE_LOOKUP", object_id.id()) + result_table_message = ResultTableReply.GetRootAsResultTableReply( + result_table_response, 0) - def object_table(self, object_id=None): - """Fetch and parse the object table information for one or more object IDs. 
+ result = {"ManagerIDs": manager_ids, + "TaskID": binary_to_hex(result_table_message.TaskId()), + "IsPut": bool(result_table_message.IsPut()), + "DataSize": result_table_message.DataSize(), + "Hash": binary_to_hex(result_table_message.Hash())} - Args: - object_id: An object ID to fetch information about. If this is None, then - the entire object table is fetched. + return result + + def object_table(self, object_id=None): + """Fetch and parse the object table info for one or more object IDs. + + Args: + object_id: An object ID to fetch information about. If this is + None, then the entire object table is fetched. - Returns: - Information from the object table. - """ - self._check_connected() - if object_id is not None: - # Return information about a single object ID. - return self._object_table(object_id) - else: - # Return the entire object table. - object_info_keys = self._keys(OBJECT_INFO_PREFIX + "*") - object_location_keys = self._keys(OBJECT_LOCATION_PREFIX + "*") - object_ids_binary = set( - [key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] + - [key[len(OBJECT_LOCATION_PREFIX):] for key in object_location_keys]) - results = {} - for object_id_binary in object_ids_binary: - results[binary_to_object_id(object_id_binary)] = self._object_table( - binary_to_object_id(object_id_binary)) - return results + Returns: + Information from the object table. + """ + self._check_connected() + if object_id is not None: + # Return information about a single object ID. + return self._object_table(object_id) + else: + # Return the entire object table. 
+ object_info_keys = self._keys(OBJECT_INFO_PREFIX + "*") + object_location_keys = self._keys(OBJECT_LOCATION_PREFIX + "*") + object_ids_binary = set( + [key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] + + [key[len(OBJECT_LOCATION_PREFIX):] + for key in object_location_keys]) + results = {} + for object_id_binary in object_ids_binary: + results[binary_to_object_id(object_id_binary)] = ( + self._object_table(binary_to_object_id(object_id_binary))) + return results - def _task_table(self, task_id): - """Fetch and parse the task table information for a single object task ID. + def _task_table(self, task_id): + """Fetch and parse the task table information for a single task ID. - Args: - task_id_binary: A string of bytes with the task ID to get information - about. + Args: + task_id_binary: A string of bytes with the task ID to get + information about. - Returns: - A dictionary with information about the task ID in question. - TASK_STATUS_MAPPING should be used to parse the "State" field into a - human-readable string. - """ - task_table_response = self._execute_command(task_id, - "RAY.TASK_TABLE_GET", - task_id.id()) - if task_table_response is None: - raise Exception("There is no entry for task ID {} in the task table." 
- .format(binary_to_hex(task_id.id()))) - task_table_message = TaskReply.GetRootAsTaskReply(task_table_response, 0) - task_spec = task_table_message.TaskSpec() - task_spec_message = TaskInfo.GetRootAsTaskInfo(task_spec, 0) - args = [] - for i in range(task_spec_message.ArgsLength()): - arg = task_spec_message.Args(i) - if len(arg.ObjectId()) != 0: - args.append(binary_to_object_id(arg.ObjectId())) - else: - args.append(pickle.loads(arg.Data())) - assert task_spec_message.RequiredResourcesLength() == 2 - required_resources = {"CPUs": task_spec_message.RequiredResources(0), - "GPUs": task_spec_message.RequiredResources(1)} - task_spec_info = { - "DriverID": binary_to_hex(task_spec_message.DriverId()), - "TaskID": binary_to_hex(task_spec_message.TaskId()), - "ParentTaskID": binary_to_hex(task_spec_message.ParentTaskId()), - "ParentCounter": task_spec_message.ParentCounter(), - "ActorID": binary_to_hex(task_spec_message.ActorId()), - "ActorCounter": task_spec_message.ActorCounter(), - "FunctionID": binary_to_hex(task_spec_message.FunctionId()), - "Args": args, - "ReturnObjectIDs": [binary_to_object_id(task_spec_message.Returns(i)) - for i in range(task_spec_message.ReturnsLength())], - "RequiredResources": required_resources} + Returns: + A dictionary with information about the task ID in question. + TASK_STATUS_MAPPING should be used to parse the "State" field + into a human-readable string. 
+ """ + task_table_response = self._execute_command(task_id, + "RAY.TASK_TABLE_GET", + task_id.id()) + if task_table_response is None: + raise Exception("There is no entry for task ID {} in the task " + "table.".format(binary_to_hex(task_id.id()))) + task_table_message = TaskReply.GetRootAsTaskReply(task_table_response, + 0) + task_spec = task_table_message.TaskSpec() + task_spec_message = TaskInfo.GetRootAsTaskInfo(task_spec, 0) + args = [] + for i in range(task_spec_message.ArgsLength()): + arg = task_spec_message.Args(i) + if len(arg.ObjectId()) != 0: + args.append(binary_to_object_id(arg.ObjectId())) + else: + args.append(pickle.loads(arg.Data())) + assert task_spec_message.RequiredResourcesLength() == 2 + required_resources = {"CPUs": task_spec_message.RequiredResources(0), + "GPUs": task_spec_message.RequiredResources(1)} + task_spec_info = { + "DriverID": binary_to_hex(task_spec_message.DriverId()), + "TaskID": binary_to_hex(task_spec_message.TaskId()), + "ParentTaskID": binary_to_hex(task_spec_message.ParentTaskId()), + "ParentCounter": task_spec_message.ParentCounter(), + "ActorID": binary_to_hex(task_spec_message.ActorId()), + "ActorCounter": task_spec_message.ActorCounter(), + "FunctionID": binary_to_hex(task_spec_message.FunctionId()), + "Args": args, + "ReturnObjectIDs": [binary_to_object_id( + task_spec_message.Returns(i)) + for i in range( + task_spec_message.ReturnsLength())], + "RequiredResources": required_resources} - return {"State": task_table_message.State(), - "LocalSchedulerID": binary_to_hex( - task_table_message.LocalSchedulerId()), - "TaskSpec": task_spec_info} + return {"State": task_table_message.State(), + "LocalSchedulerID": binary_to_hex( + task_table_message.LocalSchedulerId()), + "TaskSpec": task_spec_info} - def task_table(self, task_id=None): - """Fetch and parse the task table information for one or more task IDs. + def task_table(self, task_id=None): + """Fetch and parse the task table information for one or more task IDs. 
- Args: - task_id: A hex string of the task ID to fetch information about. If this - is None, then the task object table is fetched. + Args: + task_id: A hex string of the task ID to fetch information about. If + this is None, then the task object table is fetched. - Returns: - Information from the task table. - """ - self._check_connected() - if task_id is not None: - task_id = ray.local_scheduler.ObjectID(hex_to_binary(task_id)) - return self._task_table(task_id) - else: - task_table_keys = self._keys(TASK_PREFIX + "*") - results = {} - for key in task_table_keys: - task_id_binary = key[len(TASK_PREFIX):] - results[binary_to_hex(task_id_binary)] = self._task_table( - ray.local_scheduler.ObjectID(task_id_binary)) - return results + Returns: + Information from the task table. + """ + self._check_connected() + if task_id is not None: + task_id = ray.local_scheduler.ObjectID(hex_to_binary(task_id)) + return self._task_table(task_id) + else: + task_table_keys = self._keys(TASK_PREFIX + "*") + results = {} + for key in task_table_keys: + task_id_binary = key[len(TASK_PREFIX):] + results[binary_to_hex(task_id_binary)] = self._task_table( + ray.local_scheduler.ObjectID(task_id_binary)) + return results - def function_table(self, function_id=None): - """Fetch and parse the function table. + def function_table(self, function_id=None): + """Fetch and parse the function table. - Returns: - A dictionary that maps function IDs to information about the function. - """ - self._check_connected() - function_table_keys = self.redis_client.keys(FUNCTION_PREFIX + "*") - results = {} - for key in function_table_keys: - info = self.redis_client.hgetall(key) - function_info_parsed = { - "DriverID": binary_to_hex(info[b"driver_id"]), - "Module": decode(info[b"module"]), - "Name": decode(info[b"name"]) - } - results[binary_to_hex(info[b"function_id"])] = function_info_parsed - return results + Returns: + A dictionary that maps function IDs to information about the + function. 
+ """ + self._check_connected() + function_table_keys = self.redis_client.keys(FUNCTION_PREFIX + "*") + results = {} + for key in function_table_keys: + info = self.redis_client.hgetall(key) + function_info_parsed = { + "DriverID": binary_to_hex(info[b"driver_id"]), + "Module": decode(info[b"module"]), + "Name": decode(info[b"name"])} + results[binary_to_hex(info[b"function_id"])] = function_info_parsed + return results - def client_table(self): - """Fetch and parse the Redis DB client table. + def client_table(self): + """Fetch and parse the Redis DB client table. - Returns: - Information about the Ray clients in the cluster. - """ - self._check_connected() - db_client_keys = self.redis_client.keys(DB_CLIENT_PREFIX + "*") - node_info = dict() - for key in db_client_keys: - client_info = self.redis_client.hgetall(key) - node_ip_address = decode(client_info[b"node_ip_address"]) - if node_ip_address not in node_info: - node_info[node_ip_address] = [] - client_info_parsed = { - "ClientType": decode(client_info[b"client_type"]), - "Deleted": bool(int(decode(client_info[b"deleted"]))), - "DBClientID": binary_to_hex(client_info[b"ray_client_id"]) - } - if b"aux_address" in client_info: - client_info_parsed["AuxAddress"] = decode(client_info[b"aux_address"]) - if b"num_cpus" in client_info: - client_info_parsed["NumCPUs"] = float(decode(client_info[b"num_cpus"])) - if b"num_gpus" in client_info: - client_info_parsed["NumGPUs"] = float(decode(client_info[b"num_gpus"])) - if b"local_scheduler_socket_name" in client_info: - client_info_parsed["LocalSchedulerSocketName"] = decode( - client_info[b"local_scheduler_socket_name"]) - node_info[node_ip_address].append(client_info_parsed) + Returns: + Information about the Ray clients in the cluster. 
+ """ + self._check_connected() + db_client_keys = self.redis_client.keys(DB_CLIENT_PREFIX + "*") + node_info = dict() + for key in db_client_keys: + client_info = self.redis_client.hgetall(key) + node_ip_address = decode(client_info[b"node_ip_address"]) + if node_ip_address not in node_info: + node_info[node_ip_address] = [] + client_info_parsed = { + "ClientType": decode(client_info[b"client_type"]), + "Deleted": bool(int(decode(client_info[b"deleted"]))), + "DBClientID": binary_to_hex(client_info[b"ray_client_id"]) + } + if b"aux_address" in client_info: + client_info_parsed["AuxAddress"] = decode( + client_info[b"aux_address"]) + if b"num_cpus" in client_info: + client_info_parsed["NumCPUs"] = float( + decode(client_info[b"num_cpus"])) + if b"num_gpus" in client_info: + client_info_parsed["NumGPUs"] = float( + decode(client_info[b"num_gpus"])) + if b"local_scheduler_socket_name" in client_info: + client_info_parsed["LocalSchedulerSocketName"] = decode( + client_info[b"local_scheduler_socket_name"]) + node_info[node_ip_address].append(client_info_parsed) - return node_info + return node_info - def log_files(self): - """Fetch and return a dictionary of log file names to outputs. + def log_files(self): + """Fetch and return a dictionary of log file names to outputs. - Returns: - IP address to log file name to log file contents mappings. - """ - relevant_files = self.redis_client.keys("LOGFILE*") + Returns: + IP address to log file name to log file contents mappings. 
+ """ + relevant_files = self.redis_client.keys("LOGFILE*") - ip_filename_file = dict() + ip_filename_file = dict() - for filename in relevant_files: - filename = filename.decode("ascii") - filename_components = filename.split(":") - ip_addr = filename_components[1] + for filename in relevant_files: + filename = filename.decode("ascii") + filename_components = filename.split(":") + ip_addr = filename_components[1] - file = self.redis_client.lrange(filename, 0, -1) - file_str = [] - for x in file: - y = x.decode("ascii") - file_str.append(y) + file = self.redis_client.lrange(filename, 0, -1) + file_str = [] + for x in file: + y = x.decode("ascii") + file_str.append(y) - if ip_addr not in ip_filename_file: - ip_filename_file[ip_addr] = dict() + if ip_addr not in ip_filename_file: + ip_filename_file[ip_addr] = dict() - ip_filename_file[ip_addr][filename] = file_str + ip_filename_file[ip_addr][filename] = file_str - return ip_filename_file + return ip_filename_file - def task_profiles(self, start=None, end=None, num=None): - """Fetch and return a list of task profiles. + def task_profiles(self, start=None, end=None, num=None): + """Fetch and return a list of task profiles. - Args: - start: The start point of the time window that is queried for tasks. - end: The end point in time of the time window that is queried for tasks. - num: A limit on the number of tasks that task_profiles will return. + Args: + start: The start point of the time window that is queried for + tasks. + end: The end point in time of the time window that is queried for + tasks. + num: A limit on the number of tasks that task_profiles will return. - Returns: - A tuple of two elements. The first element is a dictionary mapping the - task ID of a task to a list of the profiling information for all of the - executions of that task. The second element is a list of profiling - information for tasks where the events have no task ID. 
- """ - if start is None: - start = 0 - if num is None: - num = sys.maxsize + Returns: + A tuple of two elements. The first element is a dictionary mapping + the task ID of a task to a list of the profiling information + for all of the executions of that task. The second element is a + list of profiling information for tasks where the events have + no task ID. + """ + if start is None: + start = 0 + if num is None: + num = sys.maxsize - task_info = dict() - event_log_sets = self.redis_client.keys("event_log*") + task_info = dict() + event_log_sets = self.redis_client.keys("event_log*") - # The heap is used to maintain the set of x tasks that occurred the most - # recently across all of the workers, where x is defined as the function - # parameter num. The key is the start time of the "get_task" component of - # each task. Calling heappop will result in the taks with the earliest - # "get_task_start" to be removed from the heap. + # The heap is used to maintain the set of x tasks that occurred the + # most recently across all of the workers, where x is defined as the + # function parameter num. The key is the start time of the "get_task" + # component of each task. Calling heappop will result in the taks with + # the earliest "get_task_start" to be removed from the heap. - heap = [] - heapq.heapify(heap) - heap_size = 0 - # Parse through event logs to determine task start and end points. - for i in range(len(event_log_sets)): - event_list = self.redis_client.zrangebyscore(event_log_sets[i], - min=start, - max=end, - start=start, - num=num) - for event in event_list: - event_dict = json.loads(event) - task_id = "" - for event in event_dict: - if "task_id" in event[3]: - task_id = event[3]["task_id"] - task_info[task_id] = dict() - for event in event_dict: - if event[1] == "ray:get_task" and event[2] == 1: - task_info[task_id]["get_task_start"] = event[0] - # Add task to min heap by its start point. 
- heapq.heappush(heap, - (task_info[task_id]["get_task_start"], task_id)) - heap_size += 1 - if event[1] == "ray:get_task" and event[2] == 2: - task_info[task_id]["get_task_end"] = event[0] - if event[1] == "ray:import_remote_function" and event[2] == 1: - task_info[task_id]["import_remote_start"] = event[0] - if event[1] == "ray:import_remote_function" and event[2] == 2: - task_info[task_id]["import_remote_end"] = event[0] - if event[1] == "ray:acquire_lock" and event[2] == 1: - task_info[task_id]["acquire_lock_start"] = event[0] - if event[1] == "ray:acquire_lock" and event[2] == 2: - task_info[task_id]["acquire_lock_end"] = event[0] - if event[1] == "ray:task:get_arguments" and event[2] == 1: - task_info[task_id]["get_arguments_start"] = event[0] - if event[1] == "ray:task:get_arguments" and event[2] == 2: - task_info[task_id]["get_arguments_end"] = event[0] - if event[1] == "ray:task:execute" and event[2] == 1: - task_info[task_id]["execute_start"] = event[0] - if event[1] == "ray:task:execute" and event[2] == 2: - task_info[task_id]["execute_end"] = event[0] - if event[1] == "ray:task:store_outputs" and event[2] == 1: - task_info[task_id]["store_outputs_start"] = event[0] - if event[1] == "ray:task:store_outputs" and event[2] == 2: - task_info[task_id]["store_outputs_end"] = event[0] - if "worker_id" in event[3]: - task_info[task_id]["worker_id"] = event[3]["worker_id"] - if "function_name" in event[3]: - task_info[task_id]["function_name"] = event[3]["function_name"] - if heap_size > num: - min_task, task_id_hex = heapq.heappop(heap) - del task_info[task_id_hex] - heap_size -= 1 - return task_info + heap = [] + heapq.heapify(heap) + heap_size = 0 + # Parse through event logs to determine task start and end points. 
+ for i in range(len(event_log_sets)): + event_list = self.redis_client.zrangebyscore(event_log_sets[i], + min=start, + max=end, + start=start, + num=num) + for event in event_list: + event_dict = json.loads(event) + task_id = "" + for event in event_dict: + if "task_id" in event[3]: + task_id = event[3]["task_id"] + task_info[task_id] = dict() + for event in event_dict: + if event[1] == "ray:get_task" and event[2] == 1: + task_info[task_id]["get_task_start"] = event[0] + # Add task to min heap by its start point. + heapq.heappush(heap, + (task_info[task_id]["get_task_start"], + task_id)) + heap_size += 1 + if event[1] == "ray:get_task" and event[2] == 2: + task_info[task_id]["get_task_end"] = event[0] + if (event[1] == "ray:import_remote_function" and + event[2] == 1): + task_info[task_id]["import_remote_start"] = event[0] + if (event[1] == "ray:import_remote_function" and + event[2] == 2): + task_info[task_id]["import_remote_end"] = event[0] + if event[1] == "ray:acquire_lock" and event[2] == 1: + task_info[task_id]["acquire_lock_start"] = event[0] + if event[1] == "ray:acquire_lock" and event[2] == 2: + task_info[task_id]["acquire_lock_end"] = event[0] + if event[1] == "ray:task:get_arguments" and event[2] == 1: + task_info[task_id]["get_arguments_start"] = event[0] + if event[1] == "ray:task:get_arguments" and event[2] == 2: + task_info[task_id]["get_arguments_end"] = event[0] + if event[1] == "ray:task:execute" and event[2] == 1: + task_info[task_id]["execute_start"] = event[0] + if event[1] == "ray:task:execute" and event[2] == 2: + task_info[task_id]["execute_end"] = event[0] + if event[1] == "ray:task:store_outputs" and event[2] == 1: + task_info[task_id]["store_outputs_start"] = event[0] + if event[1] == "ray:task:store_outputs" and event[2] == 2: + task_info[task_id]["store_outputs_end"] = event[0] + if "worker_id" in event[3]: + task_info[task_id]["worker_id"] = event[3]["worker_id"] + if "function_name" in event[3]: + task_info[task_id]["function_name"] 
= ( + event[3]["function_name"]) + if heap_size > num: + min_task, task_id_hex = heapq.heappop(heap) + del task_info[task_id_hex] + heap_size -= 1 + return task_info - def dump_catapult_trace(self, path, start=None, end=None, num=None): - """Dump task profiling information to a file. + def dump_catapult_trace(self, path, start=None, end=None, num=None): + """Dump task profiling information to a file. - This information can be viewed as a timeline of profiling information by - going to chrome://tracing in the chrome web browser and loading the - appropriate file. + This information can be viewed as a timeline of profiling information + by going to chrome://tracing in the chrome web browser and loading the + appropriate file. - Args: - path: The filepath to dump the profiling information to. - """ - if end is None: - end = time.time() - task_info = self.task_profiles(start=start, end=end, num=num) - workers = self.workers() - start_time = None - for info in task_info.values(): - task_start = min(self._get_times(info)) - if not start_time or task_start < start_time: - start_time = task_start + Args: + path: The filepath to dump the profiling information to. 
+ """ + if end is None: + end = time.time() + task_info = self.task_profiles(start=start, end=end, num=num) + workers = self.workers() + start_time = None + for info in task_info.values(): + task_start = min(self._get_times(info)) + if not start_time or task_start < start_time: + start_time = task_start - def micros(ts): - return int(1e6 * (ts - start_time)) + def micros(ts): + return int(1e6 * (ts - start_time)) - full_trace = [] - for task_id, info in task_info.items(): - task_id_hex = ray.local_scheduler.ObjectID(hex_to_binary(task_id)) - task_data = self._task_table(task_id_hex) - parent_info = task_info.get(task_data["TaskSpec"]["ParentTaskID"]) - times = self._get_times(info) - worker = workers[info["worker_id"]] - if parent_info: - parent_worker = workers[parent_info["worker_id"]] - parent_times = self._get_times(parent_info) - parent_trace = { - "cat": "submit_task", - "pid": "Node " + str(parent_worker["node_ip_address"]), - "tid": parent_info["worker_id"], - "ts": micros(min(parent_times)), - "ph": "s", - "name": "SubmitTask", - "args": {}, - "id": str(worker) - } - full_trace.append(parent_trace) + full_trace = [] + for task_id, info in task_info.items(): + task_id_hex = ray.local_scheduler.ObjectID(hex_to_binary(task_id)) + task_data = self._task_table(task_id_hex) + parent_info = task_info.get(task_data["TaskSpec"]["ParentTaskID"]) + times = self._get_times(info) + worker = workers[info["worker_id"]] + if parent_info: + parent_worker = workers[parent_info["worker_id"]] + parent_times = self._get_times(parent_info) + parent_trace = { + "cat": "submit_task", + "pid": "Node " + str(parent_worker["node_ip_address"]), + "tid": parent_info["worker_id"], + "ts": micros(min(parent_times)), + "ph": "s", + "name": "SubmitTask", + "args": {}, + "id": str(worker) + } + full_trace.append(parent_trace) - parent = { - "cat": "submit_task", - "pid": "Node " + str(parent_worker["node_ip_address"]), - "tid": parent_info["worker_id"], - "ts": micros(min(parent_times)), - 
"ph": "s", - "name": "SubmitTask", - "args": {}, - "id": str(worker) - } - full_trace.append(parent) + parent = { + "cat": "submit_task", + "pid": "Node " + str(parent_worker["node_ip_address"]), + "tid": parent_info["worker_id"], + "ts": micros(min(parent_times)), + "ph": "s", + "name": "SubmitTask", + "args": {}, + "id": str(worker) + } + full_trace.append(parent) - task_trace = { - "cat": "submit_task", - "pid": "Node " + str(worker["node_ip_address"]), - "tid": info["worker_id"], - "ts": micros(min(times)), - "ph": "f", - "name": "SubmitTask", - "args": {}, - "id": str(worker) - } - full_trace.append(task_trace) + task_trace = { + "cat": "submit_task", + "pid": "Node " + str(worker["node_ip_address"]), + "tid": info["worker_id"], + "ts": micros(min(times)), + "ph": "f", + "name": "SubmitTask", + "args": {}, + "id": str(worker) + } + full_trace.append(task_trace) - task = { - "name": info["function_name"], - "cat": "ray_task", - "ph": "X", - "ts": micros(min(times)), - "dur": micros(max(times)) - micros(min(times)), - "pid": "Node " + str(worker["node_ip_address"]), - "tid": info["worker_id"], - "args": info - } - full_trace.append(task) + task = { + "name": info["function_name"], + "cat": "ray_task", + "ph": "X", + "ts": micros(min(times)), + "dur": micros(max(times)) - micros(min(times)), + "pid": "Node " + str(worker["node_ip_address"]), + "tid": info["worker_id"], + "args": info + } + full_trace.append(task) - with open(path, "w") as outfile: - json.dump(full_trace, outfile) + with open(path, "w") as outfile: + json.dump(full_trace, outfile) - def _get_times(self, data): - """Extract the numerical times from a task profile. + def _get_times(self, data): + """Extract the numerical times from a task profile. - This is a helper method for dump_catapult_trace. + This is a helper method for dump_catapult_trace. - Args: - data: This must be a value in the dictionary returned by the - task_profiles function. 
- """ - all_times = [] - all_times.append(data["acquire_lock_start"]) - all_times.append(data["acquire_lock_end"]) - all_times.append(data["get_arguments_start"]) - all_times.append(data["get_arguments_end"]) - all_times.append(data["execute_start"]) - all_times.append(data["execute_end"]) - all_times.append(data["store_outputs_start"]) - all_times.append(data["store_outputs_end"]) - return all_times + Args: + data: This must be a value in the dictionary returned by the + task_profiles function. + """ + all_times = [] + all_times.append(data["acquire_lock_start"]) + all_times.append(data["acquire_lock_end"]) + all_times.append(data["get_arguments_start"]) + all_times.append(data["get_arguments_end"]) + all_times.append(data["execute_start"]) + all_times.append(data["execute_end"]) + all_times.append(data["store_outputs_start"]) + all_times.append(data["store_outputs_end"]) + return all_times - def workers(self): - """Get a dictionary mapping worker ID to worker information.""" - worker_keys = self.redis_client.keys("Worker*") - workers_data = dict() + def workers(self): + """Get a dictionary mapping worker ID to worker information.""" + worker_keys = self.redis_client.keys("Worker*") + workers_data = dict() - for worker_key in worker_keys: - worker_info = self.redis_client.hgetall(worker_key) - worker_id = binary_to_hex(worker_key[len("Workers:"):]) + for worker_key in worker_keys: + worker_info = self.redis_client.hgetall(worker_key) + worker_id = binary_to_hex(worker_key[len("Workers:"):]) - workers_data[worker_id] = { - "local_scheduler_socket": (worker_info[b"local_scheduler_socket"] - .decode("ascii")), - "node_ip_address": (worker_info[b"node_ip_address"] - .decode("ascii")), - "plasma_manager_socket": (worker_info[b"plasma_manager_socket"] - .decode("ascii")), - "plasma_store_socket": (worker_info[b"plasma_store_socket"] - .decode("ascii")), - "stderr_file": worker_info[b"stderr_file"].decode("ascii"), - "stdout_file": 
worker_info[b"stdout_file"].decode("ascii") - } - return workers_data + workers_data[worker_id] = { + "local_scheduler_socket": ( + worker_info[b"local_scheduler_socket"].decode("ascii")), + "node_ip_address": ( + worker_info[b"node_ip_address"].decode("ascii")), + "plasma_manager_socket": (worker_info[b"plasma_manager_socket"] + .decode("ascii")), + "plasma_store_socket": (worker_info[b"plasma_store_socket"] + .decode("ascii")), + "stderr_file": worker_info[b"stderr_file"].decode("ascii"), + "stdout_file": worker_info[b"stdout_file"].decode("ascii") + } + return workers_data diff --git a/python/ray/experimental/tfutils.py b/python/ray/experimental/tfutils.py index 8ebcd60ed..5133a49b9 100644 --- a/python/ray/experimental/tfutils.py +++ b/python/ray/experimental/tfutils.py @@ -6,114 +6,116 @@ from collections import deque, OrderedDict def unflatten(vector, shapes): - i = 0 - arrays = [] - for shape in shapes: - size = np.prod(shape) - array = vector[i:(i + size)].reshape(shape) - arrays.append(array) - i += size - assert len(vector) == i, "Passed weight does not have the correct shape." - return arrays + i = 0 + arrays = [] + for shape in shapes: + size = np.prod(shape) + array = vector[i:(i + size)].reshape(shape) + arrays.append(array) + i += size + assert len(vector) == i, "Passed weight does not have the correct shape." + return arrays class TensorFlowVariables(object): - """An object used to extract variables from a loss function. + """An object used to extract variables from a loss function. - This object also provides methods for getting and setting the weights of the - relevant variables. + This object also provides methods for getting and setting the weights of + the relevant variables. - Attributes: - sess (tf.Session): The tensorflow session used to run assignment. - loss: The loss function passed in by the user. - variables (List[tf.Variable]): Extracted variables from the loss. 
- assignment_placeholders (List[tf.placeholders]): The nodes that weights get - passed to. - assignment_nodes (List[tf.Tensor]): The nodes that assign the weights. - """ - def __init__(self, loss, sess=None): - """Creates a TensorFlowVariables instance.""" - import tensorflow as tf - self.sess = sess - self.loss = loss - queue = deque([loss]) - variable_names = [] - explored_inputs = set([loss]) + Attributes: + sess (tf.Session): The tensorflow session used to run assignment. + loss: The loss function passed in by the user. + variables (List[tf.Variable]): Extracted variables from the loss. + assignment_placeholders (List[tf.placeholders]): The nodes that weights + get passed to. + assignment_nodes (List[tf.Tensor]): The nodes that assign the weights. + """ + def __init__(self, loss, sess=None): + """Creates a TensorFlowVariables instance.""" + import tensorflow as tf + self.sess = sess + self.loss = loss + queue = deque([loss]) + variable_names = [] + explored_inputs = set([loss]) - # We do a BFS on the dependency graph of the input function to find - # the variables. - while len(queue) != 0: - tf_obj = queue.popleft() + # We do a BFS on the dependency graph of the input function to find + # the variables. + while len(queue) != 0: + tf_obj = queue.popleft() - # The object put into the queue is not necessarily an operation, so we - # want the op attribute to get the operation underlying the object. - # Only operations contain the inputs that we can explore. - if hasattr(tf_obj, "op"): - tf_obj = tf_obj.op - for input_op in tf_obj.inputs: - if input_op not in explored_inputs: - queue.append(input_op) - explored_inputs.add(input_op) - # Tensorflow control inputs can be circular, so we keep track of - explored operations.
- for control in tf_obj.control_inputs: - if control not in explored_inputs: - queue.append(control) - explored_inputs.add(control) - if "Variable" in tf_obj.node_def.op: - variable_names.append(tf_obj.node_def.name) - self.variables = OrderedDict() - for v in [v for v in tf.global_variables() - if v.op.node_def.name in variable_names]: - self.variables[v.op.node_def.name] = v - self.placeholders = dict() - self.assignment_nodes = [] + # The object put into the queue is not necessarily an operation, so + # we want the op attribute to get the operation underlying the + # object. Only operations contain the inputs that we can explore. + if hasattr(tf_obj, "op"): + tf_obj = tf_obj.op + for input_op in tf_obj.inputs: + if input_op not in explored_inputs: + queue.append(input_op) + explored_inputs.add(input_op) + # Tensorflow control inputs can be circular, so we keep track of + # explored operations. + for control in tf_obj.control_inputs: + if control not in explored_inputs: + queue.append(control) + explored_inputs.add(control) + if "Variable" in tf_obj.node_def.op: + variable_names.append(tf_obj.node_def.name) + self.variables = OrderedDict() + for v in [v for v in tf.global_variables() + if v.op.node_def.name in variable_names]: + self.variables[v.op.node_def.name] = v + self.placeholders = dict() + self.assignment_nodes = [] - # Create new placeholders to put in custom weights. - for k, var in self.variables.items(): - self.placeholders[k] = tf.placeholder(var.value().dtype, - var.get_shape().as_list()) - self.assignment_nodes.append(var.assign(self.placeholders[k])) + # Create new placeholders to put in custom weights. 
+ for k, var in self.variables.items(): + self.placeholders[k] = tf.placeholder(var.value().dtype, + var.get_shape().as_list()) + self.assignment_nodes.append(var.assign(self.placeholders[k])) - def set_session(self, sess): - """Modifies the current session used by the class.""" - self.sess = sess + def set_session(self, sess): + """Modifies the current session used by the class.""" + self.sess = sess - def get_flat_size(self): - return sum([np.prod(v.get_shape().as_list()) - for v in self.variables.values()]) + def get_flat_size(self): + return sum([np.prod(v.get_shape().as_list()) + for v in self.variables.values()]) - def _check_sess(self): - """Checks if the session is set, and if not throw an error message.""" - assert self.sess is not None, ("The session is not set. Set the session " - "either by passing it into the " - "TensorFlowVariables constructor or by " - "calling set_session(sess).") + def _check_sess(self): + """Checks if the session is set, and if not throw an error message.""" + assert self.sess is not None, ("The session is not set. 
Set the " + "session either by passing it into the " + "TensorFlowVariables constructor or by " + "calling set_session(sess).") - def get_flat(self): - """Gets the weights and returns them as a flat array.""" - self._check_sess() - return np.concatenate([v.eval(session=self.sess).flatten() - for v in self.variables.values()]) + def get_flat(self): + """Gets the weights and returns them as a flat array.""" + self._check_sess() + return np.concatenate([v.eval(session=self.sess).flatten() + for v in self.variables.values()]) - def set_flat(self, new_weights): - """Sets the weights to new_weights, converting from a flat array.""" - self._check_sess() - shapes = [v.get_shape().as_list() for v in self.variables.values()] - arrays = unflatten(new_weights, shapes) - placeholders = [self.placeholders[k] for k, v in self.variables.items()] - self.sess.run(self.assignment_nodes, - feed_dict=dict(zip(placeholders, arrays))) + def set_flat(self, new_weights): + """Sets the weights to new_weights, converting from a flat array.""" + self._check_sess() + shapes = [v.get_shape().as_list() for v in self.variables.values()] + arrays = unflatten(new_weights, shapes) + placeholders = [self.placeholders[k] + for k, v in self.variables.items()] + self.sess.run(self.assignment_nodes, + feed_dict=dict(zip(placeholders, arrays))) - def get_weights(self): - """Returns the weights of the variables of the loss function in a list.""" - self._check_sess() - return {k: v.eval(session=self.sess) for k, v in self.variables.items()} + def get_weights(self): + """Returns a dict of the weights of the loss function variables.""" + self._check_sess() + return {k: v.eval(session=self.sess) + for k, v in self.variables.items()} - def set_weights(self, new_weights): - """Sets the weights to new_weights.""" - self._check_sess() - self.sess.run(self.assignment_nodes, - feed_dict={self.placeholders[name]: value - for (name, value) in new_weights.items() - if name in self.placeholders}) + def set_weights(self,
new_weights): + """Sets the weights to new_weights.""" + self._check_sess() + self.sess.run(self.assignment_nodes, + feed_dict={self.placeholders[name]: value + for (name, value) in new_weights.items() + if name in self.placeholders}) diff --git a/python/ray/experimental/utils.py b/python/ray/experimental/utils.py index 166ae0904..f802b1430 100644 --- a/python/ray/experimental/utils.py +++ b/python/ray/experimental/utils.py @@ -11,70 +11,72 @@ import ray def tarred_directory_as_bytes(source_dir): - """Tar a directory and return it as a byte string. + """Tar a directory and return it as a byte string. - Args: - source_dir (str): The name of the directory to tar. + Args: + source_dir (str): The name of the directory to tar. - Returns: - A byte string representing the tarred file. - """ - # Get a BytesIO object. - string_file = io.BytesIO() - # Create an in-memory tarfile of the source directory. - with tarfile.open(mode="w:gz", fileobj=string_file) as tar: - tar.add(source_dir, arcname=os.path.basename(source_dir)) - string_file.seek(0) - return string_file.read() + Returns: + A byte string representing the tarred file. + """ + # Get a BytesIO object. + string_file = io.BytesIO() + # Create an in-memory tarfile of the source directory. + with tarfile.open(mode="w:gz", fileobj=string_file) as tar: + tar.add(source_dir, arcname=os.path.basename(source_dir)) + string_file.seek(0) + return string_file.read() def tarred_bytes_to_directory(tarred_bytes, target_dir): - """Take a byte string and untar it. + """Take a byte string and untar it. - Args: - tarred_bytes (str): A byte string representing the tarred file. This should - be the output of tarred_directory_as_bytes. - target_dir (str): The directory to create the untarred files in. - """ - string_file = io.BytesIO(tarred_bytes) - with tarfile.open(fileobj=string_file) as tar: - tar.extractall(path=target_dir) + Args: + tarred_bytes (str): A byte string representing the tarred file. 
This + should be the output of tarred_directory_as_bytes. + target_dir (str): The directory to create the untarred files in. + """ + string_file = io.BytesIO(tarred_bytes) + with tarfile.open(fileobj=string_file) as tar: + tar.extractall(path=target_dir) def copy_directory(source_dir, target_dir=None): - """Copy a local directory to each machine in the Ray cluster. + """Copy a local directory to each machine in the Ray cluster. - Note that both source_dir and target_dir must have the same basename). For - example, source_dir can be /a/b/c and target_dir can be /d/e/c. In this case, - the directory /d/e will be added to the Python path of each worker. + Note that both source_dir and target_dir must have the same basename). For + example, source_dir can be /a/b/c and target_dir can be /d/e/c. In this + case, the directory /d/e will be added to the Python path of each worker. - Note that this method is not completely safe to use. For example, workers - that do not do the copying and only set their paths (only one worker per node - does the copying) may try to execute functions that use the files in the - directory being copied before the directory being copied has finished - untarring. + Note that this method is not completely safe to use. For example, workers + that do not do the copying and only set their paths (only one worker per + node does the copying) may try to execute functions that use the files in + the directory being copied before the directory being copied has finished + untarring. - Args: - source_dir (str): The directory to copy. - target_dir (str): The location to copy it to on the other machines. If this - is not provided, the source_dir will be used. If it is provided and is - different from source_dir, the source_dir also be copied to the - target_dir location on this machine. 
- """ - target_dir = source_dir if target_dir is None else target_dir - source_dir = os.path.abspath(source_dir) - target_dir = os.path.abspath(target_dir) - source_basename = os.path.basename(source_dir) - target_basename = os.path.basename(target_dir) - if source_basename != target_basename: - raise Exception("The source_dir and target_dir must have the same base " - "name, {} != {}".format(source_basename, target_basename)) - tarred_bytes = tarred_directory_as_bytes(source_dir) + Args: + source_dir (str): The directory to copy. + target_dir (str): The location to copy it to on the other machines. If + this is not provided, the source_dir will be used. If it is + provided and is different from source_dir, the source_dir also be + copied to the target_dir location on this machine. + """ + target_dir = source_dir if target_dir is None else target_dir + source_dir = os.path.abspath(source_dir) + target_dir = os.path.abspath(target_dir) + source_basename = os.path.basename(source_dir) + target_basename = os.path.basename(target_dir) + if source_basename != target_basename: + raise Exception("The source_dir and target_dir must have the same " + "base name, {} != {}".format(source_basename, + target_basename)) + tarred_bytes = tarred_directory_as_bytes(source_dir) - def f(worker_info): - if worker_info["counter"] == 0: - tarred_bytes_to_directory(tarred_bytes, os.path.dirname(target_dir)) - sys.path.append(os.path.dirname(target_dir)) - # Run this function on all workers to copy the directory to all nodes and to - # add the directory to the Python path of each worker. - ray.worker.global_worker.run_function_on_all_workers(f) + def f(worker_info): + if worker_info["counter"] == 0: + tarred_bytes_to_directory(tarred_bytes, + os.path.dirname(target_dir)) + sys.path.append(os.path.dirname(target_dir)) + # Run this function on all workers to copy the directory to all nodes and + # to add the directory to the Python path of each worker. 
+ ray.worker.global_worker.run_function_on_all_workers(f) diff --git a/python/ray/global_scheduler/global_scheduler_services.py b/python/ray/global_scheduler/global_scheduler_services.py index 8cbc6f28e..a7001bb9d 100644 --- a/python/ray/global_scheduler/global_scheduler_services.py +++ b/python/ray/global_scheduler/global_scheduler_services.py @@ -10,45 +10,45 @@ import time def start_global_scheduler(redis_address, node_ip_address, use_valgrind=False, use_profiler=False, stdout_file=None, stderr_file=None): - """Start a global scheduler process. + """Start a global scheduler process. - Args: - redis_address (str): The address of the Redis instance. - node_ip_address: The IP address of the node that this scheduler will run - on. - use_valgrind (bool): True if the global scheduler should be started inside - of valgrind. If this is True, use_profiler must be False. - use_profiler (bool): True if the global scheduler should be started inside - a profiler. If this is True, use_valgrind must be False. - stdout_file: A file handle opened for writing to redirect stdout to. If no - redirection should happen, then this should be None. - stderr_file: A file handle opened for writing to redirect stderr to. If no - redirection should happen, then this should be None. + Args: + redis_address (str): The address of the Redis instance. + node_ip_address: The IP address of the node that this scheduler will + run on. + use_valgrind (bool): True if the global scheduler should be started + inside of valgrind. If this is True, use_profiler must be False. + use_profiler (bool): True if the global scheduler should be started + inside a profiler. If this is True, use_valgrind must be False. + stdout_file: A file handle opened for writing to redirect stdout to. If + no redirection should happen, then this should be None. + stderr_file: A file handle opened for writing to redirect stderr to. If + no redirection should happen, then this should be None. 
- Return: - The process ID of the global scheduler process. - """ - if use_valgrind and use_profiler: - raise Exception("Cannot use valgrind and profiler at the same time.") - global_scheduler_executable = os.path.join( - os.path.abspath(os.path.dirname(__file__)), - "../core/src/global_scheduler/global_scheduler") - command = [global_scheduler_executable, - "-r", redis_address, - "-h", node_ip_address] - if use_valgrind: - pid = subprocess.Popen(["valgrind", - "--track-origins=yes", - "--leak-check=full", - "--show-leak-kinds=all", - "--error-exitcode=1"] + command, - stdout=stdout_file, stderr=stderr_file) - time.sleep(1.0) - elif use_profiler: - pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command, - stdout=stdout_file, stderr=stderr_file) - time.sleep(1.0) - else: - pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) - time.sleep(0.1) - return pid + Return: + The process ID of the global scheduler process. + """ + if use_valgrind and use_profiler: + raise Exception("Cannot use valgrind and profiler at the same time.") + global_scheduler_executable = os.path.join( + os.path.abspath(os.path.dirname(__file__)), + "../core/src/global_scheduler/global_scheduler") + command = [global_scheduler_executable, + "-r", redis_address, + "-h", node_ip_address] + if use_valgrind: + pid = subprocess.Popen(["valgrind", + "--track-origins=yes", + "--leak-check=full", + "--show-leak-kinds=all", + "--error-exitcode=1"] + command, + stdout=stdout_file, stderr=stderr_file) + time.sleep(1.0) + elif use_profiler: + pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command, + stdout=stdout_file, stderr=stderr_file) + time.sleep(1.0) + else: + pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) + time.sleep(0.1) + return pid diff --git a/python/ray/global_scheduler/test/test.py b/python/ray/global_scheduler/test/test.py index e1a11db91..7b4d193d9 100644 --- a/python/ray/global_scheduler/test/test.py +++ 
b/python/ray/global_scheduler/test/test.py @@ -33,275 +33,285 @@ TASK_PREFIX = "TT:" def random_driver_id(): - return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) def random_task_id(): - return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) def random_function_id(): - return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) def random_object_id(): - return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) def new_port(): - return random.randint(10000, 65535) + return random.randint(10000, 65535) class TestGlobalScheduler(unittest.TestCase): - def setUp(self): - # Start one Redis server and N pairs of (plasma, local_scheduler) - self.node_ip_address = "127.0.0.1" - redis_address, redis_shards = services.start_redis(self.node_ip_address) - redis_port = services.get_port(redis_address) - time.sleep(0.1) - # Create a client for the global state store. - self.state = state.GlobalState() - self.state._initialize_global_state(self.node_ip_address, redis_port) + def setUp(self): + # Start one Redis server and N pairs of (plasma, local_scheduler) + self.node_ip_address = "127.0.0.1" + redis_address, redis_shards = services.start_redis( + self.node_ip_address) + redis_port = services.get_port(redis_address) + time.sleep(0.1) + # Create a client for the global state store. + self.state = state.GlobalState() + self.state._initialize_global_state(self.node_ip_address, redis_port) - # Start one global scheduler. - self.p1 = global_scheduler.start_global_scheduler( - redis_address, self.node_ip_address, use_valgrind=USE_VALGRIND) - self.plasma_store_pids = [] - self.plasma_manager_pids = [] - self.local_scheduler_pids = [] - self.plasma_clients = [] - self.local_scheduler_clients = [] + # Start one global scheduler. 
+ self.p1 = global_scheduler.start_global_scheduler( + redis_address, self.node_ip_address, use_valgrind=USE_VALGRIND) + self.plasma_store_pids = [] + self.plasma_manager_pids = [] + self.local_scheduler_pids = [] + self.plasma_clients = [] + self.local_scheduler_clients = [] - for i in range(NUM_CLUSTER_NODES): - # Start the Plasma store. Plasma store name is randomly generated. - plasma_store_name, p2 = plasma.start_plasma_store() - self.plasma_store_pids.append(p2) - # Start the Plasma manager. - # Assumption: Plasma manager name and port are randomly generated by the - # plasma module. - manager_info = plasma.start_plasma_manager(plasma_store_name, - redis_address) - plasma_manager_name, p3, plasma_manager_port = manager_info - self.plasma_manager_pids.append(p3) - plasma_address = "{}:{}".format(self.node_ip_address, - plasma_manager_port) - plasma_client = plasma.PlasmaClient(plasma_store_name, - plasma_manager_name) - self.plasma_clients.append(plasma_client) - # Start the local scheduler. - local_scheduler_name, p4 = local_scheduler.start_local_scheduler( - plasma_store_name, - plasma_manager_name=plasma_manager_name, - plasma_address=plasma_address, - redis_address=redis_address, - static_resource_list=[10, 0]) - # Connect to the scheduler. - local_scheduler_client = local_scheduler.LocalSchedulerClient( - local_scheduler_name, NIL_WORKER_ID, NIL_ACTOR_ID, False, 0) - self.local_scheduler_clients.append(local_scheduler_client) - self.local_scheduler_pids.append(p4) + for i in range(NUM_CLUSTER_NODES): + # Start the Plasma store. Plasma store name is randomly generated. + plasma_store_name, p2 = plasma.start_plasma_store() + self.plasma_store_pids.append(p2) + # Start the Plasma manager. + # Assumption: Plasma manager name and port are randomly generated + # by the plasma module. 
+ manager_info = plasma.start_plasma_manager(plasma_store_name, + redis_address) + plasma_manager_name, p3, plasma_manager_port = manager_info + self.plasma_manager_pids.append(p3) + plasma_address = "{}:{}".format(self.node_ip_address, + plasma_manager_port) + plasma_client = plasma.PlasmaClient(plasma_store_name, + plasma_manager_name) + self.plasma_clients.append(plasma_client) + # Start the local scheduler. + local_scheduler_name, p4 = local_scheduler.start_local_scheduler( + plasma_store_name, + plasma_manager_name=plasma_manager_name, + plasma_address=plasma_address, + redis_address=redis_address, + static_resource_list=[10, 0]) + # Connect to the scheduler. + local_scheduler_client = local_scheduler.LocalSchedulerClient( + local_scheduler_name, NIL_WORKER_ID, NIL_ACTOR_ID, False, 0) + self.local_scheduler_clients.append(local_scheduler_client) + self.local_scheduler_pids.append(p4) - def tearDown(self): - # Check that the processes are still alive. - self.assertEqual(self.p1.poll(), None) - for p2 in self.plasma_store_pids: - self.assertEqual(p2.poll(), None) - for p3 in self.plasma_manager_pids: - self.assertEqual(p3.poll(), None) - for p4 in self.local_scheduler_pids: - self.assertEqual(p4.poll(), None) + def tearDown(self): + # Check that the processes are still alive. + self.assertEqual(self.p1.poll(), None) + for p2 in self.plasma_store_pids: + self.assertEqual(p2.poll(), None) + for p3 in self.plasma_manager_pids: + self.assertEqual(p3.poll(), None) + for p4 in self.local_scheduler_pids: + self.assertEqual(p4.poll(), None) - redis_processes = services.all_processes[ - services.PROCESS_TYPE_REDIS_SERVER] - for redis_process in redis_processes: - self.assertEqual(redis_process.poll(), None) + redis_processes = services.all_processes[ + services.PROCESS_TYPE_REDIS_SERVER] + for redis_process in redis_processes: + self.assertEqual(redis_process.poll(), None) - # Kill the global scheduler. 
- if USE_VALGRIND: - self.p1.send_signal(signal.SIGTERM) - self.p1.wait() - if self.p1.returncode != 0: - os._exit(-1) - else: - self.p1.kill() - # Kill local schedulers, plasma managers, and plasma stores. - for p2 in self.local_scheduler_pids: - p2.kill() - for p3 in self.plasma_manager_pids: - p3.kill() - for p4 in self.plasma_store_pids: - p4.kill() - # Kill Redis. In the event that we are using valgrind, this needs to happen - # after we kill the global scheduler. - while redis_processes: - redis_process = redis_processes.pop() - redis_process.kill() - - def get_plasma_manager_id(self): - """Get the db_client_id with client_type equal to plasma_manager. - - Iterates over all the client table keys, gets the db_client_id for the - client with client_type matching plasma_manager. Strips the client table - prefix. TODO(atumanov): write a separate function to get all plasma manager - client IDs. - - Returns: - The db_client_id if one is found and otherwise None. - """ - db_client_id = None - - client_list = self.state.client_table()[self.node_ip_address] - for client in client_list: - if client["ClientType"] == "plasma_manager": - db_client_id = client["DBClientID"] - break - - return db_client_id - - def test_task_default_resources(self): - task1 = local_scheduler.Task(random_driver_id(), random_function_id(), - [random_object_id()], 0, random_task_id(), 0) - self.assertEqual(task1.required_resources(), [1.0, 0.0]) - task2 = local_scheduler.Task(random_driver_id(), random_function_id(), - [random_object_id()], 0, random_task_id(), 0, - local_scheduler.ObjectID(NIL_ACTOR_ID), 0, - [1.0, 2.0]) - self.assertEqual(task2.required_resources(), [1.0, 2.0]) - - def test_redis_only_single_task(self): - """ - Tests global scheduler functionality by interacting with Redis and checking - task state transitions in Redis only. TODO(atumanov): implement. 
- """ - # Check precondition for this test: - # There should be 2n+1 db clients: the global scheduler + one local - # scheduler and one plasma per node. - self.assertEqual( - len(self.state.client_table()[self.node_ip_address]), - 2 * NUM_CLUSTER_NODES + 1) - db_client_id = self.get_plasma_manager_id() - assert(db_client_id is not None) - - def test_integration_single_task(self): - # There should be three db clients, the global scheduler, the local - # scheduler, and the plasma manager. - self.assertEqual( - len(self.state.client_table()[self.node_ip_address]), - 2 * NUM_CLUSTER_NODES + 1) - - num_return_vals = [0, 1, 2, 3, 5, 10] - # Insert the object into Redis. - data_size = 0xf1f0 - metadata_size = 0x40 - plasma_client = self.plasma_clients[0] - object_dep, memory_buffer, metadata = create_object( - plasma_client, data_size, metadata_size, seal=True) - - # Sleep before submitting task to local scheduler. - time.sleep(0.1) - # Submit a task to Redis. - task = local_scheduler.Task(random_driver_id(), random_function_id(), - [local_scheduler.ObjectID(object_dep)], - num_return_vals[0], random_task_id(), 0) - self.local_scheduler_clients[0].submit(task) - time.sleep(0.1) - # There should now be a task in Redis, and it should get assigned to the - # local scheduler - num_retries = 10 - while num_retries > 0: - task_entries = self.state.task_table() - self.assertLessEqual(len(task_entries), 1) - if len(task_entries) == 1: - task_id, task = task_entries.popitem() - task_status = task["State"] - self.assertTrue(task_status in [state.TASK_STATUS_WAITING, - state.TASK_STATUS_SCHEDULED, - state.TASK_STATUS_QUEUED]) - if task_status == state.TASK_STATUS_QUEUED: - break + # Kill the global scheduler. 
+ if USE_VALGRIND: + self.p1.send_signal(signal.SIGTERM) + self.p1.wait() + if self.p1.returncode != 0: + os._exit(-1) else: - print(task_status) - print("The task has not been scheduled yet, trying again.") - num_retries -= 1 - time.sleep(1) + self.p1.kill() + # Kill local schedulers, plasma managers, and plasma stores. + for p2 in self.local_scheduler_pids: + p2.kill() + for p3 in self.plasma_manager_pids: + p3.kill() + for p4 in self.plasma_store_pids: + p4.kill() + # Kill Redis. In the event that we are using valgrind, this needs to + # happen after we kill the global scheduler. + while redis_processes: + redis_process = redis_processes.pop() + redis_process.kill() - if num_retries <= 0 and task_status != state.TASK_STATUS_QUEUED: - # Failed to submit and schedule a single task -- bail. - self.tearDown() - sys.exit(1) + def get_plasma_manager_id(self): + """Get the db_client_id with client_type equal to plasma_manager. - def integration_many_tasks_helper(self, timesync=True): - # There should be three db clients, the global scheduler, the local - # scheduler, and the plasma manager. - self.assertEqual( - len(self.state.client_table()[self.node_ip_address]), - 2 * NUM_CLUSTER_NODES + 1) - num_return_vals = [0, 1, 2, 3, 5, 10] + Iterates over all the client table keys, gets the db_client_id for the + client with client_type matching plasma_manager. Strips the client + table prefix. TODO(atumanov): write a separate function to get all + plasma manager client IDs. - # Submit a bunch of tasks to Redis. - num_tasks = 1000 - for _ in range(num_tasks): - # Create a new object for each task. - data_size = np.random.randint(1 << 12) - metadata_size = np.random.randint(1 << 9) - plasma_client = self.plasma_clients[0] - object_dep, memory_buffer, metadata = create_object(plasma_client, - data_size, - metadata_size, - seal=True) - if timesync: - # Give 10ms for object info handler to fire (long enough to yield CPU). 
- time.sleep(0.010) - task = local_scheduler.Task(random_driver_id(), random_function_id(), - [local_scheduler.ObjectID(object_dep)], - num_return_vals[0], random_task_id(), 0) - self.local_scheduler_clients[0].submit(task) - # Check that there are the correct number of tasks in Redis and that they - # all get assigned to the local scheduler. - num_retries = 10 - num_tasks_done = 0 - while num_retries > 0: - task_entries = self.state.task_table() - self.assertLessEqual(len(task_entries), num_tasks) - # First, check if all tasks made it to Redis. - if len(task_entries) == num_tasks: - task_statuses = [task_entry["State"] for task_entry in - task_entries.values()] - self.assertTrue(all([status in [state.TASK_STATUS_WAITING, - state.TASK_STATUS_SCHEDULED, - state.TASK_STATUS_QUEUED] - for status in task_statuses])) - num_tasks_done = task_statuses.count(state.TASK_STATUS_QUEUED) - num_tasks_scheduled = task_statuses.count(state.TASK_STATUS_SCHEDULED) - num_tasks_waiting = task_statuses.count(state.TASK_STATUS_WAITING) - print("tasks in Redis = {}, tasks waiting = {}, tasks scheduled = {}, " - "tasks queued = {}, retries left = {}" - .format(len(task_entries), num_tasks_waiting, - num_tasks_scheduled, num_tasks_done, num_retries)) - if all([status == state.TASK_STATUS_QUEUED for status in - task_statuses]): - # We're done, so pass. - break - num_retries -= 1 - time.sleep(0.1) + Returns: + The db_client_id if one is found and otherwise None. + """ + db_client_id = None - self.assertEqual(num_tasks_done, num_tasks) + client_list = self.state.client_table()[self.node_ip_address] + for client in client_list: + if client["ClientType"] == "plasma_manager": + db_client_id = client["DBClientID"] + break - def test_integration_many_tasks_handler_sync(self): - self.integration_many_tasks_helper(timesync=True) + return db_client_id - def test_integration_many_tasks(self): - # More realistic case: should handle out of order object and task - # notifications. 
- self.integration_many_tasks_helper(timesync=False) + def test_task_default_resources(self): + task1 = local_scheduler.Task(random_driver_id(), random_function_id(), + [random_object_id()], 0, random_task_id(), + 0) + self.assertEqual(task1.required_resources(), [1.0, 0.0]) + task2 = local_scheduler.Task(random_driver_id(), random_function_id(), + [random_object_id()], 0, random_task_id(), + 0, local_scheduler.ObjectID(NIL_ACTOR_ID), + 0, [1.0, 2.0]) + self.assertEqual(task2.required_resources(), [1.0, 2.0]) + + def test_redis_only_single_task(self): + # Tests global scheduler functionality by interacting with Redis and + # checking task state transitions in Redis only. TODO(atumanov): + # implement. + + # Check precondition for this test: + # There should be 2n+1 db clients: the global scheduler + one local + # scheduler and one plasma per node. + self.assertEqual( + len(self.state.client_table()[self.node_ip_address]), + 2 * NUM_CLUSTER_NODES + 1) + db_client_id = self.get_plasma_manager_id() + assert(db_client_id is not None) + + def test_integration_single_task(self): + # There should be three db clients, the global scheduler, the local + # scheduler, and the plasma manager. + self.assertEqual( + len(self.state.client_table()[self.node_ip_address]), + 2 * NUM_CLUSTER_NODES + 1) + + num_return_vals = [0, 1, 2, 3, 5, 10] + # Insert the object into Redis. + data_size = 0xf1f0 + metadata_size = 0x40 + plasma_client = self.plasma_clients[0] + object_dep, memory_buffer, metadata = create_object( + plasma_client, data_size, metadata_size, seal=True) + + # Sleep before submitting task to local scheduler. + time.sleep(0.1) + # Submit a task to Redis. 
+ task = local_scheduler.Task(random_driver_id(), random_function_id(), + [local_scheduler.ObjectID(object_dep)], + num_return_vals[0], random_task_id(), 0) + self.local_scheduler_clients[0].submit(task) + time.sleep(0.1) + # There should now be a task in Redis, and it should get assigned to + # the local scheduler + num_retries = 10 + while num_retries > 0: + task_entries = self.state.task_table() + self.assertLessEqual(len(task_entries), 1) + if len(task_entries) == 1: + task_id, task = task_entries.popitem() + task_status = task["State"] + self.assertTrue(task_status in [state.TASK_STATUS_WAITING, + state.TASK_STATUS_SCHEDULED, + state.TASK_STATUS_QUEUED]) + if task_status == state.TASK_STATUS_QUEUED: + break + else: + print(task_status) + print("The task has not been scheduled yet, trying again.") + num_retries -= 1 + time.sleep(1) + + if num_retries <= 0 and task_status != state.TASK_STATUS_QUEUED: + # Failed to submit and schedule a single task -- bail. + self.tearDown() + sys.exit(1) + + def integration_many_tasks_helper(self, timesync=True): + # There should be three db clients, the global scheduler, the local + # scheduler, and the plasma manager. + self.assertEqual( + len(self.state.client_table()[self.node_ip_address]), + 2 * NUM_CLUSTER_NODES + 1) + num_return_vals = [0, 1, 2, 3, 5, 10] + + # Submit a bunch of tasks to Redis. + num_tasks = 1000 + for _ in range(num_tasks): + # Create a new object for each task. + data_size = np.random.randint(1 << 12) + metadata_size = np.random.randint(1 << 9) + plasma_client = self.plasma_clients[0] + object_dep, memory_buffer, metadata = create_object(plasma_client, + data_size, + metadata_size, + seal=True) + if timesync: + # Give 10ms for object info handler to fire (long enough to + # yield CPU). 
+ time.sleep(0.010) + task = local_scheduler.Task(random_driver_id(), + random_function_id(), + [local_scheduler.ObjectID(object_dep)], + num_return_vals[0], random_task_id(), + 0) + self.local_scheduler_clients[0].submit(task) + # Check that there are the correct number of tasks in Redis and that + # they all get assigned to the local scheduler. + num_retries = 10 + num_tasks_done = 0 + while num_retries > 0: + task_entries = self.state.task_table() + self.assertLessEqual(len(task_entries), num_tasks) + # First, check if all tasks made it to Redis. + if len(task_entries) == num_tasks: + task_statuses = [task_entry["State"] for task_entry in + task_entries.values()] + self.assertTrue(all([status in [state.TASK_STATUS_WAITING, + state.TASK_STATUS_SCHEDULED, + state.TASK_STATUS_QUEUED] + for status in task_statuses])) + num_tasks_done = task_statuses.count(state.TASK_STATUS_QUEUED) + num_tasks_scheduled = task_statuses.count( + state.TASK_STATUS_SCHEDULED) + num_tasks_waiting = task_statuses.count( + state.TASK_STATUS_WAITING) + print("tasks in Redis = {}, tasks waiting = {}, " + "tasks scheduled = {}, " + "tasks queued = {}, retries left = {}" + .format(len(task_entries), num_tasks_waiting, + num_tasks_scheduled, num_tasks_done, + num_retries)) + if all([status == state.TASK_STATUS_QUEUED for status in + task_statuses]): + # We're done, so pass. + break + num_retries -= 1 + time.sleep(0.1) + + self.assertEqual(num_tasks_done, num_tasks) + + def test_integration_many_tasks_handler_sync(self): + self.integration_many_tasks_helper(timesync=True) + + def test_integration_many_tasks(self): + # More realistic case: should handle out of order object and task + # notifications. + self.integration_many_tasks_helper(timesync=False) if __name__ == "__main__": - if len(sys.argv) > 1: - # Pop the argument so we don't mess with unittest's own argument parser. 
- if sys.argv[-1] == "valgrind": - arg = sys.argv.pop() - USE_VALGRIND = True - print("Using valgrind for tests") - unittest.main(verbosity=2) + if len(sys.argv) > 1: + # Pop the argument so we don't mess with unittest's own argument + # parser. + if sys.argv[-1] == "valgrind": + arg = sys.argv.pop() + USE_VALGRIND = True + print("Using valgrind for tests") + unittest.main(verbosity=2) diff --git a/python/ray/local_scheduler/local_scheduler_services.py b/python/ray/local_scheduler/local_scheduler_services.py index 03c213095..6fa1cd160 100644 --- a/python/ray/local_scheduler/local_scheduler_services.py +++ b/python/ray/local_scheduler/local_scheduler_services.py @@ -9,7 +9,7 @@ import time def random_name(): - return str(random.randint(0, 99999999)) + return str(random.randint(0, 99999999)) def start_local_scheduler(plasma_store_name, @@ -24,95 +24,99 @@ def start_local_scheduler(plasma_store_name, stderr_file=None, static_resource_list=None, num_workers=0): - """Start a local scheduler process. + """Start a local scheduler process. - Args: - plasma_store_name (str): The name of the plasma store socket to connect to. - plasma_manager_name (str): The name of the plasma manager to connect to. - This does not need to be provided, but if it is, then the Redis address - must be provided as well. - worker_path (str): The path of the worker script to use when the local - scheduler starts up new workers. - plasma_address (str): The address of the plasma manager to connect to. This - is only used by the global scheduler to figure out which plasma managers - are connected to which local schedulers. - node_ip_address (str): The address of the node that this local scheduler is - running on. - redis_address (str): The address of the Redis instance to connect to. If - this is not provided, then the local scheduler will not connect to Redis. - use_valgrind (bool): True if the local scheduler should be started inside - of valgrind. If this is True, use_profiler must be False. 
- use_profiler (bool): True if the local scheduler should be started inside a - profiler. If this is True, use_valgrind must be False. - stdout_file: A file handle opened for writing to redirect stdout to. If no - redirection should happen, then this should be None. - stderr_file: A file handle opened for writing to redirect stderr to. If no - redirection should happen, then this should be None. - static_resource_list (list): A list of integers specifying the local - scheduler's resource capacities. The resources should appear in an order - matching the order defined in task.h. - num_workers (int): The number of workers that the local scheduler should - start. + Args: + plasma_store_name (str): The name of the plasma store socket to connect + to. + plasma_manager_name (str): The name of the plasma manager to connect + to. This does not need to be provided, but if it is, then the Redis + address must be provided as well. + worker_path (str): The path of the worker script to use when the local + scheduler starts up new workers. + plasma_address (str): The address of the plasma manager to connect to. + This is only used by the global scheduler to figure out which + plasma managers are connected to which local schedulers. + node_ip_address (str): The address of the node that this local + scheduler is running on. + redis_address (str): The address of the Redis instance to connect to. + If this is not provided, then the local scheduler will not connect + to Redis. + use_valgrind (bool): True if the local scheduler should be started + inside of valgrind. If this is True, use_profiler must be False. + use_profiler (bool): True if the local scheduler should be started + inside a profiler. If this is True, use_valgrind must be False. + stdout_file: A file handle opened for writing to redirect stdout to. If + no redirection should happen, then this should be None. + stderr_file: A file handle opened for writing to redirect stderr to. 
If + no redirection should happen, then this should be None. + static_resource_list (list): A list of integers specifying the local + scheduler's resource capacities. The resources should appear in an + order matching the order defined in task.h. + num_workers (int): The number of workers that the local scheduler + should start. - Return: - A tuple of the name of the local scheduler socket and the process ID of the - local scheduler process. - """ - if (plasma_manager_name is None) != (redis_address is None): - raise Exception("If one of the plasma_manager_name and the redis_address " - "is provided, then both must be provided.") - if use_valgrind and use_profiler: - raise Exception("Cannot use valgrind and profiler at the same time.") - local_scheduler_executable = os.path.join(os.path.dirname( - os.path.abspath(__file__)), - "../core/src/local_scheduler/local_scheduler") - local_scheduler_name = "/tmp/scheduler{}".format(random_name()) - command = [local_scheduler_executable, - "-s", local_scheduler_name, - "-p", plasma_store_name, - "-h", node_ip_address, - "-n", str(num_workers)] - if plasma_manager_name is not None: - command += ["-m", plasma_manager_name] - if worker_path is not None: - assert plasma_store_name is not None - assert plasma_manager_name is not None - assert redis_address is not None - start_worker_command = ("python {} " - "--node-ip-address={} " - "--object-store-name={} " - "--object-store-manager-name={} " - "--local-scheduler-name={} " - "--redis-address={}").format(worker_path, - node_ip_address, - plasma_store_name, - plasma_manager_name, - local_scheduler_name, - redis_address) - command += ["-w", start_worker_command] - if redis_address is not None: - command += ["-r", redis_address] - if plasma_address is not None: - command += ["-a", plasma_address] - if static_resource_list is not None: - assert all([isinstance(resource, int) or isinstance(resource, float) - for resource in static_resource_list]) - command += ["-c", 
",".join([str(resource) for resource - in static_resource_list])] + Return: + A tuple of the name of the local scheduler socket and the process ID of + the local scheduler process. + """ + if (plasma_manager_name is None) != (redis_address is None): + raise Exception("If one of the plasma_manager_name and the " + "redis_address is provided, then both must be " + "provided.") + if use_valgrind and use_profiler: + raise Exception("Cannot use valgrind and profiler at the same time.") + local_scheduler_executable = os.path.join(os.path.dirname( + os.path.abspath(__file__)), + "../core/src/local_scheduler/local_scheduler") + local_scheduler_name = "/tmp/scheduler{}".format(random_name()) + command = [local_scheduler_executable, + "-s", local_scheduler_name, + "-p", plasma_store_name, + "-h", node_ip_address, + "-n", str(num_workers)] + if plasma_manager_name is not None: + command += ["-m", plasma_manager_name] + if worker_path is not None: + assert plasma_store_name is not None + assert plasma_manager_name is not None + assert redis_address is not None + start_worker_command = ("python {} " + "--node-ip-address={} " + "--object-store-name={} " + "--object-store-manager-name={} " + "--local-scheduler-name={} " + "--redis-address={}" + .format(worker_path, + node_ip_address, + plasma_store_name, + plasma_manager_name, + local_scheduler_name, + redis_address)) + command += ["-w", start_worker_command] + if redis_address is not None: + command += ["-r", redis_address] + if plasma_address is not None: + command += ["-a", plasma_address] + if static_resource_list is not None: + assert all([isinstance(resource, int) or isinstance(resource, float) + for resource in static_resource_list]) + command += ["-c", ",".join([str(resource) for resource + in static_resource_list])] - if use_valgrind: - pid = subprocess.Popen(["valgrind", - "--track-origins=yes", - "--leak-check=full", - "--show-leak-kinds=all", - "--error-exitcode=1"] + command, - stdout=stdout_file, stderr=stderr_file) 
- time.sleep(1.0) - elif use_profiler: - pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command, - stdout=stdout_file, stderr=stderr_file) - time.sleep(1.0) - else: - pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) - time.sleep(0.1) - return local_scheduler_name, pid + if use_valgrind: + pid = subprocess.Popen(["valgrind", + "--track-origins=yes", + "--leak-check=full", + "--show-leak-kinds=all", + "--error-exitcode=1"] + command, + stdout=stdout_file, stderr=stderr_file) + time.sleep(1.0) + elif use_profiler: + pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command, + stdout=stdout_file, stderr=stderr_file) + time.sleep(1.0) + else: + pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) + time.sleep(0.1) + return local_scheduler_name, pid diff --git a/python/ray/local_scheduler/test/test.py b/python/ray/local_scheduler/test/test.py index 1a77173ce..6577119f5 100644 --- a/python/ray/local_scheduler/test/test.py +++ b/python/ray/local_scheduler/test/test.py @@ -21,195 +21,202 @@ NIL_ACTOR_ID = 20 * b"\xff" def random_object_id(): - return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) def random_driver_id(): - return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) def random_task_id(): - return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) def random_function_id(): - return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) + return local_scheduler.ObjectID(np.random.bytes(ID_SIZE)) class TestLocalSchedulerClient(unittest.TestCase): - def setUp(self): - # Start Plasma store. - plasma_store_name, self.p1 = plasma.start_plasma_store() - self.plasma_client = plasma.PlasmaClient(plasma_store_name, - release_delay=0) - # Start a local scheduler. 
- scheduler_name, self.p2 = local_scheduler.start_local_scheduler( - plasma_store_name, use_valgrind=USE_VALGRIND) - # Connect to the scheduler. - self.local_scheduler_client = local_scheduler.LocalSchedulerClient( - scheduler_name, NIL_WORKER_ID, NIL_ACTOR_ID, False, 0) + def setUp(self): + # Start Plasma store. + plasma_store_name, self.p1 = plasma.start_plasma_store() + self.plasma_client = plasma.PlasmaClient(plasma_store_name, + release_delay=0) + # Start a local scheduler. + scheduler_name, self.p2 = local_scheduler.start_local_scheduler( + plasma_store_name, use_valgrind=USE_VALGRIND) + # Connect to the scheduler. + self.local_scheduler_client = local_scheduler.LocalSchedulerClient( + scheduler_name, NIL_WORKER_ID, NIL_ACTOR_ID, False, 0) - def tearDown(self): - # Check that the processes are still alive. - self.assertEqual(self.p1.poll(), None) - self.assertEqual(self.p2.poll(), None) + def tearDown(self): + # Check that the processes are still alive. + self.assertEqual(self.p1.poll(), None) + self.assertEqual(self.p2.poll(), None) - # Kill Plasma. - self.p1.kill() - # Kill the local scheduler. - if USE_VALGRIND: - self.p2.send_signal(signal.SIGTERM) - self.p2.wait() - if self.p2.returncode != 0: - os._exit(-1) - else: - self.p2.kill() + # Kill Plasma. + self.p1.kill() + # Kill the local scheduler. + if USE_VALGRIND: + self.p2.send_signal(signal.SIGTERM) + self.p2.wait() + if self.p2.returncode != 0: + os._exit(-1) + else: + self.p2.kill() - def test_submit_and_get_task(self): - function_id = random_function_id() - object_ids = [random_object_id() for i in range(256)] - # Create and seal the objects in the object store so that we can schedule - # all of the subsequent tasks. - for object_id in object_ids: - self.plasma_client.create(object_id.id(), 0) - self.plasma_client.seal(object_id.id()) - # Define some arguments to use for the tasks. 
- args_list = [ - [], - [{}], - [()], - 1 * [1], - 10 * [1], - 100 * [1], - 1000 * [1], - 1 * ["a"], - 10 * ["a"], - 100 * ["a"], - 1000 * ["a"], - [1, 1.3, 1 << 100, "hi", u"hi", [1, 2]], - object_ids[:1], - object_ids[:2], - object_ids[:3], - object_ids[:4], - object_ids[:5], - object_ids[:10], - object_ids[:100], - object_ids[:256], - [1, object_ids[0]], - [object_ids[0], "a"], - [1, object_ids[0], "a"], - [object_ids[0], 1, object_ids[1], "a"], - object_ids[:3] + [1, "hi", 2.3] + object_ids[:5], - object_ids + 100 * ["a"] + object_ids - ] + def test_submit_and_get_task(self): + function_id = random_function_id() + object_ids = [random_object_id() for i in range(256)] + # Create and seal the objects in the object store so that we can + # schedule all of the subsequent tasks. + for object_id in object_ids: + self.plasma_client.create(object_id.id(), 0) + self.plasma_client.seal(object_id.id()) + # Define some arguments to use for the tasks. + args_list = [ + [], + [{}], + [()], + 1 * [1], + 10 * [1], + 100 * [1], + 1000 * [1], + 1 * ["a"], + 10 * ["a"], + 100 * ["a"], + 1000 * ["a"], + [1, 1.3, 1 << 100, "hi", u"hi", [1, 2]], + object_ids[:1], + object_ids[:2], + object_ids[:3], + object_ids[:4], + object_ids[:5], + object_ids[:10], + object_ids[:100], + object_ids[:256], + [1, object_ids[0]], + [object_ids[0], "a"], + [1, object_ids[0], "a"], + [object_ids[0], 1, object_ids[1], "a"], + object_ids[:3] + [1, "hi", 2.3] + object_ids[:5], + object_ids + 100 * ["a"] + object_ids + ] - for args in args_list: - for num_return_vals in [0, 1, 2, 3, 5, 10, 100]: - task = local_scheduler.Task(random_driver_id(), function_id, args, - num_return_vals, random_task_id(), 0) - # Submit a task. + for args in args_list: + for num_return_vals in [0, 1, 2, 3, 5, 10, 100]: + task = local_scheduler.Task(random_driver_id(), function_id, + args, num_return_vals, + random_task_id(), 0) + # Submit a task. + self.local_scheduler_client.submit(task) + # Get the task. 
+ new_task = self.local_scheduler_client.get_task() + self.assertEqual(task.function_id().id(), + new_task.function_id().id()) + retrieved_args = new_task.arguments() + returns = new_task.returns() + self.assertEqual(len(args), len(retrieved_args)) + self.assertEqual(num_return_vals, len(returns)) + for i in range(len(retrieved_args)): + if isinstance(args[i], local_scheduler.ObjectID): + self.assertEqual(args[i].id(), retrieved_args[i].id()) + else: + self.assertEqual(args[i], retrieved_args[i]) + + # Submit all of the tasks. + for args in args_list: + for num_return_vals in [0, 1, 2, 3, 5, 10, 100]: + task = local_scheduler.Task(random_driver_id(), function_id, + args, num_return_vals, + random_task_id(), 0) + self.local_scheduler_client.submit(task) + # Get all of the tasks. + for args in args_list: + for num_return_vals in [0, 1, 2, 3, 5, 10, 100]: + new_task = self.local_scheduler_client.get_task() + + def test_scheduling_when_objects_ready(self): + # Create a task and submit it. + object_id = random_object_id() + task = local_scheduler.Task(random_driver_id(), random_function_id(), + [object_id], 0, random_task_id(), 0) self.local_scheduler_client.submit(task) - # Get the task. - new_task = self.local_scheduler_client.get_task() - self.assertEqual(task.function_id().id(), new_task.function_id().id()) - retrieved_args = new_task.arguments() - returns = new_task.returns() - self.assertEqual(len(args), len(retrieved_args)) - self.assertEqual(num_return_vals, len(returns)) - for i in range(len(retrieved_args)): - if isinstance(args[i], local_scheduler.ObjectID): - self.assertEqual(args[i].id(), retrieved_args[i].id()) - else: - self.assertEqual(args[i], retrieved_args[i]) - # Submit all of the tasks. - for args in args_list: - for num_return_vals in [0, 1, 2, 3, 5, 10, 100]: - task = local_scheduler.Task(random_driver_id(), function_id, args, - num_return_vals, random_task_id(), 0) + # Launch a thread to get the task. 
+ def get_task(): + self.local_scheduler_client.get_task() + t = threading.Thread(target=get_task) + t.start() + # Sleep to give the thread time to call get_task. + time.sleep(0.1) + # Create and seal the object ID in the object store. This should + # trigger a scheduling event. + self.plasma_client.create(object_id.id(), 0) + self.plasma_client.seal(object_id.id()) + # Wait until the thread finishes so that we know the task was + # scheduled. + t.join() + + def test_scheduling_when_objects_evicted(self): + # Create a task with two dependencies and submit it. + object_id1 = random_object_id() + object_id2 = random_object_id() + task = local_scheduler.Task(random_driver_id(), random_function_id(), + [object_id1, object_id2], 0, + random_task_id(), 0) self.local_scheduler_client.submit(task) - # Get all of the tasks. - for args in args_list: - for num_return_vals in [0, 1, 2, 3, 5, 10, 100]: - new_task = self.local_scheduler_client.get_task() - def test_scheduling_when_objects_ready(self): - # Create a task and submit it. - object_id = random_object_id() - task = local_scheduler.Task(random_driver_id(), random_function_id(), - [object_id], 0, random_task_id(), 0) - self.local_scheduler_client.submit(task) + # Launch a thread to get the task. + def get_task(): + self.local_scheduler_client.get_task() + t = threading.Thread(target=get_task) + t.start() - # Launch a thread to get the task. - def get_task(): - self.local_scheduler_client.get_task() - t = threading.Thread(target=get_task) - t.start() - # Sleep to give the thread time to call get_task. - time.sleep(0.1) - # Create and seal the object ID in the object store. This should trigger a - # scheduling event. - self.plasma_client.create(object_id.id(), 0) - self.plasma_client.seal(object_id.id()) - # Wait until the thread finishes so that we know the task was scheduled. - t.join() + # Make one of the dependencies available. 
+ buf = self.plasma_client.create(object_id1.id(), 1) + self.plasma_client.seal(object_id1.id()) + # Release the object. + del buf + # Check that the thread is still waiting for a task. + time.sleep(0.1) + self.assertTrue(t.is_alive()) + # Force eviction of the first dependency. + self.plasma_client.evict(plasma.DEFAULT_PLASMA_STORE_MEMORY) + # Check that the thread is still waiting for a task. + time.sleep(0.1) + self.assertTrue(t.is_alive()) + # Check that the first object dependency was evicted. + object1 = self.plasma_client.get([object_id1.id()], timeout_ms=0) + self.assertEqual(object1, [None]) + # Check that the thread is still waiting for a task. + time.sleep(0.1) + self.assertTrue(t.is_alive()) - def test_scheduling_when_objects_evicted(self): - # Create a task with two dependencies and submit it. - object_id1 = random_object_id() - object_id2 = random_object_id() - task = local_scheduler.Task(random_driver_id(), random_function_id(), - [object_id1, object_id2], 0, random_task_id(), - 0) - self.local_scheduler_client.submit(task) + # Create the second dependency. + self.plasma_client.create(object_id2.id(), 1) + self.plasma_client.seal(object_id2.id()) + # Check that the thread is still waiting for a task. + time.sleep(0.1) + self.assertTrue(t.is_alive()) - # Launch a thread to get the task. - def get_task(): - self.local_scheduler_client.get_task() - t = threading.Thread(target=get_task) - t.start() + # Create the first dependency again. Both dependencies are now + # available. + self.plasma_client.create(object_id1.id(), 1) + self.plasma_client.seal(object_id1.id()) - # Make one of the dependencies available. - buf = self.plasma_client.create(object_id1.id(), 1) - self.plasma_client.seal(object_id1.id()) - # Release the object. - del buf - # Check that the thread is still waiting for a task. - time.sleep(0.1) - self.assertTrue(t.is_alive()) - # Force eviction of the first dependency. 
- self.plasma_client.evict(plasma.DEFAULT_PLASMA_STORE_MEMORY) - # Check that the thread is still waiting for a task. - time.sleep(0.1) - self.assertTrue(t.is_alive()) - # Check that the first object dependency was evicted. - object1 = self.plasma_client.get([object_id1.id()], timeout_ms=0) - self.assertEqual(object1, [None]) - # Check that the thread is still waiting for a task. - time.sleep(0.1) - self.assertTrue(t.is_alive()) - - # Create the second dependency. - self.plasma_client.create(object_id2.id(), 1) - self.plasma_client.seal(object_id2.id()) - # Check that the thread is still waiting for a task. - time.sleep(0.1) - self.assertTrue(t.is_alive()) - - # Create the first dependency again. Both dependencies are now available. - self.plasma_client.create(object_id1.id(), 1) - self.plasma_client.seal(object_id1.id()) - - # Wait until the thread finishes so that we know the task was scheduled. - t.join() + # Wait until the thread finishes so that we know the task was + # scheduled. + t.join() if __name__ == "__main__": - if len(sys.argv) > 1: - # Pop the argument so we don't mess with unittest's own argument parser. - if sys.argv[-1] == "valgrind": - arg = sys.argv.pop() - USE_VALGRIND = True - print("Using valgrind for tests") - unittest.main(verbosity=2) + if len(sys.argv) > 1: + # Pop the argument so we don't mess with unittest's own argument + # parser. + if sys.argv[-1] == "valgrind": + arg = sys.argv.pop() + USE_VALGRIND = True + print("Using valgrind for tests") + unittest.main(verbosity=2) diff --git a/python/ray/log_monitor.py b/python/ray/log_monitor.py index 924c4597a..8e58e617c 100644 --- a/python/ray/log_monitor.py +++ b/python/ray/log_monitor.py @@ -12,94 +12,102 @@ from ray.services import get_port class LogMonitor(object): - """A monitor process for monitoring Ray log files. + """A monitor process for monitoring Ray log files. - Attributes: - node_ip_address: The IP address of the node that the log monitor process is - running on. 
This will be used to determine which log files to track. - redis_client: A client used to communicate with the Redis server. - log_filenames: A list of the names of the log files that this monitor - process is monitoring. - log_files: A dictionary mapping the name of a log file to a list of strings - representing its contents. - log_file_handles: A dictionary mapping the name of a log file to a file - handle for that file. - """ - def __init__(self, redis_ip_address, redis_port, node_ip_address): - """Initialize the log monitor object.""" - self.node_ip_address = node_ip_address - self.redis_client = redis.StrictRedis(host=redis_ip_address, - port=redis_port) - self.log_files = {} - self.log_file_handles = {} - - def update_log_filenames(self): - """Get the most up-to-date list of log files to monitor from Redis.""" - num_current_log_files = len(self.log_files) - new_log_filenames = self.redis_client.lrange( - "LOG_FILENAMES:{}".format(self.node_ip_address), - num_current_log_files, -1) - for log_filename in new_log_filenames: - print("Beginning to track file {}".format(log_filename)) - assert log_filename not in self.log_files - self.log_files[log_filename] = [] - - def check_log_files_and_push_updates(self): - """Get any changes to the log files and push updates to Redis.""" - for log_filename in self.log_files: - if log_filename in self.log_file_handles: - # Get any updates to the file. - new_lines = [] - while True: - current_position = self.log_file_handles[log_filename].tell() - next_line = self.log_file_handles[log_filename].readline() - if next_line != "": - new_lines.append(next_line) - else: - self.log_file_handles[log_filename].seek(current_position) - break - - # If there are any new lines, cache them and also push them to Redis. 
- if len(new_lines) > 0: - self.log_files[log_filename] += new_lines - redis_key = "LOGFILE:{}:{}".format(self.node_ip_address, - log_filename.decode("ascii")) - self.redis_client.rpush(redis_key, *new_lines) - else: - try: - self.log_file_handles[log_filename] = open(log_filename, "r") - except IOError as e: - if e.errno == os.errno.EMFILE: - print("Warning: Some files are not being logged because there are " - "too many open files.") - elif e.errno == os.errno.ENOENT: - print("Warning: The file {} was not found.".format(log_filename)) - else: - raise e - - def run(self): - """Run the log monitor. - - This will query Redis once every second to check if there are new log files - to monitor. It will also store those log files in Redis. + Attributes: + node_ip_address: The IP address of the node that the log monitor + process is running on. This will be used to determine which log + files to track. + redis_client: A client used to communicate with the Redis server. + log_filenames: A list of the names of the log files that this monitor + process is monitoring. + log_files: A dictionary mapping the name of a log file to a list of + strings representing its contents. + log_file_handles: A dictionary mapping the name of a log file to a file + handle for that file. 
""" - while True: - self.update_log_filenames() - self.check_log_files_and_push_updates() - time.sleep(1) + def __init__(self, redis_ip_address, redis_port, node_ip_address): + """Initialize the log monitor object.""" + self.node_ip_address = node_ip_address + self.redis_client = redis.StrictRedis(host=redis_ip_address, + port=redis_port) + self.log_files = {} + self.log_file_handles = {} + + def update_log_filenames(self): + """Get the most up-to-date list of log files to monitor from Redis.""" + num_current_log_files = len(self.log_files) + new_log_filenames = self.redis_client.lrange( + "LOG_FILENAMES:{}".format(self.node_ip_address), + num_current_log_files, -1) + for log_filename in new_log_filenames: + print("Beginning to track file {}".format(log_filename)) + assert log_filename not in self.log_files + self.log_files[log_filename] = [] + + def check_log_files_and_push_updates(self): + """Get any changes to the log files and push updates to Redis.""" + for log_filename in self.log_files: + if log_filename in self.log_file_handles: + # Get any updates to the file. + new_lines = [] + while True: + current_position = ( + self.log_file_handles[log_filename].tell()) + next_line = self.log_file_handles[log_filename].readline() + if next_line != "": + new_lines.append(next_line) + else: + self.log_file_handles[log_filename].seek( + current_position) + break + + # If there are any new lines, cache them and also push them to + # Redis. 
+ if len(new_lines) > 0: + self.log_files[log_filename] += new_lines + redis_key = "LOGFILE:{}:{}".format( + self.node_ip_address, log_filename.decode("ascii")) + self.redis_client.rpush(redis_key, *new_lines) + else: + try: + self.log_file_handles[log_filename] = open(log_filename, + "r") + except IOError as e: + if e.errno == os.errno.EMFILE: + print("Warning: Some files are not being logged " + "because there are too many open files.") + elif e.errno == os.errno.ENOENT: + print("Warning: The file {} was not " + "found.".format(log_filename)) + else: + raise e + + def run(self): + """Run the log monitor. + + This will query Redis once every second to check if there are new log + files to monitor. It will also store those log files in Redis. + """ + while True: + self.update_log_filenames() + self.check_log_files_and_push_updates() + time.sleep(1) if __name__ == "__main__": - parser = argparse.ArgumentParser(description=("Parse Redis server for the " - "log monitor to connect to.")) - parser.add_argument("--redis-address", required=True, type=str, - help="The address to use for Redis.") - parser.add_argument("--node-ip-address", required=True, type=str, - help="The IP address of the node this process is on.") - args = parser.parse_args() + parser = argparse.ArgumentParser(description=("Parse Redis server for the " + "log monitor to connect " + "to.")) + parser.add_argument("--redis-address", required=True, type=str, + help="The address to use for Redis.") + parser.add_argument("--node-ip-address", required=True, type=str, + help="The IP address of the node this process is on.") + args = parser.parse_args() - redis_ip_address = get_ip_address(args.redis_address) - redis_port = get_port(args.redis_address) + redis_ip_address = get_ip_address(args.redis_address) + redis_port = get_port(args.redis_address) - log_monitor = LogMonitor(redis_ip_address, redis_port, args.node_ip_address) - log_monitor.run() + log_monitor = LogMonitor(redis_ip_address, redis_port, + 
args.node_ip_address) + log_monitor.run() diff --git a/python/ray/monitor.py b/python/ray/monitor.py index d7004d4ae..eb86443c9 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -48,354 +48,367 @@ log.setLevel(logging.INFO) class Monitor(object): - """A monitor for Ray processes. + """A monitor for Ray processes. - The monitor is in charge of cleaning up the tables in the global state after - processes have died. The monitor is currently not responsible for detecting - component failures. + The monitor is in charge of cleaning up the tables in the global state + after processes have died. The monitor is currently not responsible for + detecting component failures. - Attributes: - redis: A connection to the Redis server. - subscribe_client: A pubsub client for the Redis server. This is used to - receive notifications about failed components. - subscribed: A dictionary mapping channel names (str) to whether or not the - subscription to that channel has succeeded yet (bool). - dead_local_schedulers: A set of the local scheduler IDs of all of the local - schedulers that were up at one point and have died since then. - live_plasma_managers: A counter mapping live plasma manager IDs to the - number of heartbeats that have passed since we last heard from that - plasma manager. A plasma manager is live if we received a heartbeat from - it at any point, and if it has not timed out. - dead_plasma_managers: A set of the plasma manager IDs of all the plasma - managers that were up at one point and have died since then. - """ - def __init__(self, redis_address, redis_port): - # Initialize the Redis clients. - self.state = ray.experimental.state.GlobalState() - self.state._initialize_global_state(redis_address, redis_port) - self.redis = redis.StrictRedis(host=redis_address, port=redis_port, db=0) - # TODO(swang): Update pubsub client to use ray.experimental.state once - # subscriptions are implemented there. 
- self.subscribe_client = self.redis.pubsub() - self.subscribed = {} - # Initialize data structures to keep track of the active database clients. - self.dead_local_schedulers = set() - self.live_plasma_managers = Counter() - self.dead_plasma_managers = set() - - def subscribe(self, channel): - """Subscribe to the given channel. - - Args: - channel (str): The channel to subscribe to. - - Raises: - Exception: An exception is raised if the subscription fails. + Attributes: + redis: A connection to the Redis server. + subscribe_client: A pubsub client for the Redis server. This is used to + receive notifications about failed components. + subscribed: A dictionary mapping channel names (str) to whether or not + the subscription to that channel has succeeded yet (bool). + dead_local_schedulers: A set of the local scheduler IDs of all of the + local schedulers that were up at one point and have died since + then. + live_plasma_managers: A counter mapping live plasma manager IDs to the + number of heartbeats that have passed since we last heard from that + plasma manager. A plasma manager is live if we received a heartbeat + from it at any point, and if it has not timed out. + dead_plasma_managers: A set of the plasma manager IDs of all the plasma + managers that were up at one point and have died since then. """ - self.subscribe_client.subscribe(channel) - self.subscribed[channel] = False + def __init__(self, redis_address, redis_port): + # Initialize the Redis clients. + self.state = ray.experimental.state.GlobalState() + self.state._initialize_global_state(redis_address, redis_port) + self.redis = redis.StrictRedis(host=redis_address, port=redis_port, + db=0) + # TODO(swang): Update pubsub client to use ray.experimental.state once + # subscriptions are implemented there. + self.subscribe_client = self.redis.pubsub() + self.subscribed = {} + # Initialize data structures to keep track of the active database + # clients. 
+ self.dead_local_schedulers = set() + self.live_plasma_managers = Counter() + self.dead_plasma_managers = set() - def cleanup_task_table(self): - """Clean up global state for failed local schedulers. + def subscribe(self, channel): + """Subscribe to the given channel. - This marks any tasks that were scheduled on dead local schedulers as - TASK_STATUS_LOST. A local scheduler is deemed dead if it is in - self.dead_local_schedulers. - """ - tasks = self.state.task_table() - num_tasks_updated = 0 - for task_id, task in tasks.items(): - # See if the corresponding local scheduler is alive. - if task["LocalSchedulerID"] in self.dead_local_schedulers: - # If the task is scheduled on a dead local scheduler, mark the task as - # lost. - key = binary_to_object_id(hex_to_binary(task_id)) - ok = self.state._execute_command( - key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id), - ray.experimental.state.TASK_STATUS_LOST, NIL_ID) - if ok != b"OK": - log.warn("Failed to update lost task for dead scheduler.") - num_tasks_updated += 1 - if num_tasks_updated > 0: - log.warn("Marked {} tasks as lost.".format(num_tasks_updated)) + Args: + channel (str): The channel to subscribe to. - def cleanup_object_table(self): - """Clean up global state for failed plasma managers. + Raises: + Exception: An exception is raised if the subscription fails. + """ + self.subscribe_client.subscribe(channel) + self.subscribed[channel] = False - This removes dead plasma managers from any location entries in the object - table. A plasma manager is deemed dead if it is in - self.dead_plasma_managers. - """ - # TODO(swang): Also kill the associated plasma store, since it's no longer - # reachable without a plasma manager. 
- objects = self.state.object_table() - num_objects_removed = 0 - for object_id, obj in objects.items(): - manager_ids = obj["ManagerIDs"] - if manager_ids is None: - continue - for manager in manager_ids: - if manager in self.dead_plasma_managers: - # If the object was on a dead plasma manager, remove that location - # entry. - ok = self.state._execute_command(object_id, - "RAY.OBJECT_TABLE_REMOVE", - object_id.id(), - hex_to_binary(manager)) - if ok != b"OK": - log.warn("Failed to remove object location for dead plasma " - "manager.") - num_objects_removed += 1 - if num_objects_removed > 0: - log.warn("Marked {} objects as lost.".format(num_objects_removed)) + def cleanup_task_table(self): + """Clean up global state for failed local schedulers. - def scan_db_client_table(self): - """Scan the database client table for dead clients. + This marks any tasks that were scheduled on dead local schedulers as + TASK_STATUS_LOST. A local scheduler is deemed dead if it is in + self.dead_local_schedulers. + """ + tasks = self.state.task_table() + num_tasks_updated = 0 + for task_id, task in tasks.items(): + # See if the corresponding local scheduler is alive. + if task["LocalSchedulerID"] in self.dead_local_schedulers: + # If the task is scheduled on a dead local scheduler, mark the + # task as lost. + key = binary_to_object_id(hex_to_binary(task_id)) + ok = self.state._execute_command( + key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id), + ray.experimental.state.TASK_STATUS_LOST, NIL_ID) + if ok != b"OK": + log.warn("Failed to update lost task for dead scheduler.") + num_tasks_updated += 1 + if num_tasks_updated > 0: + log.warn("Marked {} tasks as lost.".format(num_tasks_updated)) - After subscribing to the client table, it's necessary to call this before - reading any messages from the subscription channel. This ensures that we do - not miss any notifications for deleted clients that occurred before we - subscribed. 
- """ - clients = self.state.client_table() - for node_ip_address, node_clients in clients.items(): - for client in node_clients: - db_client_id = client["DBClientID"] - client_type = client["ClientType"] - if client["Deleted"]: - if client_type == LOCAL_SCHEDULER_CLIENT_TYPE: - self.dead_local_schedulers.add(db_client_id) - elif client_type == PLASMA_MANAGER_CLIENT_TYPE: - self.dead_plasma_managers.add(db_client_id) + def cleanup_object_table(self): + """Clean up global state for failed plasma managers. - def subscribe_handler(self, channel, data): - """Handle a subscription success message from Redis. - """ - log.debug("Subscribed to {}, data was {}".format(channel, data)) - self.subscribed[channel] = True + This removes dead plasma managers from any location entries in the + object table. A plasma manager is deemed dead if it is in + self.dead_plasma_managers. + """ + # TODO(swang): Also kill the associated plasma store, since it's no + # longer reachable without a plasma manager. + objects = self.state.object_table() + num_objects_removed = 0 + for object_id, obj in objects.items(): + manager_ids = obj["ManagerIDs"] + if manager_ids is None: + continue + for manager in manager_ids: + if manager in self.dead_plasma_managers: + # If the object was on a dead plasma manager, remove that + # location entry. + ok = self.state._execute_command(object_id, + "RAY.OBJECT_TABLE_REMOVE", + object_id.id(), + hex_to_binary(manager)) + if ok != b"OK": + log.warn("Failed to remove object location for dead " + "plasma manager.") + num_objects_removed += 1 + if num_objects_removed > 0: + log.warn("Marked {} objects as lost.".format(num_objects_removed)) - def db_client_notification_handler(self, channel, data): - """Handle a notification from the db_client table from Redis. + def scan_db_client_table(self): + """Scan the database client table for dead clients. - This handler processes notifications from the db_client table. 
- Notifications should be parsed using the SubscribeToDBClientTableReply - flatbuffer. Deletions are processed, insertions are ignored. Cleanup of the - associated state in the state tables should be handled by the caller. - """ - notification_object = (SubscribeToDBClientTableReply - .GetRootAsSubscribeToDBClientTableReply(data, 0)) - db_client_id = binary_to_hex(notification_object.DbClientId()) - client_type = notification_object.ClientType() - is_insertion = notification_object.IsInsertion() + After subscribing to the client table, it's necessary to call this + before reading any messages from the subscription channel. This ensures + that we do not miss any notifications for deleted clients that occurred + before we subscribed. + """ + clients = self.state.client_table() + for node_ip_address, node_clients in clients.items(): + for client in node_clients: + db_client_id = client["DBClientID"] + client_type = client["ClientType"] + if client["Deleted"]: + if client_type == LOCAL_SCHEDULER_CLIENT_TYPE: + self.dead_local_schedulers.add(db_client_id) + elif client_type == PLASMA_MANAGER_CLIENT_TYPE: + self.dead_plasma_managers.add(db_client_id) - # If the update was an insertion, we ignore it. - if is_insertion: - return + def subscribe_handler(self, channel, data): + """Handle a subscription success message from Redis.""" + log.debug("Subscribed to {}, data was {}".format(channel, data)) + self.subscribed[channel] = True - # If the update was a deletion, add them to our accounting for dead - # local schedulers and plasma managers. 
- log.warn("Removed {}, client ID {}".format(client_type, db_client_id)) - if client_type == LOCAL_SCHEDULER_CLIENT_TYPE: - if db_client_id not in self.dead_local_schedulers: - self.dead_local_schedulers.add(db_client_id) - elif client_type == PLASMA_MANAGER_CLIENT_TYPE: - if db_client_id not in self.dead_plasma_managers: - self.dead_plasma_managers.add(db_client_id) - # Stop tracking this plasma manager's heartbeats, since it's - # already dead. - del self.live_plasma_managers[db_client_id] + def db_client_notification_handler(self, channel, data): + """Handle a notification from the db_client table from Redis. - def plasma_manager_heartbeat_handler(self, channel, data): - """Handle a plasma manager heartbeat from Redis. + This handler processes notifications from the db_client table. + Notifications should be parsed using the SubscribeToDBClientTableReply + flatbuffer. Deletions are processed, insertions are ignored. Cleanup of + the associated state in the state tables should be handled by the + caller. + """ + notification_object = (SubscribeToDBClientTableReply + .GetRootAsSubscribeToDBClientTableReply(data, + 0)) + db_client_id = binary_to_hex(notification_object.DbClientId()) + client_type = notification_object.ClientType() + is_insertion = notification_object.IsInsertion() - This resets the number of heartbeats that we've missed from this plasma - manager. - """ - # The first DB_CLIENT_ID_SIZE characters are the client ID. - db_client_id = data[:DB_CLIENT_ID_SIZE] - # Reset the number of heartbeats that we've missed from this plasma - # manager. - self.live_plasma_managers[db_client_id] = 0 + # If the update was an insertion, we ignore it. + if is_insertion: + return - def driver_removed_handler(self, channel, data): - """Handle a notification that a driver has been removed. + # If the update was a deletion, add them to our accounting for dead + # local schedulers and plasma managers. 
+ log.warn("Removed {}, client ID {}".format(client_type, db_client_id)) + if client_type == LOCAL_SCHEDULER_CLIENT_TYPE: + if db_client_id not in self.dead_local_schedulers: + self.dead_local_schedulers.add(db_client_id) + elif client_type == PLASMA_MANAGER_CLIENT_TYPE: + if db_client_id not in self.dead_plasma_managers: + self.dead_plasma_managers.add(db_client_id) + # Stop tracking this plasma manager's heartbeats, since it's + # already dead. + del self.live_plasma_managers[db_client_id] - This releases any GPU resources that were reserved for that driver in - Redis. - """ - message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0) - driver_id = message.DriverId() - log.info("Driver {} has been removed.".format(binary_to_hex(driver_id))) + def plasma_manager_heartbeat_handler(self, channel, data): + """Handle a plasma manager heartbeat from Redis. - # Get a list of the local schedulers. - client_table = ray.global_state.client_table() - local_schedulers = [] - for ip_address, clients in client_table.items(): - for client in clients: - if client["ClientType"] == "local_scheduler": - local_schedulers.append(client) + This resets the number of heartbeats that we've missed from this plasma + manager. + """ + # The first DB_CLIENT_ID_SIZE characters are the client ID. + db_client_id = data[:DB_CLIENT_ID_SIZE] + # Reset the number of heartbeats that we've missed from this plasma + # manager. + self.live_plasma_managers[db_client_id] = 0 - # Release any GPU resources that have been reserved for this driver in - # Redis. - for local_scheduler in local_schedulers: - if int(local_scheduler["NumGPUs"]) > 0: - local_scheduler_id = local_scheduler["DBClientID"] + def driver_removed_handler(self, channel, data): + """Handle a notification that a driver has been removed. - num_gpus_returned = 0 + This releases any GPU resources that were reserved for that driver in + Redis. 
+ """ + message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0) + driver_id = message.DriverId() + log.info("Driver {} has been removed." + .format(binary_to_hex(driver_id))) - # Perform a transaction to return the GPUs. - with self.redis.pipeline() as pipe: - while True: - try: - # If this key is changed before the transaction below (the - # multi/exec block), then the transaction will not take place. - pipe.watch(local_scheduler_id) + # Get a list of the local schedulers. + client_table = ray.global_state.client_table() + local_schedulers = [] + for ip_address, clients in client_table.items(): + for client in clients: + if client["ClientType"] == "local_scheduler": + local_schedulers.append(client) - result = pipe.hget(local_scheduler_id, "gpus_in_use") - gpus_in_use = dict() if result is None else json.loads(result) + # Release any GPU resources that have been reserved for this driver in + # Redis. + for local_scheduler in local_schedulers: + if int(local_scheduler["NumGPUs"]) > 0: + local_scheduler_id = local_scheduler["DBClientID"] - driver_id_hex = binary_to_hex(driver_id) - if driver_id_hex in gpus_in_use: - num_gpus_returned = gpus_in_use.pop(driver_id_hex) + num_gpus_returned = 0 - pipe.multi() + # Perform a transaction to return the GPUs. + with self.redis.pipeline() as pipe: + while True: + try: + # If this key is changed before the transaction + # below (the multi/exec block), then the + # transaction will not take place. + pipe.watch(local_scheduler_id) - pipe.hset(local_scheduler_id, "gpus_in_use", - json.dumps(gpus_in_use)) + result = pipe.hget(local_scheduler_id, + "gpus_in_use") + gpus_in_use = (dict() if result is None + else json.loads(result)) - pipe.execute() - # If a WatchError is not raise, then the operations should have - # gone through atomically. - break - except redis.WatchError: - # Another client must have changed the watched key between the - # time we started WATCHing it and the pipeline's execution. 
We - # should just retry. - continue + driver_id_hex = binary_to_hex(driver_id) + if driver_id_hex in gpus_in_use: + num_gpus_returned = gpus_in_use.pop( + driver_id_hex) - log.info("Driver {} is returning GPU IDs {} to local scheduler {}." - .format(driver_id, num_gpus_returned, local_scheduler_id)) + pipe.multi() - def process_messages(self): - """Process all messages ready in the subscription channels. + pipe.hset(local_scheduler_id, "gpus_in_use", + json.dumps(gpus_in_use)) - This reads messages from the subscription channels and calls the - appropriate handlers until there are no messages left. - """ - while True: - message = self.subscribe_client.get_message() - if message is None: - return + pipe.execute() + # If a WatchError is not raise, then the operations + # should have gone through atomically. + break + except redis.WatchError: + # Another client must have changed the watched key + # between the time we started WATCHing it and the + # pipeline's execution. We should just retry. + continue - # Parse the message. - channel = message["channel"] - data = message["data"] + log.info("Driver {} is returning GPU IDs {} to local " + "scheduler {}.".format(driver_id, num_gpus_returned, + local_scheduler_id)) - # Determine the appropriate message handler. - message_handler = None - if not self.subscribed[channel]: - # If the data was an integer, then the message was a response to an - # initial subscription request. - message_handler = self.subscribe_handler - elif channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL: - assert(self.subscribed[channel]) - # The message was a heartbeat from a plasma manager. - message_handler = self.plasma_manager_heartbeat_handler - elif channel == DB_CLIENT_TABLE_NAME: - assert(self.subscribed[channel]) - # The message was a notification from the db_client table. 
- message_handler = self.db_client_notification_handler - elif channel == DRIVER_DEATH_CHANNEL: - assert(self.subscribed[channel]) - # The message was a notification that a driver was removed. - message_handler = self.driver_removed_handler - else: - raise Exception("This code should be unreachable.") + def process_messages(self): + """Process all messages ready in the subscription channels. - # Call the handler. - assert(message_handler is not None) - message_handler(channel, data) + This reads messages from the subscription channels and calls the + appropriate handlers until there are no messages left. + """ + while True: + message = self.subscribe_client.get_message() + if message is None: + return - def run(self): - """Run the monitor. + # Parse the message. + channel = message["channel"] + data = message["data"] - This function loops forever, checking for messages about dead database - clients and cleaning up state accordingly. - """ - # Initialize the subscription channel. - self.subscribe(DB_CLIENT_TABLE_NAME) - self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL) - self.subscribe(DRIVER_DEATH_CHANNEL) + # Determine the appropriate message handler. + message_handler = None + if not self.subscribed[channel]: + # If the data was an integer, then the message was a response + # to an initial subscription request. + message_handler = self.subscribe_handler + elif channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL: + assert(self.subscribed[channel]) + # The message was a heartbeat from a plasma manager. + message_handler = self.plasma_manager_heartbeat_handler + elif channel == DB_CLIENT_TABLE_NAME: + assert(self.subscribed[channel]) + # The message was a notification from the db_client table. + message_handler = self.db_client_notification_handler + elif channel == DRIVER_DEATH_CHANNEL: + assert(self.subscribed[channel]) + # The message was a notification that a driver was removed. 
+ message_handler = self.driver_removed_handler + else: + raise Exception("This code should be unreachable.") - # Scan the database table for dead database clients. NOTE: This must be - # called before reading any messages from the subscription channel. This - # ensures that we start in a consistent state, since we may have missed - # notifications that were sent before we connected to the subscription - # channel. - self.scan_db_client_table() - # If there were any dead clients at startup, clean up the associated state - # in the state tables. - if len(self.dead_local_schedulers) > 0: - self.cleanup_task_table() - if len(self.dead_plasma_managers) > 0: - self.cleanup_object_table() - log.debug("{} dead local schedulers, {} plasma managers total, {} dead " - "plasma managers".format(len(self.dead_local_schedulers), - (len(self.live_plasma_managers) + - len(self.dead_plasma_managers)), - len(self.dead_plasma_managers))) + # Call the handler. + assert(message_handler is not None) + message_handler(channel, data) - # Handle messages from the subscription channels. - while True: - # Record how many dead local schedulers and plasma managers we had at the - # beginning of this round. - num_dead_local_schedulers = len(self.dead_local_schedulers) - num_dead_plasma_managers = len(self.dead_plasma_managers) - # Process a round of messages. - self.process_messages() - # If any new local schedulers or plasma managers were marked as dead in - # this round, clean up the associated state. - if len(self.dead_local_schedulers) > num_dead_local_schedulers: - self.cleanup_task_table() - if len(self.dead_plasma_managers) > num_dead_plasma_managers: - self.cleanup_object_table() + def run(self): + """Run the monitor. - # Handle plasma managers that timed out during this round. 
- plasma_manager_ids = list(self.live_plasma_managers.keys()) - for plasma_manager_id in plasma_manager_ids: - if ((self.live_plasma_managers - [plasma_manager_id]) >= NUM_HEARTBEATS_TIMEOUT): - log.warn("Timed out {}".format(PLASMA_MANAGER_CLIENT_TYPE)) - # Remove the plasma manager from the managers whose heartbeats we're - # tracking. - del self.live_plasma_managers[plasma_manager_id] - # Remove the plasma manager from the db_client table. The - # corresponding state in the object table will be cleaned up once we - # receive the notification for this db_client deletion. - self.redis.execute_command("RAY.DISCONNECT", plasma_manager_id) + This function loops forever, checking for messages about dead database + clients and cleaning up state accordingly. + """ + # Initialize the subscription channel. + self.subscribe(DB_CLIENT_TABLE_NAME) + self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL) + self.subscribe(DRIVER_DEATH_CHANNEL) - # Increment the number of heartbeats that we've missed from each plasma - # manager. - for plasma_manager_id in self.live_plasma_managers: - self.live_plasma_managers[plasma_manager_id] += 1 + # Scan the database table for dead database clients. NOTE: This must be + # called before reading any messages from the subscription channel. + # This ensures that we start in a consistent state, since we may have + # missed notifications that were sent before we connected to the + # subscription channel. + self.scan_db_client_table() + # If there were any dead clients at startup, clean up the associated + # state in the state tables. 
+ if len(self.dead_local_schedulers) > 0: + self.cleanup_task_table() + if len(self.dead_plasma_managers) > 0: + self.cleanup_object_table() + log.debug("{} dead local schedulers, {} plasma managers total, {} " + "dead plasma managers".format( + len(self.dead_local_schedulers), + (len(self.live_plasma_managers) + + len(self.dead_plasma_managers)), + len(self.dead_plasma_managers))) - # Wait for a heartbeat interval before processing the next round of - # messages. - time.sleep(HEARTBEAT_TIMEOUT_MILLISECONDS * 1e-3) + # Handle messages from the subscription channels. + while True: + # Record how many dead local schedulers and plasma managers we had + # at the beginning of this round. + num_dead_local_schedulers = len(self.dead_local_schedulers) + num_dead_plasma_managers = len(self.dead_plasma_managers) + # Process a round of messages. + self.process_messages() + # If any new local schedulers or plasma managers were marked as + # dead in this round, clean up the associated state. + if len(self.dead_local_schedulers) > num_dead_local_schedulers: + self.cleanup_task_table() + if len(self.dead_plasma_managers) > num_dead_plasma_managers: + self.cleanup_object_table() + + # Handle plasma managers that timed out during this round. + plasma_manager_ids = list(self.live_plasma_managers.keys()) + for plasma_manager_id in plasma_manager_ids: + if ((self.live_plasma_managers + [plasma_manager_id]) >= NUM_HEARTBEATS_TIMEOUT): + log.warn("Timed out {}".format(PLASMA_MANAGER_CLIENT_TYPE)) + # Remove the plasma manager from the managers whose + # heartbeats we're tracking. + del self.live_plasma_managers[plasma_manager_id] + # Remove the plasma manager from the db_client table. The + # corresponding state in the object table will be cleaned + # up once we receive the notification for this db_client + # deletion. + self.redis.execute_command("RAY.DISCONNECT", + plasma_manager_id) + + # Increment the number of heartbeats that we've missed from each + # plasma manager. 
+ for plasma_manager_id in self.live_plasma_managers: + self.live_plasma_managers[plasma_manager_id] += 1 + + # Wait for a heartbeat interval before processing the next round of + # messages. + time.sleep(HEARTBEAT_TIMEOUT_MILLISECONDS * 1e-3) if __name__ == "__main__": - parser = argparse.ArgumentParser(description=("Parse Redis server for the " - "monitor to connect to.")) - parser.add_argument("--redis-address", required=True, type=str, - help="the address to use for Redis") - args = parser.parse_args() + parser = argparse.ArgumentParser(description=("Parse Redis server for the " + "monitor to connect to.")) + parser.add_argument("--redis-address", required=True, type=str, + help="the address to use for Redis") + args = parser.parse_args() - redis_ip_address = get_ip_address(args.redis_address) - redis_port = get_port(args.redis_address) + redis_ip_address = get_ip_address(args.redis_address) + redis_port = get_port(args.redis_address) - # Initialize the global state. - ray.global_state._initialize_global_state(redis_ip_address, redis_port) + # Initialize the global state. 
+ ray.global_state._initialize_global_state(redis_ip_address, redis_port) - monitor = Monitor(redis_ip_address, redis_port) - monitor.run() + monitor = Monitor(redis_ip_address, redis_port) + monitor.run() diff --git a/python/ray/numbuf/__init__.py b/python/ray/numbuf/__init__.py index dc6a82e22..fdca8de3e 100644 --- a/python/ray/numbuf/__init__.py +++ b/python/ray/numbuf/__init__.py @@ -16,28 +16,26 @@ __all__ = ["deserialize_list", "numbuf_error", "store_list", "write_to_buffer"] try: - from ray.core.src.numbuf.libnumbuf import (deserialize_list, numbuf_error, - numbuf_plasma_object_exists_error, - read_from_buffer, - register_callbacks, retrieve_list, - serialize_list, store_list, - write_to_buffer) + from ray.core.src.numbuf.libnumbuf import ( + deserialize_list, numbuf_error, numbuf_plasma_object_exists_error, + read_from_buffer, register_callbacks, retrieve_list, serialize_list, + store_list, write_to_buffer) except ImportError as e: - if (hasattr(e, "msg") and isinstance(e.msg, str) and ("libstdc++" in e.msg or - "CXX" in e.msg)): - # This code path should be taken with Python 3. - e.msg += helpful_message - elif (hasattr(e, "message") and isinstance(e.message, str) and - ("libstdc++" in e.message or "CXX" in e.message)): - # This code path should be taken with Python 2. - condition = (hasattr(e, "args") and isinstance(e.args, tuple) and - len(e.args) == 1 and isinstance(e.args[0], str)) - if condition: - e.args = (e.args[0] + helpful_message,) - else: - if not hasattr(e, "args"): - e.args = () - elif not isinstance(e.args, tuple): - e.args = (e.args,) - e.args += (helpful_message,) - raise + if ((hasattr(e, "msg") and isinstance(e.msg, str) and + ("libstdc++" in e.msg or "CXX" in e.msg))): + # This code path should be taken with Python 3. + e.msg += helpful_message + elif (hasattr(e, "message") and isinstance(e.message, str) and + ("libstdc++" in e.message or "CXX" in e.message)): + # This code path should be taken with Python 2. 
+ condition = (hasattr(e, "args") and isinstance(e.args, tuple) and + len(e.args) == 1 and isinstance(e.args[0], str)) + if condition: + e.args = (e.args[0] + helpful_message,) + else: + if not hasattr(e, "args"): + e.args = () + elif not isinstance(e.args, tuple): + e.args = (e.args,) + e.args += (helpful_message,) + raise diff --git a/python/ray/plasma/plasma.py b/python/ray/plasma/plasma.py index 718d3f01f..b7f0aefde 100644 --- a/python/ray/plasma/plasma.py +++ b/python/ray/plasma/plasma.py @@ -21,342 +21,355 @@ PLASMA_WAIT_TIMEOUT = 2 ** 30 class PlasmaBuffer(object): - """This is the type of objects returned by calls to get with a PlasmaClient. + """This is the type returned by calls to get with a PlasmaClient. - We define our own class instead of directly returning a buffer object so that - we can add a custom destructor which notifies Plasma that the object is no - longer being used, so the memory in the Plasma store backing the object can - potentially be freed. + We define our own class instead of directly returning a buffer object so + that we can add a custom destructor which notifies Plasma that the object + is no longer being used, so the memory in the Plasma store backing the + object can potentially be freed. - Attributes: - buffer (buffer): A buffer containing an object in the Plasma store. - plasma_id (PlasmaID): The ID of the object in the buffer. - plasma_client (PlasmaClient): The PlasmaClient that we use to communicate - with the store and manager. - """ - def __init__(self, buff, plasma_id, plasma_client): - """Initialize a PlasmaBuffer.""" - self.buffer = buff - self.plasma_id = plasma_id - self.plasma_client = plasma_client - - def __del__(self): - """Notify Plasma that the object is no longer needed. - - If the plasma client has been shut down, then don't do anything. + Attributes: + buffer (buffer): A buffer containing an object in the Plasma store. + plasma_id (PlasmaID): The ID of the object in the buffer. 
+ plasma_client (PlasmaClient): The PlasmaClient that we use to communicate + with the store and manager. """ - if self.plasma_client.alive: - libplasma.release(self.plasma_client.conn, self.plasma_id) + def __init__(self, buff, plasma_id, plasma_client): + """Initialize a PlasmaBuffer.""" + self.buffer = buff + self.plasma_id = plasma_id + self.plasma_client = plasma_client - def __getitem__(self, index): - """Read from the PlasmaBuffer as if it were just a regular buffer.""" - # We currently don't allow slicing plasma buffers. We should handle this - # better, but it requires some care because the slice may be backed by the - # same memory in the object store, but the original plasma buffer may go - # out of scope causing the memory to no longer be accessible. - assert not isinstance(index, slice) - value = self.buffer[index] - if sys.version_info >= (3, 0) and not isinstance(index, slice): - value = chr(value) - return value + def __del__(self): + """Notify Plasma that the object is no longer needed. - def __setitem__(self, index, value): - """Write to the PlasmaBuffer as if it were just a regular buffer. + If the plasma client has been shut down, then don't do anything. + """ + if self.plasma_client.alive: + libplasma.release(self.plasma_client.conn, self.plasma_id) - This should fail because the buffer should be read only. - """ - # We currently don't allow slicing plasma buffers. We should handle this - # better, but it requires some care because the slice may be backed by the - # same memory in the object store, but the original plasma buffer may go - # out of scope causing the memory to no longer be accessible. - assert not isinstance(index, slice) - if sys.version_info >= (3, 0) and not isinstance(index, slice): - value = ord(value) - self.buffer[index] = value + def __getitem__(self, index): + """Read from the PlasmaBuffer as if it were just a regular buffer.""" + # We currently don't allow slicing plasma buffers. 
We should handle + # this better, but it requires some care because the slice may be + # backed by the same memory in the object store, but the original + # plasma buffer may go out of scope causing the memory to no longer be + # accessible. + assert not isinstance(index, slice) + value = self.buffer[index] + if sys.version_info >= (3, 0) and not isinstance(index, slice): + value = chr(value) + return value - def __len__(self): - """Return the length of the buffer.""" - return len(self.buffer) + def __setitem__(self, index, value): + """Write to the PlasmaBuffer as if it were just a regular buffer. + + This should fail because the buffer should be read only. + """ + # We currently don't allow slicing plasma buffers. We should handle + # this better, but it requires some care because the slice may be + # backed by the same memory in the object store, but the original + # plasma buffer may go out of scope causing the memory to no longer be + # accessible. + assert not isinstance(index, slice) + if sys.version_info >= (3, 0) and not isinstance(index, slice): + value = ord(value) + self.buffer[index] = value + + def __len__(self): + """Return the length of the buffer.""" + return len(self.buffer) def buffers_equal(buff1, buff2): - """Compare two buffers. These buffers may be PlasmaBuffer objects. + """Compare two buffers. These buffers may be PlasmaBuffer objects. - This method should only be used in the tests. We implement a special helper - method for doing this because doing comparisons by slicing is much faster, - but we don't want to expose slicing of PlasmaBuffer objects because it - currently is not safe. - """ - buff1_to_compare = buff1.buffer if isinstance(buff1, PlasmaBuffer) else buff1 - buff2_to_compare = buff2.buffer if isinstance(buff2, PlasmaBuffer) else buff2 - return buff1_to_compare[:] == buff2_to_compare[:] + This method should only be used in the tests. 
We implement a special helper + method for doing this because doing comparisons by slicing is much faster, + but we don't want to expose slicing of PlasmaBuffer objects because it + currently is not safe. + """ + buff1_to_compare = (buff1.buffer if isinstance(buff1, PlasmaBuffer) + else buff1) + buff2_to_compare = (buff2.buffer if isinstance(buff2, PlasmaBuffer) + else buff2) + return buff1_to_compare[:] == buff2_to_compare[:] class PlasmaClient(object): - """The PlasmaClient is used to interface with a plasma store and manager. + """The PlasmaClient is used to interface with a plasma store and manager. - The PlasmaClient can ask the PlasmaStore to allocate a new buffer, seal a - buffer, and get a buffer. Buffers are referred to by object IDs, which are - strings. - """ - - def __init__(self, store_socket_name, manager_socket_name=None, - release_delay=64): - """Initialize the PlasmaClient. - - Args: - store_socket_name (str): Name of the socket the plasma store is listening - at. - manager_socket_name (str): Name of the socket the plasma manager is - listening at. - release_delay (int): The maximum number of objects that the client will - keep and delay releasing (for caching reasons). + The PlasmaClient can ask the PlasmaStore to allocate a new buffer, seal a + buffer, and get a buffer. Buffers are referred to by object IDs, which are + strings. """ - self.store_socket_name = store_socket_name - self.manager_socket_name = manager_socket_name - self.alive = True - if manager_socket_name is not None: - self.conn = libplasma.connect(store_socket_name, manager_socket_name, - release_delay) - else: - self.conn = libplasma.connect(store_socket_name, "", release_delay) + def __init__(self, store_socket_name, manager_socket_name=None, + release_delay=64): + """Initialize the PlasmaClient. - def shutdown(self): - """Shutdown the client so that it does not send messages. + Args: + store_socket_name (str): Name of the socket the plasma store is + listening at. 
+ manager_socket_name (str): Name of the socket the plasma manager is + listening at. + release_delay (int): The maximum number of objects that the client + will keep and delay releasing (for caching reasons). + """ + self.store_socket_name = store_socket_name + self.manager_socket_name = manager_socket_name + self.alive = True - If we kill the Plasma store and Plasma manager that this client is - connected to, then we can use this method to prevent the client from trying - to send messages to the killed processes. - """ - if self.alive: - libplasma.disconnect(self.conn) - self.alive = False + if manager_socket_name is not None: + self.conn = libplasma.connect(store_socket_name, + manager_socket_name, + release_delay) + else: + self.conn = libplasma.connect(store_socket_name, "", release_delay) - def create(self, object_id, size, metadata=None): - """Create a new buffer in the PlasmaStore for a particular object ID. + def shutdown(self): + """Shutdown the client so that it does not send messages. - The returned buffer is mutable until seal is called. + If we kill the Plasma store and Plasma manager that this client is + connected to, then we can use this method to prevent the client from + trying to send messages to the killed processes. + """ + if self.alive: + libplasma.disconnect(self.conn) + self.alive = False - Args: - object_id (str): A string used to identify an object. - size (int): The size in bytes of the created buffer. - metadata (buffer): An optional buffer encoding whatever metadata the user - wishes to encode. + def create(self, object_id, size, metadata=None): + """Create a new buffer in the PlasmaStore for a particular object ID. - Raises: - plasma_object_exists_error: This exception is raised if the object could - not be created because there already is an object with the same ID in - the plasma store. 
- plasma_out_of_memory_error: This exception is raised if the object could - not be created because the plasma store is unable to evict enough - objects to create room for it. - """ - # Turn the metadata into the right type. - metadata = bytearray(b"") if metadata is None else metadata - buff = libplasma.create(self.conn, object_id, size, metadata) - return PlasmaBuffer(buff, object_id, self) + The returned buffer is mutable until seal is called. - def get(self, object_ids, timeout_ms=-1): - """Create a buffer from the PlasmaStore based on object ID. + Args: + object_id (str): A string used to identify an object. + size (int): The size in bytes of the created buffer. + metadata (buffer): An optional buffer encoding whatever metadata the + user wishes to encode. - If the object has not been sealed yet, this call will block. The retrieved - buffer is immutable. + Raises: + plasma_object_exists_error: This exception is raised if the object + could not be created because there already is an object with the + same ID in the plasma store. + plasma_out_of_memory_error: This exception is raised if the object + could not be created because the plasma store is unable to evict + enough objects to create room for it. + """ + # Turn the metadata into the right type. + metadata = bytearray(b"") if metadata is None else metadata + buff = libplasma.create(self.conn, object_id, size, metadata) + return PlasmaBuffer(buff, object_id, self) - Args: - object_ids (List[str]): A list of strings used to identify some objects. - timeout_ms (int): The number of milliseconds that the get call should - block before timing out and returning. Pass -1 if the call should block - and 0 if the call should return immediately. 
- """ - results = libplasma.get(self.conn, object_ids, timeout_ms) - assert len(object_ids) == len(results) - returns = [] - for i in range(len(object_ids)): - if results[i] is None: - returns.append(None) - else: - returns.append(PlasmaBuffer(results[i][0], object_ids[i], self)) - return returns + def get(self, object_ids, timeout_ms=-1): + """Create a buffer from the PlasmaStore based on object ID. - def get_metadata(self, object_ids, timeout_ms=-1): - """Create a buffer from the PlasmaStore based on object ID. + If the object has not been sealed yet, this call will block. The + retrieved buffer is immutable. - If the object has not been sealed yet, this call will block until the - object has been sealed. The retrieved buffer is immutable. + Args: + object_ids (List[str]): A list of strings used to identify some + objects. + timeout_ms (int): The number of milliseconds that the get call should + block before timing out and returning. Pass -1 if the call should + block and 0 if the call should return immediately. + """ + results = libplasma.get(self.conn, object_ids, timeout_ms) + assert len(object_ids) == len(results) + returns = [] + for i in range(len(object_ids)): + if results[i] is None: + returns.append(None) + else: + returns.append(PlasmaBuffer(results[i][0], object_ids[i], + self)) + return returns - Args: - object_ids (List[str]): A list of strings used to identify some objects. - timeout_ms (int): The number of milliseconds that the get call should - block before timing out and returning. Pass -1 if the call should block - and 0 if the call should return immediately. 
- """ - results = libplasma.get(self.conn, object_ids, timeout_ms) - assert len(object_ids) == len(results) - returns = [] - for i in range(len(object_ids)): - if results[i] is None: - returns.append(None) - else: - returns.append(PlasmaBuffer(results[i][1], object_ids[i], self)) - return returns + def get_metadata(self, object_ids, timeout_ms=-1): + """Create a buffer from the PlasmaStore based on object ID. - def contains(self, object_id): - """Check if the object is present and has been sealed in the PlasmaStore. + If the object has not been sealed yet, this call will block until the + object has been sealed. The retrieved buffer is immutable. - Args: - object_id (str): A string used to identify an object. - """ - return libplasma.contains(self.conn, object_id) + Args: + object_ids (List[str]): A list of strings used to identify some + objects. + timeout_ms (int): The number of milliseconds that the get call should + block before timing out and returning. Pass -1 if the call should + block and 0 if the call should return immediately. + """ + results = libplasma.get(self.conn, object_ids, timeout_ms) + assert len(object_ids) == len(results) + returns = [] + for i in range(len(object_ids)): + if results[i] is None: + returns.append(None) + else: + returns.append(PlasmaBuffer(results[i][1], object_ids[i], + self)) + return returns - def hash(self, object_id): - """Compute the hash of an object in the object store. + def contains(self, object_id): + """Check if the object is present and has been sealed. - Args: - object_id (str): A string used to identify an object. + Args: + object_id (str): A string used to identify an object. + """ + return libplasma.contains(self.conn, object_id) - Returns: - A digest string object's SHA256 hash. If the object isn't in the object - store, the string will have length zero. - """ - return libplasma.hash(self.conn, object_id) + def hash(self, object_id): + """Compute the hash of an object in the object store. 
- def seal(self, object_id): - """Seal the buffer in the PlasmaStore for a particular object ID. + Args: + object_id (str): A string used to identify an object. - Once a buffer has been sealed, the buffer is immutable and can only be - accessed through get. + Returns: + A digest string object's SHA256 hash. If the object isn't in the + object store, the string will have length zero. + """ + return libplasma.hash(self.conn, object_id) - Args: - object_id (str): A string used to identify an object. - """ - libplasma.seal(self.conn, object_id) + def seal(self, object_id): + """Seal the buffer in the PlasmaStore for a particular object ID. - def delete(self, object_id): - """Delete the buffer in the PlasmaStore for a particular object ID. + Once a buffer has been sealed, the buffer is immutable and can only be + accessed through get. - Once a buffer has been deleted, the buffer is no longer accessible. + Args: + object_id (str): A string used to identify an object. + """ + libplasma.seal(self.conn, object_id) - Args: - object_id (str): A string used to identify an object. - """ - libplasma.delete(self.conn, object_id) + def delete(self, object_id): + """Delete the buffer in the PlasmaStore for a particular object ID. - def evict(self, num_bytes): - """Evict some objects until to recover some bytes. + Once a buffer has been deleted, the buffer is no longer accessible. - Recover at least num_bytes bytes if possible. + Args: + object_id (str): A string used to identify an object. + """ + libplasma.delete(self.conn, object_id) - Args: - num_bytes (int): The number of bytes to attempt to recover. - """ - return libplasma.evict(self.conn, num_bytes) + def evict(self, num_bytes): + """Evict some objects until to recover some bytes. - def transfer(self, addr, port, object_id): - """Transfer local object with id object_id to another plasma instance + Recover at least num_bytes bytes if possible. - Args: - addr (str): IPv4 address of the plasma instance the object is sent to. 
- port (int): Port number of the plasma instance the object is sent to. - object_id (str): A string used to identify an object. - """ - return libplasma.transfer(self.conn, object_id, addr, port) + Args: + num_bytes (int): The number of bytes to attempt to recover. + """ + return libplasma.evict(self.conn, num_bytes) - def fetch(self, object_ids): - """Fetch the objects with the given IDs from other plasma manager instances. + def transfer(self, addr, port, object_id): + """Transfer local object with id object_id to another plasma instance - Args: - object_ids (List[str]): A list of strings used to identify the objects. - """ - return libplasma.fetch(self.conn, object_ids) + Args: + addr (str): IPv4 address of the plasma instance the object is sent + to. + port (int): Port number of the plasma instance the object is sent to. + object_id (str): A string used to identify an object. + """ + return libplasma.transfer(self.conn, object_id, addr, port) - def wait(self, object_ids, timeout=PLASMA_WAIT_TIMEOUT, num_returns=1): - """Wait until num_returns objects in object_ids are ready. + def fetch(self, object_ids): + """Fetch the objects with the given IDs from other plasma managers. - Currently, the object ID arguments to wait must be unique. + Args: + object_ids (List[str]): A list of strings used to identify the + objects. + """ + return libplasma.fetch(self.conn, object_ids) - Args: - object_ids (List[str]): List of object IDs to wait for. - timeout (int): Return to the caller after timeout milliseconds. - num_returns (int): We are waiting for this number of objects to be ready. + def wait(self, object_ids, timeout=PLASMA_WAIT_TIMEOUT, num_returns=1): + """Wait until num_returns objects in object_ids are ready. - Returns: - ready_ids, waiting_ids (List[str], List[str]): List of object IDs that - are ready and list of object IDs we might still wait on respectively. - """ - # Check that the object ID arguments are unique. 
The plasma manager - # currently crashes if given duplicate object IDs. - if len(object_ids) != len(set(object_ids)): - raise Exception("Wait requires a list of unique object IDs.") - ready_ids, waiting_ids = libplasma.wait(self.conn, object_ids, timeout, - num_returns) - return ready_ids, list(waiting_ids) + Currently, the object ID arguments to wait must be unique. - def subscribe(self): - """Subscribe to notifications about sealed objects.""" - self.notification_fd = libplasma.subscribe(self.conn) + Args: + object_ids (List[str]): List of object IDs to wait for. + timeout (int): Return to the caller after timeout milliseconds. + num_returns (int): We are waiting for this number of objects to be + ready. - def get_next_notification(self): - """Get the next notification from the notification socket.""" - return libplasma.receive_notification(self.notification_fd) + Returns: + ready_ids, waiting_ids (List[str], List[str]): List of object IDs + that are ready and list of object IDs we might still wait on + respectively. + """ + # Check that the object ID arguments are unique. The plasma manager + # currently crashes if given duplicate object IDs. 
+ if len(object_ids) != len(set(object_ids)): + raise Exception("Wait requires a list of unique object IDs.") + ready_ids, waiting_ids = libplasma.wait(self.conn, object_ids, timeout, + num_returns) + return ready_ids, list(waiting_ids) + + def subscribe(self): + """Subscribe to notifications about sealed objects.""" + self.notification_fd = libplasma.subscribe(self.conn) + + def get_next_notification(self): + """Get the next notification from the notification socket.""" + return libplasma.receive_notification(self.notification_fd) DEFAULT_PLASMA_STORE_MEMORY = 10 ** 9 def random_name(): - return str(random.randint(0, 99999999)) + return str(random.randint(0, 99999999)) def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, use_valgrind=False, use_profiler=False, stdout_file=None, stderr_file=None): - """Start a plasma store process. + """Start a plasma store process. - Args: - use_valgrind (bool): True if the plasma store should be started inside of - valgrind. If this is True, use_profiler must be False. - use_profiler (bool): True if the plasma store should be started inside a - profiler. If this is True, use_valgrind must be False. - stdout_file: A file handle opened for writing to redirect stdout to. If no - redirection should happen, then this should be None. - stderr_file: A file handle opened for writing to redirect stderr to. If no - redirection should happen, then this should be None. + Args: + use_valgrind (bool): True if the plasma store should be started inside of + valgrind. If this is True, use_profiler must be False. + use_profiler (bool): True if the plasma store should be started inside a + profiler. If this is True, use_valgrind must be False. + stdout_file: A file handle opened for writing to redirect stdout to. If + no redirection should happen, then this should be None. + stderr_file: A file handle opened for writing to redirect stderr to. If + no redirection should happen, then this should be None. 
- Return: - A tuple of the name of the plasma store socket and the process ID of the - plasma store process. - """ - if use_valgrind and use_profiler: - raise Exception("Cannot use valgrind and profiler at the same time.") - plasma_store_executable = os.path.join(os.path.abspath( - os.path.dirname(__file__)), - "../core/src/plasma/plasma_store") - plasma_store_name = "/tmp/plasma_store{}".format(random_name()) - command = [plasma_store_executable, - "-s", plasma_store_name, - "-m", str(plasma_store_memory)] - if use_valgrind: - pid = subprocess.Popen(["valgrind", - "--track-origins=yes", - "--leak-check=full", - "--show-leak-kinds=all", - "--leak-check-heuristics=stdstring", - "--error-exitcode=1"] + command, - stdout=stdout_file, stderr=stderr_file) - time.sleep(1.0) - elif use_profiler: - pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command, - stdout=stdout_file, stderr=stderr_file) - time.sleep(1.0) - else: - pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) - time.sleep(0.1) - return plasma_store_name, pid + Return: + A tuple of the name of the plasma store socket and the process ID of the + plasma store process. 
+ """ + if use_valgrind and use_profiler: + raise Exception("Cannot use valgrind and profiler at the same time.") + plasma_store_executable = os.path.join(os.path.abspath( + os.path.dirname(__file__)), + "../core/src/plasma/plasma_store") + plasma_store_name = "/tmp/plasma_store{}".format(random_name()) + command = [plasma_store_executable, + "-s", plasma_store_name, + "-m", str(plasma_store_memory)] + if use_valgrind: + pid = subprocess.Popen(["valgrind", + "--track-origins=yes", + "--leak-check=full", + "--show-leak-kinds=all", + "--leak-check-heuristics=stdstring", + "--error-exitcode=1"] + command, + stdout=stdout_file, stderr=stderr_file) + time.sleep(1.0) + elif use_profiler: + pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command, + stdout=stdout_file, stderr=stderr_file) + time.sleep(1.0) + else: + pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) + time.sleep(0.1) + return plasma_store_name, pid def new_port(): - return random.randint(10000, 65535) + return random.randint(10000, 65535) def start_plasma_manager(store_name, redis_address, @@ -364,69 +377,71 @@ def start_plasma_manager(store_name, redis_address, num_retries=20, use_valgrind=False, run_profiler=False, stdout_file=None, stderr_file=None): - """Start a plasma manager and return the ports it listens on. + """Start a plasma manager and return the ports it listens on. - Args: - store_name (str): The name of the plasma store socket. - redis_address (str): The address of the Redis server. - node_ip_address (str): The IP address of the node. - plasma_manager_port (int): The port to use for the plasma manager. If this - is not provided, a port will be generated at random. - use_valgrind (bool): True if the Plasma manager should be started inside of - valgrind and False otherwise. - stdout_file: A file handle opened for writing to redirect stdout to. If no - redirection should happen, then this should be None. 
- stderr_file: A file handle opened for writing to redirect stderr to. If no - redirection should happen, then this should be None. + Args: + store_name (str): The name of the plasma store socket. + redis_address (str): The address of the Redis server. + node_ip_address (str): The IP address of the node. + plasma_manager_port (int): The port to use for the plasma manager. If + this is not provided, a port will be generated at random. + use_valgrind (bool): True if the Plasma manager should be started inside + of valgrind and False otherwise. + stdout_file: A file handle opened for writing to redirect stdout to. If + no redirection should happen, then this should be None. + stderr_file: A file handle opened for writing to redirect stderr to. If + no redirection should happen, then this should be None. - Returns: - A tuple of the Plasma manager socket name, the process ID of the Plasma - manager process, and the port that the manager is listening on. + Returns: + A tuple of the Plasma manager socket name, the process ID of the Plasma + manager process, and the port that the manager is listening on. - Raises: - Exception: An exception is raised if the manager could not be started. 
- """ - plasma_manager_executable = os.path.join( - os.path.abspath(os.path.dirname(__file__)), - "../core/src/plasma/plasma_manager") - plasma_manager_name = "/tmp/plasma_manager{}".format(random_name()) - if plasma_manager_port is not None: - if num_retries != 1: - raise Exception("num_retries must be 1 if port is specified.") - else: - plasma_manager_port = new_port() - process = None - counter = 0 - while counter < num_retries: - if counter > 0: - print("Plasma manager failed to start, retrying now.") - command = [plasma_manager_executable, - "-s", store_name, - "-m", plasma_manager_name, - "-h", node_ip_address, - "-p", str(plasma_manager_port), - "-r", redis_address, - ] - if use_valgrind: - process = subprocess.Popen(["valgrind", - "--track-origins=yes", - "--leak-check=full", - "--show-leak-kinds=all", - "--error-exitcode=1"] + command, - stdout=stdout_file, stderr=stderr_file) - elif run_profiler: - process = subprocess.Popen(["valgrind", "--tool=callgrind"] + command, - stdout=stdout_file, stderr=stderr_file) + Raises: + Exception: An exception is raised if the manager could not be started. + """ + plasma_manager_executable = os.path.join( + os.path.abspath(os.path.dirname(__file__)), + "../core/src/plasma/plasma_manager") + plasma_manager_name = "/tmp/plasma_manager{}".format(random_name()) + if plasma_manager_port is not None: + if num_retries != 1: + raise Exception("num_retries must be 1 if port is specified.") else: - process = subprocess.Popen(command, stdout=stdout_file, - stderr=stderr_file) - # This sleep is critical. If the plasma_manager fails to start because the - # port is already in use, then we need it to fail within 0.1 seconds. - time.sleep(0.1) - # See if the process has terminated - if process.poll() is None: - return plasma_manager_name, process, plasma_manager_port - # Generate a new port and try again. 
- plasma_manager_port = new_port() - counter += 1 - raise Exception("Couldn't start plasma manager.") + plasma_manager_port = new_port() + process = None + counter = 0 + while counter < num_retries: + if counter > 0: + print("Plasma manager failed to start, retrying now.") + command = [plasma_manager_executable, + "-s", store_name, + "-m", plasma_manager_name, + "-h", node_ip_address, + "-p", str(plasma_manager_port), + "-r", redis_address, + ] + if use_valgrind: + process = subprocess.Popen(["valgrind", + "--track-origins=yes", + "--leak-check=full", + "--show-leak-kinds=all", + "--error-exitcode=1"] + command, + stdout=stdout_file, stderr=stderr_file) + elif run_profiler: + process = subprocess.Popen((["valgrind", "--tool=callgrind"] + + command), + stdout=stdout_file, stderr=stderr_file) + else: + process = subprocess.Popen(command, stdout=stdout_file, + stderr=stderr_file) + # This sleep is critical. If the plasma_manager fails to start because + # the port is already in use, then we need it to fail within 0.1 + # seconds. + time.sleep(0.1) + # See if the process has terminated + if process.poll() is None: + return plasma_manager_name, process, plasma_manager_port + # Generate a new port and try again. 
+ plasma_manager_port = new_port() + counter += 1 + raise Exception("Couldn't start plasma manager.") diff --git a/python/ray/plasma/test/test.py b/python/ray/plasma/test/test.py index 94a670626..a56dc60a2 100644 --- a/python/ray/plasma/test/test.py +++ b/python/ray/plasma/test/test.py @@ -22,834 +22,872 @@ PLASMA_STORE_MEMORY = 1000000000 def assert_get_object_equal(unit_test, client1, client2, object_id, memory_buffer=None, metadata=None): - client1_buff = client1.get([object_id])[0] - client2_buff = client2.get([object_id])[0] - client1_metadata = client1.get_metadata([object_id])[0] - client2_metadata = client2.get_metadata([object_id])[0] - unit_test.assertEqual(len(client1_buff), len(client2_buff)) - unit_test.assertEqual(len(client1_metadata), len(client2_metadata)) - # Check that the buffers from the two clients are the same. - unit_test.assertTrue(plasma.buffers_equal(client1_buff, client2_buff)) - # Check that the metadata buffers from the two clients are the same. - unit_test.assertTrue(plasma.buffers_equal(client1_metadata, - client2_metadata)) - # If a reference buffer was provided, check that it is the same as well. - if memory_buffer is not None: - unit_test.assertTrue(plasma.buffers_equal(memory_buffer, client1_buff)) - # If reference metadata was provided, check that it is the same as well. - if metadata is not None: - unit_test.assertTrue(plasma.buffers_equal(metadata, client1_metadata)) + client1_buff = client1.get([object_id])[0] + client2_buff = client2.get([object_id])[0] + client1_metadata = client1.get_metadata([object_id])[0] + client2_metadata = client2.get_metadata([object_id])[0] + unit_test.assertEqual(len(client1_buff), len(client2_buff)) + unit_test.assertEqual(len(client1_metadata), len(client2_metadata)) + # Check that the buffers from the two clients are the same. + unit_test.assertTrue(plasma.buffers_equal(client1_buff, client2_buff)) + # Check that the metadata buffers from the two clients are the same. 
+ unit_test.assertTrue(plasma.buffers_equal(client1_metadata, + client2_metadata)) + # If a reference buffer was provided, check that it is the same as well. + if memory_buffer is not None: + unit_test.assertTrue(plasma.buffers_equal(memory_buffer, client1_buff)) + # If reference metadata was provided, check that it is the same as well. + if metadata is not None: + unit_test.assertTrue(plasma.buffers_equal(metadata, client1_metadata)) class TestPlasmaClient(unittest.TestCase): - def setUp(self): - # Start Plasma store. - plasma_store_name, self.p = plasma.start_plasma_store( - use_valgrind=USE_VALGRIND) - # Connect to Plasma. - self.plasma_client = plasma.PlasmaClient(plasma_store_name, None, 64) - # For the eviction test - self.plasma_client2 = plasma.PlasmaClient(plasma_store_name, None, 0) + def setUp(self): + # Start Plasma store. + plasma_store_name, self.p = plasma.start_plasma_store( + use_valgrind=USE_VALGRIND) + # Connect to Plasma. + self.plasma_client = plasma.PlasmaClient(plasma_store_name, None, 64) + # For the eviction test + self.plasma_client2 = plasma.PlasmaClient(plasma_store_name, None, 0) - def tearDown(self): - # Check that the Plasma store is still alive. - self.assertEqual(self.p.poll(), None) - # Kill the plasma store process. - if USE_VALGRIND: - self.p.send_signal(signal.SIGTERM) - self.p.wait() - if self.p.returncode != 0: - os._exit(-1) - else: - self.p.kill() - - def test_create(self): - # Create an object id string. - object_id = random_object_id() - # Create a new buffer and write to it. - length = 50 - memory_buffer = self.plasma_client.create(object_id, length) - for i in range(length): - memory_buffer[i] = chr(i % 256) - # Seal the object. - self.plasma_client.seal(object_id) - # Get the object. - memory_buffer = self.plasma_client.get([object_id])[0] - for i in range(length): - self.assertEqual(memory_buffer[i], chr(i % 256)) - - def test_create_with_metadata(self): - for length in range(1000): - # Create an object id string. 
- object_id = random_object_id() - # Create a random metadata string. - metadata = generate_metadata(length) - # Create a new buffer and write to it. - memory_buffer = self.plasma_client.create(object_id, length, metadata) - for i in range(length): - memory_buffer[i] = chr(i % 256) - # Seal the object. - self.plasma_client.seal(object_id) - # Get the object. - memory_buffer = self.plasma_client.get([object_id])[0] - for i in range(length): - self.assertEqual(memory_buffer[i], chr(i % 256)) - # Get the metadata. - metadata_buffer = self.plasma_client.get_metadata([object_id])[0] - self.assertEqual(len(metadata), len(metadata_buffer)) - for i in range(len(metadata)): - self.assertEqual(chr(metadata[i]), metadata_buffer[i]) - - def test_create_existing(self): - # This test is partially used to test the code path in which we create an - # object with an ID that already exists - length = 100 - for _ in range(1000): - object_id = random_object_id() - self.plasma_client.create(object_id, length, generate_metadata(length)) - try: - self.plasma_client.create(object_id, length, generate_metadata(length)) - except plasma.plasma_object_exists_error as e: - pass - else: - self.assertTrue(False) - - def test_get(self): - num_object_ids = 100 - # Test timing out of get with various timeouts. - for timeout in [0, 10, 100, 1000]: - object_ids = [random_object_id() for _ in range(num_object_ids)] - results = self.plasma_client.get(object_ids, timeout_ms=timeout) - self.assertEqual(results, num_object_ids * [None]) - - data_buffers = [] - metadata_buffers = [] - for i in range(num_object_ids): - if i % 2 == 0: - data_buffer, metadata_buffer = create_object_with_id( - self.plasma_client, object_ids[i], 2000, 2000) - data_buffers.append(data_buffer) - metadata_buffers.append(metadata_buffer) - - # Test timing out from some but not all get calls with various timeouts. 
- for timeout in [0, 10, 100, 1000]: - data_results = self.plasma_client.get(object_ids, timeout_ms=timeout) - # metadata_results = self.plasma_client.get_metadata(object_ids, - # timeout_ms=timeout) - for i in range(num_object_ids): - if i % 2 == 0: - self.assertTrue(plasma.buffers_equal(data_buffers[i // 2], - data_results[i])) - # TODO(rkn): We should compare the metadata as well. But currently - # the types are different (e.g., memoryview versus bytearray). - # self.assertTrue(plasma.buffers_equal(metadata_buffers[i // 2], - # metadata_results[i])) + def tearDown(self): + # Check that the Plasma store is still alive. + self.assertEqual(self.p.poll(), None) + # Kill the plasma store process. + if USE_VALGRIND: + self.p.send_signal(signal.SIGTERM) + self.p.wait() + if self.p.returncode != 0: + os._exit(-1) else: - self.assertIsNone(results[i]) + self.p.kill() - def test_store_full(self): - # The store is started with 1GB, so make sure that create throws an - # exception when it is full. - def assert_create_raises_plasma_full(unit_test, size): - partial_size = np.random.randint(size) - try: - _, memory_buffer, _ = create_object(unit_test.plasma_client, - partial_size, - size - partial_size) - except plasma.plasma_out_of_memory_error as e: - pass - else: - # For some reason the above didn't throw an exception, so fail. - unit_test.assertTrue(False) + def test_create(self): + # Create an object id string. + object_id = random_object_id() + # Create a new buffer and write to it. + length = 50 + memory_buffer = self.plasma_client.create(object_id, length) + for i in range(length): + memory_buffer[i] = chr(i % 256) + # Seal the object. + self.plasma_client.seal(object_id) + # Get the object. + memory_buffer = self.plasma_client.get([object_id])[0] + for i in range(length): + self.assertEqual(memory_buffer[i], chr(i % 256)) - # Create a list to keep some of the buffers in scope. 
- memory_buffers = [] - _, memory_buffer, _ = create_object(self.plasma_client, 5 * 10 ** 8, 0) - memory_buffers.append(memory_buffer) - # Remaining space is 5 * 10 ** 8. Make sure that we can't create an object - # of size 5 * 10 ** 8 + 1, but we can create one of size 2 * 10 ** 8. - assert_create_raises_plasma_full(self, 5 * 10 ** 8 + 1) - _, memory_buffer, _ = create_object(self.plasma_client, 2 * 10 ** 8, 0) - del memory_buffer - _, memory_buffer, _ = create_object(self.plasma_client, 2 * 10 ** 8, 0) - del memory_buffer - assert_create_raises_plasma_full(self, 5 * 10 ** 8 + 1) + def test_create_with_metadata(self): + for length in range(1000): + # Create an object id string. + object_id = random_object_id() + # Create a random metadata string. + metadata = generate_metadata(length) + # Create a new buffer and write to it. + memory_buffer = self.plasma_client.create(object_id, length, + metadata) + for i in range(length): + memory_buffer[i] = chr(i % 256) + # Seal the object. + self.plasma_client.seal(object_id) + # Get the object. + memory_buffer = self.plasma_client.get([object_id])[0] + for i in range(length): + self.assertEqual(memory_buffer[i], chr(i % 256)) + # Get the metadata. + metadata_buffer = self.plasma_client.get_metadata([object_id])[0] + self.assertEqual(len(metadata), len(metadata_buffer)) + for i in range(len(metadata)): + self.assertEqual(chr(metadata[i]), metadata_buffer[i]) - _, memory_buffer, _ = create_object(self.plasma_client, 2 * 10 ** 8, 0) - memory_buffers.append(memory_buffer) - # Remaining space is 3 * 10 ** 8. 
- assert_create_raises_plasma_full(self, 3 * 10 ** 8 + 1) + def test_create_existing(self): + # This test is partially used to test the code path in which we create + # an object with an ID that already exists + length = 100 + for _ in range(1000): + object_id = random_object_id() + self.plasma_client.create(object_id, length, + generate_metadata(length)) + try: + self.plasma_client.create(object_id, length, + generate_metadata(length)) + except plasma.plasma_object_exists_error as e: + pass + else: + self.assertTrue(False) - _, memory_buffer, _ = create_object(self.plasma_client, 10 ** 8, 0) - memory_buffers.append(memory_buffer) - # Remaining space is 2 * 10 ** 8. - assert_create_raises_plasma_full(self, 2 * 10 ** 8 + 1) + def test_get(self): + num_object_ids = 100 + # Test timing out of get with various timeouts. + for timeout in [0, 10, 100, 1000]: + object_ids = [random_object_id() for _ in range(num_object_ids)] + results = self.plasma_client.get(object_ids, timeout_ms=timeout) + self.assertEqual(results, num_object_ids * [None]) - def test_contains(self): - fake_object_ids = [random_object_id() for _ in range(100)] - real_object_ids = [random_object_id() for _ in range(100)] - for object_id in real_object_ids: - self.assertFalse(self.plasma_client.contains(object_id)) - self.plasma_client.create(object_id, 100) - self.plasma_client.seal(object_id) - self.assertTrue(self.plasma_client.contains(object_id)) - for object_id in fake_object_ids: - self.assertFalse(self.plasma_client.contains(object_id)) - for object_id in real_object_ids: - self.assertTrue(self.plasma_client.contains(object_id)) + data_buffers = [] + metadata_buffers = [] + for i in range(num_object_ids): + if i % 2 == 0: + data_buffer, metadata_buffer = create_object_with_id( + self.plasma_client, object_ids[i], 2000, 2000) + data_buffers.append(data_buffer) + metadata_buffers.append(metadata_buffer) - def test_hash(self): - # Check the hash of an object that doesn't exist. 
- object_id1 = random_object_id() - self.plasma_client.hash(object_id1) + # Test timing out from some but not all get calls with various + # timeouts. + for timeout in [0, 10, 100, 1000]: + data_results = self.plasma_client.get(object_ids, + timeout_ms=timeout) + for i in range(num_object_ids): + if i % 2 == 0: + self.assertTrue(plasma.buffers_equal(data_buffers[i // 2], + data_results[i])) + else: + self.assertIsNone(results[i]) - length = 1000 - # Create a random object, and check that the hash function always returns - # the same value. - metadata = generate_metadata(length) - memory_buffer = self.plasma_client.create(object_id1, length, metadata) - for i in range(length): - memory_buffer[i] = chr(i % 256) - self.plasma_client.seal(object_id1) - self.assertEqual(self.plasma_client.hash(object_id1), - self.plasma_client.hash(object_id1)) + def test_store_full(self): + # The store is started with 1GB, so make sure that create throws an + # exception when it is full. + def assert_create_raises_plasma_full(unit_test, size): + partial_size = np.random.randint(size) + try: + _, memory_buffer, _ = create_object(unit_test.plasma_client, + partial_size, + size - partial_size) + except plasma.plasma_out_of_memory_error as e: + pass + else: + # For some reason the above didn't throw an exception, so fail. + unit_test.assertTrue(False) - # Create a second object with the same value as the first, and check that - # their hashes are equal. - object_id2 = random_object_id() - memory_buffer = self.plasma_client.create(object_id2, length, metadata) - for i in range(length): - memory_buffer[i] = chr(i % 256) - self.plasma_client.seal(object_id2) - self.assertEqual(self.plasma_client.hash(object_id1), - self.plasma_client.hash(object_id2)) + # Create a list to keep some of the buffers in scope. + memory_buffers = [] + _, memory_buffer, _ = create_object(self.plasma_client, 5 * 10 ** 8, 0) + memory_buffers.append(memory_buffer) + # Remaining space is 5 * 10 ** 8. 
Make sure that we can't create an + # object of size 5 * 10 ** 8 + 1, but we can create one of size + # 2 * 10 ** 8. + assert_create_raises_plasma_full(self, 5 * 10 ** 8 + 1) + _, memory_buffer, _ = create_object(self.plasma_client, 2 * 10 ** 8, 0) + del memory_buffer + _, memory_buffer, _ = create_object(self.plasma_client, 2 * 10 ** 8, 0) + del memory_buffer + assert_create_raises_plasma_full(self, 5 * 10 ** 8 + 1) - # Create a third object with a different value from the first two, and - # check that its hash is different. - object_id3 = random_object_id() - metadata = generate_metadata(length) - memory_buffer = self.plasma_client.create(object_id3, length, metadata) - for i in range(length): - memory_buffer[i] = chr((i + 1) % 256) - self.plasma_client.seal(object_id3) - self.assertNotEqual(self.plasma_client.hash(object_id1), - self.plasma_client.hash(object_id3)) + _, memory_buffer, _ = create_object(self.plasma_client, 2 * 10 ** 8, 0) + memory_buffers.append(memory_buffer) + # Remaining space is 3 * 10 ** 8. + assert_create_raises_plasma_full(self, 3 * 10 ** 8 + 1) - # Create a fourth object with the same value as the third, but different - # metadata. Check that its hash is different from any of the previous - # three. - object_id4 = random_object_id() - metadata4 = generate_metadata(length) - memory_buffer = self.plasma_client.create(object_id4, length, metadata4) - for i in range(length): - memory_buffer[i] = chr((i + 1) % 256) - self.plasma_client.seal(object_id4) - self.assertNotEqual(self.plasma_client.hash(object_id1), - self.plasma_client.hash(object_id4)) - self.assertNotEqual(self.plasma_client.hash(object_id3), - self.plasma_client.hash(object_id4)) + _, memory_buffer, _ = create_object(self.plasma_client, 10 ** 8, 0) + memory_buffers.append(memory_buffer) + # Remaining space is 2 * 10 ** 8. 
+ assert_create_raises_plasma_full(self, 2 * 10 ** 8 + 1) - def test_many_hashes(self): - hashes = [] - length = 2 ** 10 + def test_contains(self): + fake_object_ids = [random_object_id() for _ in range(100)] + real_object_ids = [random_object_id() for _ in range(100)] + for object_id in real_object_ids: + self.assertFalse(self.plasma_client.contains(object_id)) + self.plasma_client.create(object_id, 100) + self.plasma_client.seal(object_id) + self.assertTrue(self.plasma_client.contains(object_id)) + for object_id in fake_object_ids: + self.assertFalse(self.plasma_client.contains(object_id)) + for object_id in real_object_ids: + self.assertTrue(self.plasma_client.contains(object_id)) - for i in range(256): - object_id = random_object_id() - memory_buffer = self.plasma_client.create(object_id, length) - for j in range(length): - memory_buffer[j] = chr(i) - self.plasma_client.seal(object_id) - hashes.append(self.plasma_client.hash(object_id)) + def test_hash(self): + # Check the hash of an object that doesn't exist. + object_id1 = random_object_id() + self.plasma_client.hash(object_id1) - # Create objects of varying length. Each pair has two bits different. - for i in range(length): - object_id = random_object_id() - memory_buffer = self.plasma_client.create(object_id, length) - for j in range(length): - memory_buffer[j] = chr(0) - memory_buffer[i] = chr(1) - self.plasma_client.seal(object_id) - hashes.append(self.plasma_client.hash(object_id)) + length = 1000 + # Create a random object, and check that the hash function always + # returns the same value. + metadata = generate_metadata(length) + memory_buffer = self.plasma_client.create(object_id1, length, metadata) + for i in range(length): + memory_buffer[i] = chr(i % 256) + self.plasma_client.seal(object_id1) + self.assertEqual(self.plasma_client.hash(object_id1), + self.plasma_client.hash(object_id1)) - # Create objects of varying length, all with value 0. 
- for i in range(length): - object_id = random_object_id() - memory_buffer = self.plasma_client.create(object_id, i) - for j in range(i): - memory_buffer[j] = chr(0) - self.plasma_client.seal(object_id) - hashes.append(self.plasma_client.hash(object_id)) + # Create a second object with the same value as the first, and check + # that their hashes are equal. + object_id2 = random_object_id() + memory_buffer = self.plasma_client.create(object_id2, length, metadata) + for i in range(length): + memory_buffer[i] = chr(i % 256) + self.plasma_client.seal(object_id2) + self.assertEqual(self.plasma_client.hash(object_id1), + self.plasma_client.hash(object_id2)) - # Check that all hashes were unique. - self.assertEqual(len(set(hashes)), 256 + length + length) + # Create a third object with a different value from the first two, and + # check that its hash is different. + object_id3 = random_object_id() + metadata = generate_metadata(length) + memory_buffer = self.plasma_client.create(object_id3, length, metadata) + for i in range(length): + memory_buffer[i] = chr((i + 1) % 256) + self.plasma_client.seal(object_id3) + self.assertNotEqual(self.plasma_client.hash(object_id1), + self.plasma_client.hash(object_id3)) - # def test_individual_delete(self): - # length = 100 - # # Create an object id string. - # object_id = random_object_id() - # # Create a random metadata string. - # metadata = generate_metadata(100) - # # Create a new buffer and write to it. - # memory_buffer = self.plasma_client.create(object_id, length, metadata) - # for i in range(length): - # memory_buffer[i] = chr(i % 256) - # # Seal the object. - # self.plasma_client.seal(object_id) - # # Check that the object is present. - # self.assertTrue(self.plasma_client.contains(object_id)) - # # Delete the object. - # self.plasma_client.delete(object_id) - # # Make sure the object is no longer present. - # self.assertFalse(self.plasma_client.contains(object_id)) - # - # def test_delete(self): - # # Create some objects. 
- # object_ids = [random_object_id() for _ in range(100)] - # for object_id in object_ids: - # length = 100 - # # Create a random metadata string. - # metadata = generate_metadata(100) - # # Create a new buffer and write to it. - # memory_buffer = self.plasma_client.create(object_id, length, metadata) - # for i in range(length): - # memory_buffer[i] = chr(i % 256) - # # Seal the object. - # self.plasma_client.seal(object_id) - # # Check that the object is present. - # self.assertTrue(self.plasma_client.contains(object_id)) - # - # # Delete the objects and make sure they are no longer present. - # for object_id in object_ids: - # # Delete the object. - # self.plasma_client.delete(object_id) - # # Make sure the object is no longer present. - # self.assertFalse(self.plasma_client.contains(object_id)) + # Create a fourth object with the same value as the third, but + # different metadata. Check that its hash is different from any of the + # previous three. + object_id4 = random_object_id() + metadata4 = generate_metadata(length) + memory_buffer = self.plasma_client.create(object_id4, length, + metadata4) + for i in range(length): + memory_buffer[i] = chr((i + 1) % 256) + self.plasma_client.seal(object_id4) + self.assertNotEqual(self.plasma_client.hash(object_id1), + self.plasma_client.hash(object_id4)) + self.assertNotEqual(self.plasma_client.hash(object_id3), + self.plasma_client.hash(object_id4)) - def test_illegal_functionality(self): - # Create an object id string. - object_id = random_object_id() - # Create a new buffer and write to it. - length = 1000 - memory_buffer = self.plasma_client.create(object_id, length) - # Make sure we cannot access memory out of bounds. - self.assertRaises(Exception, lambda: memory_buffer[length]) - # Seal the object. - self.plasma_client.seal(object_id) - # This test is commented out because it currently fails. - # # Make sure the object is ready only now. 
- # def illegal_assignment(): - # memory_buffer[0] = chr(0) - # self.assertRaises(Exception, illegal_assignment) - # Get the object. - memory_buffer = self.plasma_client.get([object_id])[0] + def test_many_hashes(self): + hashes = [] + length = 2 ** 10 - # Make sure the object is read only. - def illegal_assignment(): - memory_buffer[0] = chr(0) - self.assertRaises(Exception, illegal_assignment) + for i in range(256): + object_id = random_object_id() + memory_buffer = self.plasma_client.create(object_id, length) + for j in range(length): + memory_buffer[j] = chr(i) + self.plasma_client.seal(object_id) + hashes.append(self.plasma_client.hash(object_id)) - def test_evict(self): - client = self.plasma_client2 - object_id1 = random_object_id() - b1 = client.create(object_id1, 1000) - client.seal(object_id1) - del b1 - self.assertEqual(client.evict(1), 1000) + # Create objects of varying length. Each pair has two bits different. + for i in range(length): + object_id = random_object_id() + memory_buffer = self.plasma_client.create(object_id, length) + for j in range(length): + memory_buffer[j] = chr(0) + memory_buffer[i] = chr(1) + self.plasma_client.seal(object_id) + hashes.append(self.plasma_client.hash(object_id)) - object_id2 = random_object_id() - object_id3 = random_object_id() - b2 = client.create(object_id2, 999) - b3 = client.create(object_id3, 998) - client.seal(object_id3) - del b3 - self.assertEqual(client.evict(1000), 998) + # Create objects of varying length, all with value 0. 
+ for i in range(length): + object_id = random_object_id() + memory_buffer = self.plasma_client.create(object_id, i) + for j in range(i): + memory_buffer[j] = chr(0) + self.plasma_client.seal(object_id) + hashes.append(self.plasma_client.hash(object_id)) - object_id4 = random_object_id() - b4 = client.create(object_id4, 997) - client.seal(object_id4) - del b4 - client.seal(object_id2) - del b2 - self.assertEqual(client.evict(1), 997) - self.assertEqual(client.evict(1), 999) + # Check that all hashes were unique. + self.assertEqual(len(set(hashes)), 256 + length + length) - object_id5 = random_object_id() - object_id6 = random_object_id() - object_id7 = random_object_id() - b5 = client.create(object_id5, 996) - b6 = client.create(object_id6, 995) - b7 = client.create(object_id7, 994) - client.seal(object_id5) - client.seal(object_id6) - client.seal(object_id7) - del b5 - del b6 - del b7 - self.assertEqual(client.evict(2000), 996 + 995 + 994) + # def test_individual_delete(self): + # length = 100 + # # Create an object id string. + # object_id = random_object_id() + # # Create a random metadata string. + # metadata = generate_metadata(100) + # # Create a new buffer and write to it. + # memory_buffer = self.plasma_client.create(object_id, length, metadata) + # for i in range(length): + # memory_buffer[i] = chr(i % 256) + # # Seal the object. + # self.plasma_client.seal(object_id) + # # Check that the object is present. + # self.assertTrue(self.plasma_client.contains(object_id)) + # # Delete the object. + # self.plasma_client.delete(object_id) + # # Make sure the object is no longer present. + # self.assertFalse(self.plasma_client.contains(object_id)) + # + # def test_delete(self): + # # Create some objects. + # object_ids = [random_object_id() for _ in range(100)] + # for object_id in object_ids: + # length = 100 + # # Create a random metadata string. + # metadata = generate_metadata(100) + # # Create a new buffer and write to it. 
+ # memory_buffer = self.plasma_client.create(object_id, length, + # metadata) + # for i in range(length): + # memory_buffer[i] = chr(i % 256) + # # Seal the object. + # self.plasma_client.seal(object_id) + # # Check that the object is present. + # self.assertTrue(self.plasma_client.contains(object_id)) + # + # # Delete the objects and make sure they are no longer present. + # for object_id in object_ids: + # # Delete the object. + # self.plasma_client.delete(object_id) + # # Make sure the object is no longer present. + # self.assertFalse(self.plasma_client.contains(object_id)) - def test_subscribe(self): - # Subscribe to notifications from the Plasma Store. - self.plasma_client.subscribe() - for i in [1, 10, 100, 1000, 10000, 100000]: - object_ids = [random_object_id() for _ in range(i)] - metadata_sizes = [np.random.randint(1000) for _ in range(i)] - data_sizes = [np.random.randint(1000) for _ in range(i)] - for j in range(i): - self.plasma_client.create( - object_ids[j], size=data_sizes[j], - metadata=bytearray(np.random.bytes(metadata_sizes[j]))) - self.plasma_client.seal(object_ids[j]) - # Check that we received notifications for all of the objects. - for j in range(i): - notification_info = self.plasma_client.get_next_notification() - recv_objid, recv_dsize, recv_msize = notification_info - self.assertEqual(object_ids[j], recv_objid) - self.assertEqual(data_sizes[j], recv_dsize) - self.assertEqual(metadata_sizes[j], recv_msize) + def test_illegal_functionality(self): + # Create an object id string. + object_id = random_object_id() + # Create a new buffer and write to it. + length = 1000 + memory_buffer = self.plasma_client.create(object_id, length) + # Make sure we cannot access memory out of bounds. + self.assertRaises(Exception, lambda: memory_buffer[length]) + # Seal the object. + self.plasma_client.seal(object_id) + # This test is commented out because it currently fails. + # # Make sure the object is ready only now. 
+ # def illegal_assignment(): + # memory_buffer[0] = chr(0) + # self.assertRaises(Exception, illegal_assignment) + # Get the object. + memory_buffer = self.plasma_client.get([object_id])[0] - def test_subscribe_deletions(self): - # Subscribe to notifications from the Plasma Store. We use plasma_client2 - # to make sure that all used objects will get evicted properly. - self.plasma_client2.subscribe() - for i in [1, 10, 100, 1000, 10000, 100000]: - object_ids = [random_object_id() for _ in range(i)] - # Add 1 to the sizes to make sure we have nonzero object sizes. - metadata_sizes = [np.random.randint(1000) + 1 for _ in range(i)] - data_sizes = [np.random.randint(1000) + 1 for _ in range(i)] - for j in range(i): - x = self.plasma_client2.create( - object_ids[j], size=data_sizes[j], - metadata=bytearray(np.random.bytes(metadata_sizes[j]))) - self.plasma_client2.seal(object_ids[j]) - del x - # Check that we received notifications for creating all of the objects. - for j in range(i): - notification_info = self.plasma_client2.get_next_notification() - recv_objid, recv_dsize, recv_msize = notification_info - self.assertEqual(object_ids[j], recv_objid) - self.assertEqual(data_sizes[j], recv_dsize) - self.assertEqual(metadata_sizes[j], recv_msize) + # Make sure the object is read only. + def illegal_assignment(): + memory_buffer[0] = chr(0) + self.assertRaises(Exception, illegal_assignment) - # Check that we receive notifications for deleting all objects, as we - # evict them. 
- for j in range(i): + def test_evict(self): + client = self.plasma_client2 + object_id1 = random_object_id() + b1 = client.create(object_id1, 1000) + client.seal(object_id1) + del b1 + self.assertEqual(client.evict(1), 1000) + + object_id2 = random_object_id() + object_id3 = random_object_id() + b2 = client.create(object_id2, 999) + b3 = client.create(object_id3, 998) + client.seal(object_id3) + del b3 + self.assertEqual(client.evict(1000), 998) + + object_id4 = random_object_id() + b4 = client.create(object_id4, 997) + client.seal(object_id4) + del b4 + client.seal(object_id2) + del b2 + self.assertEqual(client.evict(1), 997) + self.assertEqual(client.evict(1), 999) + + object_id5 = random_object_id() + object_id6 = random_object_id() + object_id7 = random_object_id() + b5 = client.create(object_id5, 996) + b6 = client.create(object_id6, 995) + b7 = client.create(object_id7, 994) + client.seal(object_id5) + client.seal(object_id6) + client.seal(object_id7) + del b5 + del b6 + del b7 + self.assertEqual(client.evict(2000), 996 + 995 + 994) + + def test_subscribe(self): + # Subscribe to notifications from the Plasma Store. + self.plasma_client.subscribe() + for i in [1, 10, 100, 1000, 10000, 100000]: + object_ids = [random_object_id() for _ in range(i)] + metadata_sizes = [np.random.randint(1000) for _ in range(i)] + data_sizes = [np.random.randint(1000) for _ in range(i)] + for j in range(i): + self.plasma_client.create( + object_ids[j], size=data_sizes[j], + metadata=bytearray(np.random.bytes(metadata_sizes[j]))) + self.plasma_client.seal(object_ids[j]) + # Check that we received notifications for all of the objects. 
+ for j in range(i): + notification_info = self.plasma_client.get_next_notification() + recv_objid, recv_dsize, recv_msize = notification_info + self.assertEqual(object_ids[j], recv_objid) + self.assertEqual(data_sizes[j], recv_dsize) + self.assertEqual(metadata_sizes[j], recv_msize) + + def test_subscribe_deletions(self): + # Subscribe to notifications from the Plasma Store. We use + # plasma_client2 to make sure that all used objects will get evicted + # properly. + self.plasma_client2.subscribe() + for i in [1, 10, 100, 1000, 10000, 100000]: + object_ids = [random_object_id() for _ in range(i)] + # Add 1 to the sizes to make sure we have nonzero object sizes. + metadata_sizes = [np.random.randint(1000) + 1 for _ in range(i)] + data_sizes = [np.random.randint(1000) + 1 for _ in range(i)] + for j in range(i): + x = self.plasma_client2.create( + object_ids[j], size=data_sizes[j], + metadata=bytearray(np.random.bytes(metadata_sizes[j]))) + self.plasma_client2.seal(object_ids[j]) + del x + # Check that we received notifications for creating all of the + # objects. + for j in range(i): + notification_info = self.plasma_client2.get_next_notification() + recv_objid, recv_dsize, recv_msize = notification_info + self.assertEqual(object_ids[j], recv_objid) + self.assertEqual(data_sizes[j], recv_dsize) + self.assertEqual(metadata_sizes[j], recv_msize) + + # Check that we receive notifications for deleting all objects, as + # we evict them. + for j in range(i): + self.assertEqual(self.plasma_client2.evict(1), + data_sizes[j] + metadata_sizes[j]) + notification_info = self.plasma_client2.get_next_notification() + recv_objid, recv_dsize, recv_msize = notification_info + self.assertEqual(object_ids[j], recv_objid) + self.assertEqual(-1, recv_dsize) + self.assertEqual(-1, recv_msize) + + # Test multiple deletion notifications. The first 9 object IDs have + # size 0, and the last has a nonzero size. 
When Plasma evicts 1 byte, + # it will evict all objects, so we should receive deletion + # notifications for each. + num_object_ids = 10 + object_ids = [random_object_id() for _ in range(num_object_ids)] + metadata_sizes = [0] * (num_object_ids - 1) + data_sizes = [0] * (num_object_ids - 1) + metadata_sizes.append(np.random.randint(1000)) + data_sizes.append(np.random.randint(1000)) + for i in range(num_object_ids): + x = self.plasma_client2.create( + object_ids[i], size=data_sizes[i], + metadata=bytearray(np.random.bytes(metadata_sizes[i]))) + self.plasma_client2.seal(object_ids[i]) + del x + for i in range(num_object_ids): + notification_info = self.plasma_client2.get_next_notification() + recv_objid, recv_dsize, recv_msize = notification_info + self.assertEqual(object_ids[i], recv_objid) + self.assertEqual(data_sizes[i], recv_dsize) + self.assertEqual(metadata_sizes[i], recv_msize) self.assertEqual(self.plasma_client2.evict(1), - data_sizes[j] + metadata_sizes[j]) - notification_info = self.plasma_client2.get_next_notification() - recv_objid, recv_dsize, recv_msize = notification_info - self.assertEqual(object_ids[j], recv_objid) - self.assertEqual(-1, recv_dsize) - self.assertEqual(-1, recv_msize) - - # Test multiple deletion notifications. The first 9 object IDs have size 0, - # and the last has a nonzero size. When Plasma evicts 1 byte, it will evict - # all objects, so we should receive deletion notifications for each. 
- num_object_ids = 10 - object_ids = [random_object_id() for _ in range(num_object_ids)] - metadata_sizes = [0] * (num_object_ids - 1) - data_sizes = [0] * (num_object_ids - 1) - metadata_sizes.append(np.random.randint(1000)) - data_sizes.append(np.random.randint(1000)) - for i in range(num_object_ids): - x = self.plasma_client2.create( - object_ids[i], size=data_sizes[i], - metadata=bytearray(np.random.bytes(metadata_sizes[i]))) - self.plasma_client2.seal(object_ids[i]) - del x - for i in range(num_object_ids): - notification_info = self.plasma_client2.get_next_notification() - recv_objid, recv_dsize, recv_msize = notification_info - self.assertEqual(object_ids[i], recv_objid) - self.assertEqual(data_sizes[i], recv_dsize) - self.assertEqual(metadata_sizes[i], recv_msize) - self.assertEqual(self.plasma_client2.evict(1), - data_sizes[-1] + metadata_sizes[-1]) - for i in range(num_object_ids): - notification_info = self.plasma_client2.get_next_notification() - recv_objid, recv_dsize, recv_msize = notification_info - self.assertEqual(object_ids[i], recv_objid) - self.assertEqual(-1, recv_dsize) - self.assertEqual(-1, recv_msize) + data_sizes[-1] + metadata_sizes[-1]) + for i in range(num_object_ids): + notification_info = self.plasma_client2.get_next_notification() + recv_objid, recv_dsize, recv_msize = notification_info + self.assertEqual(object_ids[i], recv_objid) + self.assertEqual(-1, recv_dsize) + self.assertEqual(-1, recv_msize) class TestPlasmaManager(unittest.TestCase): - def setUp(self): - # Start two PlasmaStores. - store_name1, self.p2 = plasma.start_plasma_store(use_valgrind=USE_VALGRIND) - store_name2, self.p3 = plasma.start_plasma_store(use_valgrind=USE_VALGRIND) - # Start a Redis server. - redis_address, _ = services.start_redis("127.0.0.1") - # Start two PlasmaManagers. 
- manager_name1, self.p4, self.port1 = plasma.start_plasma_manager( - store_name1, redis_address, use_valgrind=USE_VALGRIND) - manager_name2, self.p5, self.port2 = plasma.start_plasma_manager( - store_name2, redis_address, use_valgrind=USE_VALGRIND) - # Connect two PlasmaClients. - self.client1 = plasma.PlasmaClient(store_name1, manager_name1) - self.client2 = plasma.PlasmaClient(store_name2, manager_name2) + def setUp(self): + # Start two PlasmaStores. + store_name1, self.p2 = plasma.start_plasma_store( + use_valgrind=USE_VALGRIND) + store_name2, self.p3 = plasma.start_plasma_store( + use_valgrind=USE_VALGRIND) + # Start a Redis server. + redis_address, _ = services.start_redis("127.0.0.1") + # Start two PlasmaManagers. + manager_name1, self.p4, self.port1 = plasma.start_plasma_manager( + store_name1, redis_address, use_valgrind=USE_VALGRIND) + manager_name2, self.p5, self.port2 = plasma.start_plasma_manager( + store_name2, redis_address, use_valgrind=USE_VALGRIND) + # Connect two PlasmaClients. + self.client1 = plasma.PlasmaClient(store_name1, manager_name1) + self.client2 = plasma.PlasmaClient(store_name2, manager_name2) - # Store the processes that will be explicitly killed during tearDown so - # that a test case can remove ones that will be killed during the test. - # NOTE: If this specific order is changed, valgrind will fail. - self.processes_to_kill = [self.p4, self.p5, self.p2, self.p3] + # Store the processes that will be explicitly killed during tearDown so + # that a test case can remove ones that will be killed during the test. + # NOTE: If this specific order is changed, valgrind will fail. + self.processes_to_kill = [self.p4, self.p5, self.p2, self.p3] - def tearDown(self): - # Check that the processes are still alive. - for process in self.processes_to_kill: - self.assertEqual(process.poll(), None) + def tearDown(self): + # Check that the processes are still alive. 
+ for process in self.processes_to_kill: + self.assertEqual(process.poll(), None) - # Kill the Plasma store and Plasma manager processes. - if USE_VALGRIND: - # Give processes opportunity to finish work. - time.sleep(1) - for process in self.processes_to_kill: - process.send_signal(signal.SIGTERM) - process.wait() - if process.returncode != 0: - print("aborting due to valgrind error") - os._exit(-1) - else: - for process in self.processes_to_kill: - process.kill() + # Kill the Plasma store and Plasma manager processes. + if USE_VALGRIND: + # Give processes opportunity to finish work. + time.sleep(1) + for process in self.processes_to_kill: + process.send_signal(signal.SIGTERM) + process.wait() + if process.returncode != 0: + print("aborting due to valgrind error") + os._exit(-1) + else: + for process in self.processes_to_kill: + process.kill() - # Clean up the Redis server. - services.cleanup() + # Clean up the Redis server. + services.cleanup() - def test_fetch(self): - for _ in range(10): - # Create an object. - object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000, - 2000) - self.client1.fetch([object_id1]) - self.assertEqual(self.client1.contains(object_id1), True) - self.assertEqual(self.client2.contains(object_id1), False) - # Fetch the object from the other plasma manager. - # TODO(rkn): Right now we must wait for the object table to be updated. - while not self.client2.contains(object_id1): - self.client2.fetch([object_id1]) - # Compare the two buffers. - assert_get_object_equal(self, self.client1, self.client2, object_id1, - memory_buffer=memory_buffer1, metadata=metadata1) + def test_fetch(self): + for _ in range(10): + # Create an object. + object_id1, memory_buffer1, metadata1 = create_object(self.client1, + 2000, + 2000) + self.client1.fetch([object_id1]) + self.assertEqual(self.client1.contains(object_id1), True) + self.assertEqual(self.client2.contains(object_id1), False) + # Fetch the object from the other plasma manager. 
+ # TODO(rkn): Right now we must wait for the object table to be + # updated. + while not self.client2.contains(object_id1): + self.client2.fetch([object_id1]) + # Compare the two buffers. + assert_get_object_equal(self, self.client1, self.client2, + object_id1, + memory_buffer=memory_buffer1, + metadata=metadata1) - # Test that we can call fetch on object IDs that don't exist yet. - object_id2 = random_object_id() - self.client1.fetch([object_id2]) - self.assertEqual(self.client1.contains(object_id2), False) - memory_buffer2, metadata2 = create_object_with_id(self.client2, object_id2, - 2000, 2000) - # # Check that the object has been fetched. - # self.assertEqual(self.client1.contains(object_id2), True) - # Compare the two buffers. - # assert_get_object_equal(self, self.client1, self.client2, object_id2, - # memory_buffer=memory_buffer2, metadata=metadata2) + # Test that we can call fetch on object IDs that don't exist yet. + object_id2 = random_object_id() + self.client1.fetch([object_id2]) + self.assertEqual(self.client1.contains(object_id2), False) + memory_buffer2, metadata2 = create_object_with_id(self.client2, + object_id2, + 2000, 2000) + # # Check that the object has been fetched. + # self.assertEqual(self.client1.contains(object_id2), True) + # Compare the two buffers. + # assert_get_object_equal(self, self.client1, self.client2, object_id2, + # memory_buffer=memory_buffer2, + # metadata=metadata2) - # Test calling the same fetch request a bunch of times. - object_id3 = random_object_id() - self.assertEqual(self.client1.contains(object_id3), False) - self.assertEqual(self.client2.contains(object_id3), False) - for _ in range(10): - self.client1.fetch([object_id3]) - self.client2.fetch([object_id3]) - memory_buffer3, metadata3 = create_object_with_id(self.client1, object_id3, - 2000, 2000) - for _ in range(10): - self.client1.fetch([object_id3]) - self.client2.fetch([object_id3]) - # TODO(rkn): Right now we must wait for the object table to be updated. 
- while not self.client2.contains(object_id3): - self.client2.fetch([object_id3]) - assert_get_object_equal(self, self.client1, self.client2, object_id3, - memory_buffer=memory_buffer3, metadata=metadata3) + # Test calling the same fetch request a bunch of times. + object_id3 = random_object_id() + self.assertEqual(self.client1.contains(object_id3), False) + self.assertEqual(self.client2.contains(object_id3), False) + for _ in range(10): + self.client1.fetch([object_id3]) + self.client2.fetch([object_id3]) + memory_buffer3, metadata3 = create_object_with_id(self.client1, + object_id3, + 2000, 2000) + for _ in range(10): + self.client1.fetch([object_id3]) + self.client2.fetch([object_id3]) + # TODO(rkn): Right now we must wait for the object table to be updated. + while not self.client2.contains(object_id3): + self.client2.fetch([object_id3]) + assert_get_object_equal(self, self.client1, self.client2, object_id3, + memory_buffer=memory_buffer3, + metadata=metadata3) - def test_fetch_multiple(self): - for _ in range(20): - # Create two objects and a third fake one that doesn't exist. - object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000, - 2000) - missing_object_id = random_object_id() - object_id2, memory_buffer2, metadata2 = create_object(self.client1, 2000, - 2000) - object_ids = [object_id1, missing_object_id, object_id2] - # Fetch the objects from the other plasma store. The second object ID - # should timeout since it does not exist. - # TODO(rkn): Right now we must wait for the object table to be updated. - while ((not self.client2.contains(object_id1)) or - (not self.client2.contains(object_id2))): - self.client2.fetch(object_ids) - # Compare the buffers of the objects that do exist. 
- assert_get_object_equal(self, self.client1, self.client2, object_id1, - memory_buffer=memory_buffer1, metadata=metadata1) - assert_get_object_equal(self, self.client1, self.client2, object_id2, - memory_buffer=memory_buffer2, metadata=metadata2) - # Fetch in the other direction. The fake object still does not exist. - self.client1.fetch(object_ids) - assert_get_object_equal(self, self.client2, self.client1, object_id1, - memory_buffer=memory_buffer1, metadata=metadata1) - assert_get_object_equal(self, self.client2, self.client1, object_id2, - memory_buffer=memory_buffer2, metadata=metadata2) + def test_fetch_multiple(self): + for _ in range(20): + # Create two objects and a third fake one that doesn't exist. + object_id1, memory_buffer1, metadata1 = create_object(self.client1, + 2000, + 2000) + missing_object_id = random_object_id() + object_id2, memory_buffer2, metadata2 = create_object(self.client1, + 2000, + 2000) + object_ids = [object_id1, missing_object_id, object_id2] + # Fetch the objects from the other plasma store. The second object + # ID should timeout since it does not exist. + # TODO(rkn): Right now we must wait for the object table to be + # updated. + while ((not self.client2.contains(object_id1)) or + (not self.client2.contains(object_id2))): + self.client2.fetch(object_ids) + # Compare the buffers of the objects that do exist. + assert_get_object_equal(self, self.client1, self.client2, + object_id1, memory_buffer=memory_buffer1, + metadata=metadata1) + assert_get_object_equal(self, self.client1, self.client2, + object_id2, memory_buffer=memory_buffer2, + metadata=metadata2) + # Fetch in the other direction. The fake object still does not + # exist. 
+ self.client1.fetch(object_ids) + assert_get_object_equal(self, self.client2, self.client1, + object_id1, memory_buffer=memory_buffer1, + metadata=metadata1) + assert_get_object_equal(self, self.client2, self.client1, + object_id2, memory_buffer=memory_buffer2, + metadata=metadata2) - # Check that we can call fetch with duplicated object IDs. - object_id3 = random_object_id() - self.client1.fetch([object_id3, object_id3]) - object_id4, memory_buffer4, metadata4 = create_object(self.client1, 2000, - 2000) - time.sleep(0.1) - # TODO(rkn): Right now we must wait for the object table to be updated. - while not self.client2.contains(object_id4): - self.client2.fetch([object_id3, object_id3, object_id4, object_id4]) - assert_get_object_equal(self, self.client2, self.client1, object_id4, - memory_buffer=memory_buffer4, metadata=metadata4) + # Check that we can call fetch with duplicated object IDs. + object_id3 = random_object_id() + self.client1.fetch([object_id3, object_id3]) + object_id4, memory_buffer4, metadata4 = create_object(self.client1, + 2000, + 2000) + time.sleep(0.1) + # TODO(rkn): Right now we must wait for the object table to be updated. + while not self.client2.contains(object_id4): + self.client2.fetch([object_id3, object_id3, object_id4, + object_id4]) + assert_get_object_equal(self, self.client2, self.client1, object_id4, + memory_buffer=memory_buffer4, + metadata=metadata4) - def test_wait(self): - # Test timeout. - obj_id0 = random_object_id() - self.client1.wait([obj_id0], timeout=100, num_returns=1) - # If we get here, the test worked. + def test_wait(self): + # Test timeout. + obj_id0 = random_object_id() + self.client1.wait([obj_id0], timeout=100, num_returns=1) + # If we get here, the test worked. - # Test wait if local objects available. 
- obj_id1 = random_object_id() - self.client1.create(obj_id1, 1000) - self.client1.seal(obj_id1) - ready, waiting = self.client1.wait([obj_id1], timeout=100, num_returns=1) - self.assertEqual(set(ready), set([obj_id1])) - self.assertEqual(waiting, []) + # Test wait if local objects available. + obj_id1 = random_object_id() + self.client1.create(obj_id1, 1000) + self.client1.seal(obj_id1) + ready, waiting = self.client1.wait([obj_id1], timeout=100, + num_returns=1) + self.assertEqual(set(ready), set([obj_id1])) + self.assertEqual(waiting, []) - # Test wait if only one object available and only one object waited for. - obj_id2 = random_object_id() - self.client1.create(obj_id2, 1000) - # Don't seal. - ready, waiting = self.client1.wait([obj_id2, obj_id1], timeout=100, - num_returns=1) - self.assertEqual(set(ready), set([obj_id1])) - self.assertEqual(set(waiting), set([obj_id2])) + # Test wait if only one object available and only one object waited + # for. + obj_id2 = random_object_id() + self.client1.create(obj_id2, 1000) + # Don't seal. + ready, waiting = self.client1.wait([obj_id2, obj_id1], timeout=100, + num_returns=1) + self.assertEqual(set(ready), set([obj_id1])) + self.assertEqual(set(waiting), set([obj_id2])) - # Test wait if object is sealed later. - obj_id3 = random_object_id() + # Test wait if object is sealed later. 
+ obj_id3 = random_object_id() - def finish(): - self.client2.create(obj_id3, 1000) - self.client2.seal(obj_id3) + def finish(): + self.client2.create(obj_id3, 1000) + self.client2.seal(obj_id3) - t = threading.Timer(0.1, finish) - t.start() - ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1], - timeout=1000, num_returns=2) - self.assertEqual(set(ready), set([obj_id1, obj_id3])) - self.assertEqual(set(waiting), set([obj_id2])) + t = threading.Timer(0.1, finish) + t.start() + ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1], + timeout=1000, num_returns=2) + self.assertEqual(set(ready), set([obj_id1, obj_id3])) + self.assertEqual(set(waiting), set([obj_id2])) - # Test if the appropriate number of objects is shown if some objects are - # not ready. - ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1], 100, 3) - self.assertEqual(set(ready), set([obj_id1, obj_id3])) - self.assertEqual(set(waiting), set([obj_id2])) + # Test if the appropriate number of objects is shown if some objects + # are not ready. + ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1], 100, 3) + self.assertEqual(set(ready), set([obj_id1, obj_id3])) + self.assertEqual(set(waiting), set([obj_id2])) - # Don't forget to seal obj_id2. - self.client1.seal(obj_id2) + # Don't forget to seal obj_id2. + self.client1.seal(obj_id2) - # Test calling wait a bunch of times. - object_ids = [] - # TODO(rkn): Increasing n to 100 (or larger) will cause failures. The - # problem appears to be that the number of timers added to the manager - # event loop slow down the manager so much that some of the asynchronous - # Redis commands timeout triggering fatal failure callbacks. - n = 40 - for i in range(n * (n + 1) // 2): - if i % 2 == 0: - object_id, _, _ = create_object(self.client1, 200, 200) - else: - object_id, _, _ = create_object(self.client2, 200, 200) - object_ids.append(object_id) - # Try waiting for all of the object IDs on the first client. 
- waiting = object_ids - retrieved = [] - for i in range(1, n + 1): - ready, waiting = self.client1.wait(waiting, timeout=1000, num_returns=i) - self.assertEqual(len(ready), i) - retrieved += ready - self.assertEqual(set(retrieved), set(object_ids)) - ready, waiting = self.client1.wait(object_ids, timeout=1000, - num_returns=len(object_ids)) - self.assertEqual(set(ready), set(object_ids)) - self.assertEqual(waiting, []) - # Try waiting for all of the object IDs on the second client. - waiting = object_ids - retrieved = [] - for i in range(1, n + 1): - ready, waiting = self.client2.wait(waiting, timeout=1000, num_returns=i) - self.assertEqual(len(ready), i) - retrieved += ready - self.assertEqual(set(retrieved), set(object_ids)) - ready, waiting = self.client2.wait(object_ids, timeout=1000, - num_returns=len(object_ids)) - self.assertEqual(set(ready), set(object_ids)) - self.assertEqual(waiting, []) + # Test calling wait a bunch of times. + object_ids = [] + # TODO(rkn): Increasing n to 100 (or larger) will cause failures. The + # problem appears to be that the number of timers added to the manager + # event loop slow down the manager so much that some of the + # asynchronous Redis commands timeout triggering fatal failure + # callbacks. + n = 40 + for i in range(n * (n + 1) // 2): + if i % 2 == 0: + object_id, _, _ = create_object(self.client1, 200, 200) + else: + object_id, _, _ = create_object(self.client2, 200, 200) + object_ids.append(object_id) + # Try waiting for all of the object IDs on the first client. 
+ waiting = object_ids + retrieved = [] + for i in range(1, n + 1): + ready, waiting = self.client1.wait(waiting, timeout=1000, + num_returns=i) + self.assertEqual(len(ready), i) + retrieved += ready + self.assertEqual(set(retrieved), set(object_ids)) + ready, waiting = self.client1.wait(object_ids, timeout=1000, + num_returns=len(object_ids)) + self.assertEqual(set(ready), set(object_ids)) + self.assertEqual(waiting, []) + # Try waiting for all of the object IDs on the second client. + waiting = object_ids + retrieved = [] + for i in range(1, n + 1): + ready, waiting = self.client2.wait(waiting, timeout=1000, + num_returns=i) + self.assertEqual(len(ready), i) + retrieved += ready + self.assertEqual(set(retrieved), set(object_ids)) + ready, waiting = self.client2.wait(object_ids, timeout=1000, + num_returns=len(object_ids)) + self.assertEqual(set(ready), set(object_ids)) + self.assertEqual(waiting, []) - # Make sure that wait returns when the requested number of object IDs are - # available and does not wait for all object IDs to be available. - object_ids = [random_object_id() for _ in range(9)] + [20 * b'\x00'] - object_ids_perm = object_ids[:] - random.shuffle(object_ids_perm) - for i in range(10): - if i % 2 == 0: - create_object_with_id(self.client1, object_ids_perm[i], 2000, 2000) - else: - create_object_with_id(self.client2, object_ids_perm[i], 2000, 2000) - ready, waiting = self.client1.wait(object_ids, num_returns=(i + 1)) - self.assertEqual(set(ready), set(object_ids_perm[:(i + 1)])) - self.assertEqual(set(waiting), set(object_ids_perm[(i + 1):])) + # Make sure that wait returns when the requested number of object IDs + # are available and does not wait for all object IDs to be available. 
+ object_ids = [random_object_id() for _ in range(9)] + [20 * b'\x00'] + object_ids_perm = object_ids[:] + random.shuffle(object_ids_perm) + for i in range(10): + if i % 2 == 0: + create_object_with_id(self.client1, object_ids_perm[i], 2000, + 2000) + else: + create_object_with_id(self.client2, object_ids_perm[i], 2000, + 2000) + ready, waiting = self.client1.wait(object_ids, num_returns=(i + 1)) + self.assertEqual(set(ready), set(object_ids_perm[:(i + 1)])) + self.assertEqual(set(waiting), set(object_ids_perm[(i + 1):])) - def test_transfer(self): - num_attempts = 100 - for _ in range(100): - # Create an object. - object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000, - 2000) - # Transfer the buffer to the the other Plasma store. There is a race - # condition on the create and transfer of the object, so keep trying - # until the object appears on the second Plasma store. - for i in range(num_attempts): - self.client1.transfer("127.0.0.1", self.port2, object_id1) - buff = self.client2.get([object_id1], timeout_ms=100)[0] - if buff is not None: - break - self.assertNotEqual(buff, None) - del buff + def test_transfer(self): + num_attempts = 100 + for _ in range(100): + # Create an object. + object_id1, memory_buffer1, metadata1 = create_object(self.client1, + 2000, + 2000) + # Transfer the buffer to the the other Plasma store. There is a + # race condition on the create and transfer of the object, so keep + # trying until the object appears on the second Plasma store. + for i in range(num_attempts): + self.client1.transfer("127.0.0.1", self.port2, object_id1) + buff = self.client2.get([object_id1], timeout_ms=100)[0] + if buff is not None: + break + self.assertNotEqual(buff, None) + del buff - # Compare the two buffers. - assert_get_object_equal(self, self.client1, self.client2, object_id1, - memory_buffer=memory_buffer1, metadata=metadata1) - # # Transfer the buffer again. 
- # self.client1.transfer("127.0.0.1", self.port2, object_id1) - # # Compare the two buffers. - # assert_get_object_equal(self, self.client1, self.client2, object_id1, - # memory_buffer=memory_buffer1, - # metadata=metadata1) + # Compare the two buffers. + assert_get_object_equal(self, self.client1, self.client2, + object_id1, memory_buffer=memory_buffer1, + metadata=metadata1) + # # Transfer the buffer again. + # self.client1.transfer("127.0.0.1", self.port2, object_id1) + # # Compare the two buffers. + # assert_get_object_equal(self, self.client1, self.client2, + # object_id1, + # memory_buffer=memory_buffer1, + # metadata=metadata1) - # Create an object. - object_id2, memory_buffer2, metadata2 = create_object(self.client2, - 20000, 20000) - # Transfer the buffer to the the other Plasma store. There is a race - # condition on the create and transfer of the object, so keep trying - # until the object appears on the second Plasma store. - for i in range(num_attempts): - self.client2.transfer("127.0.0.1", self.port1, object_id2) - buff = self.client1.get([object_id2], timeout_ms=100)[0] - if buff is not None: - break - self.assertNotEqual(buff, None) - del buff + # Create an object. + object_id2, memory_buffer2, metadata2 = create_object(self.client2, + 20000, 20000) + # Transfer the buffer to the the other Plasma store. There is a + # race condition on the create and transfer of the object, so keep + # trying until the object appears on the second Plasma store. + for i in range(num_attempts): + self.client2.transfer("127.0.0.1", self.port1, object_id2) + buff = self.client1.get([object_id2], timeout_ms=100)[0] + if buff is not None: + break + self.assertNotEqual(buff, None) + del buff - # Compare the two buffers. - assert_get_object_equal(self, self.client1, self.client2, object_id2, - memory_buffer=memory_buffer2, metadata=metadata2) + # Compare the two buffers. 
+ assert_get_object_equal(self, self.client1, self.client2, + object_id2, memory_buffer=memory_buffer2, + metadata=metadata2) - def test_illegal_functionality(self): - # Create an object id string. - # object_id = random_object_id() - # Create a new buffer. - # memory_buffer = self.client1.create(object_id, 20000) - # This test is commented out because it currently fails. - # # Transferring the buffer before sealing it should fail. - # self.assertRaises(Exception, - # lambda : self.manager1.transfer(1, object_id)) - pass + def test_illegal_functionality(self): + # Create an object id string. + # object_id = random_object_id() + # Create a new buffer. + # memory_buffer = self.client1.create(object_id, 20000) + # This test is commented out because it currently fails. + # # Transferring the buffer before sealing it should fail. + # self.assertRaises(Exception, + # lambda : self.manager1.transfer(1, object_id)) + pass - def test_stresstest(self): - a = time.time() - object_ids = [] - for i in range(10000): # TODO(pcm): increase this to 100000. - object_id = random_object_id() - object_ids.append(object_id) - self.client1.create(object_id, 1) - self.client1.seal(object_id) - for object_id in object_ids: - self.client1.transfer("127.0.0.1", self.port2, object_id) - b = time.time() - a + def test_stresstest(self): + a = time.time() + object_ids = [] + for i in range(10000): # TODO(pcm): increase this to 100000. + object_id = random_object_id() + object_ids.append(object_id) + self.client1.create(object_id, 1) + self.client1.seal(object_id) + for object_id in object_ids: + self.client1.transfer("127.0.0.1", self.port2, object_id) + b = time.time() - a - print("it took", b, "seconds to put and transfer the objects") + print("it took", b, "seconds to put and transfer the objects") class TestPlasmaManagerRecovery(unittest.TestCase): - def setUp(self): - # Start a Plasma store. 
- self.store_name, self.p2 = plasma.start_plasma_store( - use_valgrind=USE_VALGRIND) - # Start a Redis server. - self.redis_address, _ = services.start_redis("127.0.0.1") - # Start a PlasmaManagers. - manager_name, self.p3, self.port1 = plasma.start_plasma_manager( - self.store_name, - self.redis_address, - use_valgrind=USE_VALGRIND) - # Connect a PlasmaClient. - self.client = plasma.PlasmaClient(self.store_name, manager_name) + def setUp(self): + # Start a Plasma store. + self.store_name, self.p2 = plasma.start_plasma_store( + use_valgrind=USE_VALGRIND) + # Start a Redis server. + self.redis_address, _ = services.start_redis("127.0.0.1") + # Start a PlasmaManagers. + manager_name, self.p3, self.port1 = plasma.start_plasma_manager( + self.store_name, + self.redis_address, + use_valgrind=USE_VALGRIND) + # Connect a PlasmaClient. + self.client = plasma.PlasmaClient(self.store_name, manager_name) - # Store the processes that will be explicitly killed during tearDown so - # that a test case can remove ones that will be killed during the test. - # NOTE: The plasma managers must be killed before the plasma store since - # plasma store death will bring down the managers. - self.processes_to_kill = [self.p3, self.p2] + # Store the processes that will be explicitly killed during tearDown so + # that a test case can remove ones that will be killed during the test. + # NOTE: The plasma managers must be killed before the plasma store + # since plasma store death will bring down the managers. + self.processes_to_kill = [self.p3, self.p2] - def tearDown(self): - # Check that the processes are still alive. - for process in self.processes_to_kill: - self.assertEqual(process.poll(), None) + def tearDown(self): + # Check that the processes are still alive. + for process in self.processes_to_kill: + self.assertEqual(process.poll(), None) - # Kill the Plasma store and Plasma manager processes. - if USE_VALGRIND: - # Give processes opportunity to finish work. 
- time.sleep(1) - for process in self.processes_to_kill: - process.send_signal(signal.SIGTERM) - process.wait() - if process.returncode != 0: - print("aborting due to valgrind error") - os._exit(-1) - else: - for process in self.processes_to_kill: - process.kill() + # Kill the Plasma store and Plasma manager processes. + if USE_VALGRIND: + # Give processes opportunity to finish work. + time.sleep(1) + for process in self.processes_to_kill: + process.send_signal(signal.SIGTERM) + process.wait() + if process.returncode != 0: + print("aborting due to valgrind error") + os._exit(-1) + else: + for process in self.processes_to_kill: + process.kill() - # Clean up the Redis server. - services.cleanup() + # Clean up the Redis server. + services.cleanup() - def test_delayed_start(self): - num_objects = 10 - # Create some objects using one client. - object_ids = [random_object_id() for _ in range(num_objects)] - for i in range(10): - create_object_with_id(self.client, object_ids[i], 2000, 2000) + def test_delayed_start(self): + num_objects = 10 + # Create some objects using one client. + object_ids = [random_object_id() for _ in range(num_objects)] + for i in range(10): + create_object_with_id(self.client, object_ids[i], 2000, 2000) - # Wait until the objects have been sealed in the store. - ready, waiting = self.client.wait(object_ids, num_returns=num_objects) - self.assertEqual(set(ready), set(object_ids)) - self.assertEqual(waiting, []) + # Wait until the objects have been sealed in the store. + ready, waiting = self.client.wait(object_ids, num_returns=num_objects) + self.assertEqual(set(ready), set(object_ids)) + self.assertEqual(waiting, []) - # Start a second plasma manager attached to the same store. - manager_name, self.p5, self.port2 = plasma.start_plasma_manager( - self.store_name, self.redis_address, use_valgrind=USE_VALGRIND) - self.processes_to_kill = [self.p5] + self.processes_to_kill + # Start a second plasma manager attached to the same store. 
+ manager_name, self.p5, self.port2 = plasma.start_plasma_manager( + self.store_name, self.redis_address, use_valgrind=USE_VALGRIND) + self.processes_to_kill = [self.p5] + self.processes_to_kill - # Check that the second manager knows about existing objects. - client2 = plasma.PlasmaClient(self.store_name, manager_name) - ready, waiting = [], object_ids - while True: - ready, waiting = client2.wait(object_ids, num_returns=num_objects, - timeout=0) - if len(ready) == len(object_ids): - break + # Check that the second manager knows about existing objects. + client2 = plasma.PlasmaClient(self.store_name, manager_name) + ready, waiting = [], object_ids + while True: + ready, waiting = client2.wait(object_ids, num_returns=num_objects, + timeout=0) + if len(ready) == len(object_ids): + break - self.assertEqual(set(ready), set(object_ids)) - self.assertEqual(waiting, []) + self.assertEqual(set(ready), set(object_ids)) + self.assertEqual(waiting, []) if __name__ == "__main__": - if len(sys.argv) > 1: - # Pop the argument so we don't mess with unittest's own argument parser. - if sys.argv[-1] == "valgrind": - arg = sys.argv.pop() - USE_VALGRIND = True - print("Using valgrind for tests") - unittest.main(verbosity=2) + if len(sys.argv) > 1: + # Pop the argument so we don't mess with unittest's own argument + # parser. 
+ if sys.argv[-1] == "valgrind": + arg = sys.argv.pop() + USE_VALGRIND = True + print("Using valgrind for tests") + unittest.main(verbosity=2) diff --git a/python/ray/plasma/utils.py b/python/ray/plasma/utils.py index e2b63f2de..f4141a1f1 100644 --- a/python/ray/plasma/utils.py +++ b/python/ray/plasma/utils.py @@ -7,39 +7,41 @@ import random def random_object_id(): - return np.random.bytes(20) + return np.random.bytes(20) def generate_metadata(length): - metadata_buffer = bytearray(length) - if length > 0: - metadata_buffer[0] = random.randint(0, 255) - metadata_buffer[-1] = random.randint(0, 255) - for _ in range(100): - metadata_buffer[random.randint(0, length - 1)] = random.randint(0, 255) - return metadata_buffer + metadata_buffer = bytearray(length) + if length > 0: + metadata_buffer[0] = random.randint(0, 255) + metadata_buffer[-1] = random.randint(0, 255) + for _ in range(100): + metadata_buffer[random.randint(0, length - 1)] = ( + random.randint(0, 255)) + return metadata_buffer def write_to_data_buffer(buff, length): - if length > 0: - buff[0] = chr(random.randint(0, 255)) - buff[-1] = chr(random.randint(0, 255)) - for _ in range(100): - buff[random.randint(0, length - 1)] = chr(random.randint(0, 255)) + if length > 0: + buff[0] = chr(random.randint(0, 255)) + buff[-1] = chr(random.randint(0, 255)) + for _ in range(100): + buff[random.randint(0, length - 1)] = chr(random.randint(0, 255)) def create_object_with_id(client, object_id, data_size, metadata_size, seal=True): - metadata = generate_metadata(metadata_size) - memory_buffer = client.create(object_id, data_size, metadata) - write_to_data_buffer(memory_buffer, data_size) - if seal: - client.seal(object_id) - return memory_buffer, metadata + metadata = generate_metadata(metadata_size) + memory_buffer = client.create(object_id, data_size, metadata) + write_to_data_buffer(memory_buffer, data_size) + if seal: + client.seal(object_id) + return memory_buffer, metadata def create_object(client, data_size, 
metadata_size, seal=True): - object_id = random_object_id() - memory_buffer, metadata = create_object_with_id(client, object_id, data_size, - metadata_size, seal=seal) - return object_id, memory_buffer, metadata + object_id = random_object_id() + memory_buffer, metadata = create_object_with_id(client, object_id, + data_size, metadata_size, + seal=seal) + return object_id, memory_buffer, metadata diff --git a/python/ray/rllib/a3c/LSTM.py b/python/ray/rllib/a3c/LSTM.py index 4aac0982e..4ddb0cb56 100644 --- a/python/ray/rllib/a3c/LSTM.py +++ b/python/ray/rllib/a3c/LSTM.py @@ -16,97 +16,98 @@ use_tf100_api = (distutils.version.LooseVersion(tf.VERSION) >= class LSTMPolicy(Policy): - def setup_graph(self, ob_space, ac_space): - """Setup model used for Policy. + def setup_graph(self, ob_space, ac_space): + """Setup model used for Policy. - In this A3C implementation, both the Critic and the Actor share the model. - """ - self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space)) + In this A3C implementation, both the Critic and the Actor share the + model. + """ + self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space)) - for i in range(4): - x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2])) - # Introduce a "fake" batch dimension of 1 after flatten so that we can do - # LSTM over the time dim. - x = tf.expand_dims(flatten(x), [0]) + for i in range(4): + x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2])) + # Introduce a "fake" batch dimension of 1 after flatten so that we can + # do LSTM over the time dim. 
+ x = tf.expand_dims(flatten(x), [0]) - size = 256 - if use_tf100_api: - lstm = rnn.BasicLSTMCell(size, state_is_tuple=True) - else: - lstm = rnn.rnn_cell.BasicLSTMCell(size, state_is_tuple=True) - self.state_size = lstm.state_size - step_size = tf.shape(self.x)[:1] + size = 256 + if use_tf100_api: + lstm = rnn.BasicLSTMCell(size, state_is_tuple=True) + else: + lstm = rnn.rnn_cell.BasicLSTMCell(size, state_is_tuple=True) + self.state_size = lstm.state_size + step_size = tf.shape(self.x)[:1] - c_init = np.zeros((1, lstm.state_size.c), np.float32) - h_init = np.zeros((1, lstm.state_size.h), np.float32) - self.state_init = [c_init, h_init] - c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c]) - h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h]) - self.state_in = [c_in, h_in] + c_init = np.zeros((1, lstm.state_size.c), np.float32) + h_init = np.zeros((1, lstm.state_size.h), np.float32) + self.state_init = [c_init, h_init] + c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c]) + h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h]) + self.state_in = [c_in, h_in] - if use_tf100_api: - state_in = rnn.LSTMStateTuple(c_in, h_in) - else: - state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in) - lstm_outputs, lstm_state = tf.nn.dynamic_rnn( - lstm, x, initial_state=state_in, sequence_length=step_size, - time_major=False) - lstm_c, lstm_h = lstm_state - x = tf.reshape(lstm_outputs, [-1, size]) - self.logits = linear(x, ac_space, "action", - normalized_columns_initializer(0.01)) - self.vf = tf.reshape(linear(x, 1, "value", - normalized_columns_initializer(1.0)), [-1]) - self.state_out = [lstm_c[:1, :], lstm_h[:1, :]] - self.sample = categorical_sample(self.logits, ac_space)[0, :] - self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, - tf.get_variable_scope().name) - self.global_step = tf.get_variable( - "global_step", [], tf.int32, - initializer=tf.constant_initializer(0, dtype=tf.int32), - trainable=False) + if use_tf100_api: + state_in = 
rnn.LSTMStateTuple(c_in, h_in) + else: + state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in) + lstm_outputs, lstm_state = tf.nn.dynamic_rnn( + lstm, x, initial_state=state_in, sequence_length=step_size, + time_major=False) + lstm_c, lstm_h = lstm_state + x = tf.reshape(lstm_outputs, [-1, size]) + self.logits = linear(x, ac_space, "action", + normalized_columns_initializer(0.01)) + self.vf = tf.reshape(linear(x, 1, "value", + normalized_columns_initializer(1.0)), [-1]) + self.state_out = [lstm_c[:1, :], lstm_h[:1, :]] + self.sample = categorical_sample(self.logits, ac_space)[0, :] + self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, + tf.get_variable_scope().name) + self.global_step = tf.get_variable( + "global_step", [], tf.int32, + initializer=tf.constant_initializer(0, dtype=tf.int32), + trainable=False) - def get_gradients(self, batch): - """Computing the gradient is actually model-dependent. + def get_gradients(self, batch): + """Computing the gradient is actually model-dependent. - The LSTM needs its hidden states in order to compute the gradient - accurately. - """ - feed_dict = { - self.x: batch.si, - self.ac: batch.a, - self.adv: batch.adv, - self.r: batch.r, - self.state_in[0]: batch.features[0], - self.state_in[1]: batch.features[1] - } - self.local_steps += 1 - return self.sess.run(self.grads, feed_dict=feed_dict) + The LSTM needs its hidden states in order to compute the gradient + accurately. 
+ """ + feed_dict = { + self.x: batch.si, + self.ac: batch.a, + self.adv: batch.adv, + self.r: batch.r, + self.state_in[0]: batch.features[0], + self.state_in[1]: batch.features[1] + } + self.local_steps += 1 + return self.sess.run(self.grads, feed_dict=feed_dict) - def act(self, ob, c, h): - return self.sess.run([self.sample, self.vf] + self.state_out, - {self.x: [ob], - self.state_in[0]: c, - self.state_in[1]: h}) + def act(self, ob, c, h): + return self.sess.run([self.sample, self.vf] + self.state_out, + {self.x: [ob], + self.state_in[0]: c, + self.state_in[1]: h}) - def value(self, ob, c, h): - return self.sess.run(self.vf, {self.x: [ob], - self.state_in[0]: c, - self.state_in[1]: h})[0] + def value(self, ob, c, h): + return self.sess.run(self.vf, {self.x: [ob], + self.state_in[0]: c, + self.state_in[1]: h})[0] - def get_initial_features(self): - return self.state_init + def get_initial_features(self): + return self.state_init class RawLSTMPolicy(LSTMPolicy): - def get_weights(self): - if not hasattr(self, "_weights"): - self._weights = self.variables.get_weights() - return self._weights + def get_weights(self): + if not hasattr(self, "_weights"): + self._weights = self.variables.get_weights() + return self._weights - def set_weights(self, weights): - self._weights = weights + def set_weights(self, weights): + self._weights = weights - def model_update(self, grads): - for var, grad in zip(self.var_list, grads): - self._weights[var.name[:-2]] -= 1e-4 * grad + def model_update(self, grads): + for var, grad in zip(self.var_list, grads): + self._weights[var.name[:-2]] -= 1e-4 * grad diff --git a/python/ray/rllib/a3c/a3c.py b/python/ray/rllib/a3c/a3c.py index bdfced2a7..f0b909e4f 100644 --- a/python/ray/rllib/a3c/a3c.py +++ b/python/ray/rllib/a3c/a3c.py @@ -22,108 +22,109 @@ DEFAULT_CONFIG = { @ray.remote class Runner(object): - """Actor object to start running simulation on workers. + """Actor object to start running simulation on workers. 
- The gradient computation is also executed from this object. - """ - def __init__(self, env_name, actor_id, logdir, start=True): - env = create_env(env_name) - self.id = actor_id - num_actions = env.action_space.n - self.policy = LSTMPolicy(env.observation_space.shape, num_actions, - actor_id) - self.runner = RunnerThread(env, self.policy, 20) - self.env = env - self.logdir = logdir - if start: - self.start() - - def pull_batch_from_queue(self): - """Take a rollout from the queue of the thread runner.""" - rollout = self.runner.queue.get(timeout=600.0) - if isinstance(rollout, BaseException): - raise rollout - while not rollout.terminal: - try: - part = self.runner.queue.get_nowait() - if isinstance(part, BaseException): - raise rollout - rollout.extend(part) - except queue.Empty: - break - return rollout - - def get_completed_rollout_metrics(self): - """Returns metrics on previously completed rollouts. - - Calling this clears the queue of completed rollout metrics. + The gradient computation is also executed from this object. 
""" - completed = [] - while True: - try: - completed.append(self.runner.metrics_queue.get_nowait()) - except queue.Empty: - break - return completed + def __init__(self, env_name, actor_id, logdir, start=True): + env = create_env(env_name) + self.id = actor_id + num_actions = env.action_space.n + self.policy = LSTMPolicy(env.observation_space.shape, num_actions, + actor_id) + self.runner = RunnerThread(env, self.policy, 20) + self.env = env + self.logdir = logdir + if start: + self.start() - def start(self): - summary_writer = tf.summary.FileWriter( - os.path.join(self.logdir, "agent_%d" % self.id)) - self.summary_writer = summary_writer - self.runner.start_runner(self.policy.sess, summary_writer) + def pull_batch_from_queue(self): + """Take a rollout from the queue of the thread runner.""" + rollout = self.runner.queue.get(timeout=600.0) + if isinstance(rollout, BaseException): + raise rollout + while not rollout.terminal: + try: + part = self.runner.queue.get_nowait() + if isinstance(part, BaseException): + raise rollout + rollout.extend(part) + except queue.Empty: + break + return rollout - def compute_gradient(self, params): - self.policy.set_weights(params) - rollout = self.pull_batch_from_queue() - batch = process_rollout(rollout, gamma=0.99, lambda_=1.0) - gradient = self.policy.get_gradients(batch) - info = {"id": self.id, - "size": len(batch.a)} - return gradient, info + def get_completed_rollout_metrics(self): + """Returns metrics on previously completed rollouts. + + Calling this clears the queue of completed rollout metrics. 
+ """ + completed = [] + while True: + try: + completed.append(self.runner.metrics_queue.get_nowait()) + except queue.Empty: + break + return completed + + def start(self): + summary_writer = tf.summary.FileWriter( + os.path.join(self.logdir, "agent_%d" % self.id)) + self.summary_writer = summary_writer + self.runner.start_runner(self.policy.sess, summary_writer) + + def compute_gradient(self, params): + self.policy.set_weights(params) + rollout = self.pull_batch_from_queue() + batch = process_rollout(rollout, gamma=0.99, lambda_=1.0) + gradient = self.policy.get_gradients(batch) + info = {"id": self.id, + "size": len(batch.a)} + return gradient, info class A3C(Algorithm): - def __init__(self, env_name, config, upload_dir=None): - config.update({"alg": "A3C"}) - Algorithm.__init__(self, env_name, config, upload_dir=upload_dir) - self.env = create_env(env_name) - self.policy = LSTMPolicy( - self.env.observation_space.shape, self.env.action_space.n, 0) - self.agents = [ - Runner.remote(env_name, i, self.logdir) - for i in range(config["num_workers"])] - self.parameters = self.policy.get_weights() - self.iteration = 0 + def __init__(self, env_name, config, upload_dir=None): + config.update({"alg": "A3C"}) + Algorithm.__init__(self, env_name, config, upload_dir=upload_dir) + self.env = create_env(env_name) + self.policy = LSTMPolicy( + self.env.observation_space.shape, self.env.action_space.n, 0) + self.agents = [ + Runner.remote(env_name, i, self.logdir) + for i in range(config["num_workers"])] + self.parameters = self.policy.get_weights() + self.iteration = 0 - def train(self): - gradient_list = [ - agent.compute_gradient.remote(self.parameters) - for agent in self.agents] - max_batches = self.config["num_batches_per_iteration"] - batches_so_far = len(gradient_list) - while gradient_list: - done_id, gradient_list = ray.wait(gradient_list) - gradient, info = ray.get(done_id)[0] - self.policy.model_update(gradient) - self.parameters = self.policy.get_weights() - if 
batches_so_far < max_batches: - batches_so_far += 1 - gradient_list.extend( - [self.agents[info["id"]].compute_gradient.remote(self.parameters)]) - res = self.fetch_metrics_from_workers() - self.iteration += 1 - return res + def train(self): + gradient_list = [ + agent.compute_gradient.remote(self.parameters) + for agent in self.agents] + max_batches = self.config["num_batches_per_iteration"] + batches_so_far = len(gradient_list) + while gradient_list: + done_id, gradient_list = ray.wait(gradient_list) + gradient, info = ray.get(done_id)[0] + self.policy.model_update(gradient) + self.parameters = self.policy.get_weights() + if batches_so_far < max_batches: + batches_so_far += 1 + gradient_list.extend( + [self.agents[info["id"]].compute_gradient.remote( + self.parameters)]) + res = self.fetch_metrics_from_workers() + self.iteration += 1 + return res - def fetch_metrics_from_workers(self): - episode_rewards = [] - episode_lengths = [] - metric_lists = [ - a.get_completed_rollout_metrics.remote() for a in self.agents] - for metrics in metric_lists: - for episode in ray.get(metrics): - episode_lengths.append(episode.episode_length) - episode_rewards.append(episode.episode_reward) - res = TrainingResult( - self.experiment_id.hex, self.iteration, - np.mean(episode_rewards), np.mean(episode_lengths), dict()) - return res + def fetch_metrics_from_workers(self): + episode_rewards = [] + episode_lengths = [] + metric_lists = [ + a.get_completed_rollout_metrics.remote() for a in self.agents] + for metrics in metric_lists: + for episode in ray.get(metrics): + episode_lengths.append(episode.episode_length) + episode_rewards.append(episode.episode_reward) + res = TrainingResult( + self.experiment_id.hex, self.iteration, + np.mean(episode_rewards), np.mean(episode_lengths), dict()) + return res diff --git a/python/ray/rllib/a3c/envs.py b/python/ray/rllib/a3c/envs.py index ba29c408e..95f26f544 100644 --- a/python/ray/rllib/a3c/envs.py +++ b/python/ray/rllib/a3c/envs.py @@ -14,94 
+14,96 @@ logger.setLevel(logging.INFO) def create_env(env_id): - env = gym.make(env_id) - env = AtariProcessing(env) - env = Diagnostic(env) - return env + env = gym.make(env_id) + env = AtariProcessing(env) + env = Diagnostic(env) + return env def _process_frame42(frame): - frame = frame[34:(34 + 160), :160] - # Resize by half, then down to 42x42 (essentially mipmapping). If we resize - # directly we lose pixels that, when mapped to 42x42, aren't close enough to - # the pixel boundary. - frame = cv2.resize(frame, (80, 80)) - frame = cv2.resize(frame, (42, 42)) - frame = frame.mean(2) - frame = frame.astype(np.float32) - frame *= (1.0 / 255.0) - frame = np.reshape(frame, [42, 42, 1]) - return frame + frame = frame[34:(34 + 160), :160] + # Resize by half, then down to 42x42 (essentially mipmapping). If we resize + # directly we lose pixels that, when mapped to 42x42, aren't close enough + # to the pixel boundary. + frame = cv2.resize(frame, (80, 80)) + frame = cv2.resize(frame, (42, 42)) + frame = frame.mean(2) + frame = frame.astype(np.float32) + frame *= (1.0 / 255.0) + frame = np.reshape(frame, [42, 42, 1]) + return frame class AtariProcessing(gym.ObservationWrapper): - def __init__(self, env=None): - super(AtariProcessing, self).__init__(env) - self.observation_space = Box(0.0, 1.0, [42, 42, 1]) + def __init__(self, env=None): + super(AtariProcessing, self).__init__(env) + self.observation_space = Box(0.0, 1.0, [42, 42, 1]) - def _observation(self, observation): - return _process_frame42(observation) + def _observation(self, observation): + return _process_frame42(observation) class Diagnostic(gym.Wrapper): - def __init__(self, env=None): - super(Diagnostic, self).__init__(env) - self.diagnostics = DiagnosticsLogger() + def __init__(self, env=None): + super(Diagnostic, self).__init__(env) + self.diagnostics = DiagnosticsLogger() - def _reset(self): - observation = self.env.reset() - return self.diagnostics._after_reset(observation) + def _reset(self): + 
observation = self.env.reset() + return self.diagnostics._after_reset(observation) - def _step(self, action): - results = self.env.step(action) - return self.diagnostics._after_step(*results) + def _step(self, action): + results = self.env.step(action) + return self.diagnostics._after_step(*results) class DiagnosticsLogger(object): - def __init__(self, log_interval=503): - self._episode_time = time.time() - self._last_time = time.time() - self._local_t = 0 - self._log_interval = log_interval - self._episode_reward = 0 - self._episode_length = 0 - self._all_rewards = [] - self._last_episode_id = -1 + def __init__(self, log_interval=503): + self._episode_time = time.time() + self._last_time = time.time() + self._local_t = 0 + self._log_interval = log_interval + self._episode_reward = 0 + self._episode_length = 0 + self._all_rewards = [] + self._last_episode_id = -1 - def _after_reset(self, observation): - logger.info("Resetting environment") - self._episode_reward = 0 - self._episode_length = 0 - self._all_rewards = [] - return observation + def _after_reset(self, observation): + logger.info("Resetting environment") + self._episode_reward = 0 + self._episode_length = 0 + self._all_rewards = [] + return observation - def _after_step(self, observation, reward, done, info): - to_log = {} - if self._episode_length == 0: - self._episode_time = time.time() + def _after_step(self, observation, reward, done, info): + to_log = {} + if self._episode_length == 0: + self._episode_time = time.time() - self._local_t += 1 + self._local_t += 1 - if self._local_t % self._log_interval == 0: - cur_time = time.time() - self._last_time = cur_time + if self._local_t % self._log_interval == 0: + cur_time = time.time() + self._last_time = cur_time - if reward is not None: - self._episode_reward += reward - if observation is not None: - self._episode_length += 1 - self._all_rewards.append(reward) + if reward is not None: + self._episode_reward += reward + if observation is not None: + 
self._episode_length += 1 + self._all_rewards.append(reward) - if done: - logger.info("Episode terminating: episode_reward=%s episode_length=%s", - self._episode_reward, self._episode_length) - total_time = time.time() - self._episode_time - to_log["global/episode_reward"] = self._episode_reward - to_log["global/episode_length"] = self._episode_length - to_log["global/episode_time"] = total_time - to_log["global/reward_per_time"] = self._episode_reward / total_time - self._episode_reward = 0 - self._episode_length = 0 - self._all_rewards = [] + if done: + logger.info("Episode terminating: episode_reward=%s " + "episode_length=%s", + self._episode_reward, self._episode_length) + total_time = time.time() - self._episode_time + to_log["global/episode_reward"] = self._episode_reward + to_log["global/episode_length"] = self._episode_length + to_log["global/episode_time"] = total_time + to_log["global/reward_per_time"] = (self._episode_reward / + total_time) + self._episode_reward = 0 + self._episode_length = 0 + self._all_rewards = [] - return observation, reward, done, to_log + return observation, reward, done, to_log diff --git a/python/ray/rllib/a3c/example.py b/python/ray/rllib/a3c/example.py index 34529882e..47fc3d777 100755 --- a/python/ray/rllib/a3c/example.py +++ b/python/ray/rllib/a3c/example.py @@ -11,22 +11,22 @@ from ray.rllib.a3c import A3C, DEFAULT_CONFIG if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run the A3C algorithm.") - parser.add_argument("--environment", default="PongDeterministic-v3", - type=str, help="The gym environment to use.") - parser.add_argument("--redis-address", default=None, type=str, - help="The Redis address of the cluster.") - parser.add_argument("--num-workers", default=4, type=int, - help="The number of A3C workers to use>") + parser = argparse.ArgumentParser(description="Run the A3C algorithm.") + parser.add_argument("--environment", default="PongDeterministic-v3", + type=str, help="The gym 
environment to use.") + parser.add_argument("--redis-address", default=None, type=str, + help="The Redis address of the cluster.") + parser.add_argument("--num-workers", default=4, type=int, + help="The number of A3C workers to use>") - args = parser.parse_args() - ray.init(redis_address=args.redis_address, num_cpus=args.num_workers) + args = parser.parse_args() + ray.init(redis_address=args.redis_address, num_cpus=args.num_workers) - config = DEFAULT_CONFIG.copy() - config["num_workers"] = args.num_workers + config = DEFAULT_CONFIG.copy() + config["num_workers"] = args.num_workers - a3c = A3C(args.environment, config) + a3c = A3C(args.environment, config) - while True: - res = a3c.train() - print("current status: {}".format(res)) + while True: + res = a3c.train() + print("current status: {}".format(res)) diff --git a/python/ray/rllib/a3c/policy.py b/python/ray/rllib/a3c/policy.py index b77ff6506..f5335c404 100644 --- a/python/ray/rllib/a3c/policy.py +++ b/python/ray/rllib/a3c/policy.py @@ -8,138 +8,139 @@ import ray class Policy(object): - """The policy base class.""" - def __init__(self, ob_space, ac_space, task, name="local"): - self.local_steps = 0 - worker_device = "/job:localhost/replica:0/task:0/cpu:0" - self.g = tf.Graph() - with self.g.as_default(), tf.device(worker_device): - with tf.variable_scope(name): - self.setup_graph(ob_space, ac_space) - assert all([hasattr(self, attr) - for attr in ["vf", "logits", "x", "var_list"]]) - print("Setting up loss") - self.setup_loss(ac_space) - self.initialize() + """The policy base class.""" + def __init__(self, ob_space, ac_space, task, name="local"): + self.local_steps = 0 + worker_device = "/job:localhost/replica:0/task:0/cpu:0" + self.g = tf.Graph() + with self.g.as_default(), tf.device(worker_device): + with tf.variable_scope(name): + self.setup_graph(ob_space, ac_space) + assert all([hasattr(self, attr) + for attr in ["vf", "logits", "x", "var_list"]]) + print("Setting up loss") + self.setup_loss(ac_space) + 
self.initialize() - def setup_graph(self): - raise NotImplementedError + def setup_graph(self): + raise NotImplementedError - def setup_loss(self, num_actions, summarize=True): - self.ac = tf.placeholder(tf.float32, [None, num_actions], name="ac") - self.adv = tf.placeholder(tf.float32, [None], name="adv") - self.r = tf.placeholder(tf.float32, [None], name="r") + def setup_loss(self, num_actions, summarize=True): + self.ac = tf.placeholder(tf.float32, [None, num_actions], name="ac") + self.adv = tf.placeholder(tf.float32, [None], name="adv") + self.r = tf.placeholder(tf.float32, [None], name="r") - log_prob_tf = tf.nn.log_softmax(self.logits) - prob_tf = tf.nn.softmax(self.logits) + log_prob_tf = tf.nn.log_softmax(self.logits) + prob_tf = tf.nn.softmax(self.logits) - # The "policy gradients" loss: its derivative is precisely the policy - # gradient. Notice that self.ac is a placeholder that is provided - # externally. adv will contain the advantages, as calculated in - # process_rollout. - pi_loss = - tf.reduce_sum(tf.reduce_sum(log_prob_tf * self.ac, - [1]) * self.adv) + # The "policy gradients" loss: its derivative is precisely the policy + # gradient. Notice that self.ac is a placeholder that is provided + # externally. adv will contain the advantages, as calculated in + # process_rollout. 
+ pi_loss = - tf.reduce_sum(tf.reduce_sum(log_prob_tf * self.ac, + [1]) * self.adv) - # loss of value function - vf_loss = 0.5 * tf.reduce_sum(tf.square(self.vf - self.r)) - vf_loss = tf.Print(vf_loss, [vf_loss], "Value Fn Loss") - entropy = - tf.reduce_sum(prob_tf * log_prob_tf) + # loss of value function + vf_loss = 0.5 * tf.reduce_sum(tf.square(self.vf - self.r)) + vf_loss = tf.Print(vf_loss, [vf_loss], "Value Fn Loss") + entropy = - tf.reduce_sum(prob_tf * log_prob_tf) - bs = tf.to_float(tf.shape(self.x)[0]) - self.loss = pi_loss + 0.5 * vf_loss - entropy * 0.01 + bs = tf.to_float(tf.shape(self.x)[0]) + self.loss = pi_loss + 0.5 * vf_loss - entropy * 0.01 - grads = tf.gradients(self.loss, self.var_list) - self.grads, _ = tf.clip_by_global_norm(grads, 40.0) + grads = tf.gradients(self.loss, self.var_list) + self.grads, _ = tf.clip_by_global_norm(grads, 40.0) - grads_and_vars = list(zip(self.grads, self.var_list)) - opt = tf.train.AdamOptimizer(1e-4) - self._apply_gradients = opt.apply_gradients(grads_and_vars) + grads_and_vars = list(zip(self.grads, self.var_list)) + opt = tf.train.AdamOptimizer(1e-4) + self._apply_gradients = opt.apply_gradients(grads_and_vars) - if summarize: - tf.summary.scalar("model/policy_loss", pi_loss / bs) - tf.summary.scalar("model/value_loss", vf_loss / bs) - tf.summary.scalar("model/entropy", entropy / bs) - tf.summary.image("model/state", self.x) - self.summary_op = tf.summary.merge_all() + if summarize: + tf.summary.scalar("model/policy_loss", pi_loss / bs) + tf.summary.scalar("model/value_loss", vf_loss / bs) + tf.summary.scalar("model/entropy", entropy / bs) + tf.summary.image("model/state", self.x) + self.summary_op = tf.summary.merge_all() - def initialize(self): - self.sess = tf.Session(graph=self.g, config=tf.ConfigProto( - intra_op_parallelism_threads=1, inter_op_parallelism_threads=2)) - self.variables = ray.experimental.TensorFlowVariables(self.loss, self.sess) - self.sess.run(tf.global_variables_initializer()) + def 
initialize(self): + self.sess = tf.Session(graph=self.g, config=tf.ConfigProto( + intra_op_parallelism_threads=1, inter_op_parallelism_threads=2)) + self.variables = ray.experimental.TensorFlowVariables(self.loss, + self.sess) + self.sess.run(tf.global_variables_initializer()) - def model_update(self, grads): - feed_dict = {self.grads[i]: grads[i] - for i in range(len(grads))} - self.sess.run(self._apply_gradients, feed_dict=feed_dict) + def model_update(self, grads): + feed_dict = {self.grads[i]: grads[i] + for i in range(len(grads))} + self.sess.run(self._apply_gradients, feed_dict=feed_dict) - def get_weights(self): - weights = self.variables.get_weights() - return weights + def get_weights(self): + weights = self.variables.get_weights() + return weights - def set_weights(self, weights): - self.variables.set_weights(weights) + def set_weights(self, weights): + self.variables.set_weights(weights) - def get_gradients(self, batch): - raise NotImplementedError + def get_gradients(self, batch): + raise NotImplementedError - def get_vf_loss(self): - raise NotImplementedError + def get_vf_loss(self): + raise NotImplementedError - def act(self, ob): - raise NotImplementedError + def act(self, ob): + raise NotImplementedError - def value(self, ob): - raise NotImplementedError + def value(self, ob): + raise NotImplementedError def normalized_columns_initializer(std=1.0): - def _initializer(shape, dtype=None, partition_info=None): - out = np.random.randn(*shape).astype(np.float32) - out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) - return tf.constant(out) - return _initializer + def _initializer(shape, dtype=None, partition_info=None): + out = np.random.randn(*shape).astype(np.float32) + out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) + return tf.constant(out) + return _initializer def flatten(x): - return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])]) + return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])]) def 
conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME", dtype=tf.float32, collections=None): - with tf.variable_scope(name): - stride_shape = [1, stride[0], stride[1], 1] - filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[3]), - num_filters] + with tf.variable_scope(name): + stride_shape = [1, stride[0], stride[1], 1] + filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[3]), + num_filters] - # There are "num input feature maps * filter height * filter width" - # inputs to each hidden unit. - fan_in = np.prod(filter_shape[:3]) - # Each unit in the lower layer receives a gradient from: - # "num output feature maps * filter height * filter width" / pooling size. - fan_out = np.prod(filter_shape[:2]) * num_filters - # Initialize weights with random weights. - w_bound = np.sqrt(6 / (fan_in + fan_out)) + # There are "num input feature maps * filter height * filter width" + # inputs to each hidden unit. + fan_in = np.prod(filter_shape[:3]) + # Each unit in the lower layer receives a gradient from: "num output + # feature maps * filter height * filter width" / pooling size. + fan_out = np.prod(filter_shape[:2]) * num_filters + # Initialize weights with random weights. 
+ w_bound = np.sqrt(6 / (fan_in + fan_out)) - w = tf.get_variable("W", filter_shape, dtype, - tf.random_uniform_initializer(-w_bound, w_bound), - collections=collections) - b = tf.get_variable("b", [1, 1, 1, num_filters], - initializer=tf.constant_initializer(0.0), - collections=collections) - return tf.nn.conv2d(x, w, stride_shape, pad) + b + w = tf.get_variable("W", filter_shape, dtype, + tf.random_uniform_initializer(-w_bound, w_bound), + collections=collections) + b = tf.get_variable("b", [1, 1, 1, num_filters], + initializer=tf.constant_initializer(0.0), + collections=collections) + return tf.nn.conv2d(x, w, stride_shape, pad) + b def linear(x, size, name, initializer=None, bias_init=0): - w = tf.get_variable(name + "/w", [x.get_shape()[1], size], - initializer=initializer) - b = tf.get_variable(name + "/b", [size], - initializer=tf.constant_initializer(bias_init)) - return tf.matmul(x, w) + b + w = tf.get_variable(name + "/w", [x.get_shape()[1], size], + initializer=initializer) + b = tf.get_variable(name + "/b", [size], + initializer=tf.constant_initializer(bias_init)) + return tf.matmul(x, w) + b def categorical_sample(logits, d): - value = tf.squeeze(tf.multinomial(logits - tf.reduce_max(logits, [1], - keep_dims=True), - 1), [1]) - return tf.one_hot(value, d) + value = tf.squeeze(tf.multinomial(logits - tf.reduce_max(logits, [1], + keep_dims=True), + 1), [1]) + return tf.one_hot(value, d) diff --git a/python/ray/rllib/a3c/runner.py b/python/ray/rllib/a3c/runner.py index 7c8dd182f..2119f419c 100644 --- a/python/ray/rllib/a3c/runner.py +++ b/python/ray/rllib/a3c/runner.py @@ -11,26 +11,26 @@ import threading def discount(x, gamma): - return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1] + return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1] def process_rollout(rollout, gamma, lambda_=1.0): - """Given a rollout, compute its returns and the advantage.""" - batch_si = np.asarray(rollout.states) - batch_a = 
np.asarray(rollout.actions) - rewards = np.asarray(rollout.rewards) - vpred_t = np.asarray(rollout.values + [rollout.r]) + """Given a rollout, compute its returns and the advantage.""" + batch_si = np.asarray(rollout.states) + batch_a = np.asarray(rollout.actions) + rewards = np.asarray(rollout.rewards) + vpred_t = np.asarray(rollout.values + [rollout.r]) - rewards_plus_v = np.asarray(rollout.rewards + [rollout.r]) - batch_r = discount(rewards_plus_v, gamma)[:-1] - delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1] - # This formula for the advantage comes "Generalized Advantage Estimation": - # https://arxiv.org/abs/1506.02438 - batch_adv = discount(delta_t, gamma * lambda_) + rewards_plus_v = np.asarray(rollout.rewards + [rollout.r]) + batch_r = discount(rewards_plus_v, gamma)[:-1] + delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1] + # This formula for the advantage comes "Generalized Advantage Estimation": + # https://arxiv.org/abs/1506.02438 + batch_adv = discount(delta_t, gamma * lambda_) - features = rollout.features[0] - return Batch(batch_si, batch_a, batch_adv, batch_r, rollout.terminal, - features) + features = rollout.features[0] + return Batch(batch_si, batch_a, batch_adv, batch_r, rollout.terminal, + features) Batch = namedtuple( @@ -41,142 +41,143 @@ CompletedRollout = namedtuple( class PartialRollout(object): - """A piece of a complete rollout. + """A piece of a complete rollout. - We run our agent, and process its experience once it has processed enough - steps. - """ - def __init__(self): - self.states = [] - self.actions = [] - self.rewards = [] - self.values = [] - self.r = 0.0 - self.terminal = False - self.features = [] + We run our agent, and process its experience once it has processed enough + steps. 
+ """ + def __init__(self): + self.states = [] + self.actions = [] + self.rewards = [] + self.values = [] + self.r = 0.0 + self.terminal = False + self.features = [] - def add(self, state, action, reward, value, terminal, features): - self.states += [state] - self.actions += [action] - self.rewards += [reward] - self.values += [value] - self.terminal = terminal - self.features += [features] + def add(self, state, action, reward, value, terminal, features): + self.states += [state] + self.actions += [action] + self.rewards += [reward] + self.values += [value] + self.terminal = terminal + self.features += [features] - def extend(self, other): - assert not self.terminal - self.states.extend(other.states) - self.actions.extend(other.actions) - self.rewards.extend(other.rewards) - self.values.extend(other.values) - self.r = other.r - self.terminal = other.terminal - self.features.extend(other.features) + def extend(self, other): + assert not self.terminal + self.states.extend(other.states) + self.actions.extend(other.actions) + self.rewards.extend(other.rewards) + self.values.extend(other.values) + self.r = other.r + self.terminal = other.terminal + self.features.extend(other.features) class RunnerThread(threading.Thread): - """This thread interacts with the environment and tells it what to do.""" - def __init__(self, env, policy, num_local_steps, visualise=False): - threading.Thread.__init__(self) - self.queue = queue.Queue(5) - self.metrics_queue = queue.Queue() - self.num_local_steps = num_local_steps - self.env = env - self.last_features = None - self.policy = policy - self.daemon = True - self.sess = None - self.summary_writer = None - self.visualise = visualise + """This thread interacts with the environment and tells it what to do.""" + def __init__(self, env, policy, num_local_steps, visualise=False): + threading.Thread.__init__(self) + self.queue = queue.Queue(5) + self.metrics_queue = queue.Queue() + self.num_local_steps = num_local_steps + self.env = env + 
self.last_features = None + self.policy = policy + self.daemon = True + self.sess = None + self.summary_writer = None + self.visualise = visualise - def start_runner(self, sess, summary_writer): - self.sess = sess - self.summary_writer = summary_writer - self.start() + def start_runner(self, sess, summary_writer): + self.sess = sess + self.summary_writer = summary_writer + self.start() - def run(self): - try: - with self.sess.as_default(): - self._run() - except BaseException as e: - self.queue.put(e) - raise e + def run(self): + try: + with self.sess.as_default(): + self._run() + except BaseException as e: + self.queue.put(e) + raise e - def _run(self): - rollout_provider = env_runner( - self.env, self.policy, self.num_local_steps, - self.summary_writer, self.visualise) - while True: - # The timeout variable exists because apparently, if one worker dies, the - # other workers won't die with it, unless the timeout is set to some - # large number. This is an empirical observation. - item = next(rollout_provider) - if isinstance(item, CompletedRollout): - self.metrics_queue.put(item) - else: - self.queue.put(item, timeout=600.0) + def _run(self): + rollout_provider = env_runner( + self.env, self.policy, self.num_local_steps, + self.summary_writer, self.visualise) + while True: + # The timeout variable exists because apparently, if one worker + # dies, the other workers won't die with it, unless the timeout is + # set to some large number. This is an empirical observation. + item = next(rollout_provider) + if isinstance(item, CompletedRollout): + self.metrics_queue.put(item) + else: + self.queue.put(item, timeout=600.0) def env_runner(env, policy, num_local_steps, summary_writer, render): - """This implements the logic of the thread runner. + """This implements the logic of the thread runner. - It continually runs the policy, and as long as the rollout exceeds a certain - length, the thread runner appends the policy to the queue. 
- """ - last_state = env.reset() - timestep_limit = env.spec.tags.get("wrapper_config.TimeLimit" - ".max_episode_steps") - last_features = policy.get_initial_features() - length = 0 - rewards = 0 - rollout_number = 0 + It continually runs the policy, and as long as the rollout exceeds a + certain length, the thread runner appends the policy to the queue. + """ + last_state = env.reset() + timestep_limit = env.spec.tags.get("wrapper_config.TimeLimit" + ".max_episode_steps") + last_features = policy.get_initial_features() + length = 0 + rewards = 0 + rollout_number = 0 - while True: - terminal_end = False - rollout = PartialRollout() + while True: + terminal_end = False + rollout = PartialRollout() - for _ in range(num_local_steps): - fetched = policy.act(last_state, *last_features) - action, value_, features = fetched[0], fetched[1], fetched[2:] - # Argmax to convert from one-hot. - state, reward, terminal, info = env.step(action.argmax()) - if render: - env.render() + for _ in range(num_local_steps): + fetched = policy.act(last_state, *last_features) + action, value_, features = fetched[0], fetched[1], fetched[2:] + # Argmax to convert from one-hot. + state, reward, terminal, info = env.step(action.argmax()) + if render: + env.render() - length += 1 - rewards += reward - if length >= timestep_limit: - terminal = True + length += 1 + rewards += reward + if length >= timestep_limit: + terminal = True - # Collect the experience. - rollout.add(last_state, action, reward, value_, terminal, last_features) + # Collect the experience. 
+ rollout.add(last_state, action, reward, value_, terminal, + last_features) - last_state = state - last_features = features + last_state = state + last_features = features - if info: - summary = tf.Summary() - for k, v in info.items(): - summary.value.add(tag=k, simple_value=float(v)) - summary_writer.add_summary(summary, rollout_number) - summary_writer.flush() + if info: + summary = tf.Summary() + for k, v in info.items(): + summary.value.add(tag=k, simple_value=float(v)) + summary_writer.add_summary(summary, rollout_number) + summary_writer.flush() - if terminal: - terminal_end = True - yield CompletedRollout(length, rewards) + if terminal: + terminal_end = True + yield CompletedRollout(length, rewards) - if length >= timestep_limit or not env.metadata.get("semantics" - ".autoreset"): - last_state = env.reset() - last_features = policy.get_initial_features() - rollout_number += 1 - length = 0 - rewards = 0 - break + if (length >= timestep_limit or + not env.metadata.get("semantics.autoreset")): + last_state = env.reset() + last_features = policy.get_initial_features() + rollout_number += 1 + length = 0 + rewards = 0 + break - if not terminal_end: - rollout.r = policy.value(last_state, *last_features) + if not terminal_end: + rollout.r = policy.value(last_state, *last_features) - # Once we have enough experience, yield it, and have the ThreadRunner - # place it on a queue. - yield rollout + # Once we have enough experience, yield it, and have the ThreadRunner + # place it on a queue. 
+ yield rollout diff --git a/python/ray/rllib/common.py b/python/ray/rllib/common.py index b025cbca8..3b01a7de8 100644 --- a/python/ray/rllib/common.py +++ b/python/ray/rllib/common.py @@ -9,39 +9,39 @@ import tempfile import uuid import smart_open if sys.version_info[0] == 2: - import cStringIO as StringIO + import cStringIO as StringIO elif sys.version_info[0] == 3: - import io as StringIO + import io as StringIO logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) class RLLibEncoder(json.JSONEncoder): - def default(self, value): - if isinstance(value, np.float32) or isinstance(value, np.float64): - if np.isnan(value): - return None - else: - return float(value) + def default(self, value): + if isinstance(value, np.float32) or isinstance(value, np.float64): + if np.isnan(value): + return None + else: + return float(value) class RLLibLogger(object): - """Writing small amounts of data to S3 with real-time updates. - """ + """Writing small amounts of data to S3 with real-time updates. + """ - def __init__(self, uri): - self.result_buffer = StringIO.StringIO() - self.uri = uri + def __init__(self, uri): + self.result_buffer = StringIO.StringIO() + self.uri = uri - def write(self, b): - # TODO(pcm): At the moment we are writing the whole results output from - # the beginning in each iteration. This will write O(n^2) bytes where n - # is the number of bytes printed so far. Fix this! This should at least - # only write the last 5MBs (S3 chunksize). - with smart_open.smart_open(self.uri, "w") as f: - self.result_buffer.write(b) - f.write(self.result_buffer.getvalue()) + def write(self, b): + # TODO(pcm): At the moment we are writing the whole results output from + # the beginning in each iteration. This will write O(n^2) bytes where n + # is the number of bytes printed so far. Fix this! This should at least + # only write the last 5MBs (S3 chunksize). 
+ with smart_open.smart_open(self.uri, "w") as f: + self.result_buffer.write(b) + f.write(self.result_buffer.getvalue()) TrainingResult = namedtuple("TrainingResult", [ @@ -54,54 +54,55 @@ TrainingResult = namedtuple("TrainingResult", [ class Algorithm(object): - """All RLlib algorithms extend this base class. + """All RLlib algorithms extend this base class. - Algorithm objects retain internal model state between calls to train(), so - you should create a new algorithm instance for each training session. + Algorithm objects retain internal model state between calls to train(), so + you should create a new algorithm instance for each training session. - Attributes: - env_name (str): Name of the OpenAI gym environment to train against. - config (obj): Algorithm-specific configuration data. - logdir (str): Directory in which training outputs should be placed. - - TODO(ekl): support checkpoint / restore of training state. - """ - - def __init__(self, env_name, config, upload_dir="file:///tmp/ray"): - """Initialize an RLLib algorithm. - - Args: - env_name (str): The name of the OpenAI gym environment to use. + Attributes: + env_name (str): Name of the OpenAI gym environment to train against. config (obj): Algorithm-specific configuration data. - upload_dir (str): Root directory into which the output directory - should be placed. Can be local like file:///tmp/ray/ or on S3 - like s3://bucketname/. 
- """ - self.experiment_id = uuid.uuid4() - self.env_name = env_name - self.config = config - self.config.update({"experiment_id": self.experiment_id.hex}) - self.config.update({"env_name": env_name}) - prefix = "{}_{}_{}".format( - env_name, - self.__class__.__name__, - datetime.today().strftime("%Y-%m-%d_%H-%M-%S")) - if upload_dir.startswith("file"): - self.logdir = "file://" + tempfile.mkdtemp(prefix=prefix, dir="/tmp/ray") - else: - self.logdir = os.path.join(upload_dir, prefix) - log_path = os.path.join(self.logdir, "config.json") - with smart_open.smart_open(log_path, "w") as f: - json.dump(self.config, f, sort_keys=True, cls=RLLibEncoder) - logger.info( - "%s algorithm created with logdir '%s'", - self.__class__.__name__, self.logdir) + logdir (str): Directory in which training outputs should be placed. - def train(self): - """Runs one logical iteration of training. - - Returns: - A TrainingResult that describes training progress. + TODO(ekl): support checkpoint / restore of training state. """ - raise NotImplementedError + def __init__(self, env_name, config, upload_dir="file:///tmp/ray"): + """Initialize an RLLib algorithm. + + Args: + env_name (str): The name of the OpenAI gym environment to use. + config (obj): Algorithm-specific configuration data. + upload_dir (str): Root directory into which the output directory + should be placed. Can be local like file:///tmp/ray/ or on S3 + like s3://bucketname/. 
+ """ + self.experiment_id = uuid.uuid4() + self.env_name = env_name + self.config = config + self.config.update({"experiment_id": self.experiment_id.hex}) + self.config.update({"env_name": env_name}) + prefix = "{}_{}_{}".format( + env_name, + self.__class__.__name__, + datetime.today().strftime("%Y-%m-%d_%H-%M-%S")) + if upload_dir.startswith("file"): + self.logdir = "file://" + tempfile.mkdtemp(prefix=prefix, + dir="/tmp/ray") + else: + self.logdir = os.path.join(upload_dir, prefix) + log_path = os.path.join(self.logdir, "config.json") + with smart_open.smart_open(log_path, "w") as f: + json.dump(self.config, f, sort_keys=True, cls=RLLibEncoder) + logger.info( + "%s algorithm created with logdir '%s'", + self.__class__.__name__, self.logdir) + + def train(self): + """Runs one logical iteration of training. + + Returns: + A TrainingResult that describes training progress. + """ + + raise NotImplementedError diff --git a/python/ray/rllib/dqn/build_graph.py b/python/ray/rllib/dqn/build_graph.py index bf54c8248..a3e4decc8 100644 --- a/python/ray/rllib/dqn/build_graph.py +++ b/python/ray/rllib/dqn/build_graph.py @@ -80,198 +80,205 @@ from ray.rllib.dqn.common import tf_util as U def build_act(make_obs_ph, q_func, num_actions, scope="deepq", reuse=None): - """Creates the act function: + """Creates the act function: - Parameters - ---------- - make_obs_ph: str -> tf.placeholder or TfInput - a function that take a name and creates a placeholder of input with - that name - q_func: (tf.Variable, int, str, bool) -> tf.Variable - the model that takes the following inputs: - observation_in: object - the output of observation placeholder - num_actions: int - number of actions - scope: str - reuse: bool - should be passed to outer variable scope - and returns a tensor of shape (batch_size, num_actions) with values of - every action. - num_actions: int - number of actions. - scope: str or VariableScope - optional scope for variable_scope. 
- reuse: bool or None - whether or not the variables should be reused. To be able to reuse the - scope must be given. + Parameters + ---------- + make_obs_ph: str -> tf.placeholder or TfInput + a function that take a name and creates a placeholder of input with + that name + q_func: (tf.Variable, int, str, bool) -> tf.Variable + the model that takes the following inputs: + observation_in: object + the output of observation placeholder + num_actions: int + number of actions + scope: str + reuse: bool + should be passed to outer variable scope + and returns a tensor of shape (batch_size, num_actions) with values of + every action. + num_actions: int + number of actions. + scope: str or VariableScope + optional scope for variable_scope. + reuse: bool or None + whether or not the variables should be reused. To be able to reuse the + scope must be given. - Returns - ------- - act: (tf.Variable, bool, float) -> tf.Variable - function to select and action given observation. -` See the top of the file for details. - """ - with tf.variable_scope(scope, reuse=reuse): - observations_ph = U.ensure_tf_input(make_obs_ph("observation")) - stochastic_ph = tf.placeholder(tf.bool, (), name="stochastic") - update_eps_ph = tf.placeholder(tf.float32, (), name="update_eps") + Returns + ------- + act: (tf.Variable, bool, float) -> tf.Variable + function to select and action given observation. + ` See the top of the file for details. 
+ """ + with tf.variable_scope(scope, reuse=reuse): + observations_ph = U.ensure_tf_input(make_obs_ph("observation")) + stochastic_ph = tf.placeholder(tf.bool, (), name="stochastic") + update_eps_ph = tf.placeholder(tf.float32, (), name="update_eps") - eps = tf.get_variable( - "eps", (), initializer=tf.constant_initializer(0)) + eps = tf.get_variable( + "eps", (), initializer=tf.constant_initializer(0)) - q_values = q_func(observations_ph.get(), num_actions, scope="q_func") - deterministic_actions = tf.argmax(q_values, axis=1) + q_values = q_func(observations_ph.get(), num_actions, scope="q_func") + deterministic_actions = tf.argmax(q_values, axis=1) - batch_size = tf.shape(observations_ph.get())[0] - random_actions = tf.random_uniform( - tf.stack([batch_size]), minval=0, maxval=num_actions, dtype=tf.int64) - chose_random = tf.random_uniform( - tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32) < eps - stochastic_actions = tf.where( - chose_random, random_actions, deterministic_actions) + batch_size = tf.shape(observations_ph.get())[0] + random_actions = tf.random_uniform( + tf.stack([batch_size]), minval=0, maxval=num_actions, + dtype=tf.int64) + chose_random = tf.random_uniform( + tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32) < eps + stochastic_actions = tf.where( + chose_random, random_actions, deterministic_actions) - output_actions = tf.cond( - stochastic_ph, lambda: stochastic_actions, - lambda: deterministic_actions) - update_eps_expr = eps.assign( - tf.cond(update_eps_ph >= 0, lambda: update_eps_ph, lambda: eps)) + output_actions = tf.cond( + stochastic_ph, lambda: stochastic_actions, + lambda: deterministic_actions) + update_eps_expr = eps.assign( + tf.cond(update_eps_ph >= 0, lambda: update_eps_ph, lambda: eps)) - act = U.function( - inputs=[observations_ph, stochastic_ph, update_eps_ph], - outputs=output_actions, - givens={update_eps_ph: -1.0, stochastic_ph: True}, - updates=[update_eps_expr]) - return act + act = U.function( + 
inputs=[observations_ph, stochastic_ph, update_eps_ph], + outputs=output_actions, + givens={update_eps_ph: -1.0, stochastic_ph: True}, + updates=[update_eps_expr]) + return act def build_train( make_obs_ph, q_func, num_actions, optimizer, grad_norm_clipping=None, gamma=1.0, double_q=True, scope="deepq", reuse=None): - """Creates the train function: + """Creates the train function: - Parameters - ---------- - make_obs_ph: str -> tf.placeholder or TfInput - a function that takes a name and creates a placeholder of input with - that name - q_func: (tf.Variable, int, str, bool) -> tf.Variable - the model that takes the following inputs: - observation_in: object - the output of observation placeholder - num_actions: int - number of actions - scope: str - reuse: bool - should be passed to outer variable scope - and returns a tensor of shape (batch_size, num_actions) with values of - every action. - num_actions: int - number of actions - reuse: bool - whether or not to reuse the graph variables - optimizer: tf.train.Optimizer - optimizer to use for the Q-learning objective. - grad_norm_clipping: float or None - clip gradient norms to this value. If None no clipping is performed. - gamma: float - discount rate. - double_q: bool - if true will use Double Q Learning (https://arxiv.org/abs/1509.06461). - In general it is a good idea to keep it enabled. - scope: str or VariableScope - optional scope for variable_scope. - reuse: bool or None - whether or not the variables should be reused. To be able to reuse the - scope must be given. 
+ Parameters + ---------- + make_obs_ph: str -> tf.placeholder or TfInput + a function that takes a name and creates a placeholder of input with + that name + q_func: (tf.Variable, int, str, bool) -> tf.Variable + the model that takes the following inputs: + observation_in: object + the output of observation placeholder + num_actions: int + number of actions + scope: str + reuse: bool + should be passed to outer variable scope + and returns a tensor of shape (batch_size, num_actions) with values of + every action. + num_actions: int + number of actions + reuse: bool + whether or not to reuse the graph variables + optimizer: tf.train.Optimizer + optimizer to use for the Q-learning objective. + grad_norm_clipping: float or None + clip gradient norms to this value. If None no clipping is performed. + gamma: float + discount rate. + double_q: bool + if true will use Double Q Learning (https://arxiv.org/abs/1509.06461). + In general it is a good idea to keep it enabled. + scope: str or VariableScope + optional scope for variable_scope. + reuse: bool or None + whether or not the variables should be reused. To be able to reuse the + scope must be given. - Returns - ------- - act: (tf.Variable, bool, float) -> tf.Variable - function to select and action given observation. -` See the top of the file for details. - train: (object, np.array, np.array, object, np.array, np.array) -> np.array - optimize the error in Bellman's equation. -` See the top of the file for details. - update_target: () -> () - copy the parameters from optimized Q function to the target Q function. -` See the top of the file for details. - debug: {str: function} - a bunch of functions to print debug data like q_values. - """ - act_f = build_act(make_obs_ph, q_func, num_actions, scope=scope, reuse=reuse) + Returns + ------- + act: (tf.Variable, bool, float) -> tf.Variable + function to select and action given observation. + ` See the top of the file for details. 
+ train: (object, np.array, np.array, object, np.array, np.array) -> np.array + optimize the error in Bellman's equation. + ` See the top of the file for details. + update_target: () -> () + copy the parameters from optimized Q function to the target Q function. + ` See the top of the file for details. + debug: {str: function} + a bunch of functions to print debug data like q_values. + """ + act_f = build_act(make_obs_ph, q_func, num_actions, scope=scope, + reuse=reuse) - with tf.variable_scope(scope, reuse=reuse): - # set up placeholders - obs_t_input = U.ensure_tf_input(make_obs_ph("obs_t")) - act_t_ph = tf.placeholder(tf.int32, [None], name="action") - rew_t_ph = tf.placeholder(tf.float32, [None], name="reward") - obs_tp1_input = U.ensure_tf_input(make_obs_ph("obs_tp1")) - done_mask_ph = tf.placeholder(tf.float32, [None], name="done") - importance_weights_ph = tf.placeholder(tf.float32, [None], name="weight") + with tf.variable_scope(scope, reuse=reuse): + # set up placeholders + obs_t_input = U.ensure_tf_input(make_obs_ph("obs_t")) + act_t_ph = tf.placeholder(tf.int32, [None], name="action") + rew_t_ph = tf.placeholder(tf.float32, [None], name="reward") + obs_tp1_input = U.ensure_tf_input(make_obs_ph("obs_tp1")) + done_mask_ph = tf.placeholder(tf.float32, [None], name="done") + importance_weights_ph = tf.placeholder(tf.float32, [None], + name="weight") - # q network evaluation - q_t = q_func( - obs_t_input.get(), num_actions, scope="q_func", - reuse=True) # reuse parameters from act - q_func_vars = U.scope_vars(U.absolute_scope_name("q_func")) + # q network evaluation + q_t = q_func( + obs_t_input.get(), num_actions, scope="q_func", + reuse=True) # reuse parameters from act + q_func_vars = U.scope_vars(U.absolute_scope_name("q_func")) - # target q network evalution - q_tp1 = q_func(obs_tp1_input.get(), num_actions, scope="target_q_func") - target_q_func_vars = U.scope_vars(U.absolute_scope_name("target_q_func")) + # target q network evalution + q_tp1 = 
q_func(obs_tp1_input.get(), num_actions, scope="target_q_func") + target_q_func_vars = U.scope_vars( + U.absolute_scope_name("target_q_func")) - # q scores for actions which we know were selected in the given state. - q_t_selected = tf.reduce_sum(q_t * tf.one_hot(act_t_ph, num_actions), 1) + # q scores for actions which we know were selected in the given state. + q_t_selected = tf.reduce_sum(q_t * tf.one_hot(act_t_ph, num_actions), + 1) - # compute estimate of best possible value starting from state at t + 1 - if double_q: - q_tp1_using_online_net = q_func( - obs_tp1_input.get(), num_actions, scope="q_func", reuse=True) - q_tp1_best_using_online_net = tf.arg_max(q_tp1_using_online_net, 1) - q_tp1_best = tf.reduce_sum( - q_tp1 * tf.one_hot(q_tp1_best_using_online_net, num_actions), 1) - else: - q_tp1_best = tf.reduce_max(q_tp1, 1) - q_tp1_best_masked = (1.0 - done_mask_ph) * q_tp1_best + # compute estimate of best possible value starting from state at t + 1 + if double_q: + q_tp1_using_online_net = q_func( + obs_tp1_input.get(), num_actions, scope="q_func", reuse=True) + q_tp1_best_using_online_net = tf.arg_max(q_tp1_using_online_net, 1) + q_tp1_best = tf.reduce_sum( + q_tp1 * tf.one_hot(q_tp1_best_using_online_net, num_actions), + 1) + else: + q_tp1_best = tf.reduce_max(q_tp1, 1) + q_tp1_best_masked = (1.0 - done_mask_ph) * q_tp1_best - # compute RHS of bellman equation - q_t_selected_target = rew_t_ph + gamma * q_tp1_best_masked + # compute RHS of bellman equation + q_t_selected_target = rew_t_ph + gamma * q_tp1_best_masked - # compute the error (potentially clipped) - td_error = q_t_selected - tf.stop_gradient(q_t_selected_target) - errors = U.huber_loss(td_error) - weighted_error = tf.reduce_mean(importance_weights_ph * errors) - # compute optimization op (potentially with gradient clipping) - if grad_norm_clipping is not None: - optimize_expr = U.minimize_and_clip( - optimizer, weighted_error, var_list=q_func_vars, - clip_val=grad_norm_clipping) - else: - 
optimize_expr = optimizer.minimize(weighted_error, var_list=q_func_vars) + # compute the error (potentially clipped) + td_error = q_t_selected - tf.stop_gradient(q_t_selected_target) + errors = U.huber_loss(td_error) + weighted_error = tf.reduce_mean(importance_weights_ph * errors) + # compute optimization op (potentially with gradient clipping) + if grad_norm_clipping is not None: + optimize_expr = U.minimize_and_clip( + optimizer, weighted_error, var_list=q_func_vars, + clip_val=grad_norm_clipping) + else: + optimize_expr = optimizer.minimize(weighted_error, + var_list=q_func_vars) - # update_target_fn will be called periodically to copy Q network to - # target Q network - update_target_expr = [] - for var, var_target in zip( - sorted(q_func_vars, key=lambda v: v.name), - sorted(target_q_func_vars, key=lambda v: v.name)): - update_target_expr.append(var_target.assign(var)) - update_target_expr = tf.group(*update_target_expr) + # update_target_fn will be called periodically to copy Q network to + # target Q network + update_target_expr = [] + for var, var_target in zip( + sorted(q_func_vars, key=lambda v: v.name), + sorted(target_q_func_vars, key=lambda v: v.name)): + update_target_expr.append(var_target.assign(var)) + update_target_expr = tf.group(*update_target_expr) - # Create callable functions - train = U.function( - inputs=[ - obs_t_input, - act_t_ph, - rew_t_ph, - obs_tp1_input, - done_mask_ph, - importance_weights_ph - ], - outputs=td_error, - updates=[optimize_expr]) - update_target = U.function([], [], updates=[update_target_expr]) + # Create callable functions + train = U.function( + inputs=[ + obs_t_input, + act_t_ph, + rew_t_ph, + obs_tp1_input, + done_mask_ph, + importance_weights_ph + ], + outputs=td_error, + updates=[optimize_expr]) + update_target = U.function([], [], updates=[update_target_expr]) - q_values = U.function([obs_t_input], q_t) + q_values = U.function([obs_t_input], q_t) - return act_f, train, update_target, {'q_values': q_values} + 
return act_f, train, update_target, {'q_values': q_values} diff --git a/python/ray/rllib/dqn/common/atari_wrappers_deprecated.py b/python/ray/rllib/dqn/common/atari_wrappers_deprecated.py index b9486c016..805744789 100644 --- a/python/ray/rllib/dqn/common/atari_wrappers_deprecated.py +++ b/python/ray/rllib/dqn/common/atari_wrappers_deprecated.py @@ -11,236 +11,240 @@ from gym import spaces class NoopResetEnv(gym.Wrapper): - def __init__(self, env=None, noop_max=30): - """Sample initial states by taking random number of no-ops on reset. - No-op is assumed to be action 0. - """ - super(NoopResetEnv, self).__init__(env) - self.noop_max = noop_max - self.override_num_noops = None - assert env.unwrapped.get_action_meanings()[0] == 'NOOP' + def __init__(self, env=None, noop_max=30): + """Sample initial states by taking random number of no-ops on reset. + No-op is assumed to be action 0. + """ + super(NoopResetEnv, self).__init__(env) + self.noop_max = noop_max + self.override_num_noops = None + assert env.unwrapped.get_action_meanings()[0] == 'NOOP' - def _reset(self): - """ Do no-op action for a number of steps in [1, noop_max].""" - self.env.reset() - if self.override_num_noops is not None: - noops = self.override_num_noops - else: - noops = np.random.randint(1, self.noop_max + 1) - assert noops > 0 - obs = None - for _ in range(noops): - obs, _, done, _ = self.env.step(0) - if done: - obs = self.env.reset() - return obs + def _reset(self): + """ Do no-op action for a number of steps in [1, noop_max].""" + self.env.reset() + if self.override_num_noops is not None: + noops = self.override_num_noops + else: + noops = np.random.randint(1, self.noop_max + 1) + assert noops > 0 + obs = None + for _ in range(noops): + obs, _, done, _ = self.env.step(0) + if done: + obs = self.env.reset() + return obs class FireResetEnv(gym.Wrapper): - def __init__(self, env=None): - """For environments where the user need to press FIRE for the game to - start.""" - super(FireResetEnv, 
self).__init__(env) - assert env.unwrapped.get_action_meanings()[1] == 'FIRE' - assert len(env.unwrapped.get_action_meanings()) >= 3 + def __init__(self, env=None): + """For environments where the user need to press FIRE for the game to + start.""" + super(FireResetEnv, self).__init__(env) + assert env.unwrapped.get_action_meanings()[1] == 'FIRE' + assert len(env.unwrapped.get_action_meanings()) >= 3 - def _reset(self): - self.env.reset() - obs, _, done, _ = self.env.step(1) - if done: - self.env.reset() - obs, _, done, _ = self.env.step(2) - if done: - self.env.reset() - return obs + def _reset(self): + self.env.reset() + obs, _, done, _ = self.env.step(1) + if done: + self.env.reset() + obs, _, done, _ = self.env.step(2) + if done: + self.env.reset() + return obs class EpisodicLifeEnv(gym.Wrapper): - def __init__(self, env=None): - """Make end-of-life == end-of-episode, but only reset on true game over. - Done by DeepMind for the DQN and co. since it helps value estimation. - """ - super(EpisodicLifeEnv, self).__init__(env) - self.lives = 0 - self.was_real_done = True - self.was_real_reset = False + def __init__(self, env=None): + """Make end-of-life == end-of-episode, but only reset on true game + over. Done by DeepMind for the DQN and co. since it helps value + estimation. + """ + super(EpisodicLifeEnv, self).__init__(env) + self.lives = 0 + self.was_real_done = True + self.was_real_reset = False - def _step(self, action): - obs, reward, done, info = self.env.step(action) - self.was_real_done = done - # check current lives, make loss of life terminal, - # then update lives to handle bonus lives - lives = self.env.unwrapped.ale.lives() - if lives < self.lives and lives > 0: - # for Qbert somtimes we stay in lives == 0 condtion for a few frames - # so its important to keep lives > 0, so that we only reset once - # the environment advertises done. 
- done = True - self.lives = lives - return obs, reward, done, info + def _step(self, action): + obs, reward, done, info = self.env.step(action) + self.was_real_done = done + # check current lives, make loss of life terminal, + # then update lives to handle bonus lives + lives = self.env.unwrapped.ale.lives() + if lives < self.lives and lives > 0: + # for Qbert somtimes we stay in lives == 0 condtion for a few + # frames so its important to keep lives > 0, so that we only reset + # once the environment advertises done. + done = True + self.lives = lives + return obs, reward, done, info - def _reset(self): - """Reset only when lives are exhausted. - This way all states are still reachable even though lives are episodic, - and the learner need not know about any of this behind-the-scenes. - """ - if self.was_real_done: - obs = self.env.reset() - self.was_real_reset = True - else: - # no-op step to advance from terminal/lost life state - obs, _, _, _ = self.env.step(0) - self.was_real_reset = False - self.lives = self.env.unwrapped.ale.lives() - return obs + def _reset(self): + """Reset only when lives are exhausted. + This way all states are still reachable even though lives are episodic, + and the learner need not know about any of this behind-the-scenes. 
+ """ + if self.was_real_done: + obs = self.env.reset() + self.was_real_reset = True + else: + # no-op step to advance from terminal/lost life state + obs, _, _, _ = self.env.step(0) + self.was_real_reset = False + self.lives = self.env.unwrapped.ale.lives() + return obs class MaxAndSkipEnv(gym.Wrapper): - def __init__(self, env=None, skip=4): - """Return only every `skip`-th frame""" - super(MaxAndSkipEnv, self).__init__(env) - # most recent raw observations (for max pooling across time steps) - self._obs_buffer = deque(maxlen=2) - self._skip = skip + def __init__(self, env=None, skip=4): + """Return only every `skip`-th frame""" + super(MaxAndSkipEnv, self).__init__(env) + # most recent raw observations (for max pooling across time steps) + self._obs_buffer = deque(maxlen=2) + self._skip = skip - def _step(self, action): - total_reward = 0.0 - done = None - for _ in range(self._skip): - obs, reward, done, info = self.env.step(action) - self._obs_buffer.append(obs) - total_reward += reward - if done: - break + def _step(self, action): + total_reward = 0.0 + done = None + for _ in range(self._skip): + obs, reward, done, info = self.env.step(action) + self._obs_buffer.append(obs) + total_reward += reward + if done: + break - max_frame = np.max(np.stack(self._obs_buffer), axis=0) + max_frame = np.max(np.stack(self._obs_buffer), axis=0) - return max_frame, total_reward, done, info + return max_frame, total_reward, done, info - def _reset(self): - """Clear past frame buffer and init. to first obs. from inner env.""" - self._obs_buffer.clear() - obs = self.env.reset() - self._obs_buffer.append(obs) - return obs + def _reset(self): + """Clear past frame buffer and init. to first obs. 
from inner env.""" + self._obs_buffer.clear() + obs = self.env.reset() + self._obs_buffer.append(obs) + return obs class ProcessFrame84(gym.ObservationWrapper): - def __init__(self, env=None): - super(ProcessFrame84, self).__init__(env) - self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1)) + def __init__(self, env=None): + super(ProcessFrame84, self).__init__(env) + self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1)) - def _observation(self, obs): - return ProcessFrame84.process(obs) + def _observation(self, obs): + return ProcessFrame84.process(obs) - @staticmethod - def process(frame): - if frame.size == 210 * 160 * 3: - img = np.reshape(frame, [210, 160, 3]).astype(np.float32) - elif frame.size == 250 * 160 * 3: - img = np.reshape(frame, [250, 160, 3]).astype(np.float32) - else: - assert False, "Unknown resolution." - img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114 - resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_AREA) - x_t = resized_screen[18:102, :] - x_t = np.reshape(x_t, [84, 84, 1]) - return x_t.astype(np.uint8) + @staticmethod + def process(frame): + if frame.size == 210 * 160 * 3: + img = np.reshape(frame, [210, 160, 3]).astype(np.float32) + elif frame.size == 250 * 160 * 3: + img = np.reshape(frame, [250, 160, 3]).astype(np.float32) + else: + assert False, "Unknown resolution." 
+ img = (img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + + img[:, :, 2] * 0.114) + resized_screen = cv2.resize(img, (84, 110), + interpolation=cv2.INTER_AREA) + x_t = resized_screen[18:102, :] + x_t = np.reshape(x_t, [84, 84, 1]) + return x_t.astype(np.uint8) class ClippedRewardsWrapper(gym.RewardWrapper): - def _reward(self, reward): - """Change all the positive rewards to 1, negative to -1 and keep zero.""" - return np.sign(reward) + def _reward(self, reward): + """Change all the positive rewards to 1, negative to -1 and keep + zero.""" + return np.sign(reward) class LazyFrames(object): - def __init__(self, frames): - """This object ensures that common frames between the observations are only - stored once. It exists purely to optimize memory usage which can be huge - for DQN's 1M frames replay buffers. + def __init__(self, frames): + """This object ensures that common frames between the observations are + only stored once. It exists purely to optimize memory usage which can + be huge for DQN's 1M frames replay buffers. - This object should only be converted to numpy array before being passed to - the model. + This object should only be converted to numpy array before being passed + to the model. - You'd not belive how complex the previous solution was.""" - self._frames = frames + You'd not belive how complex the previous solution was.""" + self._frames = frames - def __array__(self, dtype=None): - out = np.concatenate(self._frames, axis=2) - if dtype is not None: - out = out.astype(dtype) - return out + def __array__(self, dtype=None): + out = np.concatenate(self._frames, axis=2) + if dtype is not None: + out = out.astype(dtype) + return out class FrameStack(gym.Wrapper): - def __init__(self, env, k): - """Stack k last frames. + def __init__(self, env, k): + """Stack k last frames. - Returns lazy array, which is much more memory efficient. + Returns lazy array, which is much more memory efficient. 
- See Also - -------- - ray.rllib.dqn.common.atari_wrappers.LazyFrames - """ - gym.Wrapper.__init__(self, env) - self.k = k - self.frames = deque([], maxlen=k) - shp = env.observation_space.shape - self.observation_space = spaces.Box( - low=0, high=255, shape=(shp[0], shp[1], shp[2] * k)) + See Also + -------- + ray.rllib.dqn.common.atari_wrappers.LazyFrames + """ + gym.Wrapper.__init__(self, env) + self.k = k + self.frames = deque([], maxlen=k) + shp = env.observation_space.shape + self.observation_space = spaces.Box( + low=0, high=255, shape=(shp[0], shp[1], shp[2] * k)) - def _reset(self): - ob = self.env.reset() - for _ in range(self.k): - self.frames.append(ob) - return self._get_ob() + def _reset(self): + ob = self.env.reset() + for _ in range(self.k): + self.frames.append(ob) + return self._get_ob() - def _step(self, action): - ob, reward, done, info = self.env.step(action) - self.frames.append(ob) - return self._get_ob(), reward, done, info + def _step(self, action): + ob, reward, done, info = self.env.step(action) + self.frames.append(ob) + return self._get_ob(), reward, done, info - def _get_ob(self): - assert len(self.frames) == self.k - return LazyFrames(list(self.frames)) + def _get_ob(self): + assert len(self.frames) == self.k + return LazyFrames(list(self.frames)) class ScaledFloatFrame(gym.ObservationWrapper): - def _observation(self, obs): - # careful! This undoes the memory optimization, use - # with smaller replay buffers only. - return np.array(obs).astype(np.float32) / 255.0 + def _observation(self, obs): + # careful! This undoes the memory optimization, use + # with smaller replay buffers only. 
+ return np.array(obs).astype(np.float32) / 255.0 def wrap_dqn(env): - """Apply a common set of wrappers for Atari games.""" - assert 'NoFrameskip' in env.spec.id - env = EpisodicLifeEnv(env) - env = NoopResetEnv(env, noop_max=30) - env = MaxAndSkipEnv(env, skip=4) - if 'FIRE' in env.unwrapped.get_action_meanings(): - env = FireResetEnv(env) - env = ProcessFrame84(env) - env = FrameStack(env, 4) - env = ClippedRewardsWrapper(env) - return env + """Apply a common set of wrappers for Atari games.""" + assert 'NoFrameskip' in env.spec.id + env = EpisodicLifeEnv(env) + env = NoopResetEnv(env, noop_max=30) + env = MaxAndSkipEnv(env, skip=4) + if 'FIRE' in env.unwrapped.get_action_meanings(): + env = FireResetEnv(env) + env = ProcessFrame84(env) + env = FrameStack(env, 4) + env = ClippedRewardsWrapper(env) + return env class A2cProcessFrame(gym.Wrapper): - def __init__(self, env): - gym.Wrapper.__init__(self, env) - self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1)) + def __init__(self, env): + gym.Wrapper.__init__(self, env) + self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1)) - def _step(self, action): - ob, reward, done, info = self.env.step(action) - return A2cProcessFrame.process(ob), reward, done, info + def _step(self, action): + ob, reward, done, info = self.env.step(action) + return A2cProcessFrame.process(ob), reward, done, info - def _reset(self): - return A2cProcessFrame.process(self.env.reset()) + def _reset(self): + return A2cProcessFrame.process(self.env.reset()) - @staticmethod - def process(frame): - frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) - frame = cv2.resize(frame, (84, 84), interpolation=cv2.INTER_AREA) - return frame.reshape(84, 84, 1) + @staticmethod + def process(frame): + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) + frame = cv2.resize(frame, (84, 84), interpolation=cv2.INTER_AREA) + return frame.reshape(84, 84, 1) diff --git a/python/ray/rllib/dqn/common/schedules.py 
b/python/ray/rllib/dqn/common/schedules.py index 21dc0695f..d9ceb2f76 100644 --- a/python/ray/rllib/dqn/common/schedules.py +++ b/python/ray/rllib/dqn/common/schedules.py @@ -13,95 +13,96 @@ from __future__ import print_function class Schedule(object): - def value(self, t): - """Value of the schedule at time t""" - raise NotImplementedError() + def value(self, t): + """Value of the schedule at time t""" + raise NotImplementedError() class ConstantSchedule(object): - def __init__(self, value): - """Value remains constant over time. + def __init__(self, value): + """Value remains constant over time. - Parameters - ---------- - value: float - Constant value of the schedule - """ - self._v = value + Parameters + ---------- + value: float + Constant value of the schedule + """ + self._v = value - def value(self, t): - """See Schedule.value""" - return self._v + def value(self, t): + """See Schedule.value""" + return self._v def linear_interpolation(l, r, alpha): - return l + alpha * (r - l) + return l + alpha * (r - l) class PiecewiseSchedule(object): - def __init__( - self, endpoints, interpolation=linear_interpolation, - outside_value=None): + def __init__( + self, endpoints, interpolation=linear_interpolation, + outside_value=None): - """Piecewise schedule. + """Piecewise schedule. - endpoints: [(int, int)] - list of pairs `(time, value)` meanining that schedule should output - `value` when `t==time`. All the values for time must be sorted in - an increasing order. When t is between two times, e.g. - `(time_a, value_a)` - and `(time_b, value_b)`, such that `time_a <= t < time_b` then value - outputs `interpolation(value_a, value_b, alpha)` where alpha is a - fraction of time passed between `time_a` and `time_b` for time `t`. - interpolation: lambda float, float, float: float - a function that takes value to the left and to the right of t according - to the `endpoints`. Alpha is the fraction of distance from left endpoint - to right endpoint that t has covered. 
See linear_interpolation for - example. - outside_value: float - if the value is requested outside of all the intervals sepecified in - `endpoints` this value is returned. If None then AssertionError is - raised when outside value is requested. - """ - idxes = [e[0] for e in endpoints] - assert idxes == sorted(idxes) - self._interpolation = interpolation - self._outside_value = outside_value - self._endpoints = endpoints + endpoints: [(int, int)] + list of pairs `(time, value)` meanining that schedule should output + `value` when `t==time`. All the values for time must be sorted in + an increasing order. When t is between two times, e.g. + `(time_a, value_a)` + and `(time_b, value_b)`, such that `time_a <= t < time_b` then value + outputs `interpolation(value_a, value_b, alpha)` where alpha is a + fraction of time passed between `time_a` and `time_b` for time `t`. + interpolation: lambda float, float, float: float + a function that takes value to the left and to the right of t + according to the `endpoints`. Alpha is the fraction of distance from + left endpoint to right endpoint that t has covered. See + linear_interpolation for example. + outside_value: float + if the value is requested outside of all the intervals sepecified in + `endpoints` this value is returned. If None then AssertionError is + raised when outside value is requested. 
+ """ + idxes = [e[0] for e in endpoints] + assert idxes == sorted(idxes) + self._interpolation = interpolation + self._outside_value = outside_value + self._endpoints = endpoints - def value(self, t): - """See Schedule.value""" - for (l_t, l), (r_t, r) in zip(self._endpoints[:-1], self._endpoints[1:]): - if l_t <= t and t < r_t: - alpha = float(t - l_t) / (r_t - l_t) - return self._interpolation(l, r, alpha) + def value(self, t): + """See Schedule.value""" + for (l_t, l), (r_t, r) in zip(self._endpoints[:-1], + self._endpoints[1:]): + if l_t <= t and t < r_t: + alpha = float(t - l_t) / (r_t - l_t) + return self._interpolation(l, r, alpha) - # t does not belong to any of the pieces, so doom. - assert self._outside_value is not None - return self._outside_value + # t does not belong to any of the pieces, so doom. + assert self._outside_value is not None + return self._outside_value class LinearSchedule(object): - def __init__(self, schedule_timesteps, final_p, initial_p=1.0): - """Linear interpolation between initial_p and final_p over - schedule_timesteps. After this many timesteps pass final_p is - returned. + def __init__(self, schedule_timesteps, final_p, initial_p=1.0): + """Linear interpolation between initial_p and final_p over + schedule_timesteps. After this many timesteps pass final_p is + returned. 
- Parameters - ---------- - schedule_timesteps: int - Number of timesteps for which to linearly anneal initial_p - to final_p - initial_p: float - initial output value - final_p: float - final output value - """ - self.schedule_timesteps = schedule_timesteps - self.final_p = final_p - self.initial_p = initial_p + Parameters + ---------- + schedule_timesteps: int + Number of timesteps for which to linearly anneal initial_p + to final_p + initial_p: float + initial output value + final_p: float + final output value + """ + self.schedule_timesteps = schedule_timesteps + self.final_p = final_p + self.initial_p = initial_p - def value(self, t): - """See Schedule.value""" - fraction = min(float(t) / self.schedule_timesteps, 1.0) - return self.initial_p + fraction * (self.final_p - self.initial_p) + def value(self, t): + """See Schedule.value""" + fraction = min(float(t) / self.schedule_timesteps, 1.0) + return self.initial_p + fraction * (self.final_p - self.initial_p) diff --git a/python/ray/rllib/dqn/common/segment_tree.py b/python/ray/rllib/dqn/common/segment_tree.py index 70c18a63d..3409aa667 100644 --- a/python/ray/rllib/dqn/common/segment_tree.py +++ b/python/ray/rllib/dqn/common/segment_tree.py @@ -6,146 +6,148 @@ import operator class SegmentTree(object): - def __init__(self, capacity, operation, neutral_element): - """Build a Segment Tree data structure. + def __init__(self, capacity, operation, neutral_element): + """Build a Segment Tree data structure. - https://en.wikipedia.org/wiki/Segment_tree + https://en.wikipedia.org/wiki/Segment_tree - Can be used as regular array, but with two - important differences: + Can be used as regular array, but with two + important differences: - a) setting item's value is slightly slower. - It is O(lg capacity) instead of O(1). - b) user has access to an efficient `reduce` - operation which reduces `operation` over - a contiguous subsequence of items in the - array. + a) setting item's value is slightly slower. 
+ It is O(lg capacity) instead of O(1). + b) user has access to an efficient `reduce` + operation which reduces `operation` over + a contiguous subsequence of items in the + array. - Paramters - --------- - capacity: int - Total size of the array - must be a power of two. - operation: lambda obj, obj -> obj - and operation for combining elements (eg. sum, max) - must for a mathematical group together with the set of - possible values for array elements. - neutral_element: obj - neutral element for the operation above. eg. float('-inf') - for max and 0 for sum. - """ + Paramters + --------- + capacity: int + Total size of the array - must be a power of two. + operation: lambda obj, obj -> obj + and operation for combining elements (eg. sum, max) + must for a mathematical group together with the set of + possible values for array elements. + neutral_element: obj + neutral element for the operation above. eg. float('-inf') + for max and 0 for sum. + """ - assert capacity > 0 and capacity & (capacity - 1) == 0, \ - "capacity must be positive and a power of 2." - self._capacity = capacity - self._value = [neutral_element for _ in range(2 * capacity)] - self._operation = operation + assert capacity > 0 and capacity & (capacity - 1) == 0, \ + "capacity must be positive and a power of 2." 
+ self._capacity = capacity + self._value = [neutral_element for _ in range(2 * capacity)] + self._operation = operation - def _reduce_helper(self, start, end, node, node_start, node_end): - if start == node_start and end == node_end: - return self._value[node] - mid = (node_start + node_end) // 2 - if end <= mid: - return self._reduce_helper(start, end, 2 * node, node_start, mid) - else: - if mid + 1 <= start: - return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end) - else: - return self._operation( - self._reduce_helper(start, mid, 2 * node, node_start, mid), - self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end) - ) + def _reduce_helper(self, start, end, node, node_start, node_end): + if start == node_start and end == node_end: + return self._value[node] + mid = (node_start + node_end) // 2 + if end <= mid: + return self._reduce_helper(start, end, 2 * node, node_start, mid) + else: + if mid + 1 <= start: + return self._reduce_helper(start, end, 2 * node + 1, mid + 1, + node_end) + else: + return self._operation( + self._reduce_helper(start, mid, 2 * node, node_start, mid), + self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, + node_end) + ) - def reduce(self, start=0, end=None): - """Returns result of applying `self.operation` - to a contiguous subsequence of the array. + def reduce(self, start=0, end=None): + """Returns result of applying `self.operation` + to a contiguous subsequence of the array. - self.operation( - arr[start], operation(arr[start+1], operation(... arr[end]))) + self.operation( + arr[start], operation(arr[start+1], operation(... arr[end]))) - Parameters - ---------- - start: int - beginning of the subsequence - end: int - end of the subsequences + Parameters + ---------- + start: int + beginning of the subsequence + end: int + end of the subsequences - Returns - ------- - reduced: obj - result of reducing self.operation over the specified range of array - elements. 
- """ - if end is None: - end = self._capacity - if end < 0: - end += self._capacity - end -= 1 - return self._reduce_helper(start, end, 1, 0, self._capacity - 1) + Returns + ------- + reduced: obj + result of reducing self.operation over the specified range of array + elements. + """ + if end is None: + end = self._capacity + if end < 0: + end += self._capacity + end -= 1 + return self._reduce_helper(start, end, 1, 0, self._capacity - 1) - def __setitem__(self, idx, val): - # index of the leaf - idx += self._capacity - self._value[idx] = val - idx //= 2 - while idx >= 1: - self._value[idx] = self._operation( - self._value[2 * idx], - self._value[2 * idx + 1]) - idx //= 2 + def __setitem__(self, idx, val): + # index of the leaf + idx += self._capacity + self._value[idx] = val + idx //= 2 + while idx >= 1: + self._value[idx] = self._operation( + self._value[2 * idx], + self._value[2 * idx + 1]) + idx //= 2 - def __getitem__(self, idx): - assert 0 <= idx < self._capacity - return self._value[self._capacity + idx] + def __getitem__(self, idx): + assert 0 <= idx < self._capacity + return self._value[self._capacity + idx] class SumSegmentTree(SegmentTree): - def __init__(self, capacity): - super(SumSegmentTree, self).__init__( - capacity=capacity, - operation=operator.add, - neutral_element=0.0) + def __init__(self, capacity): + super(SumSegmentTree, self).__init__( + capacity=capacity, + operation=operator.add, + neutral_element=0.0) - def sum(self, start=0, end=None): - """Returns arr[start] + ... + arr[end]""" - return super(SumSegmentTree, self).reduce(start, end) + def sum(self, start=0, end=None): + """Returns arr[start] + ... + arr[end]""" + return super(SumSegmentTree, self).reduce(start, end) - def find_prefixsum_idx(self, prefixsum): - """Find the highest index `i` in the array such that - sum(arr[0] + arr[1] + ... 
+ arr[i - i]) <= prefixsum + def find_prefixsum_idx(self, prefixsum): + """Find the highest index `i` in the array such that + sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum - if array values are probabilities, this function - allows to sample indexes according to the discrete - probability efficiently. + if array values are probabilities, this function + allows to sample indexes according to the discrete + probability efficiently. - Parameters - ---------- - perfixsum: float - upperbound on the sum of array prefix + Parameters + ---------- + perfixsum: float + upperbound on the sum of array prefix - Returns - ------- - idx: int - highest index satisfying the prefixsum constraint - """ - assert 0 <= prefixsum <= self.sum() + 1e-5 - idx = 1 - while idx < self._capacity: # while non-leaf - if self._value[2 * idx] > prefixsum: - idx = 2 * idx - else: - prefixsum -= self._value[2 * idx] - idx = 2 * idx + 1 - return idx - self._capacity + Returns + ------- + idx: int + highest index satisfying the prefixsum constraint + """ + assert 0 <= prefixsum <= self.sum() + 1e-5 + idx = 1 + while idx < self._capacity: # while non-leaf + if self._value[2 * idx] > prefixsum: + idx = 2 * idx + else: + prefixsum -= self._value[2 * idx] + idx = 2 * idx + 1 + return idx - self._capacity class MinSegmentTree(SegmentTree): - def __init__(self, capacity): - super(MinSegmentTree, self).__init__( - capacity=capacity, - operation=min, - neutral_element=float('inf')) + def __init__(self, capacity): + super(MinSegmentTree, self).__init__( + capacity=capacity, + operation=min, + neutral_element=float('inf')) - def min(self, start=0, end=None): - """Returns min(arr[start], ..., arr[end])""" + def min(self, start=0, end=None): + """Returns min(arr[start], ..., arr[end])""" - return super(MinSegmentTree, self).reduce(start, end) + return super(MinSegmentTree, self).reduce(start, end) diff --git a/python/ray/rllib/dqn/common/tf_util.py b/python/ray/rllib/dqn/common/tf_util.py index 
c4254c4fb..bbc718730 100644 --- a/python/ray/rllib/dqn/common/tf_util.py +++ b/python/ray/rllib/dqn/common/tf_util.py @@ -19,57 +19,57 @@ clip = tf.clip_by_value def sum(x, axis=None, keepdims=False): - axis = None if axis is None else [axis] - return tf.reduce_sum(x, axis=axis, keep_dims=keepdims) + axis = None if axis is None else [axis] + return tf.reduce_sum(x, axis=axis, keep_dims=keepdims) def mean(x, axis=None, keepdims=False): - axis = None if axis is None else [axis] - return tf.reduce_mean(x, axis=axis, keep_dims=keepdims) + axis = None if axis is None else [axis] + return tf.reduce_mean(x, axis=axis, keep_dims=keepdims) def var(x, axis=None, keepdims=False): - meanx = mean(x, axis=axis, keepdims=keepdims) - return mean(tf.square(x - meanx), axis=axis, keepdims=keepdims) + meanx = mean(x, axis=axis, keepdims=keepdims) + return mean(tf.square(x - meanx), axis=axis, keepdims=keepdims) def std(x, axis=None, keepdims=False): - return tf.sqrt(var(x, axis=axis, keepdims=keepdims)) + return tf.sqrt(var(x, axis=axis, keepdims=keepdims)) def max(x, axis=None, keepdims=False): - axis = None if axis is None else [axis] - return tf.reduce_max(x, axis=axis, keep_dims=keepdims) + axis = None if axis is None else [axis] + return tf.reduce_max(x, axis=axis, keep_dims=keepdims) def min(x, axis=None, keepdims=False): - axis = None if axis is None else [axis] - return tf.reduce_min(x, axis=axis, keep_dims=keepdims) + axis = None if axis is None else [axis] + return tf.reduce_min(x, axis=axis, keep_dims=keepdims) def concatenate(arrs, axis=0): - return tf.concat(axis=axis, values=arrs) + return tf.concat(axis=axis, values=arrs) def argmax(x, axis=None): - return tf.argmax(x, axis=axis) + return tf.argmax(x, axis=axis) def switch(condition, then_expression, else_expression): - """Switches between two operations depending on a scalar value (int or bool). - Note that both `then_expression` and `else_expression` - should be symbolic tensors of the *same shape*. 
+ """Switches between two operations depending on a scalar value (int or + bool). Note that both `then_expression` and `else_expression` + should be symbolic tensors of the *same shape*. - # Arguments - condition: scalar tensor. - then_expression: TensorFlow operation. - else_expression: TensorFlow operation. - """ - x_shape = copy.copy(then_expression.get_shape()) - x = tf.cond(tf.cast(condition, 'bool'), - lambda: then_expression, lambda: else_expression) - x.set_shape(x_shape) - return x + # Arguments + condition: scalar tensor. + then_expression: TensorFlow operation. + else_expression: TensorFlow operation. + """ + x_shape = copy.copy(then_expression.get_shape()) + x = tf.cond(tf.cast(condition, 'bool'), + lambda: then_expression, lambda: else_expression) + x.set_shape(x_shape) + return x # ================================================================ # Extras @@ -77,22 +77,22 @@ def switch(condition, then_expression, else_expression): def l2loss(params): - if len(params) == 0: - return tf.constant(0.0) - else: - return tf.add_n([sum(tf.square(p)) for p in params]) + if len(params) == 0: + return tf.constant(0.0) + else: + return tf.add_n([sum(tf.square(p)) for p in params]) def lrelu(x, leak=0.2): - f1 = 0.5 * (1 + leak) - f2 = 0.5 * (1 - leak) - return f1 * x + f2 * abs(x) + f1 = 0.5 * (1 + leak) + f2 = 0.5 * (1 - leak) + return f1 * x + f2 * abs(x) def categorical_sample_logits(X): - # https://github.com/tensorflow/tensorflow/issues/456 - U = tf.random_uniform(tf.shape(X)) - return argmax(X - tf.log(-tf.log(U)), axis=1) + # https://github.com/tensorflow/tensorflow/issues/456 + U = tf.random_uniform(tf.shape(X)) + return argmax(X - tf.log(-tf.log(U)), axis=1) # ================================================================ @@ -101,89 +101,92 @@ def categorical_sample_logits(X): def is_placeholder(x): - return type(x) is tf.Tensor and len(x.op.inputs) == 0 + return type(x) is tf.Tensor and len(x.op.inputs) == 0 class TfInput(object): - def __init__(self, 
name="(unnamed)"): - """Generalized Tensorflow placeholder. The main differences are: - - possibly uses multiple placeholders internally and returns multiple - values - - can apply light postprocessing to the value feed to placeholder. - """ - self.name = name + def __init__(self, name="(unnamed)"): + """Generalized Tensorflow placeholder. The main differences are: + - possibly uses multiple placeholders internally and returns multiple + values + - can apply light postprocessing to the value feed to placeholder. + """ + self.name = name - def get(self): - """Return the tf variable(s) representing the possibly postprocessed value - of placeholder(s). - """ - raise NotImplemented() + def get(self): + """Return the tf variable(s) representing the possibly postprocessed + value of placeholder(s). + """ + raise NotImplemented() - def make_feed_dict(data): - """Given data input it to the placeholder(s).""" - raise NotImplemented() + def make_feed_dict(data): + """Given data input it to the placeholder(s).""" + raise NotImplemented() class PlacholderTfInput(TfInput): - def __init__(self, placeholder): - """Wrapper for regular tensorflow placeholder.""" - super().__init__(placeholder.name) - self._placeholder = placeholder + def __init__(self, placeholder): + """Wrapper for regular tensorflow placeholder.""" + super().__init__(placeholder.name) + self._placeholder = placeholder - def get(self): - return self._placeholder + def get(self): + return self._placeholder - def make_feed_dict(self, data): - return {self._placeholder: data} + def make_feed_dict(self, data): + return {self._placeholder: data} class BatchInput(PlacholderTfInput): - def __init__(self, shape, dtype=tf.float32, name=None): - """Creates a placeholder for a batch of tensors of a given shape and dtype + def __init__(self, shape, dtype=tf.float32, name=None): + """Creates a placeholder for a batch of tensors of a given shape and + dtype - Parameters - ---------- - shape: [int] - shape of a single elemenet 
of the batch - dtype: tf.dtype - number representation used for tensor contents - name: str - name of the underlying placeholder - """ - super().__init__(tf.placeholder(dtype, [None] + list(shape), name=name)) + Parameters + ---------- + shape: [int] + shape of a single elemenet of the batch + dtype: tf.dtype + number representation used for tensor contents + name: str + name of the underlying placeholder + """ + super().__init__(tf.placeholder(dtype, [None] + list(shape), + name=name)) class Uint8Input(PlacholderTfInput): - def __init__(self, shape, name=None): - """Takes input in uint8 format which is cast to float32 and divided by 255 - before passing it to the model. + def __init__(self, shape, name=None): + """Takes input in uint8 format which is cast to float32 and divided by + 255 before passing it to the model. - On GPU this ensures lower data transfer times. + On GPU this ensures lower data transfer times. - Parameters - ---------- - shape: [int] - shape of the tensor. - name: str - name of the underlying placeholder - """ + Parameters + ---------- + shape: [int] + shape of the tensor. 
+ name: str + name of the underlying placeholder + """ - super().__init__(tf.placeholder(tf.uint8, [None] + list(shape), name=name)) - self._shape = shape - self._output = tf.cast(super().get(), tf.float32) / 255.0 + super().__init__(tf.placeholder(tf.uint8, [None] + list(shape), + name=name)) + self._shape = shape + self._output = tf.cast(super().get(), tf.float32) / 255.0 - def get(self): - return self._output + def get(self): + return self._output def ensure_tf_input(thing): - """Takes either tf.placeholder of TfInput and outputs equivalent TfInput""" - if isinstance(thing, TfInput): - return thing - elif is_placeholder(thing): - return PlacholderTfInput(thing) - else: - raise ValueError("Must be a placeholder or TfInput") + """Takes either tf.placeholder of TfInput and outputs equivalent TfInput""" + if isinstance(thing, TfInput): + return thing + elif is_placeholder(thing): + return PlacholderTfInput(thing) + else: + raise ValueError("Must be a placeholder or TfInput") # ================================================================ # Mathematical utils @@ -191,11 +194,11 @@ def ensure_tf_input(thing): def huber_loss(x, delta=1.0): - """Reference: https://en.wikipedia.org/wiki/Huber_loss""" - return tf.where( - tf.abs(x) < delta, - tf.square(x) * 0.5, - delta * (tf.abs(x) - 0.5 * delta)) + """Reference: https://en.wikipedia.org/wiki/Huber_loss""" + return tf.where( + tf.abs(x) < delta, + tf.square(x) * 0.5, + delta * (tf.abs(x) - 0.5 * delta)) # ================================================================ # Optimizer utils @@ -203,15 +206,15 @@ def huber_loss(x, delta=1.0): def minimize_and_clip(optimizer, objective, var_list, clip_val=10): - """Minimized `objective` using `optimizer` w.r.t. 
variables in - `var_list` while ensure the norm of the gradients for each - variable is clipped to `clip_val` - """ - gradients = optimizer.compute_gradients(objective, var_list=var_list) - for i, (grad, var) in enumerate(gradients): - if grad is not None: - gradients[i] = (tf.clip_by_norm(grad, clip_val), var) - return optimizer.apply_gradients(gradients) + """Minimized `objective` using `optimizer` w.r.t. variables in + `var_list` while ensure the norm of the gradients for each + variable is clipped to `clip_val` + """ + gradients = optimizer.compute_gradients(objective, var_list=var_list) + for i, (grad, var) in enumerate(gradients): + if grad is not None: + gradients[i] = (tf.clip_by_norm(grad, clip_val), var) + return optimizer.apply_gradients(gradients) # ================================================================ @@ -219,51 +222,51 @@ def minimize_and_clip(optimizer, objective, var_list, clip_val=10): # ================================================================ def get_session(): - """Returns recently made Tensorflow session""" - return tf.get_default_session() + """Returns recently made Tensorflow session""" + return tf.get_default_session() def make_session(num_cpu): - """Returns a session that will use CPU's only""" - tf_config = tf.ConfigProto( - inter_op_parallelism_threads=num_cpu, - intra_op_parallelism_threads=num_cpu) - return tf.Session(config=tf_config) + """Returns a session that will use CPU's only""" + tf_config = tf.ConfigProto( + inter_op_parallelism_threads=num_cpu, + intra_op_parallelism_threads=num_cpu) + return tf.Session(config=tf_config) def single_threaded_session(): - """Returns a session which will only use a single CPU""" - return make_session(1) + """Returns a session which will only use a single CPU""" + return make_session(1) ALREADY_INITIALIZED = set() def initialize(): - """Initialize all the uninitialized variables in the global scope.""" - new_variables = set(tf.global_variables()) - ALREADY_INITIALIZED - 
get_session().run(tf.variables_initializer(new_variables)) - ALREADY_INITIALIZED.update(new_variables) + """Initialize all the uninitialized variables in the global scope.""" + new_variables = set(tf.global_variables()) - ALREADY_INITIALIZED + get_session().run(tf.variables_initializer(new_variables)) + ALREADY_INITIALIZED.update(new_variables) def eval(expr, feed_dict=None): - if feed_dict is None: - feed_dict = {} - return get_session().run(expr, feed_dict=feed_dict) + if feed_dict is None: + feed_dict = {} + return get_session().run(expr, feed_dict=feed_dict) VALUE_SETTERS = collections.OrderedDict() def set_value(v, val): - global VALUE_SETTERS - if v in VALUE_SETTERS: - set_op, set_endpoint = VALUE_SETTERS[v] - else: - set_endpoint = tf.placeholder(v.dtype) - set_op = v.assign(set_endpoint) - VALUE_SETTERS[v] = (set_op, set_endpoint) - get_session().run(set_op, feed_dict={set_endpoint: val}) + global VALUE_SETTERS + if v in VALUE_SETTERS: + set_op, set_endpoint = VALUE_SETTERS[v] + else: + set_endpoint = tf.placeholder(v.dtype) + set_op = v.assign(set_endpoint) + VALUE_SETTERS[v] = (set_op, set_endpoint) + get_session().run(set_op, feed_dict={set_endpoint: val}) # ================================================================ @@ -272,14 +275,14 @@ def set_value(v, val): def load_state(fname): - saver = tf.train.Saver() - saver.restore(get_session(), fname) + saver = tf.train.Saver() + saver.restore(get_session(), fname) def save_state(fname): - os.makedirs(os.path.dirname(fname), exist_ok=True) - saver = tf.train.Saver() - saver.save(get_session(), fname) + os.makedirs(os.path.dirname(fname), exist_ok=True) + saver = tf.train.Saver() + saver.save(get_session(), fname) # ================================================================ # Model components @@ -287,88 +290,89 @@ def save_state(fname): def normc_initializer(std=1.0): - # pylint: disable=W0613 - def _initializer(shape, dtype=None, partition_info=None): - out = 
np.random.randn(*shape).astype(np.float32) - out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) - return tf.constant(out) - return _initializer + # pylint: disable=W0613 + def _initializer(shape, dtype=None, partition_info=None): + out = np.random.randn(*shape).astype(np.float32) + out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) + return tf.constant(out) + return _initializer def conv2d( x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME", dtype=tf.float32, collections=None, summary_tag=None): - with tf.variable_scope(name): - stride_shape = [1, stride[0], stride[1], 1] - filter_shape = [ - filter_size[0], filter_size[1], int(x.get_shape()[3]), num_filters] + with tf.variable_scope(name): + stride_shape = [1, stride[0], stride[1], 1] + filter_shape = [ + filter_size[0], filter_size[1], int(x.get_shape()[3]), num_filters] - # there are "num input feature maps * filter height * filter width" - # inputs to each hidden unit - fan_in = intprod(filter_shape[:3]) - # each unit in the lower layer receives a gradient from: - # "num output feature maps * filter height * filter width" / - # pooling size - fan_out = intprod(filter_shape[:2]) * num_filters - # initialize weights with random weights - w_bound = np.sqrt(6. / (fan_in + fan_out)) + # there are "num input feature maps * filter height * filter width" + # inputs to each hidden unit + fan_in = intprod(filter_shape[:3]) + # each unit in the lower layer receives a gradient from: + # "num output feature maps * filter height * filter width" / + # pooling size + fan_out = intprod(filter_shape[:2]) * num_filters + # initialize weights with random weights + w_bound = np.sqrt(6. 
/ (fan_in + fan_out)) - w = tf.get_variable( - "W", filter_shape, dtype, - tf.random_uniform_initializer(-w_bound, w_bound), - collections=collections) - b = tf.get_variable( - "b", [1, 1, 1, num_filters], initializer=tf.zeros_initializer(), - collections=collections) + w = tf.get_variable( + "W", filter_shape, dtype, + tf.random_uniform_initializer(-w_bound, w_bound), + collections=collections) + b = tf.get_variable( + "b", [1, 1, 1, num_filters], initializer=tf.zeros_initializer(), + collections=collections) - if summary_tag is not None: - tf.summary.image( - summary_tag, - tf.transpose(tf.reshape(w, [filter_size[0], filter_size[1], -1, 1]), - [2, 0, 1, 3]), - max_images=10) + if summary_tag is not None: + tf.summary.image( + summary_tag, + tf.transpose(tf.reshape(w, [filter_size[0], + filter_size[1], -1, 1]), + [2, 0, 1, 3]), + max_images=10) - return tf.nn.conv2d(x, w, stride_shape, pad) + b + return tf.nn.conv2d(x, w, stride_shape, pad) + b def dense(x, size, name, weight_init=None, bias=True): - w = tf.get_variable(name + "/w", [x.get_shape()[1], size], - initializer=weight_init) - ret = tf.matmul(x, w) - if bias: - b = tf.get_variable( - name + "/b", [size], initializer=tf.zeros_initializer()) - return ret + b - else: - return ret + w = tf.get_variable(name + "/w", [x.get_shape()[1], size], + initializer=weight_init) + ret = tf.matmul(x, w) + if bias: + b = tf.get_variable( + name + "/b", [size], initializer=tf.zeros_initializer()) + return ret + b + else: + return ret def wndense(x, size, name, init_scale=1.0): - v = tf.get_variable( - name + "/V", [int(x.get_shape()[1]), size], - initializer=tf.random_normal_initializer(0, 0.05)) - g = tf.get_variable( - name + "/g", [size], initializer=tf.constant_initializer(init_scale)) - b = tf.get_variable( - name + "/b", [size], initializer=tf.constant_initializer(0.0)) + v = tf.get_variable( + name + "/V", [int(x.get_shape()[1]), size], + initializer=tf.random_normal_initializer(0, 0.05)) + g = tf.get_variable( + 
name + "/g", [size], initializer=tf.constant_initializer(init_scale)) + b = tf.get_variable( + name + "/b", [size], initializer=tf.constant_initializer(0.0)) - # use weight normalization (Salimans & Kingma, 2016) - x = tf.matmul(x, v) - scaler = g / tf.sqrt(sum(tf.square(v), axis=0, keepdims=True)) - return tf.reshape(scaler, [1, size]) * x + tf.reshape(b, [1, size]) + # use weight normalization (Salimans & Kingma, 2016) + x = tf.matmul(x, v) + scaler = g / tf.sqrt(sum(tf.square(v), axis=0, keepdims=True)) + return tf.reshape(scaler, [1, size]) * x + tf.reshape(b, [1, size]) def densenobias(x, size, name, weight_init=None): - return dense(x, size, name, weight_init=weight_init, bias=False) + return dense(x, size, name, weight_init=weight_init, bias=False) def dropout(x, pkeep, phase=None, mask=None): - mask = tf.floor( - pkeep + tf.random_uniform(tf.shape(x))) if mask is None else mask - if phase is None: - return mask * x - else: - return switch(phase, mask * x, pkeep * x) + mask = tf.floor( + pkeep + tf.random_uniform(tf.shape(x))) if mask is None else mask + if phase is None: + return mask * x + else: + return switch(phase, mask * x, pkeep * x) # ================================================================ @@ -377,138 +381,143 @@ def dropout(x, pkeep, phase=None, mask=None): def function(inputs, outputs, updates=None, givens=None): - """Just like Theano function. Take a bunch of tensorflow placeholders and - expressions computed based on those placeholders and produces f(inputs) -> - outputs. Function f takes values to be fed to the input's placeholders and - produces the values of the expressions in outputs. + """Just like Theano function. Take a bunch of tensorflow placeholders and + expressions computed based on those placeholders and produces f(inputs) -> + outputs. Function f takes values to be fed to the input's placeholders and + produces the values of the expressions in outputs. 
- Input values can be passed in the same order as inputs or can be provided as - kwargs based on placeholder name (passed to constructor or accessible via - placeholder.op.name). + Input values can be passed in the same order as inputs or can be provided + as kwargs based on placeholder name (passed to constructor or accessible + via placeholder.op.name). - Example: - x = tf.placeholder(tf.int32, (), name="x") - y = tf.placeholder(tf.int32, (), name="y") - z = 3 * x + 2 * y - lin = function([x, y], z, givens={y: 0}) + Example: + x = tf.placeholder(tf.int32, (), name="x") + y = tf.placeholder(tf.int32, (), name="y") + z = 3 * x + 2 * y + lin = function([x, y], z, givens={y: 0}) - with single_threaded_session(): - initialize() + with single_threaded_session(): + initialize() - assert lin(2) == 6 - assert lin(x=3) == 9 - assert lin(2, 2) == 10 - assert lin(x=2, y=3) == 12 + assert lin(2) == 6 + assert lin(x=3) == 9 + assert lin(2, 2) == 10 + assert lin(x=2, y=3) == 12 - Parameters - ---------- - inputs: [tf.placeholder or TfInput] - list of input arguments - outputs: [tf.Variable] or tf.Variable - list of outputs or a single output to be returned from function. Returned - value will also have the same shape. - """ - if isinstance(outputs, list): - return _Function(inputs, outputs, updates, givens=givens) - elif isinstance(outputs, (dict, collections.OrderedDict)): - f = _Function(inputs, outputs.values(), updates, givens=givens) - return lambda *args, **kwargs: type(outputs)( - zip(outputs.keys(), f(*args, **kwargs))) - else: - f = _Function(inputs, [outputs], updates, givens=givens) - return lambda *args, **kwargs: f(*args, **kwargs)[0] + Parameters + ---------- + inputs: [tf.placeholder or TfInput] + list of input arguments + outputs: [tf.Variable] or tf.Variable + list of outputs or a single output to be returned from function. Returned + value will also have the same shape. 
+ """ + if isinstance(outputs, list): + return _Function(inputs, outputs, updates, givens=givens) + elif isinstance(outputs, (dict, collections.OrderedDict)): + f = _Function(inputs, outputs.values(), updates, givens=givens) + return lambda *args, **kwargs: type(outputs)( + zip(outputs.keys(), f(*args, **kwargs))) + else: + f = _Function(inputs, [outputs], updates, givens=givens) + return lambda *args, **kwargs: f(*args, **kwargs)[0] class _Function(object): - def __init__(self, inputs, outputs, updates, givens, check_nan=False): - for inpt in inputs: - if not issubclass(type(inpt), TfInput): - assert len(inpt.op.inputs) == 0, \ - "inputs should all be placeholders of ray.rllib.dqn.common.TfInput" - self.inputs = inputs - updates = updates or [] - self.update_group = tf.group(*updates) - self.outputs_update = list(outputs) + [self.update_group] - self.givens = {} if givens is None else givens - self.check_nan = check_nan + def __init__(self, inputs, outputs, updates, givens, check_nan=False): + for inpt in inputs: + if not issubclass(type(inpt), TfInput): + assert len(inpt.op.inputs) == 0, ( + "inputs should all be placeholders of " + "ray.rllib.dqn.common.TfInput") + self.inputs = inputs + updates = updates or [] + self.update_group = tf.group(*updates) + self.outputs_update = list(outputs) + [self.update_group] + self.givens = {} if givens is None else givens + self.check_nan = check_nan - def _feed_input(self, feed_dict, inpt, value): - if issubclass(type(inpt), TfInput): - feed_dict.update(inpt.make_feed_dict(value)) - elif is_placeholder(inpt): - feed_dict[inpt] = value + def _feed_input(self, feed_dict, inpt, value): + if issubclass(type(inpt), TfInput): + feed_dict.update(inpt.make_feed_dict(value)) + elif is_placeholder(inpt): + feed_dict[inpt] = value - def __call__(self, *args, **kwargs): - assert len(args) <= len(self.inputs), "Too many arguments provided" - feed_dict = {} - # Update the args - for inpt, value in zip(self.inputs, args): - 
self._feed_input(feed_dict, inpt, value) - # Update the kwargs - kwargs_passed_inpt_names = set() - for inpt in self.inputs[len(args):]: - inpt_name = inpt.name.split(':')[0] - inpt_name = inpt_name.split('/')[-1] - assert inpt_name not in kwargs_passed_inpt_names, \ - ("this function has two arguments with the same name \"{}\", " + - "so kwargs cannot be used.".format(inpt_name)) - if inpt_name in kwargs: - kwargs_passed_inpt_names.add(inpt_name) - self._feed_input(feed_dict, inpt, kwargs.pop(inpt_name)) - else: - assert inpt in self.givens, "Missing argument " + inpt_name - assert len(kwargs) == 0, \ - "Function got extra arguments " + str(list(kwargs.keys())) - # Update feed dict with givens. - for inpt in self.givens: - feed_dict[inpt] = feed_dict.get(inpt, self.givens[inpt]) - results = get_session().run(self.outputs_update, feed_dict=feed_dict)[:-1] - if self.check_nan: - if any(np.isnan(r).any() for r in results): - raise RuntimeError("Nan detected") - return results + def __call__(self, *args, **kwargs): + assert len(args) <= len(self.inputs), "Too many arguments provided" + feed_dict = {} + # Update the args + for inpt, value in zip(self.inputs, args): + self._feed_input(feed_dict, inpt, value) + # Update the kwargs + kwargs_passed_inpt_names = set() + for inpt in self.inputs[len(args):]: + inpt_name = inpt.name.split(':')[0] + inpt_name = inpt_name.split('/')[-1] + assert inpt_name not in kwargs_passed_inpt_names, ( + "this function has two arguments with " + "the same name \"{}\", " + + "so kwargs cannot be used.".format(inpt_name)) + if inpt_name in kwargs: + kwargs_passed_inpt_names.add(inpt_name) + self._feed_input(feed_dict, inpt, kwargs.pop(inpt_name)) + else: + assert inpt in self.givens, "Missing argument " + inpt_name + assert len(kwargs) == 0, \ + "Function got extra arguments " + str(list(kwargs.keys())) + # Update feed dict with givens. 
+ for inpt in self.givens: + feed_dict[inpt] = feed_dict.get(inpt, self.givens[inpt]) + results = get_session().run(self.outputs_update, + feed_dict=feed_dict)[:-1] + if self.check_nan: + if any(np.isnan(r).any() for r in results): + raise RuntimeError("Nan detected") + return results def mem_friendly_function(nondata_inputs, data_inputs, outputs, batch_size): - if isinstance(outputs, list): - return _MemFriendlyFunction( - nondata_inputs, data_inputs, outputs, batch_size) - else: - f = _MemFriendlyFunction( - nondata_inputs, data_inputs, [outputs], batch_size) - return lambda *inputs: f(*inputs)[0] + if isinstance(outputs, list): + return _MemFriendlyFunction( + nondata_inputs, data_inputs, outputs, batch_size) + else: + f = _MemFriendlyFunction( + nondata_inputs, data_inputs, [outputs], batch_size) + return lambda *inputs: f(*inputs)[0] class _MemFriendlyFunction(object): - def __init__(self, nondata_inputs, data_inputs, outputs, batch_size): - self.nondata_inputs = nondata_inputs - self.data_inputs = data_inputs - self.outputs = list(outputs) - self.batch_size = batch_size + def __init__(self, nondata_inputs, data_inputs, outputs, batch_size): + self.nondata_inputs = nondata_inputs + self.data_inputs = data_inputs + self.outputs = list(outputs) + self.batch_size = batch_size - def __call__(self, *inputvals): - assert len(inputvals) == len(self.nondata_inputs) + len(self.data_inputs) - nondata_vals = inputvals[0:len(self.nondata_inputs)] - data_vals = inputvals[len(self.nondata_inputs):] - feed_dict = dict(zip(self.nondata_inputs, nondata_vals)) - n = data_vals[0].shape[0] - for v in data_vals[1:]: - assert v.shape[0] == n - for i_start in range(0, n, self.batch_size): - slice_vals = [ - v[i_start:builtins.min(i_start + self.batch_size, n)] - for v in data_vals] - for (var, val) in zip(self.data_inputs, slice_vals): - feed_dict[var] = val - results = tf.get_default_session().run(self.outputs, feed_dict=feed_dict) - if i_start == 0: - sum_results = results - else: 
+ def __call__(self, *inputvals): + assert len(inputvals) == (len(self.nondata_inputs) + + len(self.data_inputs)) + nondata_vals = inputvals[0:len(self.nondata_inputs)] + data_vals = inputvals[len(self.nondata_inputs):] + feed_dict = dict(zip(self.nondata_inputs, nondata_vals)) + n = data_vals[0].shape[0] + for v in data_vals[1:]: + assert v.shape[0] == n + for i_start in range(0, n, self.batch_size): + slice_vals = [ + v[i_start:builtins.min(i_start + self.batch_size, n)] + for v in data_vals] + for (var, val) in zip(self.data_inputs, slice_vals): + feed_dict[var] = val + results = tf.get_default_session().run(self.outputs, + feed_dict=feed_dict) + if i_start == 0: + sum_results = results + else: + for i in range(len(results)): + sum_results[i] = sum_results[i] + results[i] for i in range(len(results)): - sum_results[i] = sum_results[i] + results[i] - for i in range(len(results)): - sum_results[i] = sum_results[i] / n - return sum_results + sum_results[i] = sum_results[i] / n + return sum_results # ================================================================ # Modules @@ -516,54 +525,55 @@ class _MemFriendlyFunction(object): class Module(object): - def __init__(self, name): - self.name = name - self.first_time = True - self.scope = None - self.cache = {} + def __init__(self, name): + self.name = name + self.first_time = True + self.scope = None + self.cache = {} - def __call__(self, *args): - if args in self.cache: - print("(%s) retrieving value from cache" % (self.name,)) - return self.cache[args] - with tf.variable_scope(self.name, reuse=not self.first_time): - scope = tf.get_variable_scope().name - if self.first_time: - self.scope = scope - print("(%s) running function for the first time" % (self.name,)) - else: - assert self.scope == scope, \ - "Tried calling function with a different scope" - print("(%s) running function on new inputs" % (self.name,)) - self.first_time = False - out = self._call(*args) - self.cache[args] = out - return out + def 
__call__(self, *args): + if args in self.cache: + print("(%s) retrieving value from cache" % (self.name,)) + return self.cache[args] + with tf.variable_scope(self.name, reuse=not self.first_time): + scope = tf.get_variable_scope().name + if self.first_time: + self.scope = scope + print("(%s) running function for the first time" % + (self.name,)) + else: + assert self.scope == scope, \ + "Tried calling function with a different scope" + print("(%s) running function on new inputs" % (self.name,)) + self.first_time = False + out = self._call(*args) + self.cache[args] = out + return out - def _call(self, *args): - raise NotImplementedError + def _call(self, *args): + raise NotImplementedError - @property - def trainable_variables(self): - assert self.scope is not None, \ - "need to call module once before getting variables" - return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope) + @property + def trainable_variables(self): + assert self.scope is not None, \ + "need to call module once before getting variables" + return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope) - @property - def variables(self): - assert self.scope is not None, \ - "need to call module once before getting variables" - return tf.get_collection(tf.GraphKeys.VARIABLES, self.scope) + @property + def variables(self): + assert self.scope is not None, \ + "need to call module once before getting variables" + return tf.get_collection(tf.GraphKeys.VARIABLES, self.scope) def module(name): - @functools.wraps - def wrapper(f): - class WrapperModule(Module): - def _call(self, *args): - return f(*args) - return WrapperModule(name) - return wrapper + @functools.wraps + def wrapper(f): + class WrapperModule(Module): + def _call(self, *args): + return f(*args) + return WrapperModule(name) + return wrapper # ================================================================ # Graph traversal @@ -574,44 +584,44 @@ VARIABLES = {} def get_parents(node): - return node.op.inputs + return 
node.op.inputs def topsorted(outputs): - """ - Topological sort via non-recursive depth-first search - """ - assert isinstance(outputs, (list, tuple)) - marks = {} - out = [] - stack = [] # pylint: disable=W0621 - # i: node - # jidx = number of children visited so far from that node - # marks: state of each node, which is one of - # 0: haven't visited - # 1: have visited, but not done visiting children - # 2: done visiting children - for x in outputs: - stack.append((x, 0)) - while stack: - (i, jidx) = stack.pop() - if jidx == 0: - m = marks.get(i, 0) - if m == 0: - marks[i] = 1 - elif m == 1: - raise ValueError("not a dag") - else: - continue - ps = get_parents(i) - if jidx == len(ps): - marks[i] = 2 - out.append(i) - else: - stack.append((i, jidx + 1)) - j = ps[jidx] - stack.append((j, 0)) - return out + """ + Topological sort via non-recursive depth-first search + """ + assert isinstance(outputs, (list, tuple)) + marks = {} + out = [] + stack = [] # pylint: disable=W0621 + # i: node + # jidx = number of children visited so far from that node + # marks: state of each node, which is one of + # 0: haven't visited + # 1: have visited, but not done visiting children + # 2: done visiting children + for x in outputs: + stack.append((x, 0)) + while stack: + (i, jidx) = stack.pop() + if jidx == 0: + m = marks.get(i, 0) + if m == 0: + marks[i] = 1 + elif m == 1: + raise ValueError("not a dag") + else: + continue + ps = get_parents(i) + if jidx == len(ps): + marks[i] = 2 + out.append(i) + else: + stack.append((i, jidx + 1)) + j = ps[jidx] + stack.append((j, 0)) + return out # ================================================================ @@ -619,55 +629,55 @@ def topsorted(outputs): # ================================================================ def var_shape(x): - out = x.get_shape().as_list() - assert all(isinstance(a, int) for a in out), \ - "shape function assumes that shape is fully known" - return out + out = x.get_shape().as_list() + assert all(isinstance(a, 
int) for a in out), \ + "shape function assumes that shape is fully known" + return out def numel(x): - return intprod(var_shape(x)) + return intprod(var_shape(x)) def intprod(x): - return int(np.prod(x)) + return int(np.prod(x)) def flatgrad(loss, var_list): - grads = tf.gradients(loss, var_list) - return tf.concat(axis=0, values=[ - tf.reshape(grad if grad is not None else tf.zeros_like(v), [numel(v)]) - for (v, grad) in zip(var_list, grads) - ]) + grads = tf.gradients(loss, var_list) + return tf.concat(axis=0, values=[ + tf.reshape(grad if grad is not None else tf.zeros_like(v), [numel(v)]) + for (v, grad) in zip(var_list, grads) + ]) class SetFromFlat(object): - def __init__(self, var_list, dtype=tf.float32): - assigns = [] - shapes = list(map(var_shape, var_list)) - total_size = np.sum([intprod(shape) for shape in shapes]) + def __init__(self, var_list, dtype=tf.float32): + assigns = [] + shapes = list(map(var_shape, var_list)) + total_size = np.sum([intprod(shape) for shape in shapes]) - self.theta = theta = tf.placeholder(dtype, [total_size]) - start = 0 - assigns = [] - for (shape, v) in zip(shapes, var_list): - size = intprod(shape) - assigns.append( - tf.assign(v, tf.reshape(theta[start:start + size], shape))) - start += size - self.op = tf.group(*assigns) + self.theta = theta = tf.placeholder(dtype, [total_size]) + start = 0 + assigns = [] + for (shape, v) in zip(shapes, var_list): + size = intprod(shape) + assigns.append( + tf.assign(v, tf.reshape(theta[start:start + size], shape))) + start += size + self.op = tf.group(*assigns) - def __call__(self, theta): - get_session().run(self.op, feed_dict={self.theta: theta}) + def __call__(self, theta): + get_session().run(self.op, feed_dict={self.theta: theta}) class GetFlat(object): - def __init__(self, var_list): - self.op = tf.concat( - axis=0, values=[tf.reshape(v, [numel(v)]) for v in var_list]) + def __init__(self, var_list): + self.op = tf.concat( + axis=0, values=[tf.reshape(v, [numel(v)]) for v in 
var_list]) - def __call__(self): - return get_session().run(self.op) + def __call__(self): + return get_session().run(self.op) # ================================================================ # Misc @@ -675,16 +685,16 @@ class GetFlat(object): def fancy_slice_2d(X, inds0, inds1): - """ - like numpy X[inds0, inds1] - XXX this implementation is bad - """ - inds0 = tf.cast(inds0, tf.int64) - inds1 = tf.cast(inds1, tf.int64) - shape = tf.cast(tf.shape(X), tf.int64) - ncols = shape[1] - Xflat = tf.reshape(X, [-1]) - return tf.gather(Xflat, inds0 * ncols + inds1) + """ + like numpy X[inds0, inds1] + XXX this implementation is bad + """ + inds0 = tf.cast(inds0, tf.int64) + inds1 = tf.cast(inds1, tf.int64) + shape = tf.cast(tf.shape(X), tf.int64) + ncols = shape[1] + Xflat = tf.reshape(X, [-1]) + return tf.gather(Xflat, inds0 * ncols + inds1) # ================================================================ @@ -693,90 +703,91 @@ def fancy_slice_2d(X, inds0, inds1): def scope_vars(scope, trainable_only=False): - """ - Get variables inside a scope - The scope can be specified as a string + """ + Get variables inside a scope + The scope can be specified as a string - Parameters - ---------- - scope: str or VariableScope - scope in which the variables reside. - trainable_only: bool - whether or not to return only the variables that were marked as trainable. + Parameters + ---------- + scope: str or VariableScope + scope in which the variables reside. + trainable_only: bool + whether or not to return only the variables that were marked as + trainable. - Returns - ------- - vars: [tf.Variable] - list of variables in `scope`. - """ - return tf.get_collection( - tf.GraphKeys.TRAINABLE_VARIABLES - if trainable_only else tf.GraphKeys.VARIABLES, - scope=scope if isinstance(scope, str) else scope.name) + Returns + ------- + vars: [tf.Variable] + list of variables in `scope`. 
+ """ + return tf.get_collection( + tf.GraphKeys.TRAINABLE_VARIABLES + if trainable_only else tf.GraphKeys.VARIABLES, + scope=scope if isinstance(scope, str) else scope.name) def scope_name(): - """Returns the name of current scope as a string, e.g. deepq/q_func""" - return tf.get_variable_scope().name + """Returns the name of current scope as a string, e.g. deepq/q_func""" + return tf.get_variable_scope().name def absolute_scope_name(relative_scope_name): - """Appends parent scope name to `relative_scope_name`""" - return scope_name() + "/" + relative_scope_name + """Appends parent scope name to `relative_scope_name`""" + return scope_name() + "/" + relative_scope_name def lengths_to_mask(lengths_b, max_length): - """ - Turns a vector of lengths into a boolean mask + """ + Turns a vector of lengths into a boolean mask - Args: - lengths_b: an integer vector of lengths - max_length: maximum length to fill the mask + Args: + lengths_b: an integer vector of lengths + max_length: maximum length to fill the mask - Returns: - a boolean array of shape (batch_size, max_length) - row[i] consists of True repeated lengths_b[i] times, followed by False - """ - lengths_b = tf.convert_to_tensor(lengths_b) - assert lengths_b.get_shape().ndims == 1 - mask_bt = tf.expand_dims( - tf.range(max_length), 0) < tf.expand_dims(lengths_b, 1) - return mask_bt + Returns: + a boolean array of shape (batch_size, max_length) + row[i] consists of True repeated lengths_b[i] times, followed by False + """ + lengths_b = tf.convert_to_tensor(lengths_b) + assert lengths_b.get_shape().ndims == 1 + mask_bt = tf.expand_dims( + tf.range(max_length), 0) < tf.expand_dims(lengths_b, 1) + return mask_bt def in_session(f): - @functools.wraps(f) - def newfunc(*args, **kwargs): - with tf.Session(): - f(*args, **kwargs) - return newfunc + @functools.wraps(f) + def newfunc(*args, **kwargs): + with tf.Session(): + f(*args, **kwargs) + return newfunc _PLACEHOLDER_CACHE = {} # name -> (placeholder, dtype, shape) def 
get_placeholder(name, dtype, shape): - if name in _PLACEHOLDER_CACHE: - out, dtype1, shape1 = _PLACEHOLDER_CACHE[name] - assert dtype1 == dtype and shape1 == shape - return out - else: - out = tf.placeholder(dtype=dtype, shape=shape, name=name) - _PLACEHOLDER_CACHE[name] = (out, dtype, shape) - return out + if name in _PLACEHOLDER_CACHE: + out, dtype1, shape1 = _PLACEHOLDER_CACHE[name] + assert dtype1 == dtype and shape1 == shape + return out + else: + out = tf.placeholder(dtype=dtype, shape=shape, name=name) + _PLACEHOLDER_CACHE[name] = (out, dtype, shape) + return out def get_placeholder_cached(name): - return _PLACEHOLDER_CACHE[name][0] + return _PLACEHOLDER_CACHE[name][0] def flattenallbut0(x): - return tf.reshape(x, [-1, intprod(x.get_shape().as_list()[1:])]) + return tf.reshape(x, [-1, intprod(x.get_shape().as_list()[1:])]) def reset(): - global _PLACEHOLDER_CACHE - global VARIABLES - _PLACEHOLDER_CACHE = {} - VARIABLES = {} - tf.reset_default_graph() + global _PLACEHOLDER_CACHE + global VARIABLES + _PLACEHOLDER_CACHE = {} + VARIABLES = {} + tf.reset_default_graph() diff --git a/python/ray/rllib/dqn/dqn.py b/python/ray/rllib/dqn/dqn.py index 23eb57910..b67e6dd90 100644 --- a/python/ray/rllib/dqn/dqn.py +++ b/python/ray/rllib/dqn/dqn.py @@ -88,132 +88,141 @@ DEFAULT_CONFIG = dict( class DQN(Algorithm): - def __init__(self, env_name, config, upload_dir=None): - config.update({"alg": "DQN"}) - Algorithm.__init__(self, env_name, config, upload_dir=upload_dir) - env = gym.make(env_name) - env = ScaledFloatFrame(wrap_dqn(env)) - self.env = env - model = models.cnn_to_mlp( - convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], - hiddens=[256], dueling=True) - sess = U.make_session(num_cpu=config["num_cpu"]) - sess.__enter__() + def __init__(self, env_name, config, upload_dir=None): + config.update({"alg": "DQN"}) + Algorithm.__init__(self, env_name, config, upload_dir=upload_dir) + env = gym.make(env_name) + env = ScaledFloatFrame(wrap_dqn(env)) + self.env = env + model = 
models.cnn_to_mlp( + convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], + hiddens=[256], dueling=True) + sess = U.make_session(num_cpu=config["num_cpu"]) + sess.__enter__() - def make_obs_ph(name): - return U.BatchInput(env.observation_space.shape, name=name) + def make_obs_ph(name): + return U.BatchInput(env.observation_space.shape, name=name) - self.act, self.optimize, self.update_target, self.debug = build_train( - make_obs_ph=make_obs_ph, - q_func=model, - num_actions=env.action_space.n, - optimizer=tf.train.AdamOptimizer(learning_rate=config["lr"]), - gamma=config["gamma"], - grad_norm_clipping=10) - # Create the replay buffer - if config["prioritized_replay"]: - self.replay_buffer = PrioritizedReplayBuffer( - config["buffer_size"], alpha=config["prioritized_replay_alpha"]) - prioritized_replay_beta_iters = config["prioritized_replay_beta_iters"] - if prioritized_replay_beta_iters is None: - prioritized_replay_beta_iters = config["schedule_max_timesteps"] - self.beta_schedule = LinearSchedule( - prioritized_replay_beta_iters, - initial_p=config["prioritized_replay_beta0"], - final_p=1.0) - else: - self.replay_buffer = ReplayBuffer(config["buffer_size"]) - self.beta_schedule = None - # Create the schedule for exploration starting from 1. - self.exploration = LinearSchedule( - schedule_timesteps=int( - config["exploration_fraction"] * config["schedule_max_timesteps"]), - initial_p=1.0, - final_p=config["exploration_final_eps"]) - - # Initialize the parameters and copy them to the target network. 
- U.initialize() - self.update_target() - - self.episode_rewards = [0.0] - self.episode_lengths = [0.0] - self.saved_mean_reward = None - self.obs = self.env.reset() - self.num_timesteps = 0 - self.num_iterations = 0 - - def train(self): - config = self.config - sample_time, learn_time = 0, 0 - - for t in range(config["timesteps_per_iteration"]): - self.num_timesteps += 1 - dt = time.time() - # Take action and update exploration to the newest value - action = self.act( - np.array(self.obs)[None], update_eps=self.exploration.value(t))[0] - new_obs, rew, done, _ = self.env.step(action) - # Store transition in the replay buffer. - self.replay_buffer.add(self.obs, action, rew, new_obs, float(done)) - self.obs = new_obs - - self.episode_rewards[-1] += rew - self.episode_lengths[-1] += 1 - if done: - self.obs = self.env.reset() - self.episode_rewards.append(0.0) - self.episode_lengths.append(0.0) - sample_time += time.time() - dt - - if self.num_timesteps > config["learning_starts"] and \ - self.num_timesteps % config["train_freq"] == 0: - dt = time.time() - # Minimize the error in Bellman's equation on a batch sampled from - # replay buffer. 
+ self.act, self.optimize, self.update_target, self.debug = build_train( + make_obs_ph=make_obs_ph, + q_func=model, + num_actions=env.action_space.n, + optimizer=tf.train.AdamOptimizer(learning_rate=config["lr"]), + gamma=config["gamma"], + grad_norm_clipping=10) + # Create the replay buffer if config["prioritized_replay"]: - experience = self.replay_buffer.sample( - config["batch_size"], beta=self.beta_schedule.value(t)) - (obses_t, actions, rewards, obses_tp1, - dones, _, batch_idxes) = experience + self.replay_buffer = PrioritizedReplayBuffer( + config["buffer_size"], + alpha=config["prioritized_replay_alpha"]) + prioritized_replay_beta_iters = ( + config["prioritized_replay_beta_iters"]) + if prioritized_replay_beta_iters is None: + prioritized_replay_beta_iters = ( + config["schedule_max_timesteps"]) + self.beta_schedule = LinearSchedule( + prioritized_replay_beta_iters, + initial_p=config["prioritized_replay_beta0"], + final_p=1.0) else: - obses_t, actions, rewards, obses_tp1, dones = \ - self.replay_buffer.sample(config["batch_size"]) - batch_idxes = None - td_errors = self.optimize( - obses_t, actions, rewards, obses_tp1, dones, np.ones_like(rewards)) - if config["prioritized_replay"]: - new_priorities = np.abs(td_errors) + config["prioritized_replay_eps"] - self.replay_buffer.update_priorities(batch_idxes, new_priorities) - learn_time += (time.time() - dt) + self.replay_buffer = ReplayBuffer(config["buffer_size"]) + self.beta_schedule = None + # Create the schedule for exploration starting from 1. + self.exploration = LinearSchedule( + schedule_timesteps=int( + config["exploration_fraction"] * + config["schedule_max_timesteps"]), + initial_p=1.0, + final_p=config["exploration_final_eps"]) - if self.num_timesteps > config["learning_starts"] and \ - self.num_timesteps % config["target_network_update_freq"] == 0: - # Update target network periodically. + # Initialize the parameters and copy them to the target network. 
+ U.initialize() self.update_target() - mean_100ep_reward = round(np.mean(self.episode_rewards[-101:-1]), 1) - mean_100ep_length = round(np.mean(self.episode_lengths[-101:-1]), 1) - num_episodes = len(self.episode_rewards) + self.episode_rewards = [0.0] + self.episode_lengths = [0.0] + self.saved_mean_reward = None + self.obs = self.env.reset() + self.num_timesteps = 0 + self.num_iterations = 0 - info = { - "sample_time": sample_time, - "learn_time": learn_time, - "steps": self.num_timesteps, - "episodes": num_episodes, - "exploration": int(100 * self.exploration.value(t)) - } + def train(self): + config = self.config + sample_time, learn_time = 0, 0 - logger.record_tabular("sample_time", sample_time) - logger.record_tabular("learn_time", learn_time) - logger.record_tabular("steps", self.num_timesteps) - logger.record_tabular("episodes", num_episodes) - logger.record_tabular("mean 100 episode reward", mean_100ep_reward) - logger.record_tabular( - "% time spent exploring", int(100 * self.exploration.value(t))) - logger.dump_tabular() + for t in range(config["timesteps_per_iteration"]): + self.num_timesteps += 1 + dt = time.time() + # Take action and update exploration to the newest value + action = self.act( + np.array(self.obs)[None], + update_eps=self.exploration.value(t))[0] + new_obs, rew, done, _ = self.env.step(action) + # Store transition in the replay buffer. 
+ self.replay_buffer.add(self.obs, action, rew, new_obs, float(done)) + self.obs = new_obs - res = TrainingResult( - self.experiment_id.hex, self.num_iterations, mean_100ep_reward, - mean_100ep_length, info) - self.num_iterations += 1 - return res + self.episode_rewards[-1] += rew + self.episode_lengths[-1] += 1 + if done: + self.obs = self.env.reset() + self.episode_rewards.append(0.0) + self.episode_lengths.append(0.0) + sample_time += time.time() - dt + + if self.num_timesteps > config["learning_starts"] and \ + self.num_timesteps % config["train_freq"] == 0: + dt = time.time() + # Minimize the error in Bellman's equation on a batch sampled + # from replay buffer. + if config["prioritized_replay"]: + experience = self.replay_buffer.sample( + config["batch_size"], beta=self.beta_schedule.value(t)) + (obses_t, actions, rewards, obses_tp1, + dones, _, batch_idxes) = experience + else: + obses_t, actions, rewards, obses_tp1, dones = \ + self.replay_buffer.sample(config["batch_size"]) + batch_idxes = None + td_errors = self.optimize( + obses_t, actions, rewards, obses_tp1, dones, + np.ones_like(rewards)) + if config["prioritized_replay"]: + new_priorities = (np.abs(td_errors) + + config["prioritized_replay_eps"]) + self.replay_buffer.update_priorities(batch_idxes, + new_priorities) + learn_time += (time.time() - dt) + + if (self.num_timesteps > config["learning_starts"] and + self.num_timesteps % + config["target_network_update_freq"] == 0): + # Update target network periodically. 
+ self.update_target() + + mean_100ep_reward = round(np.mean(self.episode_rewards[-101:-1]), 1) + mean_100ep_length = round(np.mean(self.episode_lengths[-101:-1]), 1) + num_episodes = len(self.episode_rewards) + + info = { + "sample_time": sample_time, + "learn_time": learn_time, + "steps": self.num_timesteps, + "episodes": num_episodes, + "exploration": int(100 * self.exploration.value(t)) + } + + logger.record_tabular("sample_time", sample_time) + logger.record_tabular("learn_time", learn_time) + logger.record_tabular("steps", self.num_timesteps) + logger.record_tabular("episodes", num_episodes) + logger.record_tabular("mean 100 episode reward", mean_100ep_reward) + logger.record_tabular( + "% time spent exploring", int(100 * self.exploration.value(t))) + logger.dump_tabular() + + res = TrainingResult( + self.experiment_id.hex, self.num_iterations, mean_100ep_reward, + mean_100ep_length, info) + self.num_iterations += 1 + return res diff --git a/python/ray/rllib/dqn/example.py b/python/ray/rllib/dqn/example.py index 5602633e9..8d40eadfb 100755 --- a/python/ray/rllib/dqn/example.py +++ b/python/ray/rllib/dqn/example.py @@ -24,8 +24,8 @@ def main(): dqn = DQN("PongNoFrameskip-v4", config) while True: - res = dqn.train() - print("current status: {}".format(res)) + res = dqn.train() + print("current status: {}".format(res)) if __name__ == '__main__': diff --git a/python/ray/rllib/dqn/logger.py b/python/ray/rllib/dqn/logger.py index 17906e1a1..ea9957c4a 100644 --- a/python/ray/rllib/dqn/logger.py +++ b/python/ray/rllib/dqn/logger.py @@ -29,88 +29,88 @@ DISABLED = 50 class OutputFormat(object): - def writekvs(self, kvs): - """ - Write key-value pairs - """ - raise NotImplementedError + def writekvs(self, kvs): + """ + Write key-value pairs + """ + raise NotImplementedError - def writeseq(self, args): - """ - Write a sequence of other data (e.g. a logging message) - """ - pass + def writeseq(self, args): + """ + Write a sequence of other data (e.g. 
a logging message) + """ + pass - def close(self): - return + def close(self): + return class HumanOutputFormat(OutputFormat): - def __init__(self, file): - self.file = file + def __init__(self, file): + self.file = file - def writekvs(self, kvs): - # Create strings for printing - key2str = OrderedDict() - for (key, val) in kvs.items(): - valstr = '%-8.3g' % (val,) if hasattr(val, '__float__') else val - key2str[self._truncate(key)] = self._truncate(valstr) + def writekvs(self, kvs): + # Create strings for printing + key2str = OrderedDict() + for (key, val) in kvs.items(): + valstr = '%-8.3g' % (val,) if hasattr(val, '__float__') else val + key2str[self._truncate(key)] = self._truncate(valstr) - # Find max widths - keywidth = max(map(len, key2str.keys())) - valwidth = max(map(len, key2str.values())) + # Find max widths + keywidth = max(map(len, key2str.keys())) + valwidth = max(map(len, key2str.values())) - # Write out the data - dashes = '-' * (keywidth + valwidth + 7) - lines = [dashes] - for (key, val) in key2str.items(): - lines.append('| %s%s | %s%s |' % ( - key, - ' ' * (keywidth - len(key)), - val, - ' ' * (valwidth - len(val)), - )) - lines.append(dashes) - self.file.write('\n'.join(lines) + '\n') + # Write out the data + dashes = '-' * (keywidth + valwidth + 7) + lines = [dashes] + for (key, val) in key2str.items(): + lines.append('| %s%s | %s%s |' % ( + key, + ' ' * (keywidth - len(key)), + val, + ' ' * (valwidth - len(val)), + )) + lines.append(dashes) + self.file.write('\n'.join(lines) + '\n') - # Flush the output to the file - self.file.flush() + # Flush the output to the file + self.file.flush() - def _truncate(self, s): - return s[:20] + '...' if len(s) > 23 else s + def _truncate(self, s): + return s[:20] + '...' 
if len(s) > 23 else s - def writeseq(self, args): - for arg in args: - self.file.write(arg) - self.file.write('\n') - self.file.flush() + def writeseq(self, args): + for arg in args: + self.file.write(arg) + self.file.write('\n') + self.file.flush() class JSONOutputFormat(OutputFormat): - def __init__(self, file): - self.file = file + def __init__(self, file): + self.file = file - def writekvs(self, kvs): - for k, v in kvs.items(): - if hasattr(v, 'dtype'): - v = v.tolist() - kvs[k] = float(v) - self.file.write(json.dumps(kvs) + '\n') - self.file.flush() + def writekvs(self, kvs): + for k, v in kvs.items(): + if hasattr(v, 'dtype'): + v = v.tolist() + kvs[k] = float(v) + self.file.write(json.dumps(kvs) + '\n') + self.file.flush() def make_output_format(format, ev_dir): - os.makedirs(ev_dir, exist_ok=True) - if format == 'stdout': - return HumanOutputFormat(sys.stdout) - elif format == 'log': - log_file = open(osp.join(ev_dir, 'log.txt'), 'wt') - return HumanOutputFormat(log_file) - elif format == 'json': - json_file = open(osp.join(ev_dir, 'progress.json'), 'wt') - return JSONOutputFormat(json_file) - else: - raise ValueError('Unknown format specified: %s' % (format,)) + os.makedirs(ev_dir, exist_ok=True) + if format == 'stdout': + return HumanOutputFormat(sys.stdout) + elif format == 'log': + log_file = open(osp.join(ev_dir, 'log.txt'), 'wt') + return HumanOutputFormat(log_file) + elif format == 'json': + json_file = open(osp.join(ev_dir, 'progress.json'), 'wt') + return JSONOutputFormat(json_file) + else: + raise ValueError('Unknown format specified: %s' % (format,)) # ================================================================ # API @@ -118,21 +118,21 @@ def make_output_format(format, ev_dir): def logkv(key, val): - """ - Log a value of some diagnostic - Call this once for each diagnostic quantity, each iteration - """ - Logger.CURRENT.logkv(key, val) + """ + Log a value of some diagnostic + Call this once for each diagnostic quantity, each iteration + """ 
+ Logger.CURRENT.logkv(key, val) def dumpkvs(): - """ - Write all of the diagnostics from the current iteration + """ + Write all of the diagnostics from the current iteration - level: int. (see logger.py docs) If the global logger level is higher than - the level argument here, don't print to stdout. - """ - Logger.CURRENT.dumpkvs() + level: int. (see logger.py docs) If the global logger level is higher than + the level argument here, don't print to stdout. + """ + Logger.CURRENT.dumpkvs() # for backwards compatibility @@ -141,49 +141,50 @@ dump_tabular = dumpkvs def log(*args, level=INFO): - """ - Write the sequence of args, with no separators, to the console and output - files (if you've configured an output file). - """ - Logger.CURRENT.log(*args, level=level) + """ + Write the sequence of args, with no separators, to the console and output + files (if you've configured an output file). + """ + Logger.CURRENT.log(*args, level=level) def debug(*args): - log(*args, level=DEBUG) + log(*args, level=DEBUG) def info(*args): - log(*args, level=INFO) + log(*args, level=INFO) def warn(*args): - log(*args, level=WARN) + log(*args, level=WARN) def error(*args): - log(*args, level=ERROR) + log(*args, level=ERROR) def set_level(level): - """ - Set logging threshold on current logger. - """ - Logger.CURRENT.set_level(level) + """ + Set logging threshold on current logger. + """ + Logger.CURRENT.set_level(level) def get_dir(): - """ - Get directory that log files are being written to. - will be None if there is no output directory (i.e., if you didn't call start) - """ - return Logger.CURRENT.get_dir() + """ + Get directory that log files are being written to. + will be None if there is no output directory (i.e., if you didn't call + start) + """ + return Logger.CURRENT.get_dir() def get_expt_dir(): - sys.stderr.write( - "get_expt_dir() is Deprecated. Switch to get_dir() [%s]\n" % - (get_dir(),)) - return get_dir() + sys.stderr.write( + "get_expt_dir() is Deprecated. 
Switch to get_dir() [%s]\n" % + (get_dir(),)) + return get_dir() # ================================================================ @@ -192,50 +193,50 @@ def get_expt_dir(): class Logger(object): - # A logger with no output files. (See right below class definition) - # So that you can still log to the terminal without setting up any output - DEFAULT = None + # A logger with no output files. (See right below class definition) + # So that you can still log to the terminal without setting up any output + DEFAULT = None - # Current logger being used by the free functions above - CURRENT = None + # Current logger being used by the free functions above + CURRENT = None - def __init__(self, dir, output_formats): - self.name2val = OrderedDict() # values this iteration - self.level = INFO - self.dir = dir - self.output_formats = output_formats + def __init__(self, dir, output_formats): + self.name2val = OrderedDict() # values this iteration + self.level = INFO + self.dir = dir + self.output_formats = output_formats - # Logging API, forwarded - # ---------------------------------------- - def logkv(self, key, val): - self.name2val[key] = val + # Logging API, forwarded + # ---------------------------------------- + def logkv(self, key, val): + self.name2val[key] = val - def dumpkvs(self): - for fmt in self.output_formats: - fmt.writekvs(self.name2val) - self.name2val.clear() + def dumpkvs(self): + for fmt in self.output_formats: + fmt.writekvs(self.name2val) + self.name2val.clear() - def log(self, *args, level=INFO): - if self.level <= level: - self._do_log(args) + def log(self, *args, level=INFO): + if self.level <= level: + self._do_log(args) - # Configuration - # ---------------------------------------- - def set_level(self, level): - self.level = level + # Configuration + # ---------------------------------------- + def set_level(self, level): + self.level = level - def get_dir(self): - return self.dir + def get_dir(self): + return self.dir - def close(self): - for fmt in 
self.output_formats: - fmt.close() + def close(self): + for fmt in self.output_formats: + fmt.close() - # Misc - # ---------------------------------------- - def _do_log(self, args): - for fmt in self.output_formats: - fmt.writeseq(args) + # Misc + # ---------------------------------------- + def _do_log(self, args): + for fmt in self.output_formats: + fmt.writeseq(args) # ================================================================ @@ -246,60 +247,60 @@ Logger.CURRENT = Logger.DEFAULT class session(object): - """ - Context manager that sets up the loggers for an experiment. - """ + """ + Context manager that sets up the loggers for an experiment. + """ - CURRENT = None # Set to a LoggerContext object using enter/exit or cm + CURRENT = None # Set to a LoggerContext object using enter/exit or cm - def __init__(self, dir, format_strs=None): - self.dir = dir - if format_strs is None: - format_strs = LOG_OUTPUT_FORMATS - output_formats = [make_output_format(f, dir) for f in format_strs] - Logger.CURRENT = Logger(dir=dir, output_formats=output_formats) + def __init__(self, dir, format_strs=None): + self.dir = dir + if format_strs is None: + format_strs = LOG_OUTPUT_FORMATS + output_formats = [make_output_format(f, dir) for f in format_strs] + Logger.CURRENT = Logger(dir=dir, output_formats=output_formats) - def __enter__(self): - os.makedirs(self.evaluation_dir(), exist_ok=True) - output_formats = [ - make_output_format( - f, self.evaluation_dir()) for f in LOG_OUTPUT_FORMATS] - Logger.CURRENT = Logger(dir=self.dir, output_formats=output_formats) + def __enter__(self): + os.makedirs(self.evaluation_dir(), exist_ok=True) + output_formats = [ + make_output_format( + f, self.evaluation_dir()) for f in LOG_OUTPUT_FORMATS] + Logger.CURRENT = Logger(dir=self.dir, output_formats=output_formats) - def __exit__(self, *args): - Logger.CURRENT.close() - Logger.CURRENT = Logger.DEFAULT + def __exit__(self, *args): + Logger.CURRENT.close() + Logger.CURRENT = Logger.DEFAULT - def 
evaluation_dir(self): - return self.dir + def evaluation_dir(self): + return self.dir # ================================================================ def _demo(): - info("hi") - debug("shouldn't appear") - set_level(DEBUG) - debug("should appear") - dir = "/tmp/testlogging" - if os.path.exists(dir): - shutil.rmtree(dir) - with session(dir=dir): - record_tabular("a", 3) - record_tabular("b", 2.5) - dump_tabular() + info("hi") + debug("shouldn't appear") + set_level(DEBUG) + debug("should appear") + dir = "/tmp/testlogging" + if os.path.exists(dir): + shutil.rmtree(dir) + with session(dir=dir): + record_tabular("a", 3) + record_tabular("b", 2.5) + dump_tabular() + record_tabular("b", -2.5) + record_tabular("a", 5.5) + dump_tabular() + info("^^^ should see a = 5.5") + record_tabular("b", -2.5) - record_tabular("a", 5.5) dump_tabular() - info("^^^ should see a = 5.5") - record_tabular("b", -2.5) - dump_tabular() - - record_tabular("a", "longasslongasslongasslongasslongasslongassvalue") - dump_tabular() + record_tabular("a", "longasslongasslongasslongasslongasslongassvalue") + dump_tabular() if __name__ == "__main__": - _demo() + _demo() diff --git a/python/ray/rllib/dqn/models.py b/python/ray/rllib/dqn/models.py index e017e0174..fa8917156 100644 --- a/python/ray/rllib/dqn/models.py +++ b/python/ray/rllib/dqn/models.py @@ -7,91 +7,92 @@ import tensorflow.contrib.layers as layers def _mlp(hiddens, inpt, num_actions, scope, reuse=False): - with tf.variable_scope(scope, reuse=reuse): - out = inpt - for hidden in hiddens: - out = layers.fully_connected( - out, num_outputs=hidden, activation_fn=tf.nn.relu) - out = layers.fully_connected( - out, num_outputs=num_actions, activation_fn=None) - return out + with tf.variable_scope(scope, reuse=reuse): + out = inpt + for hidden in hiddens: + out = layers.fully_connected( + out, num_outputs=hidden, activation_fn=tf.nn.relu) + out = layers.fully_connected( + out, num_outputs=num_actions, activation_fn=None) + return out def 
mlp(hiddens=[]): - """This model takes as input an observation and returns values of all - actions. + """This model takes as input an observation and returns values of all + actions. - Parameters - ---------- - hiddens: [int] - list of sizes of hidden layers + Parameters + ---------- + hiddens: [int] + list of sizes of hidden layers - Returns - ------- - q_func: function - q_function for DQN algorithm. - """ - return lambda *args, **kwargs: _mlp(hiddens, *args, **kwargs) + Returns + ------- + q_func: function + q_function for DQN algorithm. + """ + return lambda *args, **kwargs: _mlp(hiddens, *args, **kwargs) def _cnn_to_mlp( convs, hiddens, dueling, inpt, num_actions, scope, reuse=False): - with tf.variable_scope(scope, reuse=reuse): - out = inpt - with tf.variable_scope("convnet"): - for num_outputs, kernel_size, stride in convs: - out = layers.convolution2d( - out, - num_outputs=num_outputs, - kernel_size=kernel_size, - stride=stride, - activation_fn=tf.nn.relu) - out = layers.flatten(out) - with tf.variable_scope("action_value"): - action_out = out - for hidden in hiddens: - action_out = layers.fully_connected( - action_out, num_outputs=hidden, activation_fn=tf.nn.relu) - action_scores = layers.fully_connected( - action_out, num_outputs=num_actions, activation_fn=None) + with tf.variable_scope(scope, reuse=reuse): + out = inpt + with tf.variable_scope("convnet"): + for num_outputs, kernel_size, stride in convs: + out = layers.convolution2d( + out, + num_outputs=num_outputs, + kernel_size=kernel_size, + stride=stride, + activation_fn=tf.nn.relu) + out = layers.flatten(out) + with tf.variable_scope("action_value"): + action_out = out + for hidden in hiddens: + action_out = layers.fully_connected( + action_out, num_outputs=hidden, activation_fn=tf.nn.relu) + action_scores = layers.fully_connected( + action_out, num_outputs=num_actions, activation_fn=None) - if dueling: - with tf.variable_scope("state_value"): - state_out = out - for hidden in hiddens: - state_out 
= layers.fully_connected( - state_out, num_outputs=hidden, activation_fn=tf.nn.relu) - state_score = layers.fully_connected( - state_out, num_outputs=1, activation_fn=None) - action_scores_mean = tf.reduce_mean(action_scores, 1) - action_scores_centered = action_scores - tf.expand_dims( - action_scores_mean, 1) - return state_score + action_scores_centered - else: - return action_scores - return out + if dueling: + with tf.variable_scope("state_value"): + state_out = out + for hidden in hiddens: + state_out = layers.fully_connected( + state_out, num_outputs=hidden, + activation_fn=tf.nn.relu) + state_score = layers.fully_connected( + state_out, num_outputs=1, activation_fn=None) + action_scores_mean = tf.reduce_mean(action_scores, 1) + action_scores_centered = action_scores - tf.expand_dims( + action_scores_mean, 1) + return state_score + action_scores_centered + else: + return action_scores + return out def cnn_to_mlp(convs, hiddens, dueling=False): - """This model takes as input an observation and returns values of all actions. + """This model takes an observation and returns values for all actions. - Parameters - ---------- - convs: [(int, int int)] - list of convolutional layers in form of - (num_outputs, kernel_size, stride) - hiddens: [int] - list of sizes of hidden layers - dueling: bool - if true double the output MLP to compute a baseline - for action scores + Parameters + ---------- + convs: [(int, int int)] + list of convolutional layers in form of + (num_outputs, kernel_size, stride) + hiddens: [int] + list of sizes of hidden layers + dueling: bool + if true double the output MLP to compute a baseline + for action scores - Returns - ------- - q_func: function - q_function for DQN algorithm. - """ + Returns + ------- + q_func: function + q_function for DQN algorithm. 
+ """ - return lambda *args, **kwargs: _cnn_to_mlp( - convs, hiddens, dueling, *args, **kwargs) + return lambda *args, **kwargs: _cnn_to_mlp( + convs, hiddens, dueling, *args, **kwargs) diff --git a/python/ray/rllib/dqn/replay_buffer.py b/python/ray/rllib/dqn/replay_buffer.py index 150ee6470..ff19f434e 100644 --- a/python/ray/rllib/dqn/replay_buffer.py +++ b/python/ray/rllib/dqn/replay_buffer.py @@ -9,188 +9,189 @@ from ray.rllib.dqn.common.segment_tree import SumSegmentTree, MinSegmentTree class ReplayBuffer(object): - def __init__(self, size): - """Create Prioritized Replay buffer. + def __init__(self, size): + """Create Prioritized Replay buffer. - Parameters - ---------- - size: int - Max number of transitions to store in the buffer. When the buffer - overflows the old memories are dropped. - """ - self._storage = [] - self._maxsize = size - self._next_idx = 0 + Parameters + ---------- + size: int + Max number of transitions to store in the buffer. When the buffer + overflows the old memories are dropped. 
+ """ + self._storage = [] + self._maxsize = size + self._next_idx = 0 - def __len__(self): - return len(self._storage) + def __len__(self): + return len(self._storage) - def add(self, obs_t, action, reward, obs_tp1, done): - data = (obs_t, action, reward, obs_tp1, done) + def add(self, obs_t, action, reward, obs_tp1, done): + data = (obs_t, action, reward, obs_tp1, done) - if self._next_idx >= len(self._storage): - self._storage.append(data) - else: - self._storage[self._next_idx] = data - self._next_idx = (self._next_idx + 1) % self._maxsize + if self._next_idx >= len(self._storage): + self._storage.append(data) + else: + self._storage[self._next_idx] = data + self._next_idx = (self._next_idx + 1) % self._maxsize - def _encode_sample(self, idxes): - obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], [] - for i in idxes: - data = self._storage[i] - obs_t, action, reward, obs_tp1, done = data - obses_t.append(np.array(obs_t, copy=False)) - actions.append(np.array(action, copy=False)) - rewards.append(reward) - obses_tp1.append(np.array(obs_tp1, copy=False)) - dones.append(done) - return np.array(obses_t), np.array(actions), np.array(rewards), \ - np.array(obses_tp1), np.array(dones) + def _encode_sample(self, idxes): + obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], [] + for i in idxes: + data = self._storage[i] + obs_t, action, reward, obs_tp1, done = data + obses_t.append(np.array(obs_t, copy=False)) + actions.append(np.array(action, copy=False)) + rewards.append(reward) + obses_tp1.append(np.array(obs_tp1, copy=False)) + dones.append(done) + return (np.array(obses_t), np.array(actions), np.array(rewards), + np.array(obses_tp1), np.array(dones)) - def sample(self, batch_size): - """Sample a batch of experiences. + def sample(self, batch_size): + """Sample a batch of experiences. - Parameters - ---------- - batch_size: int - How many transitions to sample. + Parameters + ---------- + batch_size: int + How many transitions to sample. 
- Returns - ------- - obs_batch: np.array - batch of observations - act_batch: np.array - batch of actions executed given obs_batch - rew_batch: np.array - rewards received as results of executing act_batch - next_obs_batch: np.array - next set of observations seen after executing act_batch - done_mask: np.array - done_mask[i] = 1 if executing act_batch[i] resulted in - the end of an episode and 0 otherwise. - """ - idxes = [random.randint(0, len(self._storage) - 1) - for _ in range(batch_size)] - return self._encode_sample(idxes) + Returns + ------- + obs_batch: np.array + batch of observations + act_batch: np.array + batch of actions executed given obs_batch + rew_batch: np.array + rewards received as results of executing act_batch + next_obs_batch: np.array + next set of observations seen after executing act_batch + done_mask: np.array + done_mask[i] = 1 if executing act_batch[i] resulted in + the end of an episode and 0 otherwise. + """ + idxes = [random.randint(0, len(self._storage) - 1) + for _ in range(batch_size)] + return self._encode_sample(idxes) class PrioritizedReplayBuffer(ReplayBuffer): - def __init__(self, size, alpha): - """Create Prioritized Replay buffer. + def __init__(self, size, alpha): + """Create Prioritized Replay buffer. - Parameters - ---------- - size: int - Max number of transitions to store in the buffer. When the buffer - overflows the old memories are dropped. - alpha: float - how much prioritization is used - (0 - no prioritization, 1 - full prioritization) + Parameters + ---------- + size: int + Max number of transitions to store in the buffer. When the buffer + overflows the old memories are dropped. 
+ alpha: float + how much prioritization is used + (0 - no prioritization, 1 - full prioritization) - See Also - -------- - ReplayBuffer.__init__ - """ - super(PrioritizedReplayBuffer, self).__init__(size) - assert alpha > 0 - self._alpha = alpha + See Also + -------- + ReplayBuffer.__init__ + """ + super(PrioritizedReplayBuffer, self).__init__(size) + assert alpha > 0 + self._alpha = alpha - it_capacity = 1 - while it_capacity < size: - it_capacity *= 2 + it_capacity = 1 + while it_capacity < size: + it_capacity *= 2 - self._it_sum = SumSegmentTree(it_capacity) - self._it_min = MinSegmentTree(it_capacity) - self._max_priority = 1.0 + self._it_sum = SumSegmentTree(it_capacity) + self._it_min = MinSegmentTree(it_capacity) + self._max_priority = 1.0 - def add(self, *args, **kwargs): - """See ReplayBuffer.store_effect""" - idx = self._next_idx - super().add(*args, **kwargs) - self._it_sum[idx] = self._max_priority ** self._alpha - self._it_min[idx] = self._max_priority ** self._alpha + def add(self, *args, **kwargs): + """See ReplayBuffer.store_effect""" + idx = self._next_idx + super().add(*args, **kwargs) + self._it_sum[idx] = self._max_priority ** self._alpha + self._it_min[idx] = self._max_priority ** self._alpha - def _sample_proportional(self, batch_size): - res = [] - for _ in range(batch_size): - # TODO(szymon): should we ensure no repeats? - mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1) - idx = self._it_sum.find_prefixsum_idx(mass) - res.append(idx) - return res + def _sample_proportional(self, batch_size): + res = [] + for _ in range(batch_size): + # TODO(szymon): should we ensure no repeats? + mass = random.random() * self._it_sum.sum(0, + len(self._storage) - 1) + idx = self._it_sum.find_prefixsum_idx(mass) + res.append(idx) + return res - def sample(self, batch_size, beta): - """Sample a batch of experiences. + def sample(self, batch_size, beta): + """Sample a batch of experiences. 
- compared to ReplayBuffer.sample - it also returns importance weights and idxes - of sampled experiences. + compared to ReplayBuffer.sample + it also returns importance weights and idxes + of sampled experiences. - Parameters - ---------- - batch_size: int - How many transitions to sample. - beta: float - To what degree to use importance weights - (0 - no corrections, 1 - full correction) + Parameters + ---------- + batch_size: int + How many transitions to sample. + beta: float + To what degree to use importance weights + (0 - no corrections, 1 - full correction) - Returns - ------- - obs_batch: np.array - batch of observations - act_batch: np.array - batch of actions executed given obs_batch - rew_batch: np.array - rewards received as results of executing act_batch - next_obs_batch: np.array - next set of observations seen after executing act_batch - done_mask: np.array - done_mask[i] = 1 if executing act_batch[i] resulted in - the end of an episode and 0 otherwise. - weights: np.array - Array of shape (batch_size,) and dtype np.float32 - denoting importance weight of each sampled transition - idxes: np.array - Array of shape (batch_size,) and dtype np.int32 - idexes in buffer of sampled experiences - """ - assert beta > 0 + Returns + ------- + obs_batch: np.array + batch of observations + act_batch: np.array + batch of actions executed given obs_batch + rew_batch: np.array + rewards received as results of executing act_batch + next_obs_batch: np.array + next set of observations seen after executing act_batch + done_mask: np.array + done_mask[i] = 1 if executing act_batch[i] resulted in + the end of an episode and 0 otherwise. 
+ weights: np.array + Array of shape (batch_size,) and dtype np.float32 + denoting importance weight of each sampled transition + idxes: np.array + Array of shape (batch_size,) and dtype np.int32 + idexes in buffer of sampled experiences + """ + assert beta > 0 - idxes = self._sample_proportional(batch_size) + idxes = self._sample_proportional(batch_size) - weights = [] - p_min = self._it_min.min() / self._it_sum.sum() - max_weight = (p_min * len(self._storage)) ** (-beta) + weights = [] + p_min = self._it_min.min() / self._it_sum.sum() + max_weight = (p_min * len(self._storage)) ** (-beta) - for idx in idxes: - p_sample = self._it_sum[idx] / self._it_sum.sum() - weight = (p_sample * len(self._storage)) ** (-beta) - weights.append(weight / max_weight) - weights = np.array(weights) - encoded_sample = self._encode_sample(idxes) - return tuple(list(encoded_sample) + [weights, idxes]) + for idx in idxes: + p_sample = self._it_sum[idx] / self._it_sum.sum() + weight = (p_sample * len(self._storage)) ** (-beta) + weights.append(weight / max_weight) + weights = np.array(weights) + encoded_sample = self._encode_sample(idxes) + return tuple(list(encoded_sample) + [weights, idxes]) - def update_priorities(self, idxes, priorities): - """Update priorities of sampled transitions. + def update_priorities(self, idxes, priorities): + """Update priorities of sampled transitions. - sets priority of transition at index idxes[i] in buffer - to priorities[i]. + sets priority of transition at index idxes[i] in buffer + to priorities[i]. - Parameters - ---------- - idxes: [int] - List of idxes of sampled transitions - priorities: [float] - List of updated priorities corresponding to - transitions at the sampled idxes denoted by - variable `idxes`. 
- """ - assert len(idxes) == len(priorities) - for idx, priority in zip(idxes, priorities): - assert priority > 0 - assert 0 <= idx < len(self._storage) - self._it_sum[idx] = priority ** self._alpha - self._it_min[idx] = priority ** self._alpha + Parameters + ---------- + idxes: [int] + List of idxes of sampled transitions + priorities: [float] + List of updated priorities corresponding to + transitions at the sampled idxes denoted by + variable `idxes`. + """ + assert len(idxes) == len(priorities) + for idx, priority in zip(idxes, priorities): + assert priority > 0 + assert 0 <= idx < len(self._storage) + self._it_sum[idx] = priority ** self._alpha + self._it_min[idx] = priority ** self._alpha - self._max_priority = max(self._max_priority, priority) + self._max_priority = max(self._max_priority, priority) diff --git a/python/ray/rllib/evolution_strategies/evolution_strategies.py b/python/ray/rllib/evolution_strategies/evolution_strategies.py index 8c330bef8..f56944091 100644 --- a/python/ray/rllib/evolution_strategies/evolution_strategies.py +++ b/python/ray/rllib/evolution_strategies/evolution_strategies.py @@ -43,256 +43,267 @@ DEFAULT_CONFIG = dict( @ray.remote def create_shared_noise(): - """Create a large array of noise to be shared by all workers.""" - seed = 123 - count = 250000000 - noise = np.random.RandomState(seed).randn(count).astype(np.float32) - return noise + """Create a large array of noise to be shared by all workers.""" + seed = 123 + count = 250000000 + noise = np.random.RandomState(seed).randn(count).astype(np.float32) + return noise class SharedNoiseTable(object): - def __init__(self, noise): - self.noise = noise - assert self.noise.dtype == np.float32 + def __init__(self, noise): + self.noise = noise + assert self.noise.dtype == np.float32 - def get(self, i, dim): - return self.noise[i:i + dim] + def get(self, i, dim): + return self.noise[i:i + dim] - def sample_index(self, stream, dim): - return stream.randint(0, len(self.noise) - dim + 1) + 
def sample_index(self, stream, dim): + return stream.randint(0, len(self.noise) - dim + 1) @ray.remote class Worker(object): - def __init__(self, config, policy_params, env_name, noise, - min_task_runtime=0.2): - self.min_task_runtime = min_task_runtime - self.config = config - self.policy_params = policy_params - self.noise = SharedNoiseTable(noise) + def __init__(self, config, policy_params, env_name, noise, + min_task_runtime=0.2): + self.min_task_runtime = min_task_runtime + self.config = config + self.policy_params = policy_params + self.noise = SharedNoiseTable(noise) - self.env = gym.make(env_name) - self.sess = utils.make_session(single_threaded=True) - self.policy = policies.MujocoPolicy(self.env.observation_space, - self.env.action_space, - **policy_params) - tf_util.initialize() + self.env = gym.make(env_name) + self.sess = utils.make_session(single_threaded=True) + self.policy = policies.MujocoPolicy(self.env.observation_space, + self.env.action_space, + **policy_params) + tf_util.initialize() - self.rs = np.random.RandomState() + self.rs = np.random.RandomState() - assert self.policy.needs_ob_stat == (self.config["calc_obstat_prob"] != 0) + assert self.policy.needs_ob_stat == (self.config["calc_obstat_prob"] != + 0) - def rollout_and_update_ob_stat(self, timestep_limit, task_ob_stat): - if (self.policy.needs_ob_stat and self.config["calc_obstat_prob"] != 0 and - self.rs.rand() < self.config["calc_obstat_prob"]): - rollout_rews, rollout_len, obs = self.policy.rollout( - self.env, timestep_limit=timestep_limit, save_obs=True, - random_stream=self.rs) - task_ob_stat.increment(obs.sum(axis=0), np.square(obs).sum(axis=0), - len(obs)) - else: - rollout_rews, rollout_len = self.policy.rollout( - self.env, timestep_limit=timestep_limit, random_stream=self.rs) - return rollout_rews, rollout_len + def rollout_and_update_ob_stat(self, timestep_limit, task_ob_stat): + if (self.policy.needs_ob_stat and + self.config["calc_obstat_prob"] != 0 and + self.rs.rand() < 
self.config["calc_obstat_prob"]): + rollout_rews, rollout_len, obs = self.policy.rollout( + self.env, timestep_limit=timestep_limit, save_obs=True, + random_stream=self.rs) + task_ob_stat.increment(obs.sum(axis=0), np.square(obs).sum(axis=0), + len(obs)) + else: + rollout_rews, rollout_len = self.policy.rollout( + self.env, timestep_limit=timestep_limit, random_stream=self.rs) + return rollout_rews, rollout_len - def do_rollouts(self, params, ob_mean, ob_std, timestep_limit=None): - # Set the network weights. - self.policy.set_trainable_flat(params) + def do_rollouts(self, params, ob_mean, ob_std, timestep_limit=None): + # Set the network weights. + self.policy.set_trainable_flat(params) - if self.policy.needs_ob_stat: - self.policy.set_ob_stat(ob_mean, ob_std) + if self.policy.needs_ob_stat: + self.policy.set_ob_stat(ob_mean, ob_std) - if self.config["eval_prob"] != 0: - raise NotImplementedError("Eval rollouts are not implemented.") + if self.config["eval_prob"] != 0: + raise NotImplementedError("Eval rollouts are not implemented.") - noise_inds, returns, sign_returns, lengths = [], [], [], [] - # We set eps=0 because we're incrementing only. - task_ob_stat = utils.RunningStat(self.env.observation_space.shape, eps=0) + noise_inds, returns, sign_returns, lengths = [], [], [], [] + # We set eps=0 because we're incrementing only. + task_ob_stat = utils.RunningStat(self.env.observation_space.shape, + eps=0) - # Perform some rollouts with noise. - task_tstart = time.time() - while (len(noise_inds) == 0 or - time.time() - task_tstart < self.min_task_runtime): - noise_idx = self.noise.sample_index(self.rs, self.policy.num_params) - perturbation = self.config["noise_stdev"] * self.noise.get( - noise_idx, self.policy.num_params) + # Perform some rollouts with noise. 
+ task_tstart = time.time() + while (len(noise_inds) == 0 or + time.time() - task_tstart < self.min_task_runtime): + noise_idx = self.noise.sample_index(self.rs, + self.policy.num_params) + perturbation = self.config["noise_stdev"] * self.noise.get( + noise_idx, self.policy.num_params) - # These two sampling steps could be done in parallel on different actors - # letting us update twice as frequently. - self.policy.set_trainable_flat(params + perturbation) - rews_pos, len_pos = self.rollout_and_update_ob_stat(timestep_limit, - task_ob_stat) + # These two sampling steps could be done in parallel on different + # actors letting us update twice as frequently. + self.policy.set_trainable_flat(params + perturbation) + rews_pos, len_pos = self.rollout_and_update_ob_stat(timestep_limit, + task_ob_stat) - self.policy.set_trainable_flat(params - perturbation) - rews_neg, len_neg = self.rollout_and_update_ob_stat(timestep_limit, - task_ob_stat) + self.policy.set_trainable_flat(params - perturbation) + rews_neg, len_neg = self.rollout_and_update_ob_stat(timestep_limit, + task_ob_stat) - noise_inds.append(noise_idx) - returns.append([rews_pos.sum(), rews_neg.sum()]) - sign_returns.append([np.sign(rews_pos).sum(), np.sign(rews_neg).sum()]) - lengths.append([len_pos, len_neg]) + noise_inds.append(noise_idx) + returns.append([rews_pos.sum(), rews_neg.sum()]) + sign_returns.append([np.sign(rews_pos).sum(), + np.sign(rews_neg).sum()]) + lengths.append([len_pos, len_neg]) - return Result( - noise_inds_n=np.array(noise_inds), - returns_n2=np.array(returns, dtype=np.float32), - sign_returns_n2=np.array(sign_returns, dtype=np.float32), - lengths_n2=np.array(lengths, dtype=np.int32), - eval_return=None, - eval_length=None, - ob_sum=(None if task_ob_stat.count == 0 else task_ob_stat.sum), - ob_sumsq=(None if task_ob_stat.count == 0 else task_ob_stat.sumsq), - ob_count=task_ob_stat.count) + return Result( + noise_inds_n=np.array(noise_inds), + returns_n2=np.array(returns, 
dtype=np.float32), + sign_returns_n2=np.array(sign_returns, dtype=np.float32), + lengths_n2=np.array(lengths, dtype=np.int32), + eval_return=None, + eval_length=None, + ob_sum=(None if task_ob_stat.count == 0 else task_ob_stat.sum), + ob_sumsq=(None if task_ob_stat.count == 0 + else task_ob_stat.sumsq), + ob_count=task_ob_stat.count) class EvolutionStrategies(Algorithm): - def __init__(self, env_name, config, upload_dir=None): - config.update({"alg": "EvolutionStrategies"}) + def __init__(self, env_name, config, upload_dir=None): + config.update({"alg": "EvolutionStrategies"}) - Algorithm.__init__(self, env_name, config, upload_dir=upload_dir) + Algorithm.__init__(self, env_name, config, upload_dir=upload_dir) - policy_params = { - "ac_bins": "continuous:", - "ac_noise_std": 0.01, - "nonlin_type": "tanh", - "hidden_dims": [256, 256], - "connection_type": "ff" - } + policy_params = { + "ac_bins": "continuous:", + "ac_noise_std": 0.01, + "nonlin_type": "tanh", + "hidden_dims": [256, 256], + "connection_type": "ff" + } - # Create the shared noise table. - print("Creating shared noise table.") - noise_id = create_shared_noise.remote() - self.noise = SharedNoiseTable(ray.get(noise_id)) + # Create the shared noise table. + print("Creating shared noise table.") + noise_id = create_shared_noise.remote() + self.noise = SharedNoiseTable(ray.get(noise_id)) - # Create the actors. - print("Creating actors.") - self.workers = [Worker.remote(config, policy_params, env_name, noise_id) - for _ in range(config["num_workers"])] + # Create the actors. 
+ print("Creating actors.") + self.workers = [Worker.remote(config, policy_params, env_name, + noise_id) + for _ in range(config["num_workers"])] - env = gym.make(env_name) - utils.make_session(single_threaded=False) - self.policy = policies.MujocoPolicy( - env.observation_space, env.action_space, **policy_params) - tf_util.initialize() - self.optimizer = optimizers.Adam(self.policy, config["stepsize"]) - self.ob_stat = utils.RunningStat(env.observation_space.shape, eps=1e-2) + env = gym.make(env_name) + utils.make_session(single_threaded=False) + self.policy = policies.MujocoPolicy( + env.observation_space, env.action_space, **policy_params) + tf_util.initialize() + self.optimizer = optimizers.Adam(self.policy, config["stepsize"]) + self.ob_stat = utils.RunningStat(env.observation_space.shape, eps=1e-2) - self.episodes_so_far = 0 - self.timesteps_so_far = 0 - self.tstart = time.time() - self.iteration = 0 + self.episodes_so_far = 0 + self.timesteps_so_far = 0 + self.tstart = time.time() + self.iteration = 0 - def train(self): - config = self.config + def train(self): + config = self.config - step_tstart = time.time() - theta = self.policy.get_trainable_flat() - assert theta.dtype == np.float32 + step_tstart = time.time() + theta = self.policy.get_trainable_flat() + assert theta.dtype == np.float32 - # Put the current policy weights in the object store. - theta_id = ray.put(theta) - # Use the actors to do rollouts, note that we pass in the ID of the policy - # weights. - rollout_ids = [worker.do_rollouts.remote( - theta_id, - self.ob_stat.mean if self.policy.needs_ob_stat else None, - self.ob_stat.std if self.policy.needs_ob_stat else None) - for worker in self.workers] + # Put the current policy weights in the object store. + theta_id = ray.put(theta) + # Use the actors to do rollouts, note that we pass in the ID of the + # policy weights. 
+ rollout_ids = [worker.do_rollouts.remote( + theta_id, + self.ob_stat.mean if self.policy.needs_ob_stat else None, + self.ob_stat.std if self.policy.needs_ob_stat else None) + for worker in self.workers] - # Get the results of the rollouts. - results = ray.get(rollout_ids) + # Get the results of the rollouts. + results = ray.get(rollout_ids) - curr_task_results = [] - ob_count_this_batch = 0 - # Loop over the results - for result in results: - assert result.eval_length is None, "We aren't doing eval rollouts." - assert result.noise_inds_n.ndim == 1 - assert result.returns_n2.shape == (len(result.noise_inds_n), 2) - assert result.lengths_n2.shape == (len(result.noise_inds_n), 2) - assert result.returns_n2.dtype == np.float32 + curr_task_results = [] + ob_count_this_batch = 0 + # Loop over the results + for result in results: + assert result.eval_length is None, "We aren't doing eval rollouts." + assert result.noise_inds_n.ndim == 1 + assert result.returns_n2.shape == (len(result.noise_inds_n), 2) + assert result.lengths_n2.shape == (len(result.noise_inds_n), 2) + assert result.returns_n2.dtype == np.float32 - result_num_eps = result.lengths_n2.size - result_num_timesteps = result.lengths_n2.sum() - self.episodes_so_far += result_num_eps - self.timesteps_so_far += result_num_timesteps + result_num_eps = result.lengths_n2.size + result_num_timesteps = result.lengths_n2.sum() + self.episodes_so_far += result_num_eps + self.timesteps_so_far += result_num_timesteps - curr_task_results.append(result) - # Update ob stats. - if self.policy.needs_ob_stat and result.ob_count > 0: - self.ob_stat.increment(result.ob_sum, result.ob_sumsq, result.ob_count) - ob_count_this_batch += result.ob_count + curr_task_results.append(result) + # Update ob stats. + if self.policy.needs_ob_stat and result.ob_count > 0: + self.ob_stat.increment(result.ob_sum, result.ob_sumsq, + result.ob_count) + ob_count_this_batch += result.ob_count - # Assemble the results. 
- noise_inds_n = np.concatenate([r.noise_inds_n for - r in curr_task_results]) - returns_n2 = np.concatenate([r.returns_n2 for r in curr_task_results]) - lengths_n2 = np.concatenate([r.lengths_n2 for r in curr_task_results]) - assert noise_inds_n.shape[0] == returns_n2.shape[0] == lengths_n2.shape[0] - # Process the returns. - if config["return_proc_mode"] == "centered_rank": - proc_returns_n2 = utils.compute_centered_ranks(returns_n2) - else: - raise NotImplementedError(config["return_proc_mode"]) + # Assemble the results. + noise_inds_n = np.concatenate([r.noise_inds_n for + r in curr_task_results]) + returns_n2 = np.concatenate([r.returns_n2 for r in curr_task_results]) + lengths_n2 = np.concatenate([r.lengths_n2 for r in curr_task_results]) + assert (noise_inds_n.shape[0] == returns_n2.shape[0] == + lengths_n2.shape[0]) + # Process the returns. + if config["return_proc_mode"] == "centered_rank": + proc_returns_n2 = utils.compute_centered_ranks(returns_n2) + else: + raise NotImplementedError(config["return_proc_mode"]) - # Compute and take a step. - g, count = utils.batched_weighted_sum( - proc_returns_n2[:, 0] - proc_returns_n2[:, 1], - (self.noise.get(idx, self.policy.num_params) for idx in noise_inds_n), - batch_size=500) - g /= returns_n2.size - assert (g.shape == (self.policy.num_params,) and g.dtype == np.float32 and - count == len(noise_inds_n)) - update_ratio = self.optimizer.update(-g + config["l2coeff"] * theta) + # Compute and take a step. + g, count = utils.batched_weighted_sum( + proc_returns_n2[:, 0] - proc_returns_n2[:, 1], + (self.noise.get(idx, self.policy.num_params) + for idx in noise_inds_n), + batch_size=500) + g /= returns_n2.size + assert (g.shape == (self.policy.num_params,) and + g.dtype == np.float32 and + count == len(noise_inds_n)) + update_ratio = self.optimizer.update(-g + config["l2coeff"] * theta) - # Update ob stat (we're never running the policy in the master, but we - # might be snapshotting the policy). 
- if self.policy.needs_ob_stat: - self.policy.set_ob_stat(self.ob_stat.mean, self.ob_stat.std) + # Update ob stat (we're never running the policy in the master, but we + # might be snapshotting the policy). + if self.policy.needs_ob_stat: + self.policy.set_ob_stat(self.ob_stat.mean, self.ob_stat.std) - step_tend = time.time() - tlogger.record_tabular("EpRewMean", returns_n2.mean()) - tlogger.record_tabular("EpRewStd", returns_n2.std()) - tlogger.record_tabular("EpLenMean", lengths_n2.mean()) + step_tend = time.time() + tlogger.record_tabular("EpRewMean", returns_n2.mean()) + tlogger.record_tabular("EpRewStd", returns_n2.std()) + tlogger.record_tabular("EpLenMean", lengths_n2.mean()) - tlogger.record_tabular( - "Norm", float(np.square(self.policy.get_trainable_flat()).sum())) - tlogger.record_tabular("GradNorm", float(np.square(g).sum())) - tlogger.record_tabular("UpdateRatio", float(update_ratio)) + tlogger.record_tabular( + "Norm", float(np.square(self.policy.get_trainable_flat()).sum())) + tlogger.record_tabular("GradNorm", float(np.square(g).sum())) + tlogger.record_tabular("UpdateRatio", float(update_ratio)) - tlogger.record_tabular("EpisodesThisIter", lengths_n2.size) - tlogger.record_tabular("EpisodesSoFar", self.episodes_so_far) - tlogger.record_tabular("TimestepsThisIter", lengths_n2.sum()) - tlogger.record_tabular("TimestepsSoFar", self.timesteps_so_far) + tlogger.record_tabular("EpisodesThisIter", lengths_n2.size) + tlogger.record_tabular("EpisodesSoFar", self.episodes_so_far) + tlogger.record_tabular("TimestepsThisIter", lengths_n2.sum()) + tlogger.record_tabular("TimestepsSoFar", self.timesteps_so_far) - tlogger.record_tabular("ObCount", ob_count_this_batch) + tlogger.record_tabular("ObCount", ob_count_this_batch) - tlogger.record_tabular("TimeElapsedThisIter", step_tend - step_tstart) - tlogger.record_tabular("TimeElapsed", step_tend - self.tstart) - tlogger.dump_tabular() + tlogger.record_tabular("TimeElapsedThisIter", step_tend - step_tstart) + 
tlogger.record_tabular("TimeElapsed", step_tend - self.tstart) + tlogger.dump_tabular() - if (config["snapshot_freq"] != 0 and - self.iteration % config["snapshot_freq"] == 0): - filename = os.path.join( - self.logdir, "snapshot_iter{:05d}.h5".format(self.iteration)) - assert not os.path.exists(filename) - self.policy.save(filename) - tlogger.log("Saved snapshot {}".format(filename)) + if (config["snapshot_freq"] != 0 and + self.iteration % config["snapshot_freq"] == 0): + filename = os.path.join( + self.logdir, "snapshot_iter{:05d}.h5".format(self.iteration)) + assert not os.path.exists(filename) + self.policy.save(filename) + tlogger.log("Saved snapshot {}".format(filename)) - info = { - "weights_norm": np.square(self.policy.get_trainable_flat()).sum(), - "grad_norm": np.square(g).sum(), - "update_ratio": update_ratio, - "episodes_this_iter": lengths_n2.size, - "episodes_so_far": self.episodes_so_far, - "timesteps_this_iter": lengths_n2.sum(), - "timesteps_so_far": self.timesteps_so_far, - "ob_count": ob_count_this_batch, - "time_elapsed_this_iter": step_tend - step_tstart, - "time_elapsed": step_tend - self.tstart - } - res = TrainingResult(self.experiment_id.hex, self.iteration, - returns_n2.mean(), lengths_n2.mean(), info) + info = { + "weights_norm": np.square(self.policy.get_trainable_flat()).sum(), + "grad_norm": np.square(g).sum(), + "update_ratio": update_ratio, + "episodes_this_iter": lengths_n2.size, + "episodes_so_far": self.episodes_so_far, + "timesteps_this_iter": lengths_n2.sum(), + "timesteps_so_far": self.timesteps_so_far, + "ob_count": ob_count_this_batch, + "time_elapsed_this_iter": step_tend - step_tstart, + "time_elapsed": step_tend - self.tstart + } + res = TrainingResult(self.experiment_id.hex, self.iteration, + returns_n2.mean(), lengths_n2.mean(), info) - self.iteration += 1 + self.iteration += 1 - return res + return res diff --git a/python/ray/rllib/evolution_strategies/example.py b/python/ray/rllib/evolution_strategies/example.py index 
5af53e367..e2ee67295 100755 --- a/python/ray/rllib/evolution_strategies/example.py +++ b/python/ray/rllib/evolution_strategies/example.py @@ -11,30 +11,30 @@ from ray.rllib.evolution_strategies import EvolutionStrategies, DEFAULT_CONFIG if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Train an RL agent on Pong.") - parser.add_argument("--num-workers", default=10, type=int, - help=("The number of actors to create in aggregate " - "across the cluster.")) - parser.add_argument("--env-name", default="Pendulum-v0", type=str, - help="The name of the gym environment to use.") - parser.add_argument("--stepsize", default=0.01, type=float, - help="The stepsize to use.") - parser.add_argument("--redis-address", default=None, type=str, - help="The Redis address of the cluster.") + parser = argparse.ArgumentParser(description="Train an RL agent on Pong.") + parser.add_argument("--num-workers", default=10, type=int, + help=("The number of actors to create in aggregate " + "across the cluster.")) + parser.add_argument("--env-name", default="Pendulum-v0", type=str, + help="The name of the gym environment to use.") + parser.add_argument("--stepsize", default=0.01, type=float, + help="The stepsize to use.") + parser.add_argument("--redis-address", default=None, type=str, + help="The Redis address of the cluster.") - args = parser.parse_args() - num_workers = args.num_workers - env_name = args.env_name - stepsize = args.stepsize + args = parser.parse_args() + num_workers = args.num_workers + env_name = args.env_name + stepsize = args.stepsize - ray.init(redis_address=args.redis_address, - num_workers=(0 if args.redis_address is None else None)) + ray.init(redis_address=args.redis_address, + num_workers=(0 if args.redis_address is None else None)) - config = DEFAULT_CONFIG._replace( - num_workers=num_workers, - stepsize=stepsize) + config = DEFAULT_CONFIG._replace( + num_workers=num_workers, + stepsize=stepsize) - alg = EvolutionStrategies(env_name, config) - 
while True: - result = alg.train() - print("current status: {}".format(result)) + alg = EvolutionStrategies(env_name, config) + while True: + result = alg.train() + print("current status: {}".format(result)) diff --git a/python/ray/rllib/evolution_strategies/optimizers.py b/python/ray/rllib/evolution_strategies/optimizers.py index d768314c5..66a0ca69d 100644 --- a/python/ray/rllib/evolution_strategies/optimizers.py +++ b/python/ray/rllib/evolution_strategies/optimizers.py @@ -9,49 +9,49 @@ import numpy as np class Optimizer(object): - def __init__(self, pi): - self.pi = pi - self.dim = pi.num_params - self.t = 0 + def __init__(self, pi): + self.pi = pi + self.dim = pi.num_params + self.t = 0 - def update(self, globalg): - self.t += 1 - step = self._compute_step(globalg) - theta = self.pi.get_trainable_flat() - ratio = np.linalg.norm(step) / np.linalg.norm(theta) - self.pi.set_trainable_flat(theta + step) - return ratio + def update(self, globalg): + self.t += 1 + step = self._compute_step(globalg) + theta = self.pi.get_trainable_flat() + ratio = np.linalg.norm(step) / np.linalg.norm(theta) + self.pi.set_trainable_flat(theta + step) + return ratio - def _compute_step(self, globalg): - raise NotImplementedError + def _compute_step(self, globalg): + raise NotImplementedError class SGD(Optimizer): - def __init__(self, pi, stepsize, momentum=0.9): - Optimizer.__init__(self, pi) - self.v = np.zeros(self.dim, dtype=np.float32) - self.stepsize, self.momentum = stepsize, momentum + def __init__(self, pi, stepsize, momentum=0.9): + Optimizer.__init__(self, pi) + self.v = np.zeros(self.dim, dtype=np.float32) + self.stepsize, self.momentum = stepsize, momentum - def _compute_step(self, globalg): - self.v = self.momentum * self.v + (1. - self.momentum) * globalg - step = -self.stepsize * self.v - return step + def _compute_step(self, globalg): + self.v = self.momentum * self.v + (1. 
- self.momentum) * globalg + step = -self.stepsize * self.v + return step class Adam(Optimizer): - def __init__(self, pi, stepsize, beta1=0.9, beta2=0.999, epsilon=1e-08): - Optimizer.__init__(self, pi) - self.stepsize = stepsize - self.beta1 = beta1 - self.beta2 = beta2 - self.epsilon = epsilon - self.m = np.zeros(self.dim, dtype=np.float32) - self.v = np.zeros(self.dim, dtype=np.float32) + def __init__(self, pi, stepsize, beta1=0.9, beta2=0.999, epsilon=1e-08): + Optimizer.__init__(self, pi) + self.stepsize = stepsize + self.beta1 = beta1 + self.beta2 = beta2 + self.epsilon = epsilon + self.m = np.zeros(self.dim, dtype=np.float32) + self.v = np.zeros(self.dim, dtype=np.float32) - def _compute_step(self, globalg): - a = self.stepsize * (np.sqrt(1 - self.beta2 ** self.t) / - (1 - self.beta1 ** self.t)) - self.m = self.beta1 * self.m + (1 - self.beta1) * globalg - self.v = self.beta2 * self.v + (1 - self.beta2) * (globalg * globalg) - step = -a * self.m / (np.sqrt(self.v) + self.epsilon) - return step + def _compute_step(self, globalg): + a = self.stepsize * (np.sqrt(1 - self.beta2 ** self.t) / + (1 - self.beta1 ** self.t)) + self.m = self.beta1 * self.m + (1 - self.beta1) * globalg + self.v = self.beta2 * self.v + (1 - self.beta2) * (globalg * globalg) + step = -a * self.m / (np.sqrt(self.v) + self.epsilon) + return step diff --git a/python/ray/rllib/evolution_strategies/policies.py b/python/ray/rllib/evolution_strategies/policies.py index 7997fb8ce..20273130f 100644 --- a/python/ray/rllib/evolution_strategies/policies.py +++ b/python/ray/rllib/evolution_strategies/policies.py @@ -18,224 +18,235 @@ logger = logging.getLogger(__name__) class Policy: - def __init__(self, *args, **kwargs): - self.args, self.kwargs = args, kwargs - self.scope = self._initialize(*args, **kwargs) - self.all_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, - self.scope.name) + def __init__(self, *args, **kwargs): + self.args, self.kwargs = args, kwargs + self.scope = 
self._initialize(*args, **kwargs) + self.all_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, + self.scope.name) - self.trainable_variables = tf.get_collection( - tf.GraphKeys.TRAINABLE_VARIABLES, self.scope.name) - self.num_params = sum(int(np.prod(v.get_shape().as_list())) - for v in self.trainable_variables) - self._setfromflat = U.SetFromFlat(self.trainable_variables) - self._getflat = U.GetFlat(self.trainable_variables) + self.trainable_variables = tf.get_collection( + tf.GraphKeys.TRAINABLE_VARIABLES, self.scope.name) + self.num_params = sum(int(np.prod(v.get_shape().as_list())) + for v in self.trainable_variables) + self._setfromflat = U.SetFromFlat(self.trainable_variables) + self._getflat = U.GetFlat(self.trainable_variables) - logger.info('Trainable variables ({} parameters)'.format(self.num_params)) - for v in self.trainable_variables: - shp = v.get_shape().as_list() - logger.info('- {} shape:{} size:{}'.format(v.name, shp, np.prod(shp))) - logger.info('All variables') - for v in self.all_variables: - shp = v.get_shape().as_list() - logger.info('- {} shape:{} size:{}'.format(v.name, shp, np.prod(shp))) + logger.info('Trainable variables ({} parameters)' + .format(self.num_params)) + for v in self.trainable_variables: + shp = v.get_shape().as_list() + logger.info('- {} shape:{} size:{}'.format(v.name, shp, + np.prod(shp))) + logger.info('All variables') + for v in self.all_variables: + shp = v.get_shape().as_list() + logger.info('- {} shape:{} size:{}'.format(v.name, shp, + np.prod(shp))) - placeholders = [tf.placeholder(v.value().dtype, v.get_shape().as_list()) - for v in self.all_variables] - self.set_all_vars = U.function( - inputs=placeholders, - outputs=[], - updates=[tf.group(*[v.assign(p) for v, p - in zip(self.all_variables, placeholders)])] - ) + placeholders = [tf.placeholder(v.value().dtype, + v.get_shape().as_list()) + for v in self.all_variables] + self.set_all_vars = U.function( + inputs=placeholders, + outputs=[], + 
updates=[tf.group(*[v.assign(p) for v, p + in zip(self.all_variables, placeholders)])] + ) - def _initialize(self, *args, **kwargs): - raise NotImplementedError + def _initialize(self, *args, **kwargs): + raise NotImplementedError - def save(self, filename): - assert filename.endswith('.h5') - with h5py.File(filename, 'w') as f: - for v in self.all_variables: - f[v.name] = v.eval() - # TODO: It would be nice to avoid pickle, but it's convenient to pass - # Python objects to _initialize (like Gym spaces or numpy arrays). - f.attrs['name'] = type(self).__name__ - f.attrs['args_and_kwargs'] = np.void(pickle.dumps((self.args, - self.kwargs), - protocol=-1)) + def save(self, filename): + assert filename.endswith('.h5') + with h5py.File(filename, 'w') as f: + for v in self.all_variables: + f[v.name] = v.eval() + # TODO: It would be nice to avoid pickle, but it's convenient to + # pass Python objects to _initialize (like Gym spaces or numpy + # arrays). + f.attrs['name'] = type(self).__name__ + f.attrs['args_and_kwargs'] = np.void(pickle.dumps((self.args, + self.kwargs), + protocol=-1)) - @classmethod - def Load(cls, filename, extra_kwargs=None): - with h5py.File(filename, 'r') as f: - args, kwargs = pickle.loads(f.attrs['args_and_kwargs'].tostring()) - if extra_kwargs: - kwargs.update(extra_kwargs) - policy = cls(*args, **kwargs) - policy.set_all_vars(*[f[v.name][...] for v in policy.all_variables]) - return policy + @classmethod + def Load(cls, filename, extra_kwargs=None): + with h5py.File(filename, 'r') as f: + args, kwargs = pickle.loads(f.attrs['args_and_kwargs'].tostring()) + if extra_kwargs: + kwargs.update(extra_kwargs) + policy = cls(*args, **kwargs) + policy.set_all_vars(*[f[v.name][...] + for v in policy.all_variables]) + return policy - # === Rollouts/training === + # === Rollouts/training === - def rollout(self, env, render=False, timestep_limit=None, save_obs=False, - random_stream=None): - """Do a rollout. 
+ def rollout(self, env, render=False, timestep_limit=None, save_obs=False, + random_stream=None): + """Do a rollout. - If random_stream is provided, the rollout will take noisy actions with - noise drawn from that stream. Otherwise, no action noise will be added. - """ - env_timestep_limit = env.spec.tags.get("wrapper_config.TimeLimit" - ".max_episode_steps") - timestep_limit = (env_timestep_limit if timestep_limit is None - else min(timestep_limit, env_timestep_limit)) - rews = [] - t = 0 - if save_obs: - obs = [] - ob = env.reset() - for _ in range(timestep_limit): - ac = self.act(ob[None], random_stream=random_stream)[0] - if save_obs: - obs.append(ob) - ob, rew, done, _ = env.step(ac) - rews.append(rew) - t += 1 - if render: - env.render() - if done: - break - rews = np.array(rews, dtype=np.float32) - if save_obs: - return rews, t, np.array(obs) - return rews, t + If random_stream is provided, the rollout will take noisy actions with + noise drawn from that stream. Otherwise, no action noise will be added. 
+ """ + env_timestep_limit = env.spec.tags.get("wrapper_config.TimeLimit" + ".max_episode_steps") + timestep_limit = (env_timestep_limit if timestep_limit is None + else min(timestep_limit, env_timestep_limit)) + rews = [] + t = 0 + if save_obs: + obs = [] + ob = env.reset() + for _ in range(timestep_limit): + ac = self.act(ob[None], random_stream=random_stream)[0] + if save_obs: + obs.append(ob) + ob, rew, done, _ = env.step(ac) + rews.append(rew) + t += 1 + if render: + env.render() + if done: + break + rews = np.array(rews, dtype=np.float32) + if save_obs: + return rews, t, np.array(obs) + return rews, t - def act(self, ob, random_stream=None): - raise NotImplementedError + def act(self, ob, random_stream=None): + raise NotImplementedError - def set_trainable_flat(self, x): - self._setfromflat(x) + def set_trainable_flat(self, x): + self._setfromflat(x) - def get_trainable_flat(self): - return self._getflat() + def get_trainable_flat(self): + return self._getflat() - @property - def needs_ob_stat(self): - raise NotImplementedError + @property + def needs_ob_stat(self): + raise NotImplementedError - def set_ob_stat(self, ob_mean, ob_std): - raise NotImplementedError + def set_ob_stat(self, ob_mean, ob_std): + raise NotImplementedError def bins(x, dim, num_bins, name): - scores = U.dense(x, dim * num_bins, name, U.normc_initializer(0.01)) - scores_nab = tf.reshape(scores, [-1, dim, num_bins]) - return tf.argmax(scores_nab, 2) + scores = U.dense(x, dim * num_bins, name, U.normc_initializer(0.01)) + scores_nab = tf.reshape(scores, [-1, dim, num_bins]) + return tf.argmax(scores_nab, 2) class MujocoPolicy(Policy): - def _initialize(self, ob_space, ac_space, ac_bins, ac_noise_std, nonlin_type, - hidden_dims, connection_type): - self.ac_space = ac_space - self.ac_bins = ac_bins - self.ac_noise_std = ac_noise_std - self.hidden_dims = hidden_dims - self.connection_type = connection_type + def _initialize(self, ob_space, ac_space, ac_bins, ac_noise_std, + nonlin_type, 
hidden_dims, connection_type): + self.ac_space = ac_space + self.ac_bins = ac_bins + self.ac_noise_std = ac_noise_std + self.hidden_dims = hidden_dims + self.connection_type = connection_type - assert len(ob_space.shape) == len(self.ac_space.shape) == 1 - assert (np.all(np.isfinite(self.ac_space.low)) and - np.all(np.isfinite(self.ac_space.high))), "Action bounds required" + assert len(ob_space.shape) == len(self.ac_space.shape) == 1 + assert (np.all(np.isfinite(self.ac_space.low)) and + np.all(np.isfinite(self.ac_space.high))), ("Action bounds " + "required") - self.nonlin = {'tanh': tf.tanh, - 'relu': tf.nn.relu, - 'lrelu': U.lrelu, - 'elu': tf.nn.elu}[nonlin_type] + self.nonlin = {'tanh': tf.tanh, + 'relu': tf.nn.relu, + 'lrelu': U.lrelu, + 'elu': tf.nn.elu}[nonlin_type] - with tf.variable_scope(type(self).__name__) as scope: - # Observation normalization. - ob_mean = tf.get_variable( - 'ob_mean', ob_space.shape, tf.float32, - tf.constant_initializer(np.nan), trainable=False) - ob_std = tf.get_variable( - 'ob_std', ob_space.shape, tf.float32, - tf.constant_initializer(np.nan), trainable=False) - in_mean = tf.placeholder(tf.float32, ob_space.shape) - in_std = tf.placeholder(tf.float32, ob_space.shape) - self._set_ob_mean_std = U.function([in_mean, in_std], [], updates=[ - tf.assign(ob_mean, in_mean), - tf.assign(ob_std, in_std), - ]) + with tf.variable_scope(type(self).__name__) as scope: + # Observation normalization. + ob_mean = tf.get_variable( + 'ob_mean', ob_space.shape, tf.float32, + tf.constant_initializer(np.nan), trainable=False) + ob_std = tf.get_variable( + 'ob_std', ob_space.shape, tf.float32, + tf.constant_initializer(np.nan), trainable=False) + in_mean = tf.placeholder(tf.float32, ob_space.shape) + in_std = tf.placeholder(tf.float32, ob_space.shape) + self._set_ob_mean_std = U.function([in_mean, in_std], [], updates=[ + tf.assign(ob_mean, in_mean), + tf.assign(ob_std, in_std), + ]) - # Policy network. 
- o = tf.placeholder(tf.float32, [None] + list(ob_space.shape)) - a = self._make_net(tf.clip_by_value((o - ob_mean) / ob_std, -5.0, 5.0)) - self._act = U.function([o], a) - return scope + # Policy network. + o = tf.placeholder(tf.float32, [None] + list(ob_space.shape)) + a = self._make_net(tf.clip_by_value((o - ob_mean) / ob_std, + -5.0, 5.0)) + self._act = U.function([o], a) + return scope - def _make_net(self, o): - # Process observation. - if self.connection_type == 'ff': - x = o - for ilayer, hd in enumerate(self.hidden_dims): - x = self.nonlin(U.dense(x, hd, 'l{}'.format(ilayer), - U.normc_initializer(1.0))) - else: - raise NotImplementedError(self.connection_type) + def _make_net(self, o): + # Process observation. + if self.connection_type == 'ff': + x = o + for ilayer, hd in enumerate(self.hidden_dims): + x = self.nonlin(U.dense(x, hd, 'l{}'.format(ilayer), + U.normc_initializer(1.0))) + else: + raise NotImplementedError(self.connection_type) - # Map to action. - adim = self.ac_space.shape[0] - ahigh = self.ac_space.high - alow = self.ac_space.low - assert isinstance(self.ac_bins, str) - ac_bin_mode, ac_bin_arg = self.ac_bins.split(':') + # Map to action. + adim = self.ac_space.shape[0] + ahigh = self.ac_space.high + alow = self.ac_space.low + assert isinstance(self.ac_bins, str) + ac_bin_mode, ac_bin_arg = self.ac_bins.split(':') - if ac_bin_mode == 'uniform': - # Uniformly spaced bins, from ac_space.low to ac_space.high. - num_ac_bins = int(ac_bin_arg) - aidx_na = bins(x, adim, num_ac_bins, 'out') - ac_range_1a = (ahigh - alow)[None, :] - a = (1. / (num_ac_bins - 1.) * tf.to_float(aidx_na) * ac_range_1a + - alow[None, :]) + if ac_bin_mode == 'uniform': + # Uniformly spaced bins, from ac_space.low to ac_space.high. + num_ac_bins = int(ac_bin_arg) + aidx_na = bins(x, adim, num_ac_bins, 'out') + ac_range_1a = (ahigh - alow)[None, :] + a = (1. / (num_ac_bins - 1.) 
* tf.to_float(aidx_na) * ac_range_1a + + alow[None, :]) - elif ac_bin_mode == 'custom': - # Custom bins specified as a list of values from -1 to 1. - # The bins are rescaled to ac_space.low to ac_space.high. - acvals_k = np.array(list(map(float, ac_bin_arg.split(','))), - dtype=np.float32) - logger.info('Custom action values: ' + ' '.join('{:.3f}'.format(x) - for x in acvals_k)) - assert acvals_k.ndim == 1 and acvals_k[0] == -1 and acvals_k[-1] == 1 - acvals_ak = ((ahigh - alow)[:, None] / (acvals_k[-1] - acvals_k[0]) * - (acvals_k - acvals_k[0])[None, :] + alow[:, None]) + elif ac_bin_mode == 'custom': + # Custom bins specified as a list of values from -1 to 1. + # The bins are rescaled to ac_space.low to ac_space.high. + acvals_k = np.array(list(map(float, ac_bin_arg.split(','))), + dtype=np.float32) + logger.info('Custom action values: ' + ' '.join('{:.3f}'.format(x) + for x in acvals_k)) + assert (acvals_k.ndim == 1 and acvals_k[0] == -1 and + acvals_k[-1] == 1) + acvals_ak = ((ahigh - alow)[:, None] / + (acvals_k[-1] - acvals_k[0]) * + (acvals_k - acvals_k[0])[None, :] + alow[:, None]) - aidx_na = bins(x, adim, len(acvals_k), 'out') # Values in [0, k-1]. - a = tf.gather_nd( - acvals_ak, - tf.concat([ - tf.tile(np.arange(adim)[None, :, None], - [tf.shape(aidx_na)[0], 1, 1]), - 2, - tf.expand_dims(aidx_na, -1) - ]) # (n, a, 2) - ) # (n, a) - elif ac_bin_mode == 'continuous': - a = U.dense(x, adim, 'out', U.normc_initializer(0.01)) - else: - raise NotImplementedError(ac_bin_mode) + aidx_na = bins(x, adim, len(acvals_k), + 'out') # Values in [0, k-1]. 
+ a = tf.gather_nd( + acvals_ak, + tf.concat([ + tf.tile(np.arange(adim)[None, :, None], + [tf.shape(aidx_na)[0], 1, 1]), + 2, + tf.expand_dims(aidx_na, -1) + ]) # (n, a, 2) + ) # (n, a) + elif ac_bin_mode == 'continuous': + a = U.dense(x, adim, 'out', U.normc_initializer(0.01)) + else: + raise NotImplementedError(ac_bin_mode) - return a + return a - def act(self, ob, random_stream=None): - a = self._act(ob) - if random_stream is not None and self.ac_noise_std != 0: - a += random_stream.randn(*a.shape) * self.ac_noise_std - return a + def act(self, ob, random_stream=None): + a = self._act(ob) + if random_stream is not None and self.ac_noise_std != 0: + a += random_stream.randn(*a.shape) * self.ac_noise_std + return a - @property - def needs_ob_stat(self): - return True + @property + def needs_ob_stat(self): + return True - @property - def needs_ref_batch(self): - return False + @property + def needs_ref_batch(self): + return False - def set_ob_stat(self, ob_mean, ob_std): - self._set_ob_mean_std(ob_mean, ob_std) + def set_ob_stat(self, ob_mean, ob_std): + self._set_ob_mean_std(ob_mean, ob_std) diff --git a/python/ray/rllib/evolution_strategies/tabular_logger.py b/python/ray/rllib/evolution_strategies/tabular_logger.py index 633dbbdf0..80e7b5b37 100644 --- a/python/ray/rllib/evolution_strategies/tabular_logger.py +++ b/python/ray/rllib/evolution_strategies/tabular_logger.py @@ -24,199 +24,201 @@ DISABLED = 50 class TbWriter(object): - """Based on SummaryWriter, but changed to allow for a different prefix.""" - def __init__(self, dir, prefix): - self.dir = dir - # Start at 1, because EvWriter automatically generates an object with - # step = 0. - self.step = 1 - self.evwriter = pywrap_tensorflow.EventsWriter( - compat.as_bytes(os.path.join(dir, prefix))) + """Based on SummaryWriter, but changed to allow for a different prefix.""" + def __init__(self, dir, prefix): + self.dir = dir + # Start at 1, because EvWriter automatically generates an object with + # step = 0. 
+ self.step = 1 + self.evwriter = pywrap_tensorflow.EventsWriter( + compat.as_bytes(os.path.join(dir, prefix))) - def write_values(self, key2val): - summary = tf.Summary(value=[tf.Summary.Value(tag=k, simple_value=float(v)) - for (k, v) in key2val.items()]) - event = event_pb2.Event(wall_time=time.time(), summary=summary) - event.step = self.step - self.evwriter.WriteEvent(event) - self.evwriter.Flush() - self.step += 1 + def write_values(self, key2val): + summary = tf.Summary(value=[tf.Summary.Value(tag=k, + simple_value=float(v)) + for (k, v) in key2val.items()]) + event = event_pb2.Event(wall_time=time.time(), summary=summary) + event.step = self.step + self.evwriter.WriteEvent(event) + self.evwriter.Flush() + self.step += 1 - def close(self): - self.evwriter.Close() + def close(self): + self.evwriter.Close() # API def start(dir): - if _Logger.CURRENT is not _Logger.DEFAULT: - sys.stderr.write("WARNING: You asked to start logging (dir=%s), but you " - "never stopped the previous logger (dir=%s)." - "\n" % (dir, _Logger.CURRENT.dir)) - _Logger.CURRENT = _Logger(dir=dir) + if _Logger.CURRENT is not _Logger.DEFAULT: + sys.stderr.write("WARNING: You asked to start logging (dir=%s), but " + "you never stopped the previous logger (dir=%s)." + "\n" % (dir, _Logger.CURRENT.dir)) + _Logger.CURRENT = _Logger(dir=dir) def stop(): - if _Logger.CURRENT is _Logger.DEFAULT: - sys.stderr.write("WARNING: You asked to stop logging, but you never " - "started any previous logger." - "\n" % (dir, _Logger.CURRENT.dir)) - return - _Logger.CURRENT.close() - _Logger.CURRENT = _Logger.DEFAULT + if _Logger.CURRENT is _Logger.DEFAULT: + sys.stderr.write("WARNING: You asked to stop logging, but you never " + "started any previous logger." + "\n" % (dir, _Logger.CURRENT.dir)) + return + _Logger.CURRENT.close() + _Logger.CURRENT = _Logger.DEFAULT def record_tabular(key, val): - """Log a value of some diagnostic. + """Log a value of some diagnostic. 
- Call this once for each diagnostic quantity, each iteration. - """ - _Logger.CURRENT.record_tabular(key, val) + Call this once for each diagnostic quantity, each iteration. + """ + _Logger.CURRENT.record_tabular(key, val) def dump_tabular(): - """Write all of the diagnostics from the current iteration.""" - _Logger.CURRENT.dump_tabular() + """Write all of the diagnostics from the current iteration.""" + _Logger.CURRENT.dump_tabular() def log(*args, **kwargs): - """Write the sequence of args, with no separators. + """Write the sequence of args, with no separators. - This is written to the console and output files (if you've configured an - output file). - """ - level = kwargs['level'] if 'level' in kwargs else INFO - _Logger.CURRENT.log(*args, level=level) + This is written to the console and output files (if you've configured an + output file). + """ + level = kwargs['level'] if 'level' in kwargs else INFO + _Logger.CURRENT.log(*args, level=level) def debug(*args): - log(*args, level=DEBUG) + log(*args, level=DEBUG) def info(*args): - log(*args, level=INFO) + log(*args, level=INFO) def warn(*args): - log(*args, level=WARN) + log(*args, level=WARN) def error(*args): - log(*args, level=ERROR) + log(*args, level=ERROR) def set_level(level): - """ - Set logging threshold on current logger. - """ - _Logger.CURRENT.set_level(level) + """ + Set logging threshold on current logger. + """ + _Logger.CURRENT.set_level(level) def get_dir(): - """ - Get directory that log files are being written to. - will be None if there is no output directory (i.e., if you didn't call start) - """ - return _Logger.CURRENT.get_dir() + """ + Get directory that log files are being written to. + will be None if there is no output directory (i.e., if you didn't call + start) + """ + return _Logger.CURRENT.get_dir() def get_expt_dir(): - sys.stderr.write("get_expt_dir() is Deprecated. Switch to get_dir()\n") - return get_dir() + sys.stderr.write("get_expt_dir() is Deprecated. 
Switch to get_dir()\n") + return get_dir() # Backend class _Logger(object): - # A logger with no output files. (See right below class definition) so that - # you can still log to the terminal without setting up any output files. - DEFAULT = None - # Current logger being used by the free functions above. - CURRENT = None + # A logger with no output files. (See right below class definition) so that + # you can still log to the terminal without setting up any output files. + DEFAULT = None + # Current logger being used by the free functions above. + CURRENT = None - def __init__(self, dir=None): - self.name2val = OrderedDict() # Values this iteration. - self.level = INFO - self.dir = dir - self.text_outputs = [sys.stdout] - if dir is not None: - os.makedirs(dir, exist_ok=True) - self.text_outputs.append(open(os.path.join(dir, "log.txt"), "w")) - self.tbwriter = TbWriter(dir=dir, prefix="events") - else: - self.tbwriter = None + def __init__(self, dir=None): + self.name2val = OrderedDict() # Values this iteration. + self.level = INFO + self.dir = dir + self.text_outputs = [sys.stdout] + if dir is not None: + os.makedirs(dir, exist_ok=True) + self.text_outputs.append(open(os.path.join(dir, "log.txt"), "w")) + self.tbwriter = TbWriter(dir=dir, prefix="events") + else: + self.tbwriter = None - # Logging API, forwarded + # Logging API, forwarded - def record_tabular(self, key, val): - self.name2val[key] = val + def record_tabular(self, key, val): + self.name2val[key] = val - def dump_tabular(self): - # Create strings for printing. 
- key2str = OrderedDict() - for (key, val) in self.name2val.items(): - if hasattr(val, "__float__"): - valstr = "%-8.3g" % val - else: - valstr = val - key2str[self._truncate(key)] = self._truncate(valstr) - keywidth = max(map(len, key2str.keys())) - valwidth = max(map(len, key2str.values())) - # Write to all text outputs - self._write_text("-" * (keywidth + valwidth + 7), "\n") - for (key, val) in key2str.items(): - self._write_text("| ", key, " " * (keywidth - len(key)), " | ", val, - " " * (valwidth - len(val)), " |\n") - self._write_text("-" * (keywidth + valwidth + 7), "\n") - for f in self.text_outputs: - try: - f.flush() - except OSError: - sys.stderr.write('Warning! OSError when flushing.\n') - # Write to tensorboard - if self.tbwriter is not None: - self.tbwriter.write_values(self.name2val) - self.name2val.clear() + def dump_tabular(self): + # Create strings for printing. + key2str = OrderedDict() + for (key, val) in self.name2val.items(): + if hasattr(val, "__float__"): + valstr = "%-8.3g" % val + else: + valstr = val + key2str[self._truncate(key)] = self._truncate(valstr) + keywidth = max(map(len, key2str.keys())) + valwidth = max(map(len, key2str.values())) + # Write to all text outputs + self._write_text("-" * (keywidth + valwidth + 7), "\n") + for (key, val) in key2str.items(): + self._write_text("| ", key, " " * (keywidth - len(key)), + " | ", val, " " * (valwidth - len(val)), " |\n") + self._write_text("-" * (keywidth + valwidth + 7), "\n") + for f in self.text_outputs: + try: + f.flush() + except OSError: + sys.stderr.write('Warning! 
OSError when flushing.\n') + # Write to tensorboard + if self.tbwriter is not None: + self.tbwriter.write_values(self.name2val) + self.name2val.clear() - def log(self, *args, **kwargs): - level = kwargs['level'] if 'level' in kwargs else INFO - if self.level <= level: - self._do_log(*args) + def log(self, *args, **kwargs): + level = kwargs['level'] if 'level' in kwargs else INFO + if self.level <= level: + self._do_log(*args) - # Configuration + # Configuration - def set_level(self, level): - self.level = level + def set_level(self, level): + self.level = level - def get_dir(self): - return self.dir + def get_dir(self): + return self.dir - def close(self): - for f in self.text_outputs[1:]: - f.close() - if self.tbwriter: - self.tbwriter.close() + def close(self): + for f in self.text_outputs[1:]: + f.close() + if self.tbwriter: + self.tbwriter.close() - # Misc + # Misc - def _do_log(self, *args): - self._write_text(*args + ('\n',)) - for f in self.text_outputs: - try: - f.flush() - except OSError: - print('Warning! OSError when flushing.') + def _do_log(self, *args): + self._write_text(*args + ('\n',)) + for f in self.text_outputs: + try: + f.flush() + except OSError: + print('Warning! OSError when flushing.') - def _write_text(self, *strings): - for f in self.text_outputs: - for string in strings: - f.write(string) + def _write_text(self, *strings): + for f in self.text_outputs: + for string in strings: + f.write(string) - def _truncate(self, s): - if len(s) > 33: - return s[:30] + "..." - else: - return s + def _truncate(self, s): + if len(s) > 33: + return s[:30] + "..." 
+ else: + return s _Logger.DEFAULT = _Logger() diff --git a/python/ray/rllib/evolution_strategies/tf_util.py b/python/ray/rllib/evolution_strategies/tf_util.py index 1143d4bcb..5a7cc63fb 100644 --- a/python/ray/rllib/evolution_strategies/tf_util.py +++ b/python/ray/rllib/evolution_strategies/tf_util.py @@ -12,8 +12,8 @@ import os # Tensorflow must be at least version 1.0.0 for the example to work. if int(tf.__version__.split(".")[0]) < 1: - raise Exception("Your Tensorflow version is less than 1.0.0. Please update " - "Tensorflow to the latest version.") + raise Exception("Your Tensorflow version is less than 1.0.0. Please " + "update Tensorflow to the latest version.") # ================================================================ # Import all names into common namespace @@ -25,160 +25,163 @@ clip = tf.clip_by_value def sum(x, axis=None, keepdims=False): - return tf.reduce_sum(x, reduction_indices=None if axis is None else [axis], - keep_dims=keepdims) + return tf.reduce_sum(x, reduction_indices=None if axis is None else [axis], + keep_dims=keepdims) def mean(x, axis=None, keepdims=False): - return tf.reduce_mean(x, reduction_indices=None if axis is None else [axis], - keep_dims=keepdims) + return tf.reduce_mean(x, reduction_indices=(None if axis is None + else [axis]), + keep_dims=keepdims) def var(x, axis=None, keepdims=False): - meanx = mean(x, axis=axis, keepdims=keepdims) - return mean(tf.square(x - meanx), axis=axis, keepdims=keepdims) + meanx = mean(x, axis=axis, keepdims=keepdims) + return mean(tf.square(x - meanx), axis=axis, keepdims=keepdims) def std(x, axis=None, keepdims=False): - return tf.sqrt(var(x, axis=axis, keepdims=keepdims)) + return tf.sqrt(var(x, axis=axis, keepdims=keepdims)) def max(x, axis=None, keepdims=False): - return tf.reduce_max(x, reduction_indices=None if axis is None else [axis], - keep_dims=keepdims) + return tf.reduce_max(x, reduction_indices=None if axis is None else [axis], + keep_dims=keepdims) def min(x, axis=None, 
keepdims=False): - return tf.reduce_min(x, reduction_indices=None if axis is None else [axis], - keep_dims=keepdims) + return tf.reduce_min(x, reduction_indices=None if axis is None else [axis], + keep_dims=keepdims) def concatenate(arrs, axis=0): - return tf.concat(arrs, axis) + return tf.concat(arrs, axis) def argmax(x, axis=None): - return tf.argmax(x, dimension=axis) + return tf.argmax(x, dimension=axis) # Extras def l2loss(params): - if len(params) == 0: - return tf.constant(0.0) - else: - return tf.add_n([sum(tf.square(p)) for p in params]) + if len(params) == 0: + return tf.constant(0.0) + else: + return tf.add_n([sum(tf.square(p)) for p in params]) def lrelu(x, leak=0.2): - f1 = 0.5 * (1 + leak) - f2 = 0.5 * (1 - leak) - return f1 * x + f2 * abs(x) + f1 = 0.5 * (1 + leak) + f2 = 0.5 * (1 - leak) + return f1 * x + f2 * abs(x) def categorical_sample_logits(X): - # https://github.com/tensorflow/tensorflow/issues/456 - U = tf.random_uniform(tf.shape(X)) - return argmax(X - tf.log(-tf.log(U)), axis=1) + # https://github.com/tensorflow/tensorflow/issues/456 + U = tf.random_uniform(tf.shape(X)) + return argmax(X - tf.log(-tf.log(U)), axis=1) # Global session def get_session(): - return tf.get_default_session() + return tf.get_default_session() def single_threaded_session(): - tf_config = tf.ConfigProto(inter_op_parallelism_threads=1, - intra_op_parallelism_threads=1) - return tf.Session(config=tf_config) + tf_config = tf.ConfigProto(inter_op_parallelism_threads=1, + intra_op_parallelism_threads=1) + return tf.Session(config=tf_config) ALREADY_INITIALIZED = set() def initialize(): - new_variables = set(tf.global_variables()) - ALREADY_INITIALIZED - get_session().run(tf.variables_initializer(new_variables)) - ALREADY_INITIALIZED.update(new_variables) + new_variables = set(tf.global_variables()) - ALREADY_INITIALIZED + get_session().run(tf.variables_initializer(new_variables)) + ALREADY_INITIALIZED.update(new_variables) def eval(expr, feed_dict=None): - if feed_dict 
is None: - feed_dict = {} - return get_session().run(expr, feed_dict=feed_dict) + if feed_dict is None: + feed_dict = {} + return get_session().run(expr, feed_dict=feed_dict) def set_value(v, val): - get_session().run(v.assign(val)) + get_session().run(v.assign(val)) def load_state(fname): - saver = tf.train.Saver() - saver.restore(get_session(), fname) + saver = tf.train.Saver() + saver.restore(get_session(), fname) def save_state(fname): - os.makedirs(os.path.dirname(fname), exist_ok=True) - saver = tf.train.Saver() - saver.save(get_session(), fname) + os.makedirs(os.path.dirname(fname), exist_ok=True) + saver = tf.train.Saver() + saver.save(get_session(), fname) # Model components def normc_initializer(std=1.0): - def _initializer(shape, dtype=None, partition_info=None): - out = np.random.randn(*shape).astype(np.float32) - out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) - return tf.constant(out) - return _initializer + def _initializer(shape, dtype=None, partition_info=None): + out = np.random.randn(*shape).astype(np.float32) + out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) + return tf.constant(out) + return _initializer def dense(x, size, name, weight_init=None, bias=True): - w = tf.get_variable(name + "/w", [x.get_shape()[1], size], - initializer=weight_init) - ret = tf.matmul(x, w) - if bias: - b = tf.get_variable(name + "/b", [size], - initializer=tf.zeros_initializer()) - return ret + b - else: - return ret + w = tf.get_variable(name + "/w", [x.get_shape()[1], size], + initializer=weight_init) + ret = tf.matmul(x, w) + if bias: + b = tf.get_variable(name + "/b", [size], + initializer=tf.zeros_initializer()) + return ret + b + else: + return ret # Basic Stuff def function(inputs, outputs, updates=None, givens=None): - if isinstance(outputs, list): - return _Function(inputs, outputs, updates, givens=givens) - elif isinstance(outputs, dict): - f = _Function(inputs, outputs.values(), updates, givens=givens) - return lambda 
*inputs: dict(zip(outputs.keys(), f(*inputs))) - else: - f = _Function(inputs, [outputs], updates, givens=givens) - return lambda *inputs: f(*inputs)[0] + if isinstance(outputs, list): + return _Function(inputs, outputs, updates, givens=givens) + elif isinstance(outputs, dict): + f = _Function(inputs, outputs.values(), updates, givens=givens) + return lambda *inputs: dict(zip(outputs.keys(), f(*inputs))) + else: + f = _Function(inputs, [outputs], updates, givens=givens) + return lambda *inputs: f(*inputs)[0] class _Function(object): - def __init__(self, inputs, outputs, updates, givens, check_nan=False): - assert all(len(i.op.inputs) == 0 for i in inputs), ("inputs should all be " - "placeholders") - self.inputs = inputs - updates = updates or [] - self.update_group = tf.group(*updates) - self.outputs_update = list(outputs) + [self.update_group] - self.givens = {} if givens is None else givens - self.check_nan = check_nan + def __init__(self, inputs, outputs, updates, givens, check_nan=False): + assert all(len(i.op.inputs) == 0 for i in inputs), ("inputs should " + "all be " + "placeholders") + self.inputs = inputs + updates = updates or [] + self.update_group = tf.group(*updates) + self.outputs_update = list(outputs) + [self.update_group] + self.givens = {} if givens is None else givens + self.check_nan = check_nan - def __call__(self, *inputvals): - assert len(inputvals) == len(self.inputs) - feed_dict = dict(zip(self.inputs, inputvals)) - feed_dict.update(self.givens) - results = get_session().run(self.outputs_update, feed_dict=feed_dict)[:-1] - if self.check_nan: - if any(np.isnan(r).any() for r in results): - raise RuntimeError("Nan detected") - return results + def __call__(self, *inputvals): + assert len(inputvals) == len(self.inputs) + feed_dict = dict(zip(self.inputs, inputvals)) + feed_dict.update(self.givens) + results = get_session().run(self.outputs_update, + feed_dict=feed_dict)[:-1] + if self.check_nan: + if any(np.isnan(r).any() for r in results): + 
raise RuntimeError("Nan detected") + return results # Graph traversal @@ -189,71 +192,72 @@ VARIABLES = {} def var_shape(x): - out = [k.value for k in x.get_shape()] - assert all(isinstance(a, int) for a in out), ("shape function assumes that " - "shape is fully known") - return out + out = [k.value for k in x.get_shape()] + assert all(isinstance(a, int) for a in out), ("shape function assumes " + "that shape is fully known") + return out def numel(x): - return intprod(var_shape(x)) + return intprod(var_shape(x)) def intprod(x): - return int(np.prod(x)) + return int(np.prod(x)) def flatgrad(loss, var_list): - grads = tf.gradients(loss, var_list) - return tf.concat([tf.reshape(grad, [numel(v)], 0) - for (v, grad) in zip(var_list, grads)]) + grads = tf.gradients(loss, var_list) + return tf.concat([tf.reshape(grad, [numel(v)], 0) + for (v, grad) in zip(var_list, grads)]) class SetFromFlat(object): - def __init__(self, var_list, dtype=tf.float32): - assigns = [] - shapes = list(map(var_shape, var_list)) - total_size = np.sum([intprod(shape) for shape in shapes]) + def __init__(self, var_list, dtype=tf.float32): + assigns = [] + shapes = list(map(var_shape, var_list)) + total_size = np.sum([intprod(shape) for shape in shapes]) - self.theta = theta = tf.placeholder(dtype, [total_size]) - start = 0 - assigns = [] - for (shape, v) in zip(shapes, var_list): - size = intprod(shape) - assigns.append(tf.assign(v, tf.reshape(theta[start:start + size], - shape))) - start += size - assert start == total_size - self.op = tf.group(*assigns) + self.theta = theta = tf.placeholder(dtype, [total_size]) + start = 0 + assigns = [] + for (shape, v) in zip(shapes, var_list): + size = intprod(shape) + assigns.append(tf.assign(v, tf.reshape(theta[start:start + size], + shape))) + start += size + assert start == total_size + self.op = tf.group(*assigns) - def __call__(self, theta): - get_session().run(self.op, feed_dict={self.theta: theta}) + def __call__(self, theta): + 
get_session().run(self.op, feed_dict={self.theta: theta}) class GetFlat(object): - def __init__(self, var_list): - self.op = tf.concat([tf.reshape(v, [numel(v)]) for v in var_list], 0) + def __init__(self, var_list): + self.op = tf.concat([tf.reshape(v, [numel(v)]) for v in var_list], 0) - def __call__(self): - return get_session().run(self.op) + def __call__(self): + return get_session().run(self.op) # Misc def scope_vars(scope, trainable_only): - """Get variables inside a scope. The scope can be specified as a string.""" - return tf.get_collection((tf.GraphKeys.TRAINABLE_VARIABLES if trainable_only - else tf.GraphKeys.GLOBAL_VARIABLES), - scope=(scope if isinstance(scope, str) - else scope.name)) + """Get variables inside a scope. The scope can be specified as a string.""" + return tf.get_collection((tf.GraphKeys.TRAINABLE_VARIABLES + if trainable_only + else tf.GraphKeys.GLOBAL_VARIABLES), + scope=(scope if isinstance(scope, str) + else scope.name)) def in_session(f): - @functools.wraps(f) - def newfunc(*args, **kwargs): - with tf.Session(): - f(*args, **kwargs) - return newfunc + @functools.wraps(f) + def newfunc(*args, **kwargs): + with tf.Session(): + f(*args, **kwargs) + return newfunc # A mapping from name -> (placeholder, dtype, shape). 
@@ -261,28 +265,28 @@ _PLACEHOLDER_CACHE = {} def get_placeholder(name, dtype, shape): - print("calling get_placeholder", name) - if name in _PLACEHOLDER_CACHE: - out, dtype1, shape1 = _PLACEHOLDER_CACHE[name] - assert dtype1 == dtype and shape1 == shape - return out - else: - out = tf.placeholder(dtype=dtype, shape=shape, name=name) - _PLACEHOLDER_CACHE[name] = (out, dtype, shape) - return out + print("calling get_placeholder", name) + if name in _PLACEHOLDER_CACHE: + out, dtype1, shape1 = _PLACEHOLDER_CACHE[name] + assert dtype1 == dtype and shape1 == shape + return out + else: + out = tf.placeholder(dtype=dtype, shape=shape, name=name) + _PLACEHOLDER_CACHE[name] = (out, dtype, shape) + return out def get_placeholder_cached(name): - return _PLACEHOLDER_CACHE[name][0] + return _PLACEHOLDER_CACHE[name][0] def flattenallbut0(x): - return tf.reshape(x, [-1, intprod(x.get_shape().as_list()[1:])]) + return tf.reshape(x, [-1, intprod(x.get_shape().as_list()[1:])]) def reset(): - global _PLACEHOLDER_CACHE - global VARIABLES - _PLACEHOLDER_CACHE = {} - VARIABLES = {} - tf.reset_default_graph() + global _PLACEHOLDER_CACHE + global VARIABLES + _PLACEHOLDER_CACHE = {} + VARIABLES = {} + tf.reset_default_graph() diff --git a/python/ray/rllib/evolution_strategies/utils.py b/python/ray/rllib/evolution_strategies/utils.py index 5a9c3a386..733badf72 100644 --- a/python/ray/rllib/evolution_strategies/utils.py +++ b/python/ray/rllib/evolution_strategies/utils.py @@ -10,77 +10,78 @@ import tensorflow as tf def compute_ranks(x): - """Returns ranks in [0, len(x)) + """Returns ranks in [0, len(x)) - Note: This is different from scipy.stats.rankdata, which returns ranks in - [1, len(x)]. - """ - assert x.ndim == 1 - ranks = np.empty(len(x), dtype=int) - ranks[x.argsort()] = np.arange(len(x)) - return ranks + Note: This is different from scipy.stats.rankdata, which returns ranks in + [1, len(x)]. 
+ """ + assert x.ndim == 1 + ranks = np.empty(len(x), dtype=int) + ranks[x.argsort()] = np.arange(len(x)) + return ranks def compute_centered_ranks(x): - y = compute_ranks(x.ravel()).reshape(x.shape).astype(np.float32) - y /= (x.size - 1) - y -= 0.5 - return y + y = compute_ranks(x.ravel()).reshape(x.shape).astype(np.float32) + y /= (x.size - 1) + y -= 0.5 + return y def make_session(single_threaded): - if not single_threaded: - return tf.InteractiveSession() - return tf.InteractiveSession( - config=tf.ConfigProto(inter_op_parallelism_threads=1, - intra_op_parallelism_threads=1)) + if not single_threaded: + return tf.InteractiveSession() + return tf.InteractiveSession( + config=tf.ConfigProto(inter_op_parallelism_threads=1, + intra_op_parallelism_threads=1)) def itergroups(items, group_size): - assert group_size >= 1 - group = [] - for x in items: - group.append(x) - if len(group) == group_size: - yield tuple(group) - del group[:] - if group: - yield tuple(group) + assert group_size >= 1 + group = [] + for x in items: + group.append(x) + if len(group) == group_size: + yield tuple(group) + del group[:] + if group: + yield tuple(group) def batched_weighted_sum(weights, vecs, batch_size): - total = 0 - num_items_summed = 0 - for batch_weights, batch_vecs in zip(itergroups(weights, batch_size), - itergroups(vecs, batch_size)): - assert len(batch_weights) == len(batch_vecs) <= batch_size - total += np.dot(np.asarray(batch_weights, dtype=np.float32), - np.asarray(batch_vecs, dtype=np.float32)) - num_items_summed += len(batch_weights) - return total, num_items_summed + total = 0 + num_items_summed = 0 + for batch_weights, batch_vecs in zip(itergroups(weights, batch_size), + itergroups(vecs, batch_size)): + assert len(batch_weights) == len(batch_vecs) <= batch_size + total += np.dot(np.asarray(batch_weights, dtype=np.float32), + np.asarray(batch_vecs, dtype=np.float32)) + num_items_summed += len(batch_weights) + return total, num_items_summed class RunningStat(object): - 
def __init__(self, shape, eps): - self.sum = np.zeros(shape, dtype=np.float32) - self.sumsq = np.full(shape, eps, dtype=np.float32) - self.count = eps + def __init__(self, shape, eps): + self.sum = np.zeros(shape, dtype=np.float32) + self.sumsq = np.full(shape, eps, dtype=np.float32) + self.count = eps - def increment(self, s, ssq, c): - self.sum += s - self.sumsq += ssq - self.count += c + def increment(self, s, ssq, c): + self.sum += s + self.sumsq += ssq + self.count += c - @property - def mean(self): - return self.sum / self.count + @property + def mean(self): + return self.sum / self.count - @property - def std(self): - return np.sqrt(np.maximum(self.sumsq / self.count - np.square(self.mean), - 1e-2)) + @property + def std(self): + return np.sqrt(np.maximum( + self.sumsq / self.count - np.square(self.mean), 1e-2)) - def set_from_init(self, init_mean, init_std, init_count): - self.sum[:] = init_mean * init_count - self.sumsq[:] = (np.square(init_mean) + np.square(init_std)) * init_count - self.count = init_count + def set_from_init(self, init_mean, init_std, init_count): + self.sum[:] = init_mean * init_count + self.sumsq[:] = (np.square(init_mean) + + np.square(init_std)) * init_count + self.count = init_count diff --git a/python/ray/rllib/evolution_strategies/viz.py b/python/ray/rllib/evolution_strategies/viz.py index 7c86fc281..d208b9d16 100644 --- a/python/ray/rllib/evolution_strategies/viz.py +++ b/python/ray/rllib/evolution_strategies/viz.py @@ -11,32 +11,33 @@ import click @click.option("--stochastic", is_flag=True) @click.option("--extra_kwargs") def main(env_id, policy_file, record, stochastic, extra_kwargs): - import gym - from gym import wrappers - import tensorflow as tf - from policies import MujocoPolicy - import numpy as np - - env = gym.make(env_id) - if record: - import uuid - env = wrappers.Monitor(env, "/tmp/" + str(uuid.uuid4()), force=True) - - if extra_kwargs: - import json - extra_kwargs = json.loads(extra_kwargs) - - with tf.Session(): - 
pi = MujocoPolicy.Load(policy_file, extra_kwargs=extra_kwargs) - while True: - rews, t = pi.rollout(env, render=True, - random_stream=np.random if stochastic else None) - print("return={:.4f} len={}".format(rews.sum(), t)) + import gym + from gym import wrappers + import tensorflow as tf + from policies import MujocoPolicy + import numpy as np + env = gym.make(env_id) if record: - env.close() - return + import uuid + env = wrappers.Monitor(env, "/tmp/" + str(uuid.uuid4()), force=True) + + if extra_kwargs: + import json + extra_kwargs = json.loads(extra_kwargs) + + with tf.Session(): + pi = MujocoPolicy.Load(policy_file, extra_kwargs=extra_kwargs) + while True: + rews, t = pi.rollout(env, render=True, + random_stream=(np.random if stochastic + else None)) + print("return={:.4f} len={}".format(rews.sum(), t)) + + if record: + env.close() + return if __name__ == "__main__": - main() + main() diff --git a/python/ray/rllib/example.py b/python/ray/rllib/example.py index 8d2b353a7..03ce00d55 100755 --- a/python/ray/rllib/example.py +++ b/python/ray/rllib/example.py @@ -11,15 +11,15 @@ import ray.rllib.policy_gradient as pg if __name__ == "__main__": - ray.init() + ray.init() - # TODO(ekl): get the algorithms working on a common set of envs - env_name = "CartPole-v0" - alg1 = es.EvolutionStrategies(env_name, es.DEFAULT_CONFIG) - alg2 = pg.PolicyGradient(env_name, pg.DEFAULT_CONFIG) + # TODO(ekl): get the algorithms working on a common set of envs + env_name = "CartPole-v0" + alg1 = es.EvolutionStrategies(env_name, es.DEFAULT_CONFIG) + alg2 = pg.PolicyGradient(env_name, pg.DEFAULT_CONFIG) - while True: - r1 = alg1.train() - r2 = alg2.train() - print("evolution strategies: {}".format(r1)) - print("policy gradient: {}".format(r2)) + while True: + r1 = alg1.train() + r2 = alg2.train() + print("evolution strategies: {}".format(r1)) + print("policy gradient: {}".format(r2)) diff --git a/python/ray/rllib/parallel.py b/python/ray/rllib/parallel.py index 9cc7d9f32..3aee4c170 100644 
--- a/python/ray/rllib/parallel.py +++ b/python/ray/rllib/parallel.py @@ -10,189 +10,192 @@ import tensorflow as tf class LocalSyncParallelOptimizer(object): - """Optimizer that runs in parallel across multiple local devices. + """Optimizer that runs in parallel across multiple local devices. - LocalSyncParallelOptimizer automatically splits up and loads training data - onto specified local devices (e.g. GPUs) with `load_data()`. During a call to - `optimize()`, the devices compute gradients over slices of the data in - parallel. The gradients are then averaged and applied to the shared weights. + LocalSyncParallelOptimizer automatically splits up and loads training data + onto specified local devices (e.g. GPUs) with `load_data()`. During a call + to `optimize()`, the devices compute gradients over slices of the data in + parallel. The gradients are then averaged and applied to the shared + weights. - The data loaded is pinned in device memory until the next call to - `load_data`, so you can make multiple passes (possibly in randomized order) - over the same data once loaded. + The data loaded is pinned in device memory until the next call to + `load_data`, so you can make multiple passes (possibly in randomized order) + over the same data once loaded. - This is similar to tf.train.SyncReplicasOptimizer, but works within a single - TensorFlow graph, i.e. implements in-graph replicated training: + This is similar to tf.train.SyncReplicasOptimizer, but works within a + single TensorFlow graph, i.e. implements in-graph replicated training: - https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer - - Args: - optimizer: delegate TensorFlow optimizer object. - devices: list of the names of TensorFlow devices to parallelize over. - input_placeholders: list of inputs for the loss function. Tensors of - these shapes will be passed to build_loss() in order - to define the per-device loss ops. 
- per_device_batch_size: number of tuples to optimize over at a time per - device. In each call to `optimize()`, - `len(devices) * per_device_batch_size` tuples of - data will be processed. - build_loss: function that takes the specified inputs and returns an - object with a 'loss' property that is a scalar Tensor. For - example, ray.rllib.policy_gradient.ProximalPolicyLoss. - logdir: directory to place debugging output in. - """ - - def __init__( - self, - optimizer, - devices, - input_placeholders, - per_device_batch_size, - build_loss, - logdir): - self.optimizer = optimizer - self.devices = devices - self.batch_size = per_device_batch_size * len(devices) - self.per_device_batch_size = per_device_batch_size - self.input_placeholders = input_placeholders - self.build_loss = build_loss - self.logdir = logdir - - # First initialize the shared loss network - with tf.variable_scope("tower"): - self._shared_loss = build_loss(*input_placeholders) - - # Then setup the per-device loss graphs that use the shared weights - self._batch_index = tf.placeholder(tf.int32) - data_splits = zip( - *[tf.split(ph, len(devices)) for ph in input_placeholders]) - self._towers = [] - for device, device_placeholders in zip(self.devices, data_splits): - self._towers.append(self._setup_device(device, device_placeholders)) - - avg = average_gradients([t.grads for t in self._towers]) - self._train_op = self.optimizer.apply_gradients(avg) - - def load_data(self, sess, inputs, full_trace=False): - """Bulk loads the specified inputs into device memory. - - The shape of the inputs must conform to the shapes of the input - placeholders this optimizer was constructed with. - - The data is split equally across all the devices. If the data is not - evenly divisible by the batch size, excess data will be discarded. + https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer Args: - sess: TensorFlow session. 
- inputs: list of Tensors matching the input placeholders specified at - construction time of this optimizer. - full_trace: whether to profile data loading. - - Returns: - The number of tuples loaded per device. + optimizer: Delegate TensorFlow optimizer object. + devices: List of the names of TensorFlow devices to parallelize over. + input_placeholders: List of inputs for the loss function. Tensors of + these shapes will be passed to build_loss() in order to define the + per-device loss ops. + per_device_batch_size: Number of tuples to optimize over at a time per + device. In each call to `optimize()`, + `len(devices) * per_device_batch_size` tuples of data will be + processed. + build_loss: Function that takes the specified inputs and returns an + object with a 'loss' property that is a scalar Tensor. For example, + ray.rllib.policy_gradient.ProximalPolicyLoss. + logdir: Directory to place debugging output in. """ - feed_dict = {} - assert len(self.input_placeholders) == len(inputs) - for ph, arr in zip(self.input_placeholders, inputs): - truncated_arr = make_divisible_by(arr, self.batch_size) - feed_dict[ph] = truncated_arr - truncated_len = len(truncated_arr) + def __init__(self, optimizer, devices, input_placeholders, + per_device_batch_size, build_loss, logdir): + self.optimizer = optimizer + self.devices = devices + self.batch_size = per_device_batch_size * len(devices) + self.per_device_batch_size = per_device_batch_size + self.input_placeholders = input_placeholders + self.build_loss = build_loss + self.logdir = logdir - if full_trace: - run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) - else: - run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE) - run_metadata = tf.RunMetadata() + # First initialize the shared loss network + with tf.variable_scope("tower"): + self._shared_loss = build_loss(*input_placeholders) - sess.run( - [t.init_op for t in self._towers], - feed_dict=feed_dict, - options=run_options, - 
run_metadata=run_metadata) - if full_trace: - trace = timeline.Timeline(step_stats=run_metadata.step_stats) - trace_file = open(os.path.join(self.logdir, "timeline-load.json"), "w") - trace_file.write(trace.generate_chrome_trace_format()) + # Then setup the per-device loss graphs that use the shared weights + self._batch_index = tf.placeholder(tf.int32) + data_splits = zip( + *[tf.split(ph, len(devices)) for ph in input_placeholders]) + self._towers = [] + for device, device_placeholders in zip(self.devices, data_splits): + self._towers.append(self._setup_device(device, + device_placeholders)) - tuples_per_device = truncated_len / len(self.devices) - assert tuples_per_device % self.per_device_batch_size == 0 - return tuples_per_device + avg = average_gradients([t.grads for t in self._towers]) + self._train_op = self.optimizer.apply_gradients(avg) - def optimize( - self, sess, batch_index, - extra_ops=[], extra_feed_dict={}, file_writer=None): - """Run a single step of SGD. + def load_data(self, sess, inputs, full_trace=False): + """Bulk loads the specified inputs into device memory. - Runs a SGD step over a slice of the preloaded batch with size given by - self.per_device_batch_size and offset given by the batch_index argument. + The shape of the inputs must conform to the shapes of the input + placeholders this optimizer was constructed with. - Updates shared model weights based on the averaged per-device gradients. + The data is split equally across all the devices. If the data is not + evenly divisible by the batch size, excess data will be discarded. - Args: - sess: TensorFlow session. - batch_index: offset into the preloaded data. This value must be - between `0` and `tuples_per_device`. The amount of data - to process is always fixed to `per_device_batch_size`. - extra_ops: extra ops to run with this step (e.g. for metrics). - extra_feed_dict: extra args to feed into this session run. - file_writer: if specified, tf metrics will be written out using this. 
+ Args: + sess: TensorFlow session. + inputs: List of Tensors matching the input placeholders specified + at construction time of this optimizer. + full_trace: Whether to profile data loading. - Returns: - the outputs of extra_ops evaluated over the batch. - """ + Returns: + The number of tuples loaded per device. + """ - if file_writer: - run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) - else: - run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE) - run_metadata = tf.RunMetadata() + feed_dict = {} + assert len(self.input_placeholders) == len(inputs) + for ph, arr in zip(self.input_placeholders, inputs): + truncated_arr = make_divisible_by(arr, self.batch_size) + feed_dict[ph] = truncated_arr + truncated_len = len(truncated_arr) - feed_dict = {self._batch_index: batch_index} - feed_dict.update(extra_feed_dict) - outs = sess.run( - [self._train_op] + extra_ops, - feed_dict=feed_dict, - options=run_options, - run_metadata=run_metadata) + if full_trace: + run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) + else: + run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE) + run_metadata = tf.RunMetadata() - if file_writer: - trace = timeline.Timeline(step_stats=run_metadata.step_stats) - trace_file = open(os.path.join(self.logdir, "timeline-sgd.json"), "w") - trace_file.write(trace.generate_chrome_trace_format()) - file_writer.add_run_metadata( - run_metadata, "sgd_train_{}".format(batch_index)) + sess.run( + [t.init_op for t in self._towers], + feed_dict=feed_dict, + options=run_options, + run_metadata=run_metadata) + if full_trace: + trace = timeline.Timeline(step_stats=run_metadata.step_stats) + trace_file = open(os.path.join(self.logdir, "timeline-load.json"), + "w") + trace_file.write(trace.generate_chrome_trace_format()) - return outs[1:] + tuples_per_device = truncated_len / len(self.devices) + assert tuples_per_device % self.per_device_batch_size == 0 + return tuples_per_device - def get_common_loss(self): - 
return self._shared_loss + def optimize(self, sess, batch_index, extra_ops=[], extra_feed_dict={}, + file_writer=None): + """Run a single step of SGD. - def get_device_losses(self): - return [t.loss_object for t in self._towers] + Runs a SGD step over a slice of the preloaded batch with size given by + self.per_device_batch_size and offset given by the batch_index + argument. - def _setup_device(self, device, device_input_placeholders): - with tf.device(device): - with tf.variable_scope("tower", reuse=True): - device_input_batches = [] - device_input_slices = [] - for ph in device_input_placeholders: - current_batch = tf.Variable( - ph, trainable=False, validate_shape=False, collections=[]) - device_input_batches.append(current_batch) - current_slice = tf.slice( - current_batch, - [self._batch_index] + [0] * len(ph.shape[1:]), - [self.per_device_batch_size] + [-1] * len(ph.shape[1:])) - current_slice.set_shape(ph.shape) - device_input_slices.append(current_slice) - device_loss_obj = self.build_loss(*device_input_slices) - device_grads = self.optimizer.compute_gradients( - device_loss_obj.loss, colocate_gradients_with_ops=True) - return Tower( - tf.group(*[batch.initializer for batch in device_input_batches]), - device_grads, - device_loss_obj) + Updates shared model weights based on the averaged per-device + gradients. + + Args: + sess: TensorFlow session. + batch_index: Offset into the preloaded data. This value must be + between `0` and `tuples_per_device`. The amount of data to + process is always fixed to `per_device_batch_size`. + extra_ops: Extra ops to run with this step (e.g. for metrics). + extra_feed_dict: Extra args to feed into this session run. + file_writer: If specified, tf metrics will be written out using + this. + + Returns: + The outputs of extra_ops evaluated over the batch. 
+ """ + + if file_writer: + run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) + else: + run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE) + run_metadata = tf.RunMetadata() + + feed_dict = {self._batch_index: batch_index} + feed_dict.update(extra_feed_dict) + outs = sess.run( + [self._train_op] + extra_ops, + feed_dict=feed_dict, + options=run_options, + run_metadata=run_metadata) + + if file_writer: + trace = timeline.Timeline(step_stats=run_metadata.step_stats) + trace_file = open(os.path.join(self.logdir, "timeline-sgd.json"), + "w") + trace_file.write(trace.generate_chrome_trace_format()) + file_writer.add_run_metadata( + run_metadata, "sgd_train_{}".format(batch_index)) + + return outs[1:] + + def get_common_loss(self): + return self._shared_loss + + def get_device_losses(self): + return [t.loss_object for t in self._towers] + + def _setup_device(self, device, device_input_placeholders): + with tf.device(device): + with tf.variable_scope("tower", reuse=True): + device_input_batches = [] + device_input_slices = [] + for ph in device_input_placeholders: + current_batch = tf.Variable( + ph, trainable=False, validate_shape=False, + collections=[]) + device_input_batches.append(current_batch) + current_slice = tf.slice( + current_batch, + [self._batch_index] + [0] * len(ph.shape[1:]), + ([self.per_device_batch_size] + [-1] * + len(ph.shape[1:]))) + current_slice.set_shape(ph.shape) + device_input_slices.append(current_slice) + device_loss_obj = self.build_loss(*device_input_slices) + device_grads = self.optimizer.compute_gradients( + device_loss_obj.loss, colocate_gradients_with_ops=True) + return Tower( + tf.group(*[batch.initializer + for batch in device_input_batches]), + device_grads, + device_loss_obj) # Each tower is a copy of the loss graph pinned to a specific device. 
@@ -200,50 +203,51 @@ Tower = namedtuple("Tower", ["init_op", "grads", "loss_object"]) def make_divisible_by(array, n): - return array[0:array.shape[0] - array.shape[0] % n] + return array[0:array.shape[0] - array.shape[0] % n] def average_gradients(tower_grads): - """Averages gradients across towers. + """Averages gradients across towers. - Calculate the average gradient for each shared variable across all towers. - Note that this function provides a synchronization point across all towers. + Calculate the average gradient for each shared variable across all towers. + Note that this function provides a synchronization point across all towers. - Args: - tower_grads: List of lists of (gradient, variable) tuples. The outer list - is over individual gradients. The inner list is over the gradient - calculation for each tower. + Args: + tower_grads: List of lists of (gradient, variable) tuples. The outer + list is over individual gradients. The inner list is over the + gradient calculation for each tower. - Returns: - List of pairs of (gradient, variable) where the gradient has been averaged - across all towers. + Returns: + List of pairs of (gradient, variable) where the gradient has been + averaged across all towers. - TODO(ekl): We could use NCCL if this becomes a bottleneck. - """ + TODO(ekl): We could use NCCL if this becomes a bottleneck. + """ - average_grads = [] - for grad_and_vars in zip(*tower_grads): + average_grads = [] + for grad_and_vars in zip(*tower_grads): - # Note that each grad_and_vars looks like the following: - # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) - grads = [] - for g, _ in grad_and_vars: - if g is not None: - # Add 0 dimension to the gradients to represent the tower. - expanded_g = tf.expand_dims(g, 0) + # Note that each grad_and_vars looks like the following: + # ((grad0_gpu0, var0_gpu0), ... 
, (grad0_gpuN, var0_gpuN)) + grads = [] + for g, _ in grad_and_vars: + if g is not None: + # Add 0 dimension to the gradients to represent the tower. + expanded_g = tf.expand_dims(g, 0) - # Append on a 'tower' dimension which we will average over below. - grads.append(expanded_g) + # Append on a 'tower' dimension which we will average over + # below. + grads.append(expanded_g) - # Average over the 'tower' dimension. - grad = tf.concat(axis=0, values=grads) - grad = tf.reduce_mean(grad, 0) + # Average over the 'tower' dimension. + grad = tf.concat(axis=0, values=grads) + grad = tf.reduce_mean(grad, 0) - # Keep in mind that the Variables are redundant because they are shared - # across towers. So .. we will just return the first tower's pointer to - # the Variable. - v = grad_and_vars[0][1] - grad_and_var = (grad, v) - average_grads.append(grad_and_var) + # Keep in mind that the Variables are redundant because they are shared + # across towers. So .. we will just return the first tower's pointer to + # the Variable. + v = grad_and_vars[0][1] + grad_and_var = (grad, v) + average_grads.append(grad_and_var) - return average_grads + return average_grads diff --git a/python/ray/rllib/policy_gradient/agent.py b/python/ray/rllib/policy_gradient/agent.py index 9aa44e250..ef1499c0b 100644 --- a/python/ray/rllib/policy_gradient/agent.py +++ b/python/ray/rllib/policy_gradient/agent.py @@ -25,133 +25,141 @@ from ray.rllib.policy_gradient.rollout import rollouts, add_advantage_values class Agent(object): - """ - Agent class that holds the simulator environment and the policy. + """ + Agent class that holds the simulator environment and the policy. - Initializes the tensorflow graphs for both training and evaluation. - One common policy graph is initialized on '/cpu:0' and holds all the shared - network weights. When run as a remote agent, only this graph is used. - """ + Initializes the tensorflow graphs for both training and evaluation. 
+ One common policy graph is initialized on '/cpu:0' and holds all the shared + network weights. When run as a remote agent, only this graph is used. + """ - def __init__(self, name, batchsize, preprocessor, config, logdir, is_remote): - if is_remote: - os.environ["CUDA_VISIBLE_DEVICES"] = "" - devices = ["/cpu:0"] - else: - devices = config["devices"] - self.devices = devices - self.config = config - self.logdir = logdir - self.env = BatchedEnv(name, batchsize, preprocessor=preprocessor) - if preprocessor.shape is None: - preprocessor.shape = self.env.observation_space.shape - if is_remote: - config_proto = tf.ConfigProto() - else: - config_proto = tf.ConfigProto(**config["tf_session_args"]) - self.preprocessor = preprocessor - self.sess = tf.Session(config=config_proto) - if config["use_tf_debugger"] and not is_remote: - self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess) - self.sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan) + def __init__(self, name, batchsize, preprocessor, config, logdir, + is_remote): + if is_remote: + os.environ["CUDA_VISIBLE_DEVICES"] = "" + devices = ["/cpu:0"] + else: + devices = config["devices"] + self.devices = devices + self.config = config + self.logdir = logdir + self.env = BatchedEnv(name, batchsize, preprocessor=preprocessor) + if preprocessor.shape is None: + preprocessor.shape = self.env.observation_space.shape + if is_remote: + config_proto = tf.ConfigProto() + else: + config_proto = tf.ConfigProto(**config["tf_session_args"]) + self.preprocessor = preprocessor + self.sess = tf.Session(config=config_proto) + if config["use_tf_debugger"] and not is_remote: + self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess) + self.sess.add_tensor_filter("has_inf_or_nan", + tf_debug.has_inf_or_nan) - # Defines the training inputs. 
- self.kl_coeff = tf.placeholder(name="newkl", shape=(), dtype=tf.float32) - self.observations = tf.placeholder(tf.float32, - shape=(None,) + preprocessor.shape) - self.advantages = tf.placeholder(tf.float32, shape=(None,)) + # Defines the training inputs. + self.kl_coeff = tf.placeholder(name="newkl", shape=(), + dtype=tf.float32) + self.observations = tf.placeholder(tf.float32, + shape=(None,) + preprocessor.shape) + self.advantages = tf.placeholder(tf.float32, shape=(None,)) - action_space = self.env.action_space - if isinstance(action_space, gym.spaces.Box): - # The first half of the dimensions are the means, the second half are the - # standard deviations. - self.action_dim = action_space.shape[0] - self.action_shape = (self.action_dim,) - self.logit_dim = 2 * self.action_dim - self.actions = tf.placeholder(tf.float32, shape=(None, self.action_dim)) - self.distribution_class = DiagGaussian - elif isinstance(action_space, gym.spaces.Discrete): - self.action_dim = action_space.n - self.action_shape = () - self.logit_dim = self.action_dim - self.actions = tf.placeholder(tf.int64, shape=(None,)) - self.distribution_class = Categorical - else: - raise NotImplemented("action space" + str(type(action_space)) + - "currently not supported") - self.prev_logits = tf.placeholder(tf.float32, shape=(None, self.logit_dim)) + action_space = self.env.action_space + if isinstance(action_space, gym.spaces.Box): + # The first half of the dimensions are the means, the second half + # are the standard deviations. 
+ self.action_dim = action_space.shape[0] + self.action_shape = (self.action_dim,) + self.logit_dim = 2 * self.action_dim + self.actions = tf.placeholder(tf.float32, + shape=(None, self.action_dim)) + self.distribution_class = DiagGaussian + elif isinstance(action_space, gym.spaces.Discrete): + self.action_dim = action_space.n + self.action_shape = () + self.logit_dim = self.action_dim + self.actions = tf.placeholder(tf.int64, shape=(None,)) + self.distribution_class = Categorical + else: + raise NotImplemented("action space" + str(type(action_space)) + + "currently not supported") + self.prev_logits = tf.placeholder(tf.float32, + shape=(None, self.logit_dim)) - assert config["sgd_batchsize"] % len(devices) == 0, \ - "Batch size must be evenly divisible by devices" - if is_remote: - self.batch_size = 1 - self.per_device_batch_size = 1 - else: - self.batch_size = config["sgd_batchsize"] - self.per_device_batch_size = int(self.batch_size / len(devices)) + assert config["sgd_batchsize"] % len(devices) == 0, \ + "Batch size must be evenly divisible by devices" + if is_remote: + self.batch_size = 1 + self.per_device_batch_size = 1 + else: + self.batch_size = config["sgd_batchsize"] + self.per_device_batch_size = int(self.batch_size / len(devices)) - def build_loss(obs, advs, acts, plog): - return ProximalPolicyLoss( - self.env.observation_space, self.env.action_space, - obs, advs, acts, plog, self.logit_dim, - self.kl_coeff, self.distribution_class, self.config, self.sess) + def build_loss(obs, advs, acts, plog): + return ProximalPolicyLoss( + self.env.observation_space, self.env.action_space, + obs, advs, acts, plog, self.logit_dim, + self.kl_coeff, self.distribution_class, self.config, self.sess) - self.par_opt = LocalSyncParallelOptimizer( - tf.train.AdamOptimizer(self.config["sgd_stepsize"]), - self.devices, - [self.observations, self.advantages, self.actions, self.prev_logits], - self.per_device_batch_size, - build_loss, - self.logdir) + self.par_opt = 
LocalSyncParallelOptimizer( + tf.train.AdamOptimizer(self.config["sgd_stepsize"]), + self.devices, + [self.observations, self.advantages, self.actions, + self.prev_logits], + self.per_device_batch_size, + build_loss, + self.logdir) - # Metric ops - with tf.name_scope("test_outputs"): - policies = self.par_opt.get_device_losses() - self.mean_loss = tf.reduce_mean( - tf.stack(values=[policy.loss for policy in policies]), 0) - self.mean_kl = tf.reduce_mean( - tf.stack(values=[policy.mean_kl for policy in policies]), 0) - self.mean_entropy = tf.reduce_mean( - tf.stack(values=[policy.mean_entropy for policy in policies]), 0) + # Metric ops + with tf.name_scope("test_outputs"): + policies = self.par_opt.get_device_losses() + self.mean_loss = tf.reduce_mean( + tf.stack(values=[policy.loss for policy in policies]), 0) + self.mean_kl = tf.reduce_mean( + tf.stack(values=[policy.mean_kl for policy in policies]), 0) + self.mean_entropy = tf.reduce_mean( + tf.stack(values=[policy.mean_entropy for policy in policies]), + 0) - # References to the model weights - self.common_policy = self.par_opt.get_common_loss() - self.variables = ray.experimental.TensorFlowVariables( - self.common_policy.loss, - self.sess) - self.observation_filter = MeanStdFilter(preprocessor.shape, clip=None) - self.reward_filter = MeanStdFilter((), clip=5.0) - self.sess.run(tf.global_variables_initializer()) + # References to the model weights + self.common_policy = self.par_opt.get_common_loss() + self.variables = ray.experimental.TensorFlowVariables( + self.common_policy.loss, + self.sess) + self.observation_filter = MeanStdFilter(preprocessor.shape, clip=None) + self.reward_filter = MeanStdFilter((), clip=5.0) + self.sess.run(tf.global_variables_initializer()) - def load_data(self, trajectories, full_trace): - return self.par_opt.load_data( - self.sess, - [trajectories["observations"], - trajectories["advantages"], - trajectories["actions"].squeeze(), - trajectories["logprobs"]], - full_trace=full_trace) 
+ def load_data(self, trajectories, full_trace): + return self.par_opt.load_data( + self.sess, + [trajectories["observations"], + trajectories["advantages"], + trajectories["actions"].squeeze(), + trajectories["logprobs"]], + full_trace=full_trace) - def run_sgd_minibatch(self, batch_index, kl_coeff, full_trace, file_writer): - return self.par_opt.optimize( - self.sess, - batch_index, - extra_ops=[self.mean_loss, self.mean_kl, self.mean_entropy], - extra_feed_dict={self.kl_coeff: kl_coeff}, - file_writer=file_writer if full_trace else None) + def run_sgd_minibatch(self, batch_index, kl_coeff, full_trace, + file_writer): + return self.par_opt.optimize( + self.sess, + batch_index, + extra_ops=[self.mean_loss, self.mean_kl, self.mean_entropy], + extra_feed_dict={self.kl_coeff: kl_coeff}, + file_writer=file_writer if full_trace else None) - def get_weights(self): - return self.variables.get_weights() + def get_weights(self): + return self.variables.get_weights() - def load_weights(self, weights): - self.variables.set_weights(weights) + def load_weights(self, weights): + self.variables.set_weights(weights) - def compute_trajectory(self, gamma, lam, horizon): - trajectory = rollouts( - self.common_policy, - self.env, horizon, self.observation_filter, self.reward_filter) - add_advantage_values(trajectory, gamma, lam, self.reward_filter) - return trajectory + def compute_trajectory(self, gamma, lam, horizon): + trajectory = rollouts( + self.common_policy, + self.env, horizon, self.observation_filter, self.reward_filter) + add_advantage_values(trajectory, gamma, lam, self.reward_filter) + return trajectory RemoteAgent = ray.remote(Agent) diff --git a/python/ray/rllib/policy_gradient/distributions.py b/python/ray/rllib/policy_gradient/distributions.py index 2fc953a58..915531c51 100644 --- a/python/ray/rllib/policy_gradient/distributions.py +++ b/python/ray/rllib/policy_gradient/distributions.py @@ -7,63 +7,63 @@ import numpy as np class Categorical(object): - def 
__init__(self, logits): - self.logits = logits + def __init__(self, logits): + self.logits = logits - def logp(self, x): - return -tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, - labels=x) + def logp(self, x): + return -tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=self.logits, labels=x) - def entropy(self): - a0 = self.logits - tf.reduce_max(self.logits, reduction_indices=[1], - keep_dims=True) - ea0 = tf.exp(a0) - z0 = tf.reduce_sum(ea0, reduction_indices=[1], keep_dims=True) - p0 = ea0 / z0 - return tf.reduce_sum(p0 * (tf.log(z0) - a0), reduction_indices=[1]) + def entropy(self): + a0 = self.logits - tf.reduce_max(self.logits, reduction_indices=[1], + keep_dims=True) + ea0 = tf.exp(a0) + z0 = tf.reduce_sum(ea0, reduction_indices=[1], keep_dims=True) + p0 = ea0 / z0 + return tf.reduce_sum(p0 * (tf.log(z0) - a0), reduction_indices=[1]) - def kl(self, other): - a0 = self.logits - tf.reduce_max(self.logits, reduction_indices=[1], - keep_dims=True) - a1 = other.logits - tf.reduce_max(other.logits, reduction_indices=[1], - keep_dims=True) - ea0 = tf.exp(a0) - ea1 = tf.exp(a1) - z0 = tf.reduce_sum(ea0, reduction_indices=[1], keep_dims=True) - z1 = tf.reduce_sum(ea1, reduction_indices=[1], keep_dims=True) - p0 = ea0 / z0 - return tf.reduce_sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), - reduction_indices=[1]) + def kl(self, other): + a0 = self.logits - tf.reduce_max(self.logits, reduction_indices=[1], + keep_dims=True) + a1 = other.logits - tf.reduce_max(other.logits, reduction_indices=[1], + keep_dims=True) + ea0 = tf.exp(a0) + ea1 = tf.exp(a1) + z0 = tf.reduce_sum(ea0, reduction_indices=[1], keep_dims=True) + z1 = tf.reduce_sum(ea1, reduction_indices=[1], keep_dims=True) + p0 = ea0 / z0 + return tf.reduce_sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), + reduction_indices=[1]) - def sample(self): - return tf.multinomial(self.logits, 1) + def sample(self): + return tf.multinomial(self.logits, 1) class DiagGaussian(object): - def 
__init__(self, flat): - self.flat = flat - mean, logstd = tf.split(flat, 2, axis=1) - self.mean = mean - self.logstd = logstd - self.std = tf.exp(logstd) + def __init__(self, flat): + self.flat = flat + mean, logstd = tf.split(flat, 2, axis=1) + self.mean = mean + self.logstd = logstd + self.std = tf.exp(logstd) - def logp(self, x): - return (-0.5 * tf.reduce_sum(tf.square((x - self.mean) / self.std), - reduction_indices=[1]) - - 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) - - tf.reduce_sum(self.logstd, reduction_indices=[1])) + def logp(self, x): + return (-0.5 * tf.reduce_sum(tf.square((x - self.mean) / self.std), + reduction_indices=[1]) - + 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) - + tf.reduce_sum(self.logstd, reduction_indices=[1])) - def kl(self, other): - assert isinstance(other, DiagGaussian) - return tf.reduce_sum(other.logstd - self.logstd + - (tf.square(self.std) + - tf.square(self.mean - other.mean)) / - (2.0 * tf.square(other.std)) - 0.5, - reduction_indices=[1]) + def kl(self, other): + assert isinstance(other, DiagGaussian) + return tf.reduce_sum(other.logstd - self.logstd + + (tf.square(self.std) + + tf.square(self.mean - other.mean)) / + (2.0 * tf.square(other.std)) - 0.5, + reduction_indices=[1]) - def entropy(self): - return tf.reduce_sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e), - reduction_indices=[1]) + def entropy(self): + return tf.reduce_sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e), + reduction_indices=[1]) - def sample(self): - return self.mean + self.std * tf.random_normal(tf.shape(self.mean)) + def sample(self): + return self.mean + self.std * tf.random_normal(tf.shape(self.mean)) diff --git a/python/ray/rllib/policy_gradient/env.py b/python/ray/rllib/policy_gradient/env.py index 6ddb98ee3..74795b8f4 100644 --- a/python/ray/rllib/policy_gradient/env.py +++ b/python/ray/rllib/policy_gradient/env.py @@ -7,59 +7,60 @@ import numpy as np class AtariPixelPreprocessor(object): - def __init__(self): - 
self.shape = (80, 80, 3) + def __init__(self): + self.shape = (80, 80, 3) - def __call__(self, observation): - "Convert images from (210, 160, 3) to (3, 80, 80) by downsampling." - return (observation[25:-25:2, ::2, :][None] - 128) / 128 + def __call__(self, observation): + "Convert images from (210, 160, 3) to (3, 80, 80) by downsampling." + return (observation[25:-25:2, ::2, :][None] - 128) / 128 class AtariRamPreprocessor(object): - def __init__(self): - self.shape = (128,) + def __init__(self): + self.shape = (128,) - def __call__(self, observation): - return (observation - 128) / 128 + def __call__(self, observation): + return (observation - 128) / 128 class NoPreprocessor(object): - def __init__(self): - self.shape = None + def __init__(self): + self.shape = None - def __call__(self, observation): - return observation + def __call__(self, observation): + return observation class BatchedEnv(object): - """This holds multiple gym enviroments and performs steps on all of them.""" - def __init__(self, name, batchsize, preprocessor=None): - self.envs = [gym.make(name) for _ in range(batchsize)] - self.observation_space = self.envs[0].observation_space - self.action_space = self.envs[0].action_space - self.batchsize = batchsize - self.preprocessor = preprocessor if preprocessor else lambda obs: obs[None] + """This holds multiple gym envs and performs steps on all of them.""" + def __init__(self, name, batchsize, preprocessor=None): + self.envs = [gym.make(name) for _ in range(batchsize)] + self.observation_space = self.envs[0].observation_space + self.action_space = self.envs[0].action_space + self.batchsize = batchsize + self.preprocessor = (preprocessor if preprocessor + else lambda obs: obs[None]) - def reset(self): - observations = [self.preprocessor(env.reset()) for env in self.envs] - self.shape = observations[0].shape - self.dones = [False for _ in range(self.batchsize)] - return np.vstack(observations) + def reset(self): + observations = 
[self.preprocessor(env.reset()) for env in self.envs] + self.shape = observations[0].shape + self.dones = [False for _ in range(self.batchsize)] + return np.vstack(observations) - def step(self, actions, render=False): - observations = [] - rewards = [] - for i, action in enumerate(actions): - if self.dones[i]: - observations.append(np.zeros(self.shape)) - rewards.append(0.0) - continue - observation, reward, done, info = self.envs[i].step( - action if len(action) > 1 else action[0]) - if render: - self.envs[0].render() - observations.append(self.preprocessor(observation)) - rewards.append(reward) - self.dones[i] = done - return (np.vstack(observations), np.array(rewards, dtype="float32"), - np.array(self.dones)) + def step(self, actions, render=False): + observations = [] + rewards = [] + for i, action in enumerate(actions): + if self.dones[i]: + observations.append(np.zeros(self.shape)) + rewards.append(0.0) + continue + observation, reward, done, info = self.envs[i].step( + action if len(action) > 1 else action[0]) + if render: + self.envs[0].render() + observations.append(self.preprocessor(observation)) + rewards.append(reward) + self.dones[i] = done + return (np.vstack(observations), np.array(rewards, dtype="float32"), + np.array(self.dones)) diff --git a/python/ray/rllib/policy_gradient/example.py b/python/ray/rllib/policy_gradient/example.py index dc81ced01..fda9a1fe6 100755 --- a/python/ray/rllib/policy_gradient/example.py +++ b/python/ray/rllib/policy_gradient/example.py @@ -11,28 +11,28 @@ from ray.rllib.policy_gradient import PolicyGradient, DEFAULT_CONFIG if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run the policy gradient " - "algorithm.") - parser.add_argument("--environment", default="Pong-v0", type=str, - help="The gym environment to use.") - parser.add_argument("--redis-address", default=None, type=str, - help="The Redis address of the cluster.") - parser.add_argument("--use-tf-debugger", default=False, type=bool, - 
help="Run the script inside of tf-dbg.") - parser.add_argument("--load-checkpoint", default=None, type=str, - help="Continue training from a checkpoint.") + parser = argparse.ArgumentParser(description="Run the policy gradient " + "algorithm.") + parser.add_argument("--environment", default="Pong-v0", type=str, + help="The gym environment to use.") + parser.add_argument("--redis-address", default=None, type=str, + help="The Redis address of the cluster.") + parser.add_argument("--use-tf-debugger", default=False, type=bool, + help="Run the script inside of tf-dbg.") + parser.add_argument("--load-checkpoint", default=None, type=str, + help="Continue training from a checkpoint.") - args = parser.parse_args() - config = DEFAULT_CONFIG.copy() - config["use_tf_debugger"] = args.use_tf_debugger - if args.load_checkpoint: - config["load_checkpoint"] = args.load_checkpoint + args = parser.parse_args() + config = DEFAULT_CONFIG.copy() + config["use_tf_debugger"] = args.use_tf_debugger + if args.load_checkpoint: + config["load_checkpoint"] = args.load_checkpoint - ray.init(redis_address=args.redis_address) + ray.init(redis_address=args.redis_address) - alg = PolicyGradient(args.environment, config) - result = alg.train() - while result.training_iteration < config["max_iterations"]: - print("\n== iteration", result.training_iteration) + alg = PolicyGradient(args.environment, config) result = alg.train() - print("current status: {}".format(result)) + while result.training_iteration < config["max_iterations"]: + print("\n== iteration", result.training_iteration) + result = alg.train() + print("current status: {}".format(result)) diff --git a/python/ray/rllib/policy_gradient/filter.py b/python/ray/rllib/policy_gradient/filter.py index 692490f58..bf778ac27 100644 --- a/python/ray/rllib/policy_gradient/filter.py +++ b/python/ray/rllib/policy_gradient/filter.py @@ -6,127 +6,127 @@ import numpy as np class NoFilter(object): - def __init__(self): - pass + def __init__(self): + pass - 
def __call__(self, x, update=True): - return np.asarray(x) + def __call__(self, x, update=True): + return np.asarray(x) # http://www.johndcook.com/blog/standard_deviation/ class RunningStat(object): - def __init__(self, shape=None): - self._n = 0 - self._M = np.zeros(shape) - self._S = np.zeros(shape) + def __init__(self, shape=None): + self._n = 0 + self._M = np.zeros(shape) + self._S = np.zeros(shape) - def push(self, x): - x = np.asarray(x) - # Unvectorized update of the running statistics. - assert x.shape == self._M.shape, ("x.shape = {}, self.shape = {}" - .format(x.shape, self._M.shape)) - n1 = self._n - self._n += 1 - if self._n == 1: - self._M[...] = x - else: - delta = x - self._M - self._M[...] += delta / self._n - self._S[...] += delta * delta * n1 / self._n + def push(self, x): + x = np.asarray(x) + # Unvectorized update of the running statistics. + assert x.shape == self._M.shape, ("x.shape = {}, self.shape = {}" + .format(x.shape, self._M.shape)) + n1 = self._n + self._n += 1 + if self._n == 1: + self._M[...] = x + else: + delta = x - self._M + self._M[...] += delta / self._n + self._S[...] 
+= delta * delta * n1 / self._n - def update(self, other): - n1 = self._n - n2 = other._n - n = n1 + n2 - delta = self._M - other._M - delta2 = delta * delta - M = (n1 * self._M + n2 * other._M) / n - S = self._S + other._S + delta2 * n1 * n2 / n - self._n = n - self._M = M - self._S = S + def update(self, other): + n1 = self._n + n2 = other._n + n = n1 + n2 + delta = self._M - other._M + delta2 = delta * delta + M = (n1 * self._M + n2 * other._M) / n + S = self._S + other._S + delta2 * n1 * n2 / n + self._n = n + self._M = M + self._S = S - @property - def n(self): - return self._n + @property + def n(self): + return self._n - @property - def mean(self): - return self._M + @property + def mean(self): + return self._M - @property - def var(self): - return self._S / (self._n - 1) if self._n > 1 else np.square(self._M) + @property + def var(self): + return self._S / (self._n - 1) if self._n > 1 else np.square(self._M) - @property - def std(self): - return np.sqrt(self.var) + @property + def std(self): + return np.sqrt(self.var) - @property - def shape(self): - return self._M.shape + @property + def shape(self): + return self._M.shape class MeanStdFilter(object): - def __init__(self, shape, demean=True, destd=True, clip=10.0): - self.demean = demean - self.destd = destd - self.clip = clip + def __init__(self, shape, demean=True, destd=True, clip=10.0): + self.demean = demean + self.destd = destd + self.clip = clip - self.rs = RunningStat(shape) + self.rs = RunningStat(shape) - def __call__(self, x, update=True): - x = np.asarray(x) - if update: - if len(x.shape) == len(self.rs.shape) + 1: - # The vectorized case. - for i in range(x.shape[0]): - self.rs.push(x[i]) - else: - # The unvectorized case. 
- self.rs.push(x) - if self.demean: - x = x - self.rs.mean - if self.destd: - x = x / (self.rs.std + 1e-8) - if self.clip: - x = np.clip(x, -self.clip, self.clip) - return x + def __call__(self, x, update=True): + x = np.asarray(x) + if update: + if len(x.shape) == len(self.rs.shape) + 1: + # The vectorized case. + for i in range(x.shape[0]): + self.rs.push(x[i]) + else: + # The unvectorized case. + self.rs.push(x) + if self.demean: + x = x - self.rs.mean + if self.destd: + x = x / (self.rs.std + 1e-8) + if self.clip: + x = np.clip(x, -self.clip, self.clip) + return x def test_running_stat(): - for shp in ((), (3,), (3, 4)): - li = [] - rs = RunningStat(shp) - for _ in range(5): - val = np.random.randn(*shp) - rs.push(val) - li.append(val) - m = np.mean(li, axis=0) - assert np.allclose(rs.mean, m) - v = np.square(m) if (len(li) == 1) else np.var(li, ddof=1, axis=0) - assert np.allclose(rs.var, v) + for shp in ((), (3,), (3, 4)): + li = [] + rs = RunningStat(shp) + for _ in range(5): + val = np.random.randn(*shp) + rs.push(val) + li.append(val) + m = np.mean(li, axis=0) + assert np.allclose(rs.mean, m) + v = np.square(m) if (len(li) == 1) else np.var(li, ddof=1, axis=0) + assert np.allclose(rs.var, v) def test_combining_stat(): - for shape in [(), (3,), (3, 4)]: - li = [] - rs1 = RunningStat(shape) - rs2 = RunningStat(shape) - rs = RunningStat(shape) - for _ in range(5): - val = np.random.randn(*shape) - rs1.push(val) - rs.push(val) - li.append(val) - for _ in range(9): - rs2.push(val) - rs.push(val) - li.append(val) - rs1.update(rs2) - assert np.allclose(rs.mean, rs1.mean) - assert np.allclose(rs.std, rs1.std) + for shape in [(), (3,), (3, 4)]: + li = [] + rs1 = RunningStat(shape) + rs2 = RunningStat(shape) + rs = RunningStat(shape) + for _ in range(5): + val = np.random.randn(*shape) + rs1.push(val) + rs.push(val) + li.append(val) + for _ in range(9): + rs2.push(val) + rs.push(val) + li.append(val) + rs1.update(rs2) + assert np.allclose(rs.mean, rs1.mean) + assert 
np.allclose(rs.std, rs1.std) test_running_stat() diff --git a/python/ray/rllib/policy_gradient/loss.py b/python/ray/rllib/policy_gradient/loss.py index 2f2129150..211128727 100644 --- a/python/ray/rllib/policy_gradient/loss.py +++ b/python/ray/rllib/policy_gradient/loss.py @@ -10,43 +10,43 @@ from ray.rllib.policy_gradient.models.fcnet import fc_net class ProximalPolicyLoss(object): - def __init__( - self, observation_space, action_space, - observations, advantages, actions, prev_logits, logit_dim, - kl_coeff, distribution_class, config, sess): - assert (isinstance(action_space, gym.spaces.Discrete) or - isinstance(action_space, gym.spaces.Box)) - self.prev_dist = distribution_class(prev_logits) + def __init__( + self, observation_space, action_space, + observations, advantages, actions, prev_logits, logit_dim, + kl_coeff, distribution_class, config, sess): + assert (isinstance(action_space, gym.spaces.Discrete) or + isinstance(action_space, gym.spaces.Box)) + self.prev_dist = distribution_class(prev_logits) - # Saved so that we can compute actions given different observations - self.observations = observations + # Saved so that we can compute actions given different observations + self.observations = observations - if len(observation_space.shape) > 1: - self.curr_logits = vision_net(observations, num_classes=logit_dim) - else: - assert len(observation_space.shape) == 1 - self.curr_logits = fc_net(observations, num_classes=logit_dim) - self.curr_dist = distribution_class(self.curr_logits) - self.sampler = self.curr_dist.sample() + if len(observation_space.shape) > 1: + self.curr_logits = vision_net(observations, num_classes=logit_dim) + else: + assert len(observation_space.shape) == 1 + self.curr_logits = fc_net(observations, num_classes=logit_dim) + self.curr_dist = distribution_class(self.curr_logits) + self.sampler = self.curr_dist.sample() - # Make loss functions. 
- self.ratio = tf.exp(self.curr_dist.logp(actions) - - self.prev_dist.logp(actions)) - self.kl = self.prev_dist.kl(self.curr_dist) - self.mean_kl = tf.reduce_mean(self.kl) - self.entropy = self.curr_dist.entropy() - self.mean_entropy = tf.reduce_mean(self.entropy) - self.surr1 = self.ratio * advantages - self.surr2 = tf.clip_by_value(self.ratio, 1 - config["clip_param"], - 1 + config["clip_param"]) * advantages - self.surr = tf.minimum(self.surr1, self.surr2) - self.loss = tf.reduce_mean(-self.surr + kl_coeff * self.kl - - config["entropy_coeff"] * self.entropy) - self.sess = sess + # Make loss functions. + self.ratio = tf.exp(self.curr_dist.logp(actions) - + self.prev_dist.logp(actions)) + self.kl = self.prev_dist.kl(self.curr_dist) + self.mean_kl = tf.reduce_mean(self.kl) + self.entropy = self.curr_dist.entropy() + self.mean_entropy = tf.reduce_mean(self.entropy) + self.surr1 = self.ratio * advantages + self.surr2 = tf.clip_by_value(self.ratio, 1 - config["clip_param"], + 1 + config["clip_param"]) * advantages + self.surr = tf.minimum(self.surr1, self.surr2) + self.loss = tf.reduce_mean(-self.surr + kl_coeff * self.kl - + config["entropy_coeff"] * self.entropy) + self.sess = sess - def compute_actions(self, observations): - return self.sess.run([self.sampler, self.curr_logits], - feed_dict={self.observations: observations}) + def compute_actions(self, observations): + return self.sess.run([self.sampler, self.curr_logits], + feed_dict={self.observations: observations}) - def loss(self): - return self.loss + def loss(self): + return self.loss diff --git a/python/ray/rllib/policy_gradient/models/fcnet.py b/python/ray/rllib/policy_gradient/models/fcnet.py index e431ea44c..057b2004d 100644 --- a/python/ray/rllib/policy_gradient/models/fcnet.py +++ b/python/ray/rllib/policy_gradient/models/fcnet.py @@ -9,30 +9,30 @@ import numpy as np def normc_initializer(std=1.0): - def _initializer(shape, dtype=None, partition_info=None): - out = 
np.random.randn(*shape).astype(np.float32) - out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) - return tf.constant(out) - return _initializer + def _initializer(shape, dtype=None, partition_info=None): + out = np.random.randn(*shape).astype(np.float32) + out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) + return tf.constant(out) + return _initializer def fc_net(inputs, num_classes=10, logstd=False): - with tf.name_scope("fc_net"): - fc1 = slim.fully_connected(inputs, 128, - weights_initializer=normc_initializer(1.0), - scope="fc1") - fc2 = slim.fully_connected(fc1, 128, - weights_initializer=normc_initializer(1.0), - scope="fc2") - fc3 = slim.fully_connected(fc2, 128, - weights_initializer=normc_initializer(1.0), - scope="fc3") - fc4 = slim.fully_connected(fc3, num_classes, - weights_initializer=normc_initializer(0.01), - activation_fn=None, scope="fc4") - if logstd: - logstd = tf.get_variable(name="logstd", shape=[num_classes], - initializer=tf.zeros_initializer) - return tf.concat(1, [fc4, logstd]) - else: - return fc4 + with tf.name_scope("fc_net"): + fc1 = slim.fully_connected(inputs, 128, + weights_initializer=normc_initializer(1.0), + scope="fc1") + fc2 = slim.fully_connected(fc1, 128, + weights_initializer=normc_initializer(1.0), + scope="fc2") + fc3 = slim.fully_connected(fc2, 128, + weights_initializer=normc_initializer(1.0), + scope="fc3") + fc4 = slim.fully_connected(fc3, num_classes, + weights_initializer=normc_initializer(0.01), + activation_fn=None, scope="fc4") + if logstd: + logstd = tf.get_variable(name="logstd", shape=[num_classes], + initializer=tf.zeros_initializer) + return tf.concat(1, [fc4, logstd]) + else: + return fc4 diff --git a/python/ray/rllib/policy_gradient/models/visionnet.py b/python/ray/rllib/policy_gradient/models/visionnet.py index 75adc751e..2b03b2ff7 100644 --- a/python/ray/rllib/policy_gradient/models/visionnet.py +++ b/python/ray/rllib/policy_gradient/models/visionnet.py @@ -7,10 +7,10 @@ import 
tensorflow.contrib.slim as slim def vision_net(inputs, num_classes=10): - with tf.name_scope("vision_net"): - conv1 = slim.conv2d(inputs, 16, [8, 8], 4, scope="conv1") - conv2 = slim.conv2d(conv1, 32, [4, 4], 2, scope="conv2") - fc1 = slim.conv2d(conv2, 512, [10, 10], padding="VALID", scope="fc1") - fc2 = slim.conv2d(fc1, num_classes, [1, 1], activation_fn=None, - normalizer_fn=None, scope="fc2") - return tf.squeeze(fc2, [1, 2]) + with tf.name_scope("vision_net"): + conv1 = slim.conv2d(inputs, 16, [8, 8], 4, scope="conv1") + conv2 = slim.conv2d(conv1, 32, [4, 4], 2, scope="conv2") + fc1 = slim.conv2d(conv2, 512, [10, 10], padding="VALID", scope="fc1") + fc2 = slim.conv2d(fc1, num_classes, [1, 1], activation_fn=None, + normalizer_fn=None, scope="fc2") + return tf.squeeze(fc2, [1, 2]) diff --git a/python/ray/rllib/policy_gradient/policy_gradient.py b/python/ray/rllib/policy_gradient/policy_gradient.py index e9069c5b5..760ca5320 100644 --- a/python/ray/rllib/policy_gradient/policy_gradient.py +++ b/python/ray/rllib/policy_gradient/policy_gradient.py @@ -43,165 +43,170 @@ DEFAULT_CONFIG = { class PolicyGradient(Algorithm): - def __init__(self, env_name, config, upload_dir=None): - config.update({"alg": "PolicyGradient"}) + def __init__(self, env_name, config, upload_dir=None): + config.update({"alg": "PolicyGradient"}) - Algorithm.__init__(self, env_name, config, upload_dir=upload_dir) + Algorithm.__init__(self, env_name, config, upload_dir=upload_dir) - # TODO(ekl) the preprocessor should be associated with the env elsewhere - if self.env_name == "Pong-v0": - preprocessor = AtariPixelPreprocessor() - elif self.env_name == "Pong-ram-v3": - preprocessor = AtariRamPreprocessor() - elif self.env_name == "CartPole-v0": - preprocessor = NoPreprocessor() - elif self.env_name == "Walker2d-v1": - preprocessor = NoPreprocessor() - else: - preprocessor = AtariPixelPreprocessor() + # TODO(ekl): The preprocessor should be associated with the env + # elsewhere. 
+ if self.env_name == "Pong-v0": + preprocessor = AtariPixelPreprocessor() + elif self.env_name == "Pong-ram-v3": + preprocessor = AtariRamPreprocessor() + elif self.env_name == "CartPole-v0": + preprocessor = NoPreprocessor() + elif self.env_name == "Walker2d-v1": + preprocessor = NoPreprocessor() + else: + preprocessor = AtariPixelPreprocessor() - self.preprocessor = preprocessor - self.global_step = 0 - self.j = 0 - self.kl_coeff = config["kl_coeff"] - self.model = Agent( - self.env_name, 1, self.preprocessor, self.config, self.logdir, False) - self.agents = [ - RemoteAgent.remote( - self.env_name, 1, self.preprocessor, self.config, - self.logdir, True) - for _ in range(config["num_agents"])] + self.preprocessor = preprocessor + self.global_step = 0 + self.j = 0 + self.kl_coeff = config["kl_coeff"] + self.model = Agent( + self.env_name, 1, self.preprocessor, self.config, self.logdir, + False) + self.agents = [ + RemoteAgent.remote( + self.env_name, 1, self.preprocessor, self.config, + self.logdir, True) + for _ in range(config["num_agents"])] - def train(self): - agents = self.agents - config = self.config - model = self.model - j = self.j - self.j += 1 + def train(self): + agents = self.agents + config = self.config + model = self.model + j = self.j + self.j += 1 - saver = tf.train.Saver(max_to_keep=None) - if "load_checkpoint" in config: - saver.restore(model.sess, config["load_checkpoint"]) + saver = tf.train.Saver(max_to_keep=None) + if "load_checkpoint" in config: + saver.restore(model.sess, config["load_checkpoint"]) - # TF does not support to write logs to S3 at the moment - write_tf_logs = self.logdir.startswith("file") - iter_start = time.time() - if write_tf_logs: - file_writer = tf.summary.FileWriter(self.logdir, model.sess.graph) - if config["model_checkpoint_file"]: - checkpoint_path = saver.save( - model.sess, - os.path.join(self.logdir, config["model_checkpoint_file"] % j)) - print("Checkpoint saved in file: %s" % checkpoint_path) - 
checkpointing_end = time.time() - weights = ray.put(model.get_weights()) - [a.load_weights.remote(weights) for a in agents] - trajectory, total_reward, traj_len_mean = collect_samples( - agents, config["timesteps_per_batch"], 0.995, 1.0, 2000) - print("total reward is ", total_reward) - print("trajectory length mean is ", traj_len_mean) - print("timesteps:", trajectory["dones"].shape[0]) - if write_tf_logs: - traj_stats = tf.Summary(value=[ - tf.Summary.Value( - tag="policy_gradient/rollouts/mean_reward", - simple_value=total_reward), - tf.Summary.Value( - tag="policy_gradient/rollouts/traj_len_mean", - simple_value=traj_len_mean)]) - file_writer.add_summary(traj_stats, self.global_step) - self.global_step += 1 - trajectory["advantages"] = ((trajectory["advantages"] - - trajectory["advantages"].mean()) / - trajectory["advantages"].std()) - rollouts_end = time.time() - print("Computing policy (iterations=" + str(config["num_sgd_iter"]) + - ", stepsize=" + str(config["sgd_stepsize"]) + "):") - names = ["iter", "loss", "kl", "entropy"] - print(("{:>15}" * len(names)).format(*names)) - trajectory = shuffle(trajectory) - shuffle_end = time.time() - tuples_per_device = model.load_data( - trajectory, j == 0 and config["full_trace_data_load"]) - load_end = time.time() - checkpointing_time = checkpointing_end - iter_start - rollouts_time = rollouts_end - checkpointing_end - shuffle_time = shuffle_end - rollouts_end - load_time = load_end - shuffle_end - sgd_time = 0 - for i in range(config["num_sgd_iter"]): - sgd_start = time.time() - batch_index = 0 - num_batches = int(tuples_per_device) // int(model.per_device_batch_size) - loss, kl, entropy = [], [], [] - permutation = np.random.permutation(num_batches) - while batch_index < num_batches: - full_trace = ( - i == 0 and j == 0 and - batch_index == config["full_trace_nth_sgd_batch"]) - batch_loss, batch_kl, batch_entropy = model.run_sgd_minibatch( - permutation[batch_index] * model.per_device_batch_size, - self.kl_coeff, 
full_trace, - file_writer if write_tf_logs else None) - loss.append(batch_loss) - kl.append(batch_kl) - entropy.append(batch_entropy) - batch_index += 1 - loss = np.mean(loss) - kl = np.mean(kl) - entropy = np.mean(entropy) - sgd_end = time.time() - print("{:>15}{:15.5e}{:15.5e}{:15.5e}".format(i, loss, kl, entropy)) + # TF does not support to write logs to S3 at the moment + write_tf_logs = self.logdir.startswith("file") + iter_start = time.time() + if write_tf_logs: + file_writer = tf.summary.FileWriter(self.logdir, model.sess.graph) + if config["model_checkpoint_file"]: + checkpoint_path = saver.save( + model.sess, + os.path.join(self.logdir, + config["model_checkpoint_file"] % j)) + print("Checkpoint saved in file: %s" % checkpoint_path) + checkpointing_end = time.time() + weights = ray.put(model.get_weights()) + [a.load_weights.remote(weights) for a in agents] + trajectory, total_reward, traj_len_mean = collect_samples( + agents, config["timesteps_per_batch"], 0.995, 1.0, 2000) + print("total reward is ", total_reward) + print("trajectory length mean is ", traj_len_mean) + print("timesteps:", trajectory["dones"].shape[0]) + if write_tf_logs: + traj_stats = tf.Summary(value=[ + tf.Summary.Value( + tag="policy_gradient/rollouts/mean_reward", + simple_value=total_reward), + tf.Summary.Value( + tag="policy_gradient/rollouts/traj_len_mean", + simple_value=traj_len_mean)]) + file_writer.add_summary(traj_stats, self.global_step) + self.global_step += 1 + trajectory["advantages"] = ((trajectory["advantages"] - + trajectory["advantages"].mean()) / + trajectory["advantages"].std()) + rollouts_end = time.time() + print("Computing policy (iterations=" + str(config["num_sgd_iter"]) + + ", stepsize=" + str(config["sgd_stepsize"]) + "):") + names = ["iter", "loss", "kl", "entropy"] + print(("{:>15}" * len(names)).format(*names)) + trajectory = shuffle(trajectory) + shuffle_end = time.time() + tuples_per_device = model.load_data( + trajectory, j == 0 and 
config["full_trace_data_load"]) + load_end = time.time() + checkpointing_time = checkpointing_end - iter_start + rollouts_time = rollouts_end - checkpointing_end + shuffle_time = shuffle_end - rollouts_end + load_time = load_end - shuffle_end + sgd_time = 0 + for i in range(config["num_sgd_iter"]): + sgd_start = time.time() + batch_index = 0 + num_batches = (int(tuples_per_device) // + int(model.per_device_batch_size)) + loss, kl, entropy = [], [], [] + permutation = np.random.permutation(num_batches) + while batch_index < num_batches: + full_trace = ( + i == 0 and j == 0 and + batch_index == config["full_trace_nth_sgd_batch"]) + batch_loss, batch_kl, batch_entropy = model.run_sgd_minibatch( + permutation[batch_index] * model.per_device_batch_size, + self.kl_coeff, full_trace, + file_writer if write_tf_logs else None) + loss.append(batch_loss) + kl.append(batch_kl) + entropy.append(batch_entropy) + batch_index += 1 + loss = np.mean(loss) + kl = np.mean(kl) + entropy = np.mean(entropy) + sgd_end = time.time() + print("{:>15}{:15.5e}{:15.5e}{:15.5e}".format(i, loss, kl, + entropy)) - values = [] - if i == config["num_sgd_iter"] - 1: - metric_prefix = "policy_gradient/sgd/final_iter/" - values.append(tf.Summary.Value( - tag=metric_prefix + "kl_coeff", - simple_value=self.kl_coeff)) - else: - metric_prefix = "policy_gradient/sgd/intermediate_iters/" - values.extend([ - tf.Summary.Value( - tag=metric_prefix + "mean_entropy", - simple_value=entropy), - tf.Summary.Value( - tag=metric_prefix + "mean_loss", - simple_value=loss), - tf.Summary.Value( - tag=metric_prefix + "mean_kl", - simple_value=kl)]) - if write_tf_logs: - sgd_stats = tf.Summary(value=values) - file_writer.add_summary(sgd_stats, self.global_step) - self.global_step += 1 - sgd_time += sgd_end - sgd_start - if kl > 2.0 * config["kl_target"]: - self.kl_coeff *= 1.5 - elif kl < 0.5 * config["kl_target"]: - self.kl_coeff *= 0.5 + values = [] + if i == config["num_sgd_iter"] - 1: + metric_prefix = 
"policy_gradient/sgd/final_iter/" + values.append(tf.Summary.Value( + tag=metric_prefix + "kl_coeff", + simple_value=self.kl_coeff)) + else: + metric_prefix = "policy_gradient/sgd/intermediate_iters/" + values.extend([ + tf.Summary.Value( + tag=metric_prefix + "mean_entropy", + simple_value=entropy), + tf.Summary.Value( + tag=metric_prefix + "mean_loss", + simple_value=loss), + tf.Summary.Value( + tag=metric_prefix + "mean_kl", + simple_value=kl)]) + if write_tf_logs: + sgd_stats = tf.Summary(value=values) + file_writer.add_summary(sgd_stats, self.global_step) + self.global_step += 1 + sgd_time += sgd_end - sgd_start + if kl > 2.0 * config["kl_target"]: + self.kl_coeff *= 1.5 + elif kl < 0.5 * config["kl_target"]: + self.kl_coeff *= 0.5 - info = { - "kl_divergence": kl, - "kl_coefficient": self.kl_coeff, - "checkpointing_time": checkpointing_time, - "rollouts_time": rollouts_time, - "shuffle_time": shuffle_time, - "load_time": load_time, - "sgd_time": sgd_time, - "sample_throughput": len(trajectory["observations"]) / sgd_time - } + info = { + "kl_divergence": kl, + "kl_coefficient": self.kl_coeff, + "checkpointing_time": checkpointing_time, + "rollouts_time": rollouts_time, + "shuffle_time": shuffle_time, + "load_time": load_time, + "sgd_time": sgd_time, + "sample_throughput": len(trajectory["observations"]) / sgd_time + } - print("kl div:", kl) - print("kl coeff:", self.kl_coeff) - print("checkpointing time:", checkpointing_time) - print("rollouts time:", rollouts_time) - print("shuffle time:", shuffle_time) - print("load time:", load_time) - print("sgd time:", sgd_time) - print("sgd examples/s:", len(trajectory["observations"]) / sgd_time) + print("kl div:", kl) + print("kl coeff:", self.kl_coeff) + print("checkpointing time:", checkpointing_time) + print("rollouts time:", rollouts_time) + print("shuffle time:", shuffle_time) + print("load time:", load_time) + print("sgd time:", sgd_time) + print("sgd examples/s:", len(trajectory["observations"]) / sgd_time) - 
result = TrainingResult( - self.experiment_id.hex, j, total_reward, traj_len_mean, info) + result = TrainingResult( + self.experiment_id.hex, j, total_reward, traj_len_mean, info) - return result + return result diff --git a/python/ray/rllib/policy_gradient/rollout.py b/python/ray/rllib/policy_gradient/rollout.py index aed531ec2..356a49710 100644 --- a/python/ray/rllib/policy_gradient/rollout.py +++ b/python/ray/rllib/policy_gradient/rollout.py @@ -11,93 +11,95 @@ from ray.rllib.policy_gradient.utils import flatten, concatenate def rollouts(policy, env, horizon, observation_filter=NoFilter(), reward_filter=NoFilter()): - """Perform a batch of rollouts of a policy in an environment. + """Perform a batch of rollouts of a policy in an environment. - Args: - policy: The policy that will be rollout out. Can be an arbitrary object - that supports a compute_actions(observation) function. - env: The environment the rollout is computed in. Needs to support the - OpenAI gym API and needs to support batches of data. - horizon: Upper bound for the number of timesteps for each rollout in the - batch. - observation_filter: Function that is applied to each of the observations. - reward_filter: Function that is applied to each of the rewards. + Args: + policy: The policy that will be rollout out. Can be an arbitrary object + that supports a compute_actions(observation) function. + env: The environment the rollout is computed in. Needs to support the + OpenAI gym API and needs to support batches of data. + horizon: Upper bound for the number of timesteps for each rollout in + the batch. + observation_filter: Function that is applied to each of the + observations. + reward_filter: Function that is applied to each of the rewards. - Returns: - A trajectory, which is a dictionary with keys "observations", "rewards", - "orig_rewards", "actions", "logprobs", "dones". Each value is an array of - shape (num_timesteps, env.batchsize, shape). 
- """ + Returns: + A trajectory, which is a dictionary with keys "observations", + "rewards", "orig_rewards", "actions", "logprobs", "dones". Each + value is an array of shape (num_timesteps, env.batchsize, shape). + """ - observation = observation_filter(env.reset()) - done = np.array(env.batchsize * [False]) - t = 0 - observations = [] - raw_rewards = [] # Empirical rewards - actions = [] - logprobs = [] - dones = [] + observation = observation_filter(env.reset()) + done = np.array(env.batchsize * [False]) + t = 0 + observations = [] + raw_rewards = [] # Empirical rewards + actions = [] + logprobs = [] + dones = [] - while not done.all() and t < horizon: - action, logprob = policy.compute_actions(observation) - observations.append(observation[None]) - actions.append(action[None]) - logprobs.append(logprob[None]) - observation, raw_reward, done = env.step(action) - observation = observation_filter(observation) - raw_rewards.append(raw_reward[None]) - dones.append(done[None]) - t += 1 + while not done.all() and t < horizon: + action, logprob = policy.compute_actions(observation) + observations.append(observation[None]) + actions.append(action[None]) + logprobs.append(logprob[None]) + observation, raw_reward, done = env.step(action) + observation = observation_filter(observation) + raw_rewards.append(raw_reward[None]) + dones.append(done[None]) + t += 1 - return {"observations": np.vstack(observations), - "raw_rewards": np.vstack(raw_rewards), - "actions": np.vstack(actions), - "logprobs": np.vstack(logprobs), - "dones": np.vstack(dones)} + return {"observations": np.vstack(observations), + "raw_rewards": np.vstack(raw_rewards), + "actions": np.vstack(actions), + "logprobs": np.vstack(logprobs), + "dones": np.vstack(dones)} def add_advantage_values(trajectory, gamma, lam, reward_filter): - rewards = trajectory["raw_rewards"] - dones = trajectory["dones"] - advantages = np.zeros_like(rewards) - last_advantage = np.zeros(rewards.shape[1], dtype="float32") + rewards = 
trajectory["raw_rewards"] + dones = trajectory["dones"] + advantages = np.zeros_like(rewards) + last_advantage = np.zeros(rewards.shape[1], dtype="float32") - for t in reversed(range(len(rewards))): - delta = rewards[t, :] * (1 - dones[t, :]) - last_advantage = delta + gamma * lam * last_advantage - advantages[t, :] = last_advantage - reward_filter(advantages[t, :]) + for t in reversed(range(len(rewards))): + delta = rewards[t, :] * (1 - dones[t, :]) + last_advantage = delta + gamma * lam * last_advantage + advantages[t, :] = last_advantage + reward_filter(advantages[t, :]) - trajectory["advantages"] = advantages + trajectory["advantages"] = advantages @ray.remote def compute_trajectory(policy, env, gamma, lam, horizon, observation_filter, reward_filter): - trajectory = rollouts(policy, env, horizon, observation_filter, - reward_filter) - add_advantage_values(trajectory, gamma, lam, reward_filter) - return trajectory + trajectory = rollouts(policy, env, horizon, observation_filter, + reward_filter) + add_advantage_values(trajectory, gamma, lam, reward_filter) + return trajectory def collect_samples(agents, num_timesteps, gamma, lam, horizon, observation_filter=NoFilter(), reward_filter=NoFilter()): - num_timesteps_so_far = 0 - trajectories = [] - total_rewards = [] - traj_len_means = [] - while num_timesteps_so_far < num_timesteps: - trajectory_batch = ray.get( - [agent.compute_trajectory.remote(gamma, lam, horizon) - for agent in agents]) - trajectory = concatenate(trajectory_batch) - trajectory = flatten(trajectory) - not_done = np.logical_not(trajectory["dones"]) - total_rewards.append( - trajectory["raw_rewards"][not_done].sum(axis=0).mean() / len(agents)) - traj_len_means.append(not_done.sum(axis=0).mean() / len(agents)) - trajectory = {key: val[not_done] for key, val in trajectory.items()} - num_timesteps_so_far += len(trajectory["dones"]) - trajectories.append(trajectory) - return (concatenate(trajectories), np.mean(total_rewards), - np.mean(traj_len_means)) 
+ num_timesteps_so_far = 0 + trajectories = [] + total_rewards = [] + traj_len_means = [] + while num_timesteps_so_far < num_timesteps: + trajectory_batch = ray.get( + [agent.compute_trajectory.remote(gamma, lam, horizon) + for agent in agents]) + trajectory = concatenate(trajectory_batch) + trajectory = flatten(trajectory) + not_done = np.logical_not(trajectory["dones"]) + total_rewards.append( + trajectory["raw_rewards"][not_done].sum(axis=0).mean() / + len(agents)) + traj_len_means.append(not_done.sum(axis=0).mean() / len(agents)) + trajectory = {key: val[not_done] for key, val in trajectory.items()} + num_timesteps_so_far += len(trajectory["dones"]) + trajectories.append(trajectory) + return (concatenate(trajectories), np.mean(total_rewards), + np.mean(traj_len_means)) diff --git a/python/ray/rllib/policy_gradient/test/test.py b/python/ray/rllib/policy_gradient/test/test.py index 0743bdc5d..325ee437f 100644 --- a/python/ray/rllib/policy_gradient/test/test.py +++ b/python/ray/rllib/policy_gradient/test/test.py @@ -13,49 +13,49 @@ from ray.rllib.policy_gradient.utils import flatten, concatenate class DistibutionsTest(unittest.TestCase): - def testCategorical(self): - num_samples = 100000 - logits = tf.placeholder(tf.float32, shape=(None, 10)) - z = 8 * (np.random.rand(10) - 0.5) - data = np.tile(z, (num_samples, 1)) - c = Categorical(logits) - sample_op = c.sample() - sess = tf.Session() - sess.run(tf.global_variables_initializer()) - samples = sess.run(sample_op, feed_dict={logits: data}) - counts = np.zeros(10) - for sample in samples: - counts[sample] += 1.0 - probs = np.exp(z) / np.sum(np.exp(z)) - self.assertTrue(np.sum(np.abs(probs - counts / num_samples)) <= 0.01) + def testCategorical(self): + num_samples = 100000 + logits = tf.placeholder(tf.float32, shape=(None, 10)) + z = 8 * (np.random.rand(10) - 0.5) + data = np.tile(z, (num_samples, 1)) + c = Categorical(logits) + sample_op = c.sample() + sess = tf.Session() + 
sess.run(tf.global_variables_initializer()) + samples = sess.run(sample_op, feed_dict={logits: data}) + counts = np.zeros(10) + for sample in samples: + counts[sample] += 1.0 + probs = np.exp(z) / np.sum(np.exp(z)) + self.assertTrue(np.sum(np.abs(probs - counts / num_samples)) <= 0.01) class UtilsTest(unittest.TestCase): - def testFlatten(self): - d = {"s": np.array([[[1, -1], [2, -2]], [[3, -3], [4, -4]]]), - "a": np.array([[[5], [-5]], [[6], [-6]]])} - flat = flatten(d.copy(), start=0, stop=2) - assert_allclose(d["s"][0][0][:], flat["s"][0][:]) - assert_allclose(d["s"][0][1][:], flat["s"][1][:]) - assert_allclose(d["s"][1][0][:], flat["s"][2][:]) - assert_allclose(d["s"][1][1][:], flat["s"][3][:]) - assert_allclose(d["a"][0][0], flat["a"][0]) - assert_allclose(d["a"][0][1], flat["a"][1]) - assert_allclose(d["a"][1][0], flat["a"][2]) - assert_allclose(d["a"][1][1], flat["a"][3]) + def testFlatten(self): + d = {"s": np.array([[[1, -1], [2, -2]], [[3, -3], [4, -4]]]), + "a": np.array([[[5], [-5]], [[6], [-6]]])} + flat = flatten(d.copy(), start=0, stop=2) + assert_allclose(d["s"][0][0][:], flat["s"][0][:]) + assert_allclose(d["s"][0][1][:], flat["s"][1][:]) + assert_allclose(d["s"][1][0][:], flat["s"][2][:]) + assert_allclose(d["s"][1][1][:], flat["s"][3][:]) + assert_allclose(d["a"][0][0], flat["a"][0]) + assert_allclose(d["a"][0][1], flat["a"][1]) + assert_allclose(d["a"][1][0], flat["a"][2]) + assert_allclose(d["a"][1][1], flat["a"][3]) - def testConcatenate(self): - d1 = {"s": np.array([0, 1]), "a": np.array([2, 3])} - d2 = {"s": np.array([4, 5]), "a": np.array([6, 7])} - d = concatenate([d1, d2]) - assert_allclose(d["s"], np.array([0, 1, 4, 5])) - assert_allclose(d["a"], np.array([2, 3, 6, 7])) + def testConcatenate(self): + d1 = {"s": np.array([0, 1]), "a": np.array([2, 3])} + d2 = {"s": np.array([4, 5]), "a": np.array([6, 7])} + d = concatenate([d1, d2]) + assert_allclose(d["s"], np.array([0, 1, 4, 5])) + assert_allclose(d["a"], np.array([2, 3, 6, 7])) - D = 
concatenate([d]) - assert_allclose(D["s"], np.array([0, 1, 4, 5])) - assert_allclose(D["a"], np.array([2, 3, 6, 7])) + D = concatenate([d]) + assert_allclose(D["s"], np.array([0, 1, 4, 5])) + assert_allclose(D["a"], np.array([2, 3, 6, 7])) if __name__ == "__main__": - unittest.main(verbosity=2) + unittest.main(verbosity=2) diff --git a/python/ray/rllib/policy_gradient/utils.py b/python/ray/rllib/policy_gradient/utils.py index 0762b930c..e92092f11 100644 --- a/python/ray/rllib/policy_gradient/utils.py +++ b/python/ray/rllib/policy_gradient/utils.py @@ -6,31 +6,31 @@ import numpy as np def flatten(weights, start=0, stop=2): - """This methods reshapes all values in a dictionary. + """This methods reshapes all values in a dictionary. - The indices from start to stop will be flattened into a single index. + The indices from start to stop will be flattened into a single index. - Args: - weights: A dictionary mapping keys to numpy arrays. - start: The starting index. - stop: The ending index. - """ - for key, val in weights.items(): - new_shape = val.shape[0:start] + (-1,) + val.shape[stop:] - weights[key] = val.reshape(new_shape) - return weights + Args: + weights: A dictionary mapping keys to numpy arrays. + start: The starting index. + stop: The ending index. 
+ """ + for key, val in weights.items(): + new_shape = val.shape[0:start] + (-1,) + val.shape[stop:] + weights[key] = val.reshape(new_shape) + return weights def concatenate(weights_list): - keys = weights_list[0].keys() - result = {} - for key in keys: - result[key] = np.concatenate([l[key] for l in weights_list]) - return result + keys = weights_list[0].keys() + result = {} + for key in keys: + result[key] = np.concatenate([l[key] for l in weights_list]) + return result def shuffle(trajectory): - permutation = np.random.permutation(trajectory["dones"].shape[0]) - for key, val in trajectory.items(): - trajectory[key] = val[permutation] - return trajectory + permutation = np.random.permutation(trajectory["dones"].shape[0]) + for key, val in trajectory.items(): + trajectory[key] = val[permutation] + return trajectory diff --git a/python/ray/rllib/train.py b/python/ray/rllib/train.py index 42dcd0622..88daf9582 100755 --- a/python/ray/rllib/train.py +++ b/python/ray/rllib/train.py @@ -22,36 +22,36 @@ parser.add_argument("--upload-dir", default="file:///tmp/ray", type=str) if __name__ == "__main__": - args = parser.parse_args() + args = parser.parse_args() - ray.init() + ray.init() - env_name = args.env - if args.alg == "PolicyGradient": - alg = pg.PolicyGradient( - env_name, pg.DEFAULT_CONFIG, upload_dir=args.upload_dir) - elif args.alg == "EvolutionStrategies": - alg = es.EvolutionStrategies( - env_name, es.DEFAULT_CONFIG, upload_dir=args.upload_dir) - elif args.alg == "DQN": - alg = dqn.DQN( - env_name, dqn.DEFAULT_CONFIG, upload_dir=args.upload_dir) - elif args.alg == "A3C": - alg = a3c.A3C( - env_name, a3c.DEFAULT_CONFIG, upload_dir=args.upload_dir) - else: - assert False, ("Unknown algorithm, check --alg argument. 
Valid choices " - "are PolicyGradientPolicyGradient, EvolutionStrategies, " - "DQN and A3C.") + env_name = args.env + if args.alg == "PolicyGradient": + alg = pg.PolicyGradient( + env_name, pg.DEFAULT_CONFIG, upload_dir=args.upload_dir) + elif args.alg == "EvolutionStrategies": + alg = es.EvolutionStrategies( + env_name, es.DEFAULT_CONFIG, upload_dir=args.upload_dir) + elif args.alg == "DQN": + alg = dqn.DQN( + env_name, dqn.DEFAULT_CONFIG, upload_dir=args.upload_dir) + elif args.alg == "A3C": + alg = a3c.A3C( + env_name, a3c.DEFAULT_CONFIG, upload_dir=args.upload_dir) + else: + assert False, ("Unknown algorithm, check --alg argument. Valid " + "choices are PolicyGradientPolicyGradient, " + "EvolutionStrategies, DQN and A3C.") - result_logger = ray.rllib.common.RLLibLogger( - os.path.join(alg.logdir, "result.json")) + result_logger = ray.rllib.common.RLLibLogger( + os.path.join(alg.logdir, "result.json")) - while True: - result = alg.train() + while True: + result = alg.train() - # We need to use a custom json serializer class so that NaNs get encoded - # as null as required by Athena. - json.dump(result._asdict(), result_logger, - cls=ray.rllib.common.RLLibEncoder) - result_logger.write("\n") + # We need to use a custom json serializer class so that NaNs get + # encoded as null as required by Athena. 
+ json.dump(result._asdict(), result_logger, + cls=ray.rllib.common.RLLibEncoder) + result_logger.write("\n") diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 26f279d9c..26441cd8d 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -10,34 +10,35 @@ import ray.services as services def check_no_existing_redis_clients(node_ip_address, redis_address): - redis_ip_address, redis_port = redis_address.split(":") - redis_client = redis.StrictRedis(host=redis_ip_address, port=int(redis_port)) - # The client table prefix must be kept in sync with the file - # "src/common/redis_module/ray_redis_module.cc" where it is defined. - REDIS_CLIENT_TABLE_PREFIX = "CL:" - client_keys = redis_client.keys("{}*".format(REDIS_CLIENT_TABLE_PREFIX)) - # Filter to clients on the same node and do some basic checking. - for key in client_keys: - info = redis_client.hgetall(key) - assert b"ray_client_id" in info - assert b"node_ip_address" in info - assert b"client_type" in info - assert b"deleted" in info - # Clients that ran on the same node but that are marked dead can be - # ignored. - deleted = info[b"deleted"] - deleted = bool(int(deleted)) - if deleted: - continue + redis_ip_address, redis_port = redis_address.split(":") + redis_client = redis.StrictRedis(host=redis_ip_address, + port=int(redis_port)) + # The client table prefix must be kept in sync with the file + # "src/common/redis_module/ray_redis_module.cc" where it is defined. + REDIS_CLIENT_TABLE_PREFIX = "CL:" + client_keys = redis_client.keys("{}*".format(REDIS_CLIENT_TABLE_PREFIX)) + # Filter to clients on the same node and do some basic checking. + for key in client_keys: + info = redis_client.hgetall(key) + assert b"ray_client_id" in info + assert b"node_ip_address" in info + assert b"client_type" in info + assert b"deleted" in info + # Clients that ran on the same node but that are marked dead can be + # ignored. 
+ deleted = info[b"deleted"] + deleted = bool(int(deleted)) + if deleted: + continue - if info[b"node_ip_address"].decode("ascii") == node_ip_address: - raise Exception("This Redis instance is already connected to clients " - "with this IP address.") + if info[b"node_ip_address"].decode("ascii") == node_ip_address: + raise Exception("This Redis instance is already connected to " + "clients with this IP address.") @click.group() def cli(): - pass + pass @click.command() @@ -64,120 +65,122 @@ def cli(): help="provide this argument to block forever in this command") def start(node_ip_address, redis_address, redis_port, num_redis_shards, object_manager_port, num_workers, num_cpus, num_gpus, head, block): - # Note that we redirect stdout and stderr to /dev/null because otherwise - # attempts to print may cause exceptions if a process is started inside of an - # SSH connection and the SSH connection dies. TODO(rkn): This is a temporary - # fix. We should actually redirect stdout and stderr to Redis in some way. + # Note that we redirect stdout and stderr to /dev/null because otherwise + # attempts to print may cause exceptions if a process is started inside of + # an SSH connection and the SSH connection dies. TODO(rkn): This is a + # temporary fix. We should actually redirect stdout and stderr to Redis in + # some way. - if head: - # Start Ray on the head node. - if redis_address is not None: - raise Exception("If --head is passed in, a Redis server will be " - "started, so a Redis address should not be provided.") + if head: + # Start Ray on the head node. + if redis_address is not None: + raise Exception("If --head is passed in, a Redis server will be " + "started, so a Redis address should not be " + "provided.") - # Get the node IP address if one is not provided. - if node_ip_address is None: - node_ip_address = services.get_node_ip_address() - print("Using IP address {} for this node.".format(node_ip_address)) + # Get the node IP address if one is not provided. 
+ if node_ip_address is None: + node_ip_address = services.get_node_ip_address() + print("Using IP address {} for this node.".format(node_ip_address)) - address_info = {} - # Use the provided object manager port if there is one. - if object_manager_port is not None: - address_info["object_manager_ports"] = [object_manager_port] - if address_info == {}: - address_info = None + address_info = {} + # Use the provided object manager port if there is one. + if object_manager_port is not None: + address_info["object_manager_ports"] = [object_manager_port] + if address_info == {}: + address_info = None - address_info = services.start_ray_head( - address_info=address_info, - node_ip_address=node_ip_address, - redis_port=redis_port, - num_workers=num_workers, - cleanup=False, - redirect_output=True, - num_cpus=num_cpus, - num_gpus=num_gpus, - num_redis_shards=num_redis_shards) - print(address_info) - print("\nStarted Ray on this node. You can add additional nodes to the " - "cluster by calling\n\n" - " ray start --redis-address {}\n\n" - "from the node you wish to add. You can connect a driver to the " - "cluster from Python by running\n\n" - " import ray\n" - " ray.init(redis_address=\"{}\")\n\n" - "If you have trouble connecting from a different machine, check " - "that your firewall is configured properly. If you wish to " - "terminate the processes that have been started, run\n\n" - " ray stop".format(address_info["redis_address"], - address_info["redis_address"])) - else: - # Start Ray on a non-head node. - if redis_port is not None: - raise Exception("If --head is not passed in, --redis-port is not " - "allowed") - if redis_address is None: - raise Exception("If --head is not passed in, --redis-address must be " - "provided.") - if num_redis_shards is not None: - raise Exception("If --head is not passed in, --num-redis-shards must " - "not be provided.") - redis_ip_address, redis_port = redis_address.split(":") - # Wait for the Redis server to be started. 
And throw an exception if we - # can't connect to it. - services.wait_for_redis_to_start(redis_ip_address, int(redis_port)) - # Get the node IP address if one is not provided. - if node_ip_address is None: - node_ip_address = services.get_node_ip_address(redis_address) - print("Using IP address {} for this node.".format(node_ip_address)) - # Check that there aren't already Redis clients with the same IP address - # connected with this Redis instance. This raises an exception if the Redis - # server already has clients on this node. - check_no_existing_redis_clients(node_ip_address, redis_address) - address_info = services.start_ray_node( - node_ip_address=node_ip_address, - redis_address=redis_address, - object_manager_ports=[object_manager_port], - num_workers=num_workers, - cleanup=False, - redirect_output=True, - num_cpus=num_cpus, - num_gpus=num_gpus) - print(address_info) - print("\nStarted Ray on this node. If you wish to terminate the processes " - "that have been started, run\n\n" - " ray stop") + address_info = services.start_ray_head( + address_info=address_info, + node_ip_address=node_ip_address, + redis_port=redis_port, + num_workers=num_workers, + cleanup=False, + redirect_output=True, + num_cpus=num_cpus, + num_gpus=num_gpus, + num_redis_shards=num_redis_shards) + print(address_info) + print("\nStarted Ray on this node. You can add additional nodes to " + "the cluster by calling\n\n" + " ray start --redis-address {}\n\n" + "from the node you wish to add. You can connect a driver to the " + "cluster from Python by running\n\n" + " import ray\n" + " ray.init(redis_address=\"{}\")\n\n" + "If you have trouble connecting from a different machine, check " + "that your firewall is configured properly. If you wish to " + "terminate the processes that have been started, run\n\n" + " ray stop".format(address_info["redis_address"], + address_info["redis_address"])) + else: + # Start Ray on a non-head node. 
+ if redis_port is not None: + raise Exception("If --head is not passed in, --redis-port is not " + "allowed") + if redis_address is None: + raise Exception("If --head is not passed in, --redis-address must " + "be provided.") + if num_redis_shards is not None: + raise Exception("If --head is not passed in, --num-redis-shards " + "must not be provided.") + redis_ip_address, redis_port = redis_address.split(":") + # Wait for the Redis server to be started. And throw an exception if we + # can't connect to it. + services.wait_for_redis_to_start(redis_ip_address, int(redis_port)) + # Get the node IP address if one is not provided. + if node_ip_address is None: + node_ip_address = services.get_node_ip_address(redis_address) + print("Using IP address {} for this node.".format(node_ip_address)) + # Check that there aren't already Redis clients with the same IP + # address connected with this Redis instance. This raises an exception + # if the Redis server already has clients on this node. + check_no_existing_redis_clients(node_ip_address, redis_address) + address_info = services.start_ray_node( + node_ip_address=node_ip_address, + redis_address=redis_address, + object_manager_ports=[object_manager_port], + num_workers=num_workers, + cleanup=False, + redirect_output=True, + num_cpus=num_cpus, + num_gpus=num_gpus) + print(address_info) + print("\nStarted Ray on this node. If you wish to terminate the " + "processes that have been started, run\n\n" + " ray stop") - if block: - import time - while True: - time.sleep(30) + if block: + import time + while True: + time.sleep(30) @click.command() def stop(): - subprocess.call(["killall global_scheduler plasma_store plasma_manager " - "local_scheduler"], shell=True) + subprocess.call(["killall global_scheduler plasma_store plasma_manager " + "local_scheduler"], shell=True) - # Find the PID of the monitor process and kill it. 
- subprocess.call(["kill $(ps aux | grep monitor.py | grep -v grep | " - "awk '{ print $2 }') 2> /dev/null"], shell=True) + # Find the PID of the monitor process and kill it. + subprocess.call(["kill $(ps aux | grep monitor.py | grep -v grep | " + "awk '{ print $2 }') 2> /dev/null"], shell=True) - # Find the PID of the Redis process and kill it. - subprocess.call(["kill $(ps aux | grep redis-server | grep -v grep | " - "awk '{ print $2 }') 2> /dev/null"], shell=True) + # Find the PID of the Redis process and kill it. + subprocess.call(["kill $(ps aux | grep redis-server | grep -v grep | " + "awk '{ print $2 }') 2> /dev/null"], shell=True) - # Find the PIDs of the worker processes and kill them. - subprocess.call(["kill -9 $(ps aux | grep default_worker.py | " - "grep -v grep | awk '{ print $2 }') 2> /dev/null"], - shell=True) + # Find the PIDs of the worker processes and kill them. + subprocess.call(["kill -9 $(ps aux | grep default_worker.py | " + "grep -v grep | awk '{ print $2 }') 2> /dev/null"], + shell=True) - # Find the PID of the Ray log monitor process and kill it. - subprocess.call(["kill $(ps aux | grep log_monitor.py | grep -v grep | " - "awk '{ print $2 }') 2> /dev/null"], shell=True) + # Find the PID of the Ray log monitor process and kill it. + subprocess.call(["kill $(ps aux | grep log_monitor.py | grep -v grep | " + "awk '{ print $2 }') 2> /dev/null"], shell=True) - # Find the PID of the jupyter process and kill it. - subprocess.call(["kill $(ps aux | grep jupyter | grep -v grep | " - "awk '{ print $2 }') 2> /dev/null"], shell=True) + # Find the PID of the jupyter process and kill it. 
+ subprocess.call(["kill $(ps aux | grep jupyter | grep -v grep | " + "awk '{ print $2 }') 2> /dev/null"], shell=True) cli.add_command(start) @@ -185,8 +188,8 @@ cli.add_command(stop) def main(): - return cli() + return cli() if __name__ == "__main__": - main() + main() diff --git a/python/ray/serialization.py b/python/ray/serialization.py index bd66a83ba..5a99bda67 100644 --- a/python/ray/serialization.py +++ b/python/ray/serialization.py @@ -8,58 +8,61 @@ import ray.numbuf class RaySerializationException(Exception): - def __init__(self, message, example_object): - Exception.__init__(self, message) - self.example_object = example_object + def __init__(self, message, example_object): + Exception.__init__(self, message) + self.example_object = example_object class RayDeserializationException(Exception): - def __init__(self, message, class_id): - Exception.__init__(self, message) - self.class_id = class_id + def __init__(self, message, class_id): + Exception.__init__(self, message) + self.class_id = class_id class RayNotDictionarySerializable(Exception): - pass + pass def check_serializable(cls): - """Throws an exception if Ray cannot serialize this class efficiently. + """Throws an exception if Ray cannot serialize this class efficiently. - Args: - cls (type): The class to be serialized. + Args: + cls (type): The class to be serialized. - Raises: - Exception: An exception is raised if Ray cannot serialize this class - efficiently. - """ - if is_named_tuple(cls): - # This case works. - return - if not hasattr(cls, "__new__"): - print("The class {} does not have a '__new__' attribute and is probably " - "an old-stye class. Please make it a new-style class by inheriting " - "from 'object'.") - raise RayNotDictionarySerializable("The class {} does not have a " - "'__new__' attribute and is probably " - "an old-style class. We do not support " - "this. Please make it a new-style " - "class by inheriting from 'object'." 
- .format(cls)) - try: - obj = cls.__new__(cls) - except: - raise RayNotDictionarySerializable("The class {} has overridden '__new__'" - ", so Ray may not be able to serialize " - "it efficiently.".format(cls)) - if not hasattr(obj, "__dict__"): - raise RayNotDictionarySerializable("Objects of the class {} do not have a " - "'__dict__' attribute, so Ray cannot " - "serialize it efficiently.".format(cls)) - if hasattr(obj, "__slots__"): - raise RayNotDictionarySerializable("The class {} uses '__slots__', so Ray " - "may not be able to serialize it " - "efficiently.".format(cls)) + Raises: + Exception: An exception is raised if Ray cannot serialize this class + efficiently. + """ + if is_named_tuple(cls): + # This case works. + return + if not hasattr(cls, "__new__"): + print("The class {} does not have a '__new__' attribute and is " + "probably an old-stye class. Please make it a new-style class " + "by inheriting from 'object'.") + raise RayNotDictionarySerializable("The class {} does not have a " + "'__new__' attribute and is " + "probably an old-style class. We " + "do not support this. Please make " + "it a new-style class by " + "inheriting from 'object'." + .format(cls)) + try: + obj = cls.__new__(cls) + except: + raise RayNotDictionarySerializable("The class {} has overridden " + "'__new__', so Ray may not be able " + "to serialize it efficiently." 
+ .format(cls)) + if not hasattr(obj, "__dict__"): + raise RayNotDictionarySerializable("Objects of the class {} do not " + "have a '__dict__' attribute, so " + "Ray cannot serialize it " + "efficiently.".format(cls)) + if hasattr(obj, "__slots__"): + raise RayNotDictionarySerializable("The class {} uses '__slots__', so " + "Ray may not be able to serialize " + "it efficiently.".format(cls)) # This field keeps track of a whitelisted set of classes that Ray will @@ -72,134 +75,136 @@ custom_deserializers = dict() def is_named_tuple(cls): - """Return True if cls is a namedtuple and False otherwise.""" - b = cls.__bases__ - if len(b) != 1 or b[0] != tuple: - return False - f = getattr(cls, "_fields", None) - if not isinstance(f, tuple): - return False - return all(type(n) == str for n in f) + """Return True if cls is a namedtuple and False otherwise.""" + b = cls.__bases__ + if len(b) != 1 or b[0] != tuple: + return False + f = getattr(cls, "_fields", None) + if not isinstance(f, tuple): + return False + return all(type(n) == str for n in f) def add_class_to_whitelist(cls, class_id, pickle=False, custom_serializer=None, custom_deserializer=None): - """Add cls to the list of classes that we can serialize. + """Add cls to the list of classes that we can serialize. - Args: - cls (type): The class that we can serialize. - class_id: A string of bytes used to identify the class. - pickle (bool): True if the serialization should be done with pickle. False - if it should be done efficiently with Ray. - custom_serializer: This argument is optional, but can be provided to - serialize objects of the class in a particular way. - custom_deserializer: This argument is optional, but can be provided to - deserialize objects of the class in a particular way. 
- """ - type_to_class_id[cls] = class_id - whitelisted_classes[class_id] = cls - if pickle: - classes_to_pickle.add(class_id) - if custom_serializer is not None: - custom_serializers[class_id] = custom_serializer - custom_deserializers[class_id] = custom_deserializer + Args: + cls (type): The class that we can serialize. + class_id: A string of bytes used to identify the class. + pickle (bool): True if the serialization should be done with pickle. + False if it should be done efficiently with Ray. + custom_serializer: This argument is optional, but can be provided to + serialize objects of the class in a particular way. + custom_deserializer: This argument is optional, but can be provided to + deserialize objects of the class in a particular way. + """ + type_to_class_id[cls] = class_id + whitelisted_classes[class_id] = cls + if pickle: + classes_to_pickle.add(class_id) + if custom_serializer is not None: + custom_serializers[class_id] = custom_serializer + custom_deserializers[class_id] = custom_deserializer def serialize(obj): - """This is the callback that will be used by numbuf. + """This is the callback that will be used by numbuf. - If numbuf does not know how to serialize an object, it will call this method. + If numbuf does not know how to serialize an object, it will call this + method. - Args: - obj (object): A Python object. + Args: + obj (object): A Python object. - Returns: - A dictionary that has the key "_pyttype_" to identify the class, and - contains all information needed to reconstruct the object. - """ - if type(obj) not in type_to_class_id: - raise RaySerializationException("Ray does not know how to serialize " - "objects of type {}.".format(type(obj)), - obj) - class_id = type_to_class_id[type(obj)] + Returns: + A dictionary that has the key "_pyttype_" to identify the class, and + contains all information needed to reconstruct the object. 
+ """ + if type(obj) not in type_to_class_id: + raise RaySerializationException("Ray does not know how to serialize " + "objects of type {}." + .format(type(obj)), + obj) + class_id = type_to_class_id[type(obj)] - if class_id in classes_to_pickle: - serialized_obj = {"data": pickle.dumps(obj), - "pickle": True} - elif class_id in custom_serializers: - serialized_obj = {"data": custom_serializers[class_id](obj)} - else: - # Handle the namedtuple case. - if is_named_tuple(type(obj)): - serialized_obj = {} - serialized_obj["_ray_getnewargs_"] = obj.__getnewargs__() - elif hasattr(obj, "__dict__"): - serialized_obj = obj.__dict__ + if class_id in classes_to_pickle: + serialized_obj = {"data": pickle.dumps(obj), + "pickle": True} + elif class_id in custom_serializers: + serialized_obj = {"data": custom_serializers[class_id](obj)} else: - raise RaySerializationException("We do not know how to serialize the " - "object '{}'".format(obj), obj) - result = dict(serialized_obj, **{"_pytype_": class_id}) - return result + # Handle the namedtuple case. + if is_named_tuple(type(obj)): + serialized_obj = {} + serialized_obj["_ray_getnewargs_"] = obj.__getnewargs__() + elif hasattr(obj, "__dict__"): + serialized_obj = obj.__dict__ + else: + raise RaySerializationException("We do not know how to serialize " + "the object '{}'".format(obj), obj) + result = dict(serialized_obj, **{"_pytype_": class_id}) + return result def deserialize(serialized_obj): - """This is the callback that will be used by numbuf. + """This is the callback that will be used by numbuf. - If numbuf encounters a dictionary that contains the key "_pytype_" during + If numbuf encounters a dictionary that contains the key "_pytype_" during deserialization, it will ask this callback to deserialize the object. - Args: - serialized_obj (object): A dictionary that contains the key "_pytype_". + Args: + serialized_obj (object): A dictionary that contains the key "_pytype_". - Returns: - A Python object. 
+ Returns: + A Python object. - Raises: - An exception is raised if we do not know how to deserialize the object. - """ - class_id = serialized_obj["_pytype_"] + Raises: + An exception is raised if we do not know how to deserialize the object. + """ + class_id = serialized_obj["_pytype_"] - if "pickle" in serialized_obj: - # The object was pickled, so unpickle it. - obj = pickle.loads(serialized_obj["data"]) - else: - assert class_id not in classes_to_pickle - if class_id not in whitelisted_classes: - # If this happens, that means that the call to _register_class, which - # should have added the class to the list of whitelisted classes, has not - # yet propagated to this worker. It should happen if we wait a little - # longer. - raise RayDeserializationException("The class {} is not one of the " - "whitelisted classes." - .format(class_id), class_id) - cls = whitelisted_classes[class_id] - if class_id in custom_deserializers: - obj = custom_deserializers[class_id](serialized_obj["data"]) + if "pickle" in serialized_obj: + # The object was pickled, so unpickle it. + obj = pickle.loads(serialized_obj["data"]) else: - # In this case, serialized_obj should just be the __dict__ field. - if "_ray_getnewargs_" in serialized_obj: - obj = cls.__new__(cls, *serialized_obj["_ray_getnewargs_"]) - else: - obj = cls.__new__(cls) - serialized_obj.pop("_pytype_") - obj.__dict__.update(serialized_obj) - return obj + assert class_id not in classes_to_pickle + if class_id not in whitelisted_classes: + # If this happens, that means that the call to _register_class, + # which should have added the class to the list of whitelisted + # classes, has not yet propagated to this worker. It should happen + # if we wait a little longer. + raise RayDeserializationException("The class {} is not one of the " + "whitelisted classes." 
+ .format(class_id), class_id) + cls = whitelisted_classes[class_id] + if class_id in custom_deserializers: + obj = custom_deserializers[class_id](serialized_obj["data"]) + else: + # In this case, serialized_obj should just be the __dict__ field. + if "_ray_getnewargs_" in serialized_obj: + obj = cls.__new__(cls, *serialized_obj["_ray_getnewargs_"]) + else: + obj = cls.__new__(cls) + serialized_obj.pop("_pytype_") + obj.__dict__.update(serialized_obj) + return obj def set_callbacks(): - """Register the custom callbacks with numbuf. + """Register the custom callbacks with numbuf. - The serialize callback is used to serialize objects that numbuf does not know - how to serialize (for example custom Python classes). The deserialize - callback is used to serialize objects that were serialized by the serialize - callback. - """ - ray.numbuf.register_callbacks(serialize, deserialize) + The serialize callback is used to serialize objects that numbuf does not + know how to serialize (for example custom Python classes). The deserialize + callback is used to serialize objects that were serialized by the serialize + callback. 
+ """ + ray.numbuf.register_callbacks(serialize, deserialize) def clear_state(): - type_to_class_id.clear() - whitelisted_classes.clear() - classes_to_pickle.clear() - custom_serializers.clear() - custom_deserializers.clear() + type_to_class_id.clear() + whitelisted_classes.clear() + classes_to_pickle.clear() + custom_serializers.clear() + custom_deserializers.clear() diff --git a/python/ray/services.py b/python/ray/services.py index faa69ceb3..a37c16a30 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -61,184 +61,187 @@ ObjectStoreAddress = namedtuple("ObjectStoreAddress", ["name", def address(ip_address, port): - return ip_address + ":" + str(port) + return ip_address + ":" + str(port) def get_ip_address(address): - try: - ip_address = address.split(":")[0] - except: - raise Exception("Unable to parse IP address from address " - "{}".format(address)) - return ip_address + try: + ip_address = address.split(":")[0] + except: + raise Exception("Unable to parse IP address from address " + "{}".format(address)) + return ip_address def get_port(address): - try: - port = int(address.split(":")[1]) - except: - raise Exception("Unable to parse port from address {}".format(address)) - return port + try: + port = int(address.split(":")[1]) + except: + raise Exception("Unable to parse port from address {}".format(address)) + return port def new_port(): - return random.randint(10000, 65535) + return random.randint(10000, 65535) def random_name(): - return str(random.randint(0, 99999999)) + return str(random.randint(0, 99999999)) def kill_process(p): - """Kill a process. + """Kill a process. - Args: - p: The process to kill. + Args: + p: The process to kill. - Returns: - True if the process was killed successfully and false otherwise. - """ - if p.poll() is not None: - # The process has already terminated. 
- return True - if any([RUN_LOCAL_SCHEDULER_PROFILER, RUN_PLASMA_MANAGER_PROFILER, - RUN_PLASMA_STORE_PROFILER]): - # Give process signal to write profiler data. - os.kill(p.pid, signal.SIGINT) - # Wait for profiling data to be written. - time.sleep(0.1) + Returns: + True if the process was killed successfully and false otherwise. + """ + if p.poll() is not None: + # The process has already terminated. + return True + if any([RUN_LOCAL_SCHEDULER_PROFILER, RUN_PLASMA_MANAGER_PROFILER, + RUN_PLASMA_STORE_PROFILER]): + # Give process signal to write profiler data. + os.kill(p.pid, signal.SIGINT) + # Wait for profiling data to be written. + time.sleep(0.1) - # Allow the process one second to exit gracefully. - p.terminate() - timer = threading.Timer(1, lambda p: p.kill(), [p]) - try: - timer.start() - p.wait() - finally: - timer.cancel() + # Allow the process one second to exit gracefully. + p.terminate() + timer = threading.Timer(1, lambda p: p.kill(), [p]) + try: + timer.start() + p.wait() + finally: + timer.cancel() - if p.poll() is not None: - return True + if p.poll() is not None: + return True - # If the process did not exit within one second, force kill it. - p.kill() - if p.poll() is not None: - return True + # If the process did not exit within one second, force kill it. + p.kill() + if p.poll() is not None: + return True - # The process was not killed for some reason. - return False + # The process was not killed for some reason. + return False def cleanup(): - """When running in local mode, shutdown the Ray processes. + """When running in local mode, shutdown the Ray processes. - This method is used to shutdown processes that were started with - services.start_ray_head(). It kills all scheduler, object store, and worker - processes that were started by this services module. Driver processes are - started and disconnected by worker.py. - """ - successfully_shut_down = True - # Terminate the processes in reverse order. 
- for process_type in all_processes.keys(): - # Kill all of the processes of a certain type. - for p in all_processes[process_type]: - success = kill_process(p) - successfully_shut_down = successfully_shut_down and success - # Reset the list of processes of this type. - all_processes[process_type] = [] - if not successfully_shut_down: - print("Ray did not shut down properly.") + This method is used to shutdown processes that were started with + services.start_ray_head(). It kills all scheduler, object store, and worker + processes that were started by this services module. Driver processes are + started and disconnected by worker.py. + """ + successfully_shut_down = True + # Terminate the processes in reverse order. + for process_type in all_processes.keys(): + # Kill all of the processes of a certain type. + for p in all_processes[process_type]: + success = kill_process(p) + successfully_shut_down = successfully_shut_down and success + # Reset the list of processes of this type. + all_processes[process_type] = [] + if not successfully_shut_down: + print("Ray did not shut down properly.") def all_processes_alive(exclude=[]): - """Check if all of the processes are still alive. + """Check if all of the processes are still alive. - Args: - exclude: Don't check the processes whose types are in this list. - """ - for process_type, processes in all_processes.items(): - # Note that p.poll() returns the exit code that the process exited with, so - # an exit code of None indicates that the process is still alive. - processes_alive = [p.poll() is None for p in processes] - if (not all(processes_alive) and process_type not in exclude): - print("A process of type {} has died.".format(process_type)) - return False - return True + Args: + exclude: Don't check the processes whose types are in this list. 
+ """ + for process_type, processes in all_processes.items(): + # Note that p.poll() returns the exit code that the process exited + # with, so an exit code of None indicates that the process is still + # alive. + processes_alive = [p.poll() is None for p in processes] + if (not all(processes_alive) and process_type not in exclude): + print("A process of type {} has died.".format(process_type)) + return False + return True def get_node_ip_address(address="8.8.8.8:53"): - """Determine the IP address of the local node. + """Determine the IP address of the local node. - Args: - address (str): The IP address and port of any known live service on the - network you care about. + Args: + address (str): The IP address and port of any known live service on the + network you care about. - Returns: - The IP address of the current node. - """ - ip_address, port = address.split(":") - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - s.connect((ip_address, int(port))) - return s.getsockname()[0] + Returns: + The IP address of the current node. + """ + ip_address, port = address.split(":") + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect((ip_address, int(port))) + return s.getsockname()[0] def record_log_files_in_redis(redis_address, node_ip_address, log_files): - """Record in Redis that a new log file has been created. + """Record in Redis that a new log file has been created. - This is used so that each log monitor can check Redis and figure out which - log files it is reponsible for monitoring. + This is used so that each log monitor can check Redis and figure out which + log files it is reponsible for monitoring. - Args: - redis_address: The address of the redis server. - node_ip_address: The IP address of the node that the log file exists on. - log_files: A list of file handles for the log files. If one of the file - handles is None, we ignore it. 
- """ - for log_file in log_files: - if log_file is not None: - redis_ip_address, redis_port = redis_address.split(":") - redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port) - # The name of the key storing the list of log filenames for this IP - # address. - log_file_list_key = "LOG_FILENAMES:{}".format(node_ip_address) - redis_client.rpush(log_file_list_key, log_file.name) + Args: + redis_address: The address of the redis server. + node_ip_address: The IP address of the node that the log file exists + on. + log_files: A list of file handles for the log files. If one of the file + handles is None, we ignore it. + """ + for log_file in log_files: + if log_file is not None: + redis_ip_address, redis_port = redis_address.split(":") + redis_client = redis.StrictRedis(host=redis_ip_address, + port=redis_port) + # The name of the key storing the list of log filenames for this IP + # address. + log_file_list_key = "LOG_FILENAMES:{}".format(node_ip_address) + redis_client.rpush(log_file_list_key, log_file.name) def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5): - """Wait for a Redis server to be available. + """Wait for a Redis server to be available. - This is accomplished by creating a Redis client and sending a random command - to the server until the command gets through. + This is accomplished by creating a Redis client and sending a random + command to the server until the command gets through. - Args: - redis_ip_address (str): The IP address of the redis server. - redis_port (int): The port of the redis server. - num_retries (int): The number of times to try connecting with redis. The - client will sleep for one second between attempts. + Args: + redis_ip_address (str): The IP address of the redis server. + redis_port (int): The port of the redis server. + num_retries (int): The number of times to try connecting with redis. + The client will sleep for one second between attempts. 
- Raises: - Exception: An exception is raised if we could not connect with Redis. - """ - redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port) - # Wait for the Redis server to start. - counter = 0 - while counter < num_retries: - try: - # Run some random command and see if it worked. - print("Waiting for redis server at {}:{} to respond..." - .format(redis_ip_address, redis_port)) - redis_client.client_list() - except redis.ConnectionError as e: - # Wait a little bit. - time.sleep(1) - print("Failed to connect to the redis server, retrying.") - counter += 1 - else: - break - if counter == num_retries: - raise Exception("Unable to connect to Redis. If the Redis instance is on " - "a different machine, check that your firewall is " - "configured properly.") + Raises: + Exception: An exception is raised if we could not connect with Redis. + """ + redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port) + # Wait for the Redis server to start. + counter = 0 + while counter < num_retries: + try: + # Run some random command and see if it worked. + print("Waiting for redis server at {}:{} to respond..." + .format(redis_ip_address, redis_port)) + redis_client.client_list() + except redis.ConnectionError as e: + # Wait a little bit. + time.sleep(1) + print("Failed to connect to the redis server, retrying.") + counter += 1 + else: + break + if counter == num_retries: + raise Exception("Unable to connect to Redis. If the Redis instance is " + "on a different machine, check that your firewall is " + "configured properly.") def start_redis(node_ip_address, @@ -246,54 +249,56 @@ def start_redis(node_ip_address, num_redis_shards=1, redirect_output=False, cleanup=True): - """Start the Redis global state store. + """Start the Redis global state store. - Args: - node_ip_address: The IP address of the current node. This is only used for - recording the log filenames in Redis. 
- port (int): If provided, the primary Redis shard will be started on this - port. - num_redis_shards (int): If provided, the number of Redis shards to start, - in addition to the primary one. The default value is one shard. - cleanup (bool): True if using Ray in local mode. If cleanup is true, then - all Redis processes started by this method will be killed by - serices.cleanup() when the Python process that imported services exits. + Args: + node_ip_address: The IP address of the current node. This is only used + for recording the log filenames in Redis. + port (int): If provided, the primary Redis shard will be started on + this port. + num_redis_shards (int): If provided, the number of Redis shards to + start, in addition to the primary one. The default value is one + shard. + cleanup (bool): True if using Ray in local mode. If cleanup is true, + then all Redis processes started by this method will be killed by + services.cleanup() when the Python process that imported services + exits. - Returns: - A tuple of the address for the primary Redis shard and a list of addresses - for the remaining shards. - """ - redis_stdout_file, redis_stderr_file = new_log_files( - "redis", redirect_output) - assigned_port, _ = start_redis_instance( - node_ip_address=node_ip_address, port=port, - stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, - cleanup=cleanup) - if port is not None: - assert assigned_port == port - port = assigned_port - redis_address = address(node_ip_address, port) - - # Register the number of Redis shards in the primary shard, so that clients - # know how many redis shards to expect under RedisShards. - redis_client = redis.StrictRedis(host=node_ip_address, port=port) - redis_client.set("NumRedisShards", str(num_redis_shards)) - - # Start other Redis shards listening on random ports. Each Redis shard logs - # to a separate file, prefixed by "redis-". 
- redis_shards = [] - for i in range(num_redis_shards): + Returns: + A tuple of the address for the primary Redis shard and a list of + addresses for the remaining shards. + """ redis_stdout_file, redis_stderr_file = new_log_files( - "redis-{}".format(i), redirect_output) - redis_shard_port, _ = start_redis_instance( - node_ip_address=node_ip_address, stdout_file=redis_stdout_file, - stderr_file=redis_stderr_file, cleanup=cleanup) - shard_address = address(node_ip_address, redis_shard_port) - redis_shards.append(shard_address) - # Store redis shard information in the primary redis shard. - redis_client.rpush("RedisShards", shard_address) + "redis", redirect_output) + assigned_port, _ = start_redis_instance( + node_ip_address=node_ip_address, port=port, + stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, + cleanup=cleanup) + if port is not None: + assert assigned_port == port + port = assigned_port + redis_address = address(node_ip_address, port) - return redis_address, redis_shards + # Register the number of Redis shards in the primary shard, so that clients + # know how many redis shards to expect under RedisShards. + redis_client = redis.StrictRedis(host=node_ip_address, port=port) + redis_client.set("NumRedisShards", str(num_redis_shards)) + + # Start other Redis shards listening on random ports. Each Redis shard logs + # to a separate file, prefixed by "redis-". + redis_shards = [] + for i in range(num_redis_shards): + redis_stdout_file, redis_stderr_file = new_log_files( + "redis-{}".format(i), redirect_output) + redis_shard_port, _ = start_redis_instance( + node_ip_address=node_ip_address, stdout_file=redis_stdout_file, + stderr_file=redis_stderr_file, cleanup=cleanup) + shard_address = address(node_ip_address, redis_shard_port) + redis_shards.append(shard_address) + # Store redis shard information in the primary redis shard. 
+ redis_client.rpush("RedisShards", shard_address) + + return redis_address, redis_shards def start_redis_instance(node_ip_address="127.0.0.1", @@ -302,188 +307,190 @@ def start_redis_instance(node_ip_address="127.0.0.1", stdout_file=None, stderr_file=None, cleanup=True): - """Start a single Redis server. + """Start a single Redis server. - Args: - node_ip_address (str): The IP address of the current node. This is only - used for recording the log filenames in Redis. - port (int): If provided, start a Redis server with this port. - num_retries (int): The number of times to attempt to start Redis. If a port - is provided, this defaults to 1. - stdout_file: A file handle opened for writing to redirect stdout to. If no - redirection should happen, then this should be None. - stderr_file: A file handle opened for writing to redirect stderr to. If no - redirection should happen, then this should be None. - cleanup (bool): True if using Ray in local mode. If cleanup is true, then - this process will be killed by serices.cleanup() when the Python process - that imported services exits. + Args: + node_ip_address (str): The IP address of the current node. This is only + used for recording the log filenames in Redis. + port (int): If provided, start a Redis server with this port. + num_retries (int): The number of times to attempt to start Redis. If a + port is provided, this defaults to 1. + stdout_file: A file handle opened for writing to redirect stdout to. If + no redirection should happen, then this should be None. + stderr_file: A file handle opened for writing to redirect stderr to. If + no redirection should happen, then this should be None. + cleanup (bool): True if using Ray in local mode. If cleanup is true, + then this process will be killed by serices.cleanup() when the + Python process that imported services exits. - Returns: - A tuple of the port used by Redis and a handle to the process that was - started. 
If a port is passed in, then the returned port value is the - same. + Returns: + A tuple of the port used by Redis and a handle to the process that was + started. If a port is passed in, then the returned port value is + the same. - Raises: - Exception: An exception is raised if Redis could not be started. - """ - redis_filepath = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "./core/src/common/thirdparty/redis/src/redis-server") - redis_module = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "./core/src/common/redis_module/libray_redis_module.so") - assert os.path.isfile(redis_filepath) - assert os.path.isfile(redis_module) - counter = 0 - if port is not None: - # If a port is specified, then try only once to connect. - num_retries = 1 - else: - port = new_port() - while counter < num_retries: - if counter > 0: - print("Redis failed to start, retrying now.") - p = subprocess.Popen([redis_filepath, - "--port", str(port), - "--loglevel", "warning", - "--loadmodule", redis_module], - stdout=stdout_file, stderr=stderr_file) - time.sleep(0.1) - # Check if Redis successfully started (or at least if it the executable did - # not exit within 0.1 seconds). - if p.poll() is None: - if cleanup: - all_processes[PROCESS_TYPE_REDIS_SERVER].append(p) - break - port = new_port() - counter += 1 - if counter == num_retries: - raise Exception("Couldn't start Redis.") + Raises: + Exception: An exception is raised if Redis could not be started. + """ + redis_filepath = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "./core/src/common/thirdparty/redis/src/redis-server") + redis_module = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "./core/src/common/redis_module/libray_redis_module.so") + assert os.path.isfile(redis_filepath) + assert os.path.isfile(redis_module) + counter = 0 + if port is not None: + # If a port is specified, then try only once to connect. 
+ num_retries = 1 + else: + port = new_port() + while counter < num_retries: + if counter > 0: + print("Redis failed to start, retrying now.") + p = subprocess.Popen([redis_filepath, + "--port", str(port), + "--loglevel", "warning", + "--loadmodule", redis_module], + stdout=stdout_file, stderr=stderr_file) + time.sleep(0.1) + # Check if Redis successfully started (or at least if it the executable + # did not exit within 0.1 seconds). + if p.poll() is None: + if cleanup: + all_processes[PROCESS_TYPE_REDIS_SERVER].append(p) + break + port = new_port() + counter += 1 + if counter == num_retries: + raise Exception("Couldn't start Redis.") - # Create a Redis client just for configuring Redis. - redis_client = redis.StrictRedis(host="127.0.0.1", port=port) - # Wait for the Redis server to start. - wait_for_redis_to_start("127.0.0.1", port) - # Configure Redis to generate keyspace notifications. TODO(rkn): Change this - # to only generate notifications for the export keys. - redis_client.config_set("notify-keyspace-events", "Kl") - # Configure Redis to not run in protected mode so that processes on other - # hosts can connect to it. TODO(rkn): Do this in a more secure way. - redis_client.config_set("protected-mode", "no") - # Increase the hard and soft limits for the redis client pubsub buffer to - # 128MB. This is a hack to make it less likely for pubsub messages to be - # dropped and for pubsub connections to therefore be killed. - cur_config = (redis_client.config_get("client-output-buffer-limit") - ["client-output-buffer-limit"]) - cur_config_list = cur_config.split() - assert len(cur_config_list) == 12 - cur_config_list[8:] = ["pubsub", "134217728", "134217728", "60"] - redis_client.config_set("client-output-buffer-limit", - " ".join(cur_config_list)) - # Put a time stamp in Redis to indicate when it was started. - redis_client.set("redis_start_time", time.time()) - # Record the log files in Redis. 
- record_log_files_in_redis(address(node_ip_address, port), node_ip_address, - [stdout_file, stderr_file]) - return port, p + # Create a Redis client just for configuring Redis. + redis_client = redis.StrictRedis(host="127.0.0.1", port=port) + # Wait for the Redis server to start. + wait_for_redis_to_start("127.0.0.1", port) + # Configure Redis to generate keyspace notifications. TODO(rkn): Change + # this to only generate notifications for the export keys. + redis_client.config_set("notify-keyspace-events", "Kl") + # Configure Redis to not run in protected mode so that processes on other + # hosts can connect to it. TODO(rkn): Do this in a more secure way. + redis_client.config_set("protected-mode", "no") + # Increase the hard and soft limits for the redis client pubsub buffer to + # 128MB. This is a hack to make it less likely for pubsub messages to be + # dropped and for pubsub connections to therefore be killed. + cur_config = (redis_client.config_get("client-output-buffer-limit") + ["client-output-buffer-limit"]) + cur_config_list = cur_config.split() + assert len(cur_config_list) == 12 + cur_config_list[8:] = ["pubsub", "134217728", "134217728", "60"] + redis_client.config_set("client-output-buffer-limit", + " ".join(cur_config_list)) + # Put a time stamp in Redis to indicate when it was started. + redis_client.set("redis_start_time", time.time()) + # Record the log files in Redis. + record_log_files_in_redis(address(node_ip_address, port), node_ip_address, + [stdout_file, stderr_file]) + return port, p def start_log_monitor(redis_address, node_ip_address, stdout_file=None, stderr_file=None, cleanup=cleanup): - """Start a log monitor process. + """Start a log monitor process. - Args: - redis_address (str): The address of the Redis instance. - node_ip_address (str): The IP address of the node that this log monitor is - running on. - stdout_file: A file handle opened for writing to redirect stdout to. If no - redirection should happen, then this should be None. 
- stderr_file: A file handle opened for writing to redirect stderr to. If no - redirection should happen, then this should be None. - cleanup (bool): True if using Ray in local mode. If cleanup is true, then - this process will be killed by services.cleanup() when the Python process - that imported services exits. - """ - log_monitor_filepath = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "log_monitor.py") - p = subprocess.Popen(["python", log_monitor_filepath, - "--redis-address", redis_address, - "--node-ip-address", node_ip_address], - stdout=stdout_file, stderr=stderr_file) - if cleanup: - all_processes[PROCESS_TYPE_LOG_MONITOR].append(p) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) + Args: + redis_address (str): The address of the Redis instance. + node_ip_address (str): The IP address of the node that this log monitor + is running on. + stdout_file: A file handle opened for writing to redirect stdout to. If + no redirection should happen, then this should be None. + stderr_file: A file handle opened for writing to redirect stderr to. If + no redirection should happen, then this should be None. + cleanup (bool): True if using Ray in local mode. If cleanup is true, + then this process will be killed by services.cleanup() when the + Python process that imported services exits. + """ + log_monitor_filepath = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "log_monitor.py") + p = subprocess.Popen(["python", log_monitor_filepath, + "--redis-address", redis_address, + "--node-ip-address", node_ip_address], + stdout=stdout_file, stderr=stderr_file) + if cleanup: + all_processes[PROCESS_TYPE_LOG_MONITOR].append(p) + record_log_files_in_redis(redis_address, node_ip_address, + [stdout_file, stderr_file]) def start_global_scheduler(redis_address, node_ip_address, stdout_file=None, stderr_file=None, cleanup=True): - """Start a global scheduler process. + """Start a global scheduler process. 
- Args: - redis_address (str): The address of the Redis instance. - node_ip_address: The IP address of the node that this scheduler will run - on. - stdout_file: A file handle opened for writing to redirect stdout to. If no - redirection should happen, then this should be None. - stderr_file: A file handle opened for writing to redirect stderr to. If no - redirection should happen, then this should be None. - cleanup (bool): True if using Ray in local mode. If cleanup is true, then - this process will be killed by services.cleanup() when the Python process - that imported services exits. - """ - p = global_scheduler.start_global_scheduler(redis_address, - node_ip_address, - stdout_file=stdout_file, - stderr_file=stderr_file) - if cleanup: - all_processes[PROCESS_TYPE_GLOBAL_SCHEDULER].append(p) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) + Args: + redis_address (str): The address of the Redis instance. + node_ip_address: The IP address of the node that this scheduler will + run on. + stdout_file: A file handle opened for writing to redirect stdout to. If + no redirection should happen, then this should be None. + stderr_file: A file handle opened for writing to redirect stderr to. If + no redirection should happen, then this should be None. + cleanup (bool): True if using Ray in local mode. If cleanup is true, + then this process will be killed by services.cleanup() when the + Python process that imported services exits. + """ + p = global_scheduler.start_global_scheduler(redis_address, + node_ip_address, + stdout_file=stdout_file, + stderr_file=stderr_file) + if cleanup: + all_processes[PROCESS_TYPE_GLOBAL_SCHEDULER].append(p) + record_log_files_in_redis(redis_address, node_ip_address, + [stdout_file, stderr_file]) def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True): - """Start a UI process. + """Start a UI process. - Args: - redis_address: The address of the primary Redis shard. 
- stdout_file: A file handle opened for writing to redirect stdout to. If no - redirection should happen, then this should be None. - stderr_file: A file handle opened for writing to redirect stderr to. If no - redirection should happen, then this should be None. - cleanup (bool): True if using Ray in local mode. If cleanup is true, then - this process will be killed by services.cleanup() when the Python process - that imported services exits. - """ - new_env = os.environ.copy() - notebook_filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), - "WebUI.ipynb") - # We copy the notebook file so that the original doesn't get modified by the - # user. - random_ui_id = random.randint(0, 100000) - new_notebook_filepath = "/tmp/raylogs/ray_ui{}.ipynb".format(random_ui_id) - new_notebook_directory = os.path.dirname(new_notebook_filepath) - shutil.copy(notebook_filepath, new_notebook_filepath) - port = 8888 - new_env = os.environ.copy() - new_env["REDIS_ADDRESS"] = redis_address - command = ["jupyter", "notebook", "--no-browser", - "--port={}".format(port), - "--NotebookApp.iopub_data_rate_limit=10000000000", - "--NotebookApp.open_browser=False"] - try: - ui_process = subprocess.Popen(command, env=new_env, - cwd=new_notebook_directory, - stdout=stdout_file, stderr=stderr_file) - except: - print("Failed to start the UI, you may need to run 'pip install jupyter'.") - else: - if cleanup: - all_processes[PROCESS_TYPE_WEB_UI].append(ui_process) + Args: + redis_address: The address of the primary Redis shard. + stdout_file: A file handle opened for writing to redirect stdout to. If + no redirection should happen, then this should be None. + stderr_file: A file handle opened for writing to redirect stderr to. If + no redirection should happen, then this should be None. + cleanup (bool): True if using Ray in local mode. If cleanup is true, + then this process will be killed by services.cleanup() when the + Python process that imported services exits. 
+ """ + new_env = os.environ.copy() + notebook_filepath = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "WebUI.ipynb") + # We copy the notebook file so that the original doesn't get modified by + # the user. + random_ui_id = random.randint(0, 100000) + new_notebook_filepath = "/tmp/raylogs/ray_ui{}.ipynb".format(random_ui_id) + new_notebook_directory = os.path.dirname(new_notebook_filepath) + shutil.copy(notebook_filepath, new_notebook_filepath) + port = 8888 + new_env = os.environ.copy() + new_env["REDIS_ADDRESS"] = redis_address + command = ["jupyter", "notebook", "--no-browser", + "--port={}".format(port), + "--NotebookApp.iopub_data_rate_limit=10000000000", + "--NotebookApp.open_browser=False"] + try: + ui_process = subprocess.Popen(command, env=new_env, + cwd=new_notebook_directory, + stdout=stdout_file, stderr=stderr_file) + except: + print("Failed to start the UI, you may need to run " + "'pip install jupyter'.") + else: + if cleanup: + all_processes[PROCESS_TYPE_WEB_UI].append(ui_process) - print("View the web UI at http://localhost:{}/notebooks/ray_ui{}.ipynb" - .format(port, random_ui_id)) + print("View the web UI at http://localhost:{}/notebooks/ray_ui{}.ipynb" + .format(port, random_ui_id)) def start_local_scheduler(redis_address, @@ -498,58 +505,61 @@ def start_local_scheduler(redis_address, num_cpus=None, num_gpus=None, num_workers=0): - """Start a local scheduler process. + """Start a local scheduler process. - Args: - redis_address (str): The address of the Redis instance. - node_ip_address (str): The IP address of the node that this local scheduler - is running on. - plasma_store_name (str): The name of the plasma store socket to connect to. - plasma_manager_name (str): The name of the plasma manager socket to connect - to. - worker_path (str): The path of the script to use when the local scheduler - starts up new workers. - stdout_file: A file handle opened for writing to redirect stdout to. 
If no - redirection should happen, then this should be None. - stderr_file: A file handle opened for writing to redirect stderr to. If no - redirection should happen, then this should be None. - cleanup (bool): True if using Ray in local mode. If cleanup is true, then - this process will be killed by serices.cleanup() when the Python process - that imported services exits. - num_cpus: The number of CPUs the local scheduler should be configured with. - num_gpus: The number of GPUs the local scheduler should be configured with. - num_workers (int): The number of workers that the local scheduler should - start. + Args: + redis_address (str): The address of the Redis instance. + node_ip_address (str): The IP address of the node that this local + scheduler is running on. + plasma_store_name (str): The name of the plasma store socket to connect + to. + plasma_manager_name (str): The name of the plasma manager socket to + connect to. + worker_path (str): The path of the script to use when the local + scheduler starts up new workers. + stdout_file: A file handle opened for writing to redirect stdout to. If + no redirection should happen, then this should be None. + stderr_file: A file handle opened for writing to redirect stderr to. If + no redirection should happen, then this should be None. + cleanup (bool): True if using Ray in local mode. If cleanup is true, + then this process will be killed by serices.cleanup() when the + Python process that imported services exits. + num_cpus: The number of CPUs the local scheduler should be configured + with. + num_gpus: The number of GPUs the local scheduler should be configured + with. + num_workers (int): The number of workers that the local scheduler + should start. - Return: - The name of the local scheduler socket. - """ - if num_cpus is None: - # By default, use the number of hardware execution threads for the number - # of cores. 
- num_cpus = psutil.cpu_count() - if num_gpus is None: - # By default, assume this node has no GPUs. - num_gpus = 0 - print("Starting local scheduler with {} CPUs and {} GPUs.".format(num_cpus, - num_gpus)) - local_scheduler_name, p = ray.local_scheduler.start_local_scheduler( - plasma_store_name, - plasma_manager_name, - worker_path=worker_path, - node_ip_address=node_ip_address, - redis_address=redis_address, - plasma_address=plasma_address, - use_profiler=RUN_LOCAL_SCHEDULER_PROFILER, - stdout_file=stdout_file, - stderr_file=stderr_file, - static_resource_list=[num_cpus, num_gpus], - num_workers=num_workers) - if cleanup: - all_processes[PROCESS_TYPE_LOCAL_SCHEDULER].append(p) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) - return local_scheduler_name + Return: + The name of the local scheduler socket. + """ + if num_cpus is None: + # By default, use the number of hardware execution threads for the + # number of cores. + num_cpus = psutil.cpu_count() + if num_gpus is None: + # By default, assume this node has no GPUs. + num_gpus = 0 + print("Starting local scheduler with {} CPUs and {} GPUs." 
+ .format(num_cpus, num_gpus)) + local_scheduler_name, p = ray.local_scheduler.start_local_scheduler( + plasma_store_name, + plasma_manager_name, + worker_path=worker_path, + node_ip_address=node_ip_address, + redis_address=redis_address, + plasma_address=plasma_address, + use_profiler=RUN_LOCAL_SCHEDULER_PROFILER, + stdout_file=stdout_file, + stderr_file=stderr_file, + static_resource_list=[num_cpus, num_gpus], + num_workers=num_workers) + if cleanup: + all_processes[PROCESS_TYPE_LOCAL_SCHEDULER].append(p) + record_log_files_in_redis(redis_address, node_ip_address, + [stdout_file, stderr_file]) + return local_scheduler_name def start_objstore(node_ip_address, redis_address, @@ -557,157 +567,166 @@ def start_objstore(node_ip_address, redis_address, store_stderr_file=None, manager_stdout_file=None, manager_stderr_file=None, cleanup=True, objstore_memory=None): - """This method starts an object store process. + """This method starts an object store process. - Args: - node_ip_address (str): The IP address of the node running the object store. - redis_address (str): The address of the Redis instance to connect to. - object_manager_port (int): The port to use for the object manager. If this - is not provided, one will be generated randomly. - store_stdout_file: A file handle opened for writing to redirect stdout to. - If no redirection should happen, then this should be None. - store_stderr_file: A file handle opened for writing to redirect stderr to. - If no redirection should happen, then this should be None. - manager_stdout_file: A file handle opened for writing to redirect stdout - to. If no redirection should happen, then this should be None. - manager_stderr_file: A file handle opened for writing to redirect stderr - to. If no redirection should happen, then this should be None. - cleanup (bool): True if using Ray in local mode. If cleanup is true, then - this process will be killed by serices.cleanup() when the Python process - that imported services exits. 
- objstore_memory: The amount of memory (in bytes) to start the object store - with. + Args: + node_ip_address (str): The IP address of the node running the object + store. + redis_address (str): The address of the Redis instance to connect to. + object_manager_port (int): The port to use for the object manager. If + this is not provided, one will be generated randomly. + store_stdout_file: A file handle opened for writing to redirect stdout + to. If no redirection should happen, then this should be None. + store_stderr_file: A file handle opened for writing to redirect stderr + to. If no redirection should happen, then this should be None. + manager_stdout_file: A file handle opened for writing to redirect + stdout to. If no redirection should happen, then this should be + None. + manager_stderr_file: A file handle opened for writing to redirect + stderr to. If no redirection should happen, then this should be + None. + cleanup (bool): True if using Ray in local mode. If cleanup is true, + then this process will be killed by serices.cleanup() when the + Python process that imported services exits. + objstore_memory: The amount of memory (in bytes) to start the object + store with. - Return: - A tuple of the Plasma store socket name, the Plasma manager socket name, - and the plasma manager port. - """ - if objstore_memory is None: - # Compute a fraction of the system memory for the Plasma store to use. - system_memory = psutil.virtual_memory().total - if sys.platform == "linux" or sys.platform == "linux2": - # On linux we use /dev/shm, its size is half the size of the physical - # memory. To not overflow it, we set the plasma memory limit to 0.4 times - # the size of the physical memory. - objstore_memory = int(system_memory * 0.4) - # Compare the requested memory size to the memory available in /dev/shm. 
- shm_fd = os.open("/dev/shm", os.O_RDONLY) - try: - shm_fs_stats = os.fstatvfs(shm_fd) - # The value shm_fs_stats.f_bsize is the block size and the value - # shm_fs_stats.f_bavail is the number of available blocks. - shm_avail = shm_fs_stats.f_bsize * shm_fs_stats.f_bavail - if objstore_memory > shm_avail: - print("Warning: Reducing object store memory because /dev/shm has " - "only {} bytes available. You may be able to free up space by " - "deleting files in /dev/shm. If you are inside a Docker " - "container, you may need to pass an argument with the flag " - "'--shm-size' to 'docker run'.".format(shm_avail)) - objstore_memory = int(shm_avail * 0.8) - finally: - os.close(shm_fd) + Return: + A tuple of the Plasma store socket name, the Plasma manager socket + name, and the plasma manager port. + """ + if objstore_memory is None: + # Compute a fraction of the system memory for the Plasma store to use. + system_memory = psutil.virtual_memory().total + if sys.platform == "linux" or sys.platform == "linux2": + # On linux we use /dev/shm, its size is half the size of the + # physical memory. To not overflow it, we set the plasma memory + # limit to 0.4 times the size of the physical memory. + objstore_memory = int(system_memory * 0.4) + # Compare the requested memory size to the memory available in + # /dev/shm. + shm_fd = os.open("/dev/shm", os.O_RDONLY) + try: + shm_fs_stats = os.fstatvfs(shm_fd) + # The value shm_fs_stats.f_bsize is the block size and the + # value shm_fs_stats.f_bavail is the number of available + # blocks. + shm_avail = shm_fs_stats.f_bsize * shm_fs_stats.f_bavail + if objstore_memory > shm_avail: + print("Warning: Reducing object store memory because " + "/dev/shm has only {} bytes available. You may be " + "able to free up space by deleting files in " + "/dev/shm. 
If you are inside a Docker container, " + "you may need to pass an argument with the flag " + "'--shm-size' to 'docker run'.".format(shm_avail)) + objstore_memory = int(shm_avail * 0.8) + finally: + os.close(shm_fd) + else: + objstore_memory = int(system_memory * 0.8) + # Start the Plasma store. + plasma_store_name, p1 = ray.plasma.start_plasma_store( + plasma_store_memory=objstore_memory, + use_profiler=RUN_PLASMA_STORE_PROFILER, + stdout_file=store_stdout_file, + stderr_file=store_stderr_file) + # Start the plasma manager. + if object_manager_port is not None: + (plasma_manager_name, p2, + plasma_manager_port) = ray.plasma.start_plasma_manager( + plasma_store_name, + redis_address, + plasma_manager_port=object_manager_port, + node_ip_address=node_ip_address, + num_retries=1, + run_profiler=RUN_PLASMA_MANAGER_PROFILER, + stdout_file=manager_stdout_file, + stderr_file=manager_stderr_file) + assert plasma_manager_port == object_manager_port else: - objstore_memory = int(system_memory * 0.8) - # Start the Plasma store. - plasma_store_name, p1 = ray.plasma.start_plasma_store( - plasma_store_memory=objstore_memory, - use_profiler=RUN_PLASMA_STORE_PROFILER, - stdout_file=store_stdout_file, - stderr_file=store_stderr_file) - # Start the plasma manager. 
- if object_manager_port is not None: - (plasma_manager_name, p2, - plasma_manager_port) = ray.plasma.start_plasma_manager( - plasma_store_name, - redis_address, - plasma_manager_port=object_manager_port, - node_ip_address=node_ip_address, - num_retries=1, - run_profiler=RUN_PLASMA_MANAGER_PROFILER, - stdout_file=manager_stdout_file, - stderr_file=manager_stderr_file) - assert plasma_manager_port == object_manager_port - else: - (plasma_manager_name, p2, - plasma_manager_port) = ray.plasma.start_plasma_manager( - plasma_store_name, - redis_address, - node_ip_address=node_ip_address, - run_profiler=RUN_PLASMA_MANAGER_PROFILER, - stdout_file=manager_stdout_file, - stderr_file=manager_stderr_file) - if cleanup: - all_processes[PROCESS_TYPE_PLASMA_STORE].append(p1) - all_processes[PROCESS_TYPE_PLASMA_MANAGER].append(p2) - record_log_files_in_redis(redis_address, node_ip_address, - [store_stdout_file, store_stderr_file, - manager_stdout_file, manager_stderr_file]) + (plasma_manager_name, p2, + plasma_manager_port) = ray.plasma.start_plasma_manager( + plasma_store_name, + redis_address, + node_ip_address=node_ip_address, + run_profiler=RUN_PLASMA_MANAGER_PROFILER, + stdout_file=manager_stdout_file, + stderr_file=manager_stderr_file) + if cleanup: + all_processes[PROCESS_TYPE_PLASMA_STORE].append(p1) + all_processes[PROCESS_TYPE_PLASMA_MANAGER].append(p2) + record_log_files_in_redis(redis_address, node_ip_address, + [store_stdout_file, store_stderr_file, + manager_stdout_file, manager_stderr_file]) - return ObjectStoreAddress(plasma_store_name, plasma_manager_name, - plasma_manager_port) + return ObjectStoreAddress(plasma_store_name, plasma_manager_name, + plasma_manager_port) def start_worker(node_ip_address, object_store_name, object_store_manager_name, local_scheduler_name, redis_address, worker_path, stdout_file=None, stderr_file=None, cleanup=True): - """This method starts a worker process. + """This method starts a worker process. 
- Args: - node_ip_address (str): The IP address of the node that this worker is - running on. - object_store_name (str): The name of the object store. - object_store_manager_name (str): The name of the object store manager. - local_scheduler_name (str): The name of the local scheduler. - redis_address (str): The address that the Redis server is listening on. - worker_path (str): The path of the source code which the worker process - will run. - stdout_file: A file handle opened for writing to redirect stdout to. If no - redirection should happen, then this should be None. - stderr_file: A file handle opened for writing to redirect stderr to. If no - redirection should happen, then this should be None. - cleanup (bool): True if using Ray in local mode. If cleanup is true, then - this process will be killed by services.cleanup() when the Python process - that imported services exits. This is True by default. - """ - command = ["python", - worker_path, - "--node-ip-address=" + node_ip_address, - "--object-store-name=" + object_store_name, - "--object-store-manager-name=" + object_store_manager_name, - "--local-scheduler-name=" + local_scheduler_name, - "--redis-address=" + str(redis_address)] - p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) - if cleanup: - all_processes[PROCESS_TYPE_WORKER].append(p) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) + Args: + node_ip_address (str): The IP address of the node that this worker is + running on. + object_store_name (str): The name of the object store. + object_store_manager_name (str): The name of the object store manager. + local_scheduler_name (str): The name of the local scheduler. + redis_address (str): The address that the Redis server is listening on. + worker_path (str): The path of the source code which the worker process + will run. + stdout_file: A file handle opened for writing to redirect stdout to. 
If + no redirection should happen, then this should be None. + stderr_file: A file handle opened for writing to redirect stderr to. If + no redirection should happen, then this should be None. + cleanup (bool): True if using Ray in local mode. If cleanup is true, + then this process will be killed by services.cleanup() when the + Python process that imported services exits. This is True by + default. + """ + command = ["python", + worker_path, + "--node-ip-address=" + node_ip_address, + "--object-store-name=" + object_store_name, + "--object-store-manager-name=" + object_store_manager_name, + "--local-scheduler-name=" + local_scheduler_name, + "--redis-address=" + str(redis_address)] + p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) + if cleanup: + all_processes[PROCESS_TYPE_WORKER].append(p) + record_log_files_in_redis(redis_address, node_ip_address, + [stdout_file, stderr_file]) def start_monitor(redis_address, node_ip_address, stdout_file=None, stderr_file=None, cleanup=True): - """Run a process to monitor the other processes. + """Run a process to monitor the other processes. - Args: - redis_address (str): The address that the Redis server is listening on. - node_ip_address: The IP address of the node that this process will run on. - stdout_file: A file handle opened for writing to redirect stdout to. If no - redirection should happen, then this should be None. - stderr_file: A file handle opened for writing to redirect stderr to. If no - redirection should happen, then this should be None. - cleanup (bool): True if using Ray in local mode. If cleanup is true, then - this process will be killed by services.cleanup() when the Python process - that imported services exits. This is True by default. 
- """ - monitor_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), - "monitor.py") - command = ["python", - monitor_path, - "--redis-address=" + str(redis_address)] - p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) - if cleanup: - all_processes[PROCESS_TYPE_WORKER].append(p) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) + Args: + redis_address (str): The address that the Redis server is listening on. + node_ip_address: The IP address of the node that this process will run + on. + stdout_file: A file handle opened for writing to redirect stdout to. If + no redirection should happen, then this should be None. + stderr_file: A file handle opened for writing to redirect stderr to. If + no redirection should happen, then this should be None. + cleanup (bool): True if using Ray in local mode. If cleanup is true, + then this process will be killed by services.cleanup() when the + Python process that imported services exits. This is True by + default. + """ + monitor_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), + "monitor.py") + command = ["python", + monitor_path, + "--redis-address=" + str(redis_address)] + p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) + if cleanup: + all_processes[PROCESS_TYPE_WORKER].append(p) + record_log_files_in_redis(redis_address, node_ip_address, + [stdout_file, stderr_file]) def start_ray_processes(address_info=None, @@ -725,224 +744,230 @@ def start_ray_processes(address_info=None, start_workers_from_local_scheduler=True, num_cpus=None, num_gpus=None): - """Helper method to start Ray processes. + """Helper method to start Ray processes. - Args: - address_info (dict): A dictionary with address information for processes - that have already been started. If provided, address_info will be - modified to include processes that are newly started. - node_ip_address (str): The IP address of this node. 
- redis_port (int): The port that the primary Redis shard should listen to. - If None, then a random port will be chosen. If the key "redis_address" is - in address_info, then this argument will be ignored. - num_workers (int): The number of workers to start. - num_local_schedulers (int): The total number of local schedulers required. - This is also the total number of object stores required. This method will - start new instances of local schedulers and object stores until there are - num_local_schedulers existing instances of each, including ones already - registered with the given address_info. - num_redis_shards: The number of Redis shards to start in addition to the - primary Redis shard. - worker_path (str): The path of the source code that will be run by the - worker. - cleanup (bool): If cleanup is true, then the processes started here will be - killed by services.cleanup() when the Python process that called this - method exits. - redirect_output (bool): True if stdout and stderr should be redirected to a - file. - include_global_scheduler (bool): If include_global_scheduler is True, then - start a global scheduler process. - include_log_monitor (bool): If True, then start a log monitor to monitor - the log files for all processes on this node and push their contents to - Redis. - include_webui (bool): If True, then attempt to start the web UI. Note that - this is only possible with Python 3. - start_workers_from_local_scheduler (bool): If this flag is True, then start - the initial workers from the local scheduler. Else, start them from - Python. - num_cpus: A list of length num_local_schedulers containing the number of - CPUs each local scheduler should be configured with. - num_gpus: A list of length num_local_schedulers containing the number of - GPUs each local scheduler should be configured with. + Args: + address_info (dict): A dictionary with address information for + processes that have already been started. 
If provided, address_info + will be modified to include processes that are newly started. + node_ip_address (str): The IP address of this node. + redis_port (int): The port that the primary Redis shard should listen + to. If None, then a random port will be chosen. If the key + "redis_address" is in address_info, then this argument will be + ignored. + num_workers (int): The number of workers to start. + num_local_schedulers (int): The total number of local schedulers + required. This is also the total number of object stores required. + This method will start new instances of local schedulers and object + stores until there are num_local_schedulers existing instances of + each, including ones already registered with the given + address_info. + num_redis_shards: The number of Redis shards to start in addition to + the primary Redis shard. + worker_path (str): The path of the source code that will be run by the + worker. + cleanup (bool): If cleanup is true, then the processes started here + will be killed by services.cleanup() when the Python process that + called this method exits. + redirect_output (bool): True if stdout and stderr should be redirected + to a file. + include_global_scheduler (bool): If include_global_scheduler is True, + then start a global scheduler process. + include_log_monitor (bool): If True, then start a log monitor to + monitor the log files for all processes on this node and push their + contents to Redis. + include_webui (bool): If True, then attempt to start the web UI. Note + that this is only possible with Python 3. + start_workers_from_local_scheduler (bool): If this flag is True, then + start the initial workers from the local scheduler. Else, start + them from Python. + num_cpus: A list of length num_local_schedulers containing the number + of CPUs each local scheduler should be configured with. + num_gpus: A list of length num_local_schedulers containing the number + of GPUs each local scheduler should be configured with. 
- Returns: - A dictionary of the address information for the processes that were - started. - """ - if not isinstance(num_cpus, list): - num_cpus = num_local_schedulers * [num_cpus] - if not isinstance(num_gpus, list): - num_gpus = num_local_schedulers * [num_gpus] - assert len(num_cpus) == num_local_schedulers - assert len(num_gpus) == num_local_schedulers + Returns: + A dictionary of the address information for the processes that were + started. + """ + if not isinstance(num_cpus, list): + num_cpus = num_local_schedulers * [num_cpus] + if not isinstance(num_gpus, list): + num_gpus = num_local_schedulers * [num_gpus] + assert len(num_cpus) == num_local_schedulers + assert len(num_gpus) == num_local_schedulers - if num_workers is not None: - workers_per_local_scheduler = num_local_schedulers * [num_workers] - else: - workers_per_local_scheduler = [] - for cpus in num_cpus: - workers_per_local_scheduler.append(cpus if cpus is not None - else psutil.cpu_count()) - - if address_info is None: - address_info = {} - address_info["node_ip_address"] = node_ip_address - - if worker_path is None: - worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), - "workers/default_worker.py") - - # Start Redis if there isn't already an instance running. TODO(rkn): We are - # suppressing the output of Redis because on Linux it prints a bunch of - # warning messages when it starts up. Instead of suppressing the output, we - # should address the warnings. - redis_address = address_info.get("redis_address") - redis_shards = address_info.get("redis_shards", []) - if redis_address is None: - redis_address, redis_shards = start_redis( - node_ip_address, port=redis_port, num_redis_shards=num_redis_shards, - redirect_output=redirect_output, cleanup=cleanup) - address_info["redis_address"] = redis_address - time.sleep(0.1) - - # Start monitoring the processes. 
- monitor_stdout_file, monitor_stderr_file = new_log_files("monitor", - redirect_output) - start_monitor(redis_address, - node_ip_address, - stdout_file=monitor_stdout_file, - stderr_file=monitor_stderr_file) - - if redis_shards == []: - # Get redis shards from primary redis instance. - redis_ip_address, redis_port = redis_address.split(":") - redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port) - redis_shards = redis_client.lrange("RedisShards", start=0, end=-1) - redis_shards = [shard.decode("ascii") for shard in redis_shards] - address_info["redis_shards"] = redis_shards - - # Start the log monitor, if necessary. - if include_log_monitor: - log_monitor_stdout_file, log_monitor_stderr_file = new_log_files( - "log_monitor", redirect_output=True) - start_log_monitor(redis_address, - node_ip_address, - stdout_file=log_monitor_stdout_file, - stderr_file=log_monitor_stderr_file, - cleanup=cleanup) - - # Start the global scheduler, if necessary. - if include_global_scheduler: - global_scheduler_stdout_file, global_scheduler_stderr_file = new_log_files( - "global_scheduler", redirect_output) - start_global_scheduler(redis_address, - node_ip_address, - stdout_file=global_scheduler_stdout_file, - stderr_file=global_scheduler_stderr_file, - cleanup=cleanup) - - # Initialize with existing services. - if "object_store_addresses" not in address_info: - address_info["object_store_addresses"] = [] - object_store_addresses = address_info["object_store_addresses"] - if "local_scheduler_socket_names" not in address_info: - address_info["local_scheduler_socket_names"] = [] - local_scheduler_socket_names = address_info["local_scheduler_socket_names"] - - # Get the ports to use for the object managers if any are provided. 
- object_manager_ports = (address_info["object_manager_ports"] - if "object_manager_ports" in address_info else None) - if not isinstance(object_manager_ports, list): - object_manager_ports = num_local_schedulers * [object_manager_ports] - assert len(object_manager_ports) == num_local_schedulers - - # Start any object stores that do not yet exist. - for i in range(num_local_schedulers - len(object_store_addresses)): - # Start Plasma. - plasma_store_stdout_file, plasma_store_stderr_file = new_log_files( - "plasma_store_{}".format(i), redirect_output) - plasma_manager_stdout_file, plasma_manager_stderr_file = new_log_files( - "plasma_manager_{}".format(i), redirect_output) - object_store_address = start_objstore( - node_ip_address, - redis_address, - object_manager_port=object_manager_ports[i], - store_stdout_file=plasma_store_stdout_file, - store_stderr_file=plasma_store_stderr_file, - manager_stdout_file=plasma_manager_stdout_file, - manager_stderr_file=plasma_manager_stderr_file, - cleanup=cleanup) - object_store_addresses.append(object_store_address) - time.sleep(0.1) - - # Start any local schedulers that do not yet exist. - for i in range(len(local_scheduler_socket_names), num_local_schedulers): - # Connect the local scheduler to the object store at the same index. - object_store_address = object_store_addresses[i] - plasma_address = "{}:{}".format(node_ip_address, - object_store_address.manager_port) - # Determine how many workers this local scheduler should start. - if start_workers_from_local_scheduler: - num_local_scheduler_workers = workers_per_local_scheduler[i] - workers_per_local_scheduler[i] = 0 + if num_workers is not None: + workers_per_local_scheduler = num_local_schedulers * [num_workers] else: - # If we're starting the workers from Python, the local scheduler should - # not start any workers. - num_local_scheduler_workers = 0 - # Start the local scheduler. 
- local_scheduler_stdout_file, local_scheduler_stderr_file = new_log_files( - "local_scheduler_{}".format(i), redirect_output) - local_scheduler_name = start_local_scheduler( - redis_address, - node_ip_address, - object_store_address.name, - object_store_address.manager_name, - worker_path, - plasma_address=plasma_address, - stdout_file=local_scheduler_stdout_file, - stderr_file=local_scheduler_stderr_file, - cleanup=cleanup, - num_cpus=num_cpus[i], - num_gpus=num_gpus[i], - num_workers=num_local_scheduler_workers) - local_scheduler_socket_names.append(local_scheduler_name) - time.sleep(0.1) + workers_per_local_scheduler = [] + for cpus in num_cpus: + workers_per_local_scheduler.append(cpus if cpus is not None + else psutil.cpu_count()) - # Make sure that we have exactly num_local_schedulers instances of object - # stores and local schedulers. - assert len(object_store_addresses) == num_local_schedulers - assert len(local_scheduler_socket_names) == num_local_schedulers + if address_info is None: + address_info = {} + address_info["node_ip_address"] = node_ip_address - # Start any workers that the local scheduler has not already started. - for i, num_local_scheduler_workers in enumerate(workers_per_local_scheduler): - object_store_address = object_store_addresses[i] - local_scheduler_name = local_scheduler_socket_names[i] - for j in range(num_local_scheduler_workers): - worker_stdout_file, worker_stderr_file = new_log_files( - "worker_{}_{}".format(i, j), redirect_output) - start_worker(node_ip_address, - object_store_address.name, - object_store_address.manager_name, - local_scheduler_name, - redis_address, - worker_path, - stdout_file=worker_stdout_file, - stderr_file=worker_stderr_file, - cleanup=cleanup) - workers_per_local_scheduler[i] -= 1 + if worker_path is None: + worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), + "workers/default_worker.py") - # Make sure that we've started all the workers. 
- assert(sum(workers_per_local_scheduler) == 0) + # Start Redis if there isn't already an instance running. TODO(rkn): We are + # suppressing the output of Redis because on Linux it prints a bunch of + # warning messages when it starts up. Instead of suppressing the output, we + # should address the warnings. + redis_address = address_info.get("redis_address") + redis_shards = address_info.get("redis_shards", []) + if redis_address is None: + redis_address, redis_shards = start_redis( + node_ip_address, port=redis_port, + num_redis_shards=num_redis_shards, + redirect_output=redirect_output, cleanup=cleanup) + address_info["redis_address"] = redis_address + time.sleep(0.1) - # Try to start the web UI. - if include_webui: - ui_stdout_file, ui_stderr_file = new_log_files( - "webui", redirect_output=True) - start_ui(redis_address, stdout_file=ui_stdout_file, - stderr_file=ui_stderr_file, cleanup=cleanup) + # Start monitoring the processes. + monitor_stdout_file, monitor_stderr_file = new_log_files( + "monitor", redirect_output) + start_monitor(redis_address, + node_ip_address, + stdout_file=monitor_stdout_file, + stderr_file=monitor_stderr_file) - # Return the addresses of the relevant processes. - return address_info + if redis_shards == []: + # Get redis shards from primary redis instance. + redis_ip_address, redis_port = redis_address.split(":") + redis_client = redis.StrictRedis(host=redis_ip_address, + port=redis_port) + redis_shards = redis_client.lrange("RedisShards", start=0, end=-1) + redis_shards = [shard.decode("ascii") for shard in redis_shards] + address_info["redis_shards"] = redis_shards + + # Start the log monitor, if necessary. 
+ if include_log_monitor: + log_monitor_stdout_file, log_monitor_stderr_file = new_log_files( + "log_monitor", redirect_output=True) + start_log_monitor(redis_address, + node_ip_address, + stdout_file=log_monitor_stdout_file, + stderr_file=log_monitor_stderr_file, + cleanup=cleanup) + + # Start the global scheduler, if necessary. + if include_global_scheduler: + global_scheduler_stdout_file, global_scheduler_stderr_file = ( + new_log_files("global_scheduler", redirect_output)) + start_global_scheduler(redis_address, + node_ip_address, + stdout_file=global_scheduler_stdout_file, + stderr_file=global_scheduler_stderr_file, + cleanup=cleanup) + + # Initialize with existing services. + if "object_store_addresses" not in address_info: + address_info["object_store_addresses"] = [] + object_store_addresses = address_info["object_store_addresses"] + if "local_scheduler_socket_names" not in address_info: + address_info["local_scheduler_socket_names"] = [] + local_scheduler_socket_names = address_info["local_scheduler_socket_names"] + + # Get the ports to use for the object managers if any are provided. + object_manager_ports = (address_info["object_manager_ports"] + if "object_manager_ports" in address_info + else None) + if not isinstance(object_manager_ports, list): + object_manager_ports = num_local_schedulers * [object_manager_ports] + assert len(object_manager_ports) == num_local_schedulers + + # Start any object stores that do not yet exist. + for i in range(num_local_schedulers - len(object_store_addresses)): + # Start Plasma. 
+ plasma_store_stdout_file, plasma_store_stderr_file = new_log_files( + "plasma_store_{}".format(i), redirect_output) + plasma_manager_stdout_file, plasma_manager_stderr_file = new_log_files( + "plasma_manager_{}".format(i), redirect_output) + object_store_address = start_objstore( + node_ip_address, + redis_address, + object_manager_port=object_manager_ports[i], + store_stdout_file=plasma_store_stdout_file, + store_stderr_file=plasma_store_stderr_file, + manager_stdout_file=plasma_manager_stdout_file, + manager_stderr_file=plasma_manager_stderr_file, + cleanup=cleanup) + object_store_addresses.append(object_store_address) + time.sleep(0.1) + + # Start any local schedulers that do not yet exist. + for i in range(len(local_scheduler_socket_names), num_local_schedulers): + # Connect the local scheduler to the object store at the same index. + object_store_address = object_store_addresses[i] + plasma_address = "{}:{}".format(node_ip_address, + object_store_address.manager_port) + # Determine how many workers this local scheduler should start. + if start_workers_from_local_scheduler: + num_local_scheduler_workers = workers_per_local_scheduler[i] + workers_per_local_scheduler[i] = 0 + else: + # If we're starting the workers from Python, the local scheduler + # should not start any workers. + num_local_scheduler_workers = 0 + # Start the local scheduler. 
+ local_scheduler_stdout_file, local_scheduler_stderr_file = ( + new_log_files("local_scheduler_{}".format(i), redirect_output)) + local_scheduler_name = start_local_scheduler( + redis_address, + node_ip_address, + object_store_address.name, + object_store_address.manager_name, + worker_path, + plasma_address=plasma_address, + stdout_file=local_scheduler_stdout_file, + stderr_file=local_scheduler_stderr_file, + cleanup=cleanup, + num_cpus=num_cpus[i], + num_gpus=num_gpus[i], + num_workers=num_local_scheduler_workers) + local_scheduler_socket_names.append(local_scheduler_name) + time.sleep(0.1) + + # Make sure that we have exactly num_local_schedulers instances of object + # stores and local schedulers. + assert len(object_store_addresses) == num_local_schedulers + assert len(local_scheduler_socket_names) == num_local_schedulers + + # Start any workers that the local scheduler has not already started. + for i, num_local_scheduler_workers in enumerate( + workers_per_local_scheduler): + object_store_address = object_store_addresses[i] + local_scheduler_name = local_scheduler_socket_names[i] + for j in range(num_local_scheduler_workers): + worker_stdout_file, worker_stderr_file = new_log_files( + "worker_{}_{}".format(i, j), redirect_output) + start_worker(node_ip_address, + object_store_address.name, + object_store_address.manager_name, + local_scheduler_name, + redis_address, + worker_path, + stdout_file=worker_stdout_file, + stderr_file=worker_stderr_file, + cleanup=cleanup) + workers_per_local_scheduler[i] -= 1 + + # Make sure that we've started all the workers. + assert(sum(workers_per_local_scheduler) == 0) + + # Try to start the web UI. + if include_webui: + ui_stdout_file, ui_stderr_file = new_log_files( + "webui", redirect_output=True) + start_ui(redis_address, stdout_file=ui_stdout_file, + stderr_file=ui_stderr_file, cleanup=cleanup) + + # Return the addresses of the relevant processes. 
+ return address_info def start_ray_node(node_ip_address, @@ -955,44 +980,45 @@ def start_ray_node(node_ip_address, redirect_output=False, num_cpus=None, num_gpus=None): - """Start the Ray processes for a single node. + """Start the Ray processes for a single node. - This assumes that the Ray processes on some master node have already been - started. + This assumes that the Ray processes on some master node have already been + started. - Args: - node_ip_address (str): The IP address of this node. - redis_address (str): The address of the Redis server. - object_manager_ports (list): A list of the ports to use for the object - managers. There should be one per object manager being started on this - node (typically just one). - num_workers (int): The number of workers to start. - num_local_schedulers (int): The number of local schedulers to start. This - is also the number of plasma stores and plasma managers to start. - worker_path (str): The path of the source code that will be run by the - worker. - cleanup (bool): If cleanup is true, then the processes started here will be - killed by services.cleanup() when the Python process that called this - method exits. - redirect_output (bool): True if stdout and stderr should be redirected to a - file. + Args: + node_ip_address (str): The IP address of this node. + redis_address (str): The address of the Redis server. + object_manager_ports (list): A list of the ports to use for the object + managers. There should be one per object manager being started on + this node (typically just one). + num_workers (int): The number of workers to start. + num_local_schedulers (int): The number of local schedulers to start. + This is also the number of plasma stores and plasma managers to + start. + worker_path (str): The path of the source code that will be run by the + worker. 
+ cleanup (bool): If cleanup is true, then the processes started here + will be killed by services.cleanup() when the Python process that + called this method exits. + redirect_output (bool): True if stdout and stderr should be redirected + to a file. - Returns: - A dictionary of the address information for the processes that were - started. - """ - address_info = {"redis_address": redis_address, - "object_manager_ports": object_manager_ports} - return start_ray_processes(address_info=address_info, - node_ip_address=node_ip_address, - num_workers=num_workers, - num_local_schedulers=num_local_schedulers, - worker_path=worker_path, - include_log_monitor=True, - cleanup=cleanup, - redirect_output=redirect_output, - num_cpus=num_cpus, - num_gpus=num_gpus) + Returns: + A dictionary of the address information for the processes that were + started. + """ + address_info = {"redis_address": redis_address, + "object_manager_ports": object_manager_ports} + return start_ray_processes(address_info=address_info, + node_ip_address=node_ip_address, + num_workers=num_workers, + num_local_schedulers=num_local_schedulers, + worker_path=worker_path, + include_log_monitor=True, + cleanup=cleanup, + redirect_output=redirect_output, + num_cpus=num_cpus, + num_gpus=num_gpus) def start_ray_head(address_info=None, @@ -1007,106 +1033,108 @@ def start_ray_head(address_info=None, num_cpus=None, num_gpus=None, num_redis_shards=None): - """Start Ray in local mode. + """Start Ray in local mode. - Args: - address_info (dict): A dictionary with address information for processes - that have already been started. If provided, address_info will be - modified to include processes that are newly started. - node_ip_address (str): The IP address of this node. - redis_port (int): The port that the primary Redis shard should listen to. - If None, then a random port will be chosen. If the key "redis_address" is - in address_info, then this argument will be ignored. 
- num_workers (int): The number of workers to start. - num_local_schedulers (int): The total number of local schedulers required. - This is also the total number of object stores required. This method will - start new instances of local schedulers and object stores until there are - at least num_local_schedulers existing instances of each, including ones - already registered with the given address_info. - worker_path (str): The path of the source code that will be run by the - worker. - cleanup (bool): If cleanup is true, then the processes started here will be - killed by services.cleanup() when the Python process that called this - method exits. - redirect_output (bool): True if stdout and stderr should be redirected to a - file. - start_workers_from_local_scheduler (bool): If this flag is True, then start - the initial workers from the local scheduler. Else, start them from - Python. - num_cpus (int): number of cpus to configure the local scheduler with. - num_gpus (int): number of gpus to configure the local scheduler with. - num_redis_shards: The number of Redis shards to start in addition to the - primary Redis shard. + Args: + address_info (dict): A dictionary with address information for + processes that have already been started. If provided, address_info + will be modified to include processes that are newly started. + node_ip_address (str): The IP address of this node. + redis_port (int): The port that the primary Redis shard should listen + to. If None, then a random port will be chosen. If the key + "redis_address" is in address_info, then this argument will be + ignored. + num_workers (int): The number of workers to start. + num_local_schedulers (int): The total number of local schedulers + required. This is also the total number of object stores required. 
+ This method will start new instances of local schedulers and object + stores until there are at least num_local_schedulers existing + instances of each, including ones already registered with the given + address_info. + worker_path (str): The path of the source code that will be run by the + worker. + cleanup (bool): If cleanup is true, then the processes started here + will be killed by services.cleanup() when the Python process that + called this method exits. + redirect_output (bool): True if stdout and stderr should be redirected + to a file. + start_workers_from_local_scheduler (bool): If this flag is True, then + start the initial workers from the local scheduler. Else, start + them from Python. + num_cpus (int): number of cpus to configure the local scheduler with. + num_gpus (int): number of gpus to configure the local scheduler with. + num_redis_shards: The number of Redis shards to start in addition to + the primary Redis shard. - Returns: - A dictionary of the address information for the processes that were - started. - """ - num_redis_shards = 1 if num_redis_shards is None else num_redis_shards - return start_ray_processes( - address_info=address_info, - node_ip_address=node_ip_address, - redis_port=redis_port, - num_workers=num_workers, - num_local_schedulers=num_local_schedulers, - worker_path=worker_path, - cleanup=cleanup, - redirect_output=redirect_output, - include_global_scheduler=True, - include_log_monitor=True, - include_webui=True, - start_workers_from_local_scheduler=start_workers_from_local_scheduler, - num_cpus=num_cpus, - num_gpus=num_gpus, - num_redis_shards=num_redis_shards) + Returns: + A dictionary of the address information for the processes that were + started. 
+ """ + num_redis_shards = 1 if num_redis_shards is None else num_redis_shards + return start_ray_processes( + address_info=address_info, + node_ip_address=node_ip_address, + redis_port=redis_port, + num_workers=num_workers, + num_local_schedulers=num_local_schedulers, + worker_path=worker_path, + cleanup=cleanup, + redirect_output=redirect_output, + include_global_scheduler=True, + include_log_monitor=True, + include_webui=True, + start_workers_from_local_scheduler=start_workers_from_local_scheduler, + num_cpus=num_cpus, + num_gpus=num_gpus, + num_redis_shards=num_redis_shards) def try_to_create_directory(directory_path): - """Attempt to create a directory that is globally readable/writable. + """Attempt to create a directory that is globally readable/writable. - Args: - directory_path: The path of the directory to create. - """ - if not os.path.exists(directory_path): - try: - os.makedirs(directory_path) - except OSError as e: - if e.errno != os.errno.EEXIST: - raise e - print("Attempted to create '{}', but the directory already " - "exists.".format(directory_path)) - # Change the log directory permissions so others can use it. This is - # important when multiple people are using the same machine. - os.chmod(directory_path, 0o0777) + Args: + directory_path: The path of the directory to create. + """ + if not os.path.exists(directory_path): + try: + os.makedirs(directory_path) + except OSError as e: + if e.errno != os.errno.EEXIST: + raise e + print("Attempted to create '{}', but the directory already " + "exists.".format(directory_path)) + # Change the log directory permissions so others can use it. This is + # important when multiple people are using the same machine. + os.chmod(directory_path, 0o0777) def new_log_files(name, redirect_output): - """Generate partially randomized filenames for log files. + """Generate partially randomized filenames for log files. - Args: - name (str): descriptive string for this log file. 
- redirect_output (bool): True if files should be generated for logging - stdout and stderr and false if stdout and stderr should not be - redirected. + Args: + name (str): descriptive string for this log file. + redirect_output (bool): True if files should be generated for logging + stdout and stderr and false if stdout and stderr should not be + redirected. - Returns: - If redirect_output is true, this will return a tuple of two filehandles. - The first is for redirecting stdout and the second is for redirecting - stderr. If redirect_output is false, this will return a tuple of two None - objects. - """ - if not redirect_output: - return None, None + Returns: + If redirect_output is true, this will return a tuple of two + filehandles. The first is for redirecting stdout and the second is + for redirecting stderr. If redirect_output is false, this will + return a tuple of two None objects. + """ + if not redirect_output: + return None, None - # Create a directory to be used for process log files. - logs_dir = "/tmp/raylogs" - try_to_create_directory(logs_dir) - # Create another directory that will be used by some of the RL algorithms. - try_to_create_directory("/tmp/ray") + # Create a directory to be used for process log files. + logs_dir = "/tmp/raylogs" + try_to_create_directory(logs_dir) + # Create another directory that will be used by some of the RL algorithms. 
+ try_to_create_directory("/tmp/ray") - log_id = random.randint(0, 1000000000) - log_stdout = "{}/{}-{:010d}.out".format(logs_dir, name, log_id) - log_stderr = "{}/{}-{:010d}.err".format(logs_dir, name, log_id) - log_stdout_file = open(log_stdout, "a") - log_stderr_file = open(log_stderr, "a") - return log_stdout_file, log_stderr_file + log_id = random.randint(0, 1000000000) + log_stdout = "{}/{}-{:010d}.out".format(logs_dir, name, log_id) + log_stderr = "{}/{}-{:010d}.err".format(logs_dir, name, log_id) + log_stdout_file = open(log_stdout, "a") + log_stderr_file = open(log_stderr, "a") + return log_stdout_file, log_stderr_file diff --git a/python/ray/signature.py b/python/ray/signature.py index 4c28a024b..6c557fe03 100644 --- a/python/ray/signature.py +++ b/python/ray/signature.py @@ -13,160 +13,165 @@ FunctionSignature = namedtuple("FunctionSignature", ["arg_names", """This class is used to represent a function signature. Attributes: - keyword_names: The names of the functions keyword arguments. This is used to - test if an incorrect keyword argument has been passed to the function. - arg_defaults: A dictionary mapping from argument name to argument default - value. If the argument is not a keyword argument, the default value will be - funcsigs._empty. - arg_is_positionals: A dictionary mapping from argument name to a bool. The - bool will be true if the argument is a *args argument. Otherwise it will be - false. - function_name: The name of the function whose signature is being inspected. - This is used for printing better error messages. + keyword_names: The names of the functions keyword arguments. This is used + to test if an incorrect keyword argument has been passed to the + function. + arg_defaults: A dictionary mapping from argument name to argument default + value. If the argument is not a keyword argument, the default value + will be funcsigs._empty. + arg_is_positionals: A dictionary mapping from argument name to a bool. 
The + bool will be true if the argument is a *args argument. Otherwise it + will be false. + function_name: The name of the function whose signature is being + inspected. This is used for printing better error messages. """ def check_signature_supported(func, warn=False): - """Check if we support the signature of this function. + """Check if we support the signature of this function. - We currently do not allow remote functions to have **kwargs. We also do not - support keyword arguments in conjunction with a *args argument. + We currently do not allow remote functions to have **kwargs. We also do not + support keyword arguments in conjunction with a *args argument. - Args: - func: The function whose signature should be checked. - warn: If this is true, a warning will be printed if the signature is not - supported. If it is false, an exception will be raised if the signature - is not supported. + Args: + func: The function whose signature should be checked. + warn: If this is true, a warning will be printed if the signature is + not supported. If it is false, an exception will be raised if the + signature is not supported. - Raises: - Exception: An exception is raised if the signature is not supported. - """ - function_name = func.__name__ - sig_params = [(k, v) for k, v - in funcsigs.signature(func).parameters.items()] + Raises: + Exception: An exception is raised if the signature is not supported. 
+ """ + function_name = func.__name__ + sig_params = [(k, v) for k, v + in funcsigs.signature(func).parameters.items()] - has_vararg_param = False - has_kwargs_param = False - has_keyword_arg = False - for keyword_name, parameter in sig_params: - if parameter.kind == parameter.VAR_KEYWORD: - has_kwargs_param = True - if parameter.kind == parameter.VAR_POSITIONAL: - has_vararg_param = True - if parameter.default != funcsigs._empty: - has_keyword_arg = True + has_vararg_param = False + has_kwargs_param = False + has_keyword_arg = False + for keyword_name, parameter in sig_params: + if parameter.kind == parameter.VAR_KEYWORD: + has_kwargs_param = True + if parameter.kind == parameter.VAR_POSITIONAL: + has_vararg_param = True + if parameter.default != funcsigs._empty: + has_keyword_arg = True - if has_kwargs_param: - message = ("The function {} has a **kwargs argument, which is " - "currently not supported.".format(function_name)) - if warn: - print(message) - else: - raise Exception(message) - # Check if the user specified a variable number of arguments and any keyword - # arguments. - if has_vararg_param and has_keyword_arg: - message = ("Function {} has a *args argument as well as a keyword " - "argument, which is currently not supported." - .format(function_name)) - if warn: - print(message) - else: - raise Exception(message) + if has_kwargs_param: + message = ("The function {} has a **kwargs argument, which is " + "currently not supported.".format(function_name)) + if warn: + print(message) + else: + raise Exception(message) + # Check if the user specified a variable number of arguments and any + # keyword arguments. + if has_vararg_param and has_keyword_arg: + message = ("Function {} has a *args argument as well as a keyword " + "argument, which is currently not supported." + .format(function_name)) + if warn: + print(message) + else: + raise Exception(message) def extract_signature(func, ignore_first=False): - """Extract the function signature from the function. 
+ """Extract the function signature from the function. - Args: - func: The function whose signature should be extracted. - ignore_first: True if the first argument should be ignored. This should be - used when func is a method of a class. + Args: + func: The function whose signature should be extracted. + ignore_first: True if the first argument should be ignored. This should + be used when func is a method of a class. - Returns: - A function signature object, which includes the names of the keyword - arguments as well as their default values. - """ - sig_params = [(k, v) for k, v - in funcsigs.signature(func).parameters.items()] + Returns: + A function signature object, which includes the names of the keyword + arguments as well as their default values. + """ + sig_params = [(k, v) for k, v + in funcsigs.signature(func).parameters.items()] - if ignore_first: - if len(sig_params) == 0: - raise Exception("Methods must take a 'self' argument, but the method " - "'{}' does not have one.".format(func.__name__)) - sig_params = sig_params[1:] + if ignore_first: + if len(sig_params) == 0: + raise Exception("Methods must take a 'self' argument, but the " + "method '{}' does not have one." + .format(func.__name__)) + sig_params = sig_params[1:] - # Extract the names of the keyword arguments. - keyword_names = set() - for keyword_name, parameter in sig_params: - if parameter.default != funcsigs._empty: - keyword_names.add(keyword_name) + # Extract the names of the keyword arguments. + keyword_names = set() + for keyword_name, parameter in sig_params: + if parameter.default != funcsigs._empty: + keyword_names.add(keyword_name) - # Construct the argument default values and other argument information. 
- arg_names = [] - arg_defaults = [] - arg_is_positionals = [] - for keyword_name, parameter in sig_params: - arg_names.append(keyword_name) - arg_defaults.append(parameter.default) - arg_is_positionals.append(parameter.kind == parameter.VAR_POSITIONAL) + # Construct the argument default values and other argument information. + arg_names = [] + arg_defaults = [] + arg_is_positionals = [] + for keyword_name, parameter in sig_params: + arg_names.append(keyword_name) + arg_defaults.append(parameter.default) + arg_is_positionals.append(parameter.kind == parameter.VAR_POSITIONAL) - return FunctionSignature(arg_names, arg_defaults, arg_is_positionals, - keyword_names, func.__name__) + return FunctionSignature(arg_names, arg_defaults, arg_is_positionals, + keyword_names, func.__name__) def extend_args(function_signature, args, kwargs): - """Extend the arguments that were passed into a function. + """Extend the arguments that were passed into a function. - This extends the arguments that were passed into a function with the default - arguments provided in the function definition. + This extends the arguments that were passed into a function with the + default arguments provided in the function definition. - Args: - function_signature: The function signature of the function being called. - args: The non-keyword arguments passed into the function. - kwargs: The keyword arguments passed into the function. + Args: + function_signature: The function signature of the function being + called. + args: The non-keyword arguments passed into the function. + kwargs: The keyword arguments passed into the function. - Returns: - An extended list of arguments to pass into the function. + Returns: + An extended list of arguments to pass into the function. - Raises: - Exception: An exception may be raised if the function cannot be called with - these arguments. 
- """ - arg_names = function_signature.arg_names - arg_defaults = function_signature.arg_defaults - arg_is_positionals = function_signature.arg_is_positionals - keyword_names = function_signature.keyword_names - function_name = function_signature.function_name + Raises: + Exception: An exception may be raised if the function cannot be called + with these arguments. + """ + arg_names = function_signature.arg_names + arg_defaults = function_signature.arg_defaults + arg_is_positionals = function_signature.arg_is_positionals + keyword_names = function_signature.keyword_names + function_name = function_signature.function_name - args = list(args) + args = list(args) - for keyword_name in kwargs: - if keyword_name not in keyword_names: - raise Exception("The name '{}' is not a valid keyword argument for the " - "function '{}'.".format(keyword_name, function_name)) + for keyword_name in kwargs: + if keyword_name not in keyword_names: + raise Exception("The name '{}' is not a valid keyword argument " + "for the function '{}'." + .format(keyword_name, function_name)) - # Fill in the remaining arguments. - zipped_info = list(zip(arg_names, arg_defaults, - arg_is_positionals))[len(args):] - for keyword_name, default_value, is_positional in zipped_info: - if keyword_name in kwargs: - args.append(kwargs[keyword_name]) - else: - if default_value != funcsigs._empty: - args.append(default_value) - else: - # This means that there is a missing argument. Unless this is the last - # argument and it is a *args argument in which case it can be omitted. - if not is_positional: - raise Exception("No value was provided for the argument '{}' for " - "the function '{}'.".format(keyword_name, - function_name)) + # Fill in the remaining arguments. 
+ zipped_info = list(zip(arg_names, arg_defaults, + arg_is_positionals))[len(args):] + for keyword_name, default_value, is_positional in zipped_info: + if keyword_name in kwargs: + args.append(kwargs[keyword_name]) + else: + if default_value != funcsigs._empty: + args.append(default_value) + else: + # This means that there is a missing argument. Unless this is + # the last argument and it is a *args argument in which case it + # can be omitted. + if not is_positional: + raise Exception("No value was provided for the argument " + "'{}' for the function '{}'." + .format(keyword_name, function_name)) - too_many_arguments = (len(args) > len(arg_names) and - (len(arg_is_positionals) == 0 or - not arg_is_positionals[-1])) - if too_many_arguments: - raise Exception("Too many arguments were passed to the function '{}'" - .format(function_name)) - return args + too_many_arguments = (len(args) > len(arg_names) and + (len(arg_is_positionals) == 0 or + not arg_is_positionals[-1])) + if too_many_arguments: + raise Exception("Too many arguments were passed to the function '{}'" + .format(function_name)) + return args diff --git a/python/ray/test/test_functions.py b/python/ray/test/test_functions.py index 79e5e2f8d..105dbf184 100644 --- a/python/ray/test/test_functions.py +++ b/python/ray/test/test_functions.py @@ -11,99 +11,99 @@ import numpy as np @ray.remote(num_return_vals=2) def handle_int(a, b): - return a + 1, b + 1 + return a + 1, b + 1 # Test timing @ray.remote def empty_function(): - pass + pass @ray.remote def trivial_function(): - return 1 + return 1 # Test keyword arguments @ray.remote def keyword_fct1(a, b="hello"): - return "{} {}".format(a, b) + return "{} {}".format(a, b) @ray.remote def keyword_fct2(a="hello", b="world"): - return "{} {}".format(a, b) + return "{} {}".format(a, b) @ray.remote def keyword_fct3(a, b, c="hello", d="world"): - return "{} {} {} {}".format(a, b, c, d) + return "{} {} {} {}".format(a, b, c, d) # Test variable numbers of arguments 
@ray.remote def varargs_fct1(*a): - return " ".join(map(str, a)) + return " ".join(map(str, a)) @ray.remote def varargs_fct2(a, *b): - return " ".join(map(str, b)) + return " ".join(map(str, b)) try: - @ray.remote - def kwargs_throw_exception(**c): - return () - kwargs_exception_thrown = False + @ray.remote + def kwargs_throw_exception(**c): + return () + kwargs_exception_thrown = False except: - kwargs_exception_thrown = True + kwargs_exception_thrown = True try: - @ray.remote - def varargs_and_kwargs_throw_exception(a, b="hi", *c): - return "{} {} {}".format(a, b, c) - varargs_and_kwargs_exception_thrown = False + @ray.remote + def varargs_and_kwargs_throw_exception(a, b="hi", *c): + return "{} {} {}".format(a, b, c) + varargs_and_kwargs_exception_thrown = False except: - varargs_and_kwargs_exception_thrown = True + varargs_and_kwargs_exception_thrown = True # test throwing an exception @ray.remote def throw_exception_fct1(): - raise Exception("Test function 1 intentionally failed.") + raise Exception("Test function 1 intentionally failed.") @ray.remote def throw_exception_fct2(): - raise Exception("Test function 2 intentionally failed.") + raise Exception("Test function 2 intentionally failed.") @ray.remote(num_return_vals=3) def throw_exception_fct3(x): - raise Exception("Test function 3 intentionally failed.") + raise Exception("Test function 3 intentionally failed.") # test Python mode @ray.remote def python_mode_f(): - return np.array([0, 0]) + return np.array([0, 0]) @ray.remote def python_mode_g(x): - x[0] = 1 - return x + x[0] = 1 + return x # test no return values @ray.remote def no_op(): - pass + pass diff --git a/python/ray/test/test_utils.py b/python/ray/test/test_utils.py index 3e531a27d..179ad8117 100644 --- a/python/ray/test/test_utils.py +++ b/python/ray/test/test_utils.py @@ -15,119 +15,124 @@ EVENT_KEY = "RAY_MULTI_NODE_TEST_KEY" def _wait_for_nodes_to_join(num_nodes, timeout=20): - """Wait until the nodes have joined the cluster. 
+ """Wait until the nodes have joined the cluster. - This will wait until exactly num_nodes have joined the cluster and each node - has a local scheduler and a plasma manager. + This will wait until exactly num_nodes have joined the cluster and each + node has a local scheduler and a plasma manager. - Args: - num_nodes: The number of nodes to wait for. - timeout: The amount of time in seconds to wait before failing. + Args: + num_nodes: The number of nodes to wait for. + timeout: The amount of time in seconds to wait before failing. - Raises: - Exception: An exception is raised if too many nodes join the cluster or if - the timeout expires while we are waiting. - """ - start_time = time.time() - while time.time() - start_time < timeout: - client_table = ray.global_state.client_table() - num_ready_nodes = len(client_table) - if num_ready_nodes == num_nodes: - ready = True - # Check that for each node, a local scheduler and a plasma manager are - # present. - for ip_address, clients in client_table.items(): - client_types = [client["ClientType"] for client in clients] - if "local_scheduler" not in client_types: - ready = False - if "plasma_manager" not in client_types: - ready = False - if ready: - return - if num_ready_nodes > num_nodes: - # Too many nodes have joined. Something must be wrong. - raise Exception("{} nodes have joined the cluster, but we were " - "expecting {} nodes.".format(num_ready_nodes, num_nodes)) - time.sleep(0.1) + Raises: + Exception: An exception is raised if too many nodes join the cluster or + if the timeout expires while we are waiting. + """ + start_time = time.time() + while time.time() - start_time < timeout: + client_table = ray.global_state.client_table() + num_ready_nodes = len(client_table) + if num_ready_nodes == num_nodes: + ready = True + # Check that for each node, a local scheduler and a plasma manager + # are present. 
+ for ip_address, clients in client_table.items(): + client_types = [client["ClientType"] for client in clients] + if "local_scheduler" not in client_types: + ready = False + if "plasma_manager" not in client_types: + ready = False + if ready: + return + if num_ready_nodes > num_nodes: + # Too many nodes have joined. Something must be wrong. + raise Exception("{} nodes have joined the cluster, but we were " + "expecting {} nodes.".format(num_ready_nodes, + num_nodes)) + time.sleep(0.1) - # If we get here then we timed out. - raise Exception("Timed out while waiting for {} nodes to join. Only {} " - "nodes have joined so far.".format(num_ready_nodes, - num_nodes)) + # If we get here then we timed out. + raise Exception("Timed out while waiting for {} nodes to join. Only {} " + "nodes have joined so far.".format(num_ready_nodes, + num_nodes)) def _broadcast_event(event_name, redis_address, data=None): - """Broadcast an event. + """Broadcast an event. - This is used to synchronize drivers for the multi-node tests. + This is used to synchronize drivers for the multi-node tests. - Args: - event_name: The name of the event to wait for. - redis_address: The address of the Redis server to use for synchronization. - data: Extra data to include in the broadcast (this will be returned by the - corresponding _wait_for_event call). This data must be json serializable. - """ - redis_host, redis_port = redis_address.split(":") - redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port)) - payload = json.dumps((event_name, data)) - redis_client.rpush(EVENT_KEY, payload) + Args: + event_name: The name of the event to wait for. + redis_address: The address of the Redis server to use for + synchronization. + data: Extra data to include in the broadcast (this will be returned by + the corresponding _wait_for_event call). This data must be json + serializable. 
+ """ + redis_host, redis_port = redis_address.split(":") + redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port)) + payload = json.dumps((event_name, data)) + redis_client.rpush(EVENT_KEY, payload) def _wait_for_event(event_name, redis_address, extra_buffer=0): - """Block until an event has been broadcast. + """Block until an event has been broadcast. - This is used to synchronize drivers for the multi-node tests. + This is used to synchronize drivers for the multi-node tests. - Args: - event_name: The name of the event to wait for. - redis_address: The address of the Redis server to use for synchronization. - extra_buffer: An amount of time in seconds to wait after the event. + Args: + event_name: The name of the event to wait for. + redis_address: The address of the Redis server to use for + synchronization. + extra_buffer: An amount of time in seconds to wait after the event. - Returns: - The data that was passed into the corresponding _broadcast_event call. - """ - redis_host, redis_port = redis_address.split(":") - redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port)) - while True: - event_infos = redis_client.lrange(EVENT_KEY, 0, -1) - events = dict() - for event_info in event_infos: - name, data = json.loads(event_info) - if name in events: - raise Exception("The same event {} was broadcast twice.".format(name)) - events[name] = data - if event_name in events: - # Potentially sleep a little longer and then return the event data. - time.sleep(extra_buffer) - return events[event_name] - time.sleep(0.1) + Returns: + The data that was passed into the corresponding _broadcast_event call. 
+ """ + redis_host, redis_port = redis_address.split(":") + redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port)) + while True: + event_infos = redis_client.lrange(EVENT_KEY, 0, -1) + events = dict() + for event_info in event_infos: + name, data = json.loads(event_info) + if name in events: + raise Exception("The same event {} was broadcast twice." + .format(name)) + events[name] = data + if event_name in events: + # Potentially sleep a little longer and then return the event data. + time.sleep(extra_buffer) + return events[event_name] + time.sleep(0.1) def _pid_alive(pid): - """Check if the process with this PID is alive or not. + """Check if the process with this PID is alive or not. - Args: - pid: The pid to check. + Args: + pid: The pid to check. - Returns: - This returns false if the process is dead or defunct. Otherwise, it returns - true. - """ - try: - os.kill(pid, 0) - except OSError: - return False - else: - if psutil.Process(pid).status() == psutil.STATUS_ZOMBIE: - return False + Returns: + This returns false if the process is dead or defunct. Otherwise, it + returns true. + """ + try: + os.kill(pid, 0) + except OSError: + return False else: - return True + if psutil.Process(pid).status() == psutil.STATUS_ZOMBIE: + return False + else: + return True def wait_for_pid_to_exit(pid, timeout=20): - start_time = time.time() - while time.time() - start_time < timeout: - if not _pid_alive(pid): - return - time.sleep(0.1) - raise Exception("Timed out while waiting for process to exit.") + start_time = time.time() + while time.time() - start_time < timeout: + if not _pid_alive(pid): + return + time.sleep(0.1) + raise Exception("Timed out while waiting for process to exit.") diff --git a/python/ray/utils.py b/python/ray/utils.py index 55febb1d4..2f6ed1423 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -11,51 +11,52 @@ import ray.local_scheduler def random_string(): - """Generate a random string to use as an ID. 
+ """Generate a random string to use as an ID. - Note that users may seed numpy, which could cause this function to generate - duplicate IDs. Therefore, we need to seed numpy ourselves, but we can't - interfere with the state of the user's random number generator, so we extract - the state of the random number generator and reset it after we are done. + Note that users may seed numpy, which could cause this function to generate + duplicate IDs. Therefore, we need to seed numpy ourselves, but we can't + interfere with the state of the user's random number generator, so we + extract the state of the random number generator and reset it after we are + done. - TODO(rkn): If we want to later guarantee that these are generated in a - deterministic manner, then we will need to make some changes here. + TODO(rkn): If we want to later guarantee that these are generated in a + deterministic manner, then we will need to make some changes here. - Returns: - A random byte string of length 20. - """ - # Get the state of the numpy random number generator. - numpy_state = np.random.get_state() - # Try to use true randomness. - np.random.seed(None) - # Generate the random ID. - random_id = np.random.bytes(20) - # Reset the state of the numpy random number generator. - np.random.set_state(numpy_state) - return random_id + Returns: + A random byte string of length 20. + """ + # Get the state of the numpy random number generator. + numpy_state = np.random.get_state() + # Try to use true randomness. + np.random.seed(None) + # Generate the random ID. + random_id = np.random.bytes(20) + # Reset the state of the numpy random number generator. 
+ np.random.set_state(numpy_state) + return random_id def decode(byte_str): - """Make this unicode in Python 3, otherwise leave it as bytes.""" - if sys.version_info >= (3, 0): - return byte_str.decode("ascii") - else: - return byte_str + """Make this unicode in Python 3, otherwise leave it as bytes.""" + if sys.version_info >= (3, 0): + return byte_str.decode("ascii") + else: + return byte_str def binary_to_object_id(binary_object_id): - return ray.local_scheduler.ObjectID(binary_object_id) + return ray.local_scheduler.ObjectID(binary_object_id) def binary_to_hex(identifier): - hex_identifier = binascii.hexlify(identifier) - if sys.version_info >= (3, 0): - hex_identifier = hex_identifier.decode() - return hex_identifier + hex_identifier = binascii.hexlify(identifier) + if sys.version_info >= (3, 0): + hex_identifier = hex_identifier.decode() + return hex_identifier def hex_to_binary(hex_identifier): - return binascii.unhexlify(hex_identifier) + return binascii.unhexlify(hex_identifier) FunctionProperties = collections.namedtuple("FunctionProperties", diff --git a/python/ray/worker.py b/python/ray/worker.py index 49e84dec8..5373530cd 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -62,506 +62,530 @@ TASK_STATUS_RUNNING = 8 class FunctionID(object): - def __init__(self, function_id): - self.function_id = function_id + def __init__(self, function_id): + self.function_id = function_id - def id(self): - return self.function_id + def id(self): + return self.function_id contained_objectids = [] def numbuf_serialize(value): - """This serializes a value and tracks the object IDs inside the value. + """This serializes a value and tracks the object IDs inside the value. - We also define a custom ObjectID serializer which also closes over the global - variable contained_objectids, and whenever the custom serializer is called, - it adds the releevant ObjectID to the list contained_objectids. 
The list - contained_objectids should be reset between calls to numbuf_serialize. + We also define a custom ObjectID serializer which also closes over the + global variable contained_objectids, and whenever the custom serializer is + called, it adds the releevant ObjectID to the list contained_objectids. The + list contained_objectids should be reset between calls to numbuf_serialize. - Args: - value: A Python object that will be serialized. + Args: + value: A Python object that will be serialized. - Returns: - The serialized object. - """ - assert len(contained_objectids) == 0, "This should be unreachable." - return ray.numbuf.serialize_list([value]) + Returns: + The serialized object. + """ + assert len(contained_objectids) == 0, "This should be unreachable." + return ray.numbuf.serialize_list([value]) class RayTaskError(Exception): - """An object used internally to represent a task that threw an exception. + """An object used internally to represent a task that threw an exception. - If a task throws an exception during execution, a RayTaskError is stored in - the object store for each of the task's outputs. When an object is retrieved - from the object store, the Python method that retrieved it checks to see if - the object is a RayTaskError and if it is then an exception is thrown - propagating the error message. + If a task throws an exception during execution, a RayTaskError is stored in + the object store for each of the task's outputs. When an object is + retrieved from the object store, the Python method that retrieved it checks + to see if the object is a RayTaskError and if it is then an exception is + thrown propagating the error message. - Currently, we either use the exception attribute or the traceback attribute - but not both. + Currently, we either use the exception attribute or the traceback attribute + but not both. - Attributes: - function_name (str): The name of the function that failed and produced the - RayTaskError. 
- exception (Exception): The exception object thrown by the failed task. - traceback_str (str): The traceback from the exception. - """ + Attributes: + function_name (str): The name of the function that failed and produced + the RayTaskError. + exception (Exception): The exception object thrown by the failed task. + traceback_str (str): The traceback from the exception. + """ - def __init__(self, function_name, exception, traceback_str): - """Initialize a RayTaskError.""" - self.function_name = function_name - if isinstance(exception, RayGetError) or isinstance(exception, - RayGetArgumentError): - self.exception = exception - else: - self.exception = None - self.traceback_str = traceback_str + def __init__(self, function_name, exception, traceback_str): + """Initialize a RayTaskError.""" + self.function_name = function_name + if (isinstance(exception, RayGetError) or + isinstance(exception, RayGetArgumentError)): + self.exception = exception + else: + self.exception = None + self.traceback_str = traceback_str - def __str__(self): - """Format a RayTaskError as a string.""" - if self.traceback_str is None: - # This path is taken if getting the task arguments failed. - return ("Remote function {}{}{} failed with:\n\n{}" - .format(colorama.Fore.RED, self.function_name, - colorama.Fore.RESET, self.exception)) - else: - # This path is taken if the task execution failed. - return ("Remote function {}{}{} failed with:\n\n{}" - .format(colorama.Fore.RED, self.function_name, - colorama.Fore.RESET, self.traceback_str)) + def __str__(self): + """Format a RayTaskError as a string.""" + if self.traceback_str is None: + # This path is taken if getting the task arguments failed. + return ("Remote function {}{}{} failed with:\n\n{}" + .format(colorama.Fore.RED, self.function_name, + colorama.Fore.RESET, self.exception)) + else: + # This path is taken if the task execution failed. 
+ return ("Remote function {}{}{} failed with:\n\n{}" + .format(colorama.Fore.RED, self.function_name, + colorama.Fore.RESET, self.traceback_str)) class RayGetError(Exception): - """An exception used when get is called on an output of a failed task. + """An exception used when get is called on an output of a failed task. - Attributes: - objectid (lib.ObjectID): The ObjectID that get was called on. - task_error (RayTaskError): The RayTaskError object created by the failed - task. - """ + Attributes: + objectid (lib.ObjectID): The ObjectID that get was called on. + task_error (RayTaskError): The RayTaskError object created by the + failed task. + """ - def __init__(self, objectid, task_error): - """Initialize a RayGetError object.""" - self.objectid = objectid - self.task_error = task_error + def __init__(self, objectid, task_error): + """Initialize a RayGetError object.""" + self.objectid = objectid + self.task_error = task_error - def __str__(self): - """Format a RayGetError as a string.""" - return ("Could not get objectid {}. It was created by remote function " - "{}{}{} which failed with:\n\n{}" - .format(self.objectid, colorama.Fore.RED, - self.task_error.function_name, colorama.Fore.RESET, - self.task_error)) + def __str__(self): + """Format a RayGetError as a string.""" + return ("Could not get objectid {}. It was created by remote function " + "{}{}{} which failed with:\n\n{}" + .format(self.objectid, colorama.Fore.RED, + self.task_error.function_name, colorama.Fore.RESET, + self.task_error)) class RayGetArgumentError(Exception): - """An exception used when a task's argument was produced by a failed task. + """An exception used when a task's argument was produced by a failed task. - Attributes: - argument_index (int): The index (zero indexed) of the failed argument in - present task's remote function call. - function_name (str): The name of the function for the current task. - objectid (lib.ObjectID): The ObjectID that was passed in as the argument. 
- task_error (RayTaskError): The RayTaskError object created by the failed - task. - """ + Attributes: + argument_index (int): The index (zero indexed) of the failed argument + in present task's remote function call. + function_name (str): The name of the function for the current task. + objectid (lib.ObjectID): The ObjectID that was passed in as the + argument. + task_error (RayTaskError): The RayTaskError object created by the + failed task. + """ - def __init__(self, function_name, argument_index, objectid, task_error): - """Initialize a RayGetArgumentError object.""" - self.argument_index = argument_index - self.function_name = function_name - self.objectid = objectid - self.task_error = task_error + def __init__(self, function_name, argument_index, objectid, task_error): + """Initialize a RayGetArgumentError object.""" + self.argument_index = argument_index + self.function_name = function_name + self.objectid = objectid + self.task_error = task_error - def __str__(self): - """Format a RayGetArgumentError as a string.""" - return ("Failed to get objectid {} as argument {} for remote function " - "{}{}{}. It was created by remote function {}{}{} which failed " - "with:\n{}".format(self.objectid, self.argument_index, - colorama.Fore.RED, self.function_name, - colorama.Fore.RESET, colorama.Fore.RED, - self.task_error.function_name, - colorama.Fore.RESET, self.task_error)) + def __str__(self): + """Format a RayGetArgumentError as a string.""" + return ("Failed to get objectid {} as argument {} for remote function " + "{}{}{}. It was created by remote function {}{}{} which " + "failed with:\n{}".format(self.objectid, self.argument_index, + colorama.Fore.RED, + self.function_name, + colorama.Fore.RESET, + colorama.Fore.RED, + self.task_error.function_name, + colorama.Fore.RESET, + self.task_error)) class Worker(object): - """A class used to define the control flow of a worker process. + """A class used to define the control flow of a worker process. 
- Note: - The methods in this class are considered unexposed to the user. The - functions outside of this class are considered exposed. + Note: + The methods in this class are considered unexposed to the user. The + functions outside of this class are considered exposed. - Attributes: - functions (Dict[str, Callable]): A dictionary mapping the name of a remote - function to the remote function itself. This is the set of remote - functions that can be executed by this worker. - connected (bool): True if Ray has been started and False otherwise. - mode: The mode of the worker. One of SCRIPT_MODE, PYTHON_MODE, SILENT_MODE, - and WORKER_MODE. - cached_remote_functions (List[Tuple[str, str]]): A list of pairs - representing the remote functions that were defined before he worker - called connect. The first element is the name of the remote function, and - the second element is the serialized remote function. When the worker - eventually does call connect, if it is a driver, it will export these - functions to the scheduler. If cached_remote_functions is None, that - means that connect has been called already. - cached_functions_to_run (List): A list of functions to run on all of the - workers that should be exported as soon as connect is called. - """ - - def __init__(self): - """Initialize a Worker object.""" - # The functions field is a dictionary that maps a driver ID to a dictionary - # of functions that have been registered for that driver (this inner - # dictionary maps function IDs to a tuple of the function name and the - # function itself). This should only be used on workers that execute remote - # functions. 
- self.functions = collections.defaultdict(lambda: {}) - # The function_properties field is a dictionary that maps a driver ID to a - # dictionary of functions that have been registered for that driver (this - # inner dictionary maps function IDs to a tuple of the number of values - # returned by that function, the number of CPUs required by that function, - # and the number of GPUs required by that function). This is used when - # submitting a function (which can be done both on workers and on drivers). - self.function_properties = collections.defaultdict(lambda: {}) - # This is a dictionary mapping driver ID to a dictionary that maps remote - # function IDs for that driver to a counter of the number of times that - # remote function has been executed on this worker. The counter is - # incremented every time the function is executed on this worker. When the - # counter reaches the maximum number of executions allowed for a particular - # function, the worker is killed. - self.num_task_executions = collections.defaultdict(lambda: {}) - self.connected = False - self.mode = None - self.cached_remote_functions = [] - self.cached_functions_to_run = [] - self.fetch_and_register_actor = None - self.make_actor = None - self.actors = {} - # Use a defaultdict for the actor counts. If this is accessed with a - # missing key, the default value of 0 is returned, and that key value pair - # is added to the dict. - self.actor_counters = collections.defaultdict(lambda: 0) - - def set_mode(self, mode): - """Set the mode of the worker. - - The mode SCRIPT_MODE should be used if this Worker is a driver that is - being run as a Python script or interactively in a shell. It will print - information about task failures. - - The mode WORKER_MODE should be used if this Worker is not a driver. It will - not print information about tasks. 
- - The mode PYTHON_MODE should be used if this Worker is a driver and if you - want to run the driver in a manner equivalent to serial Python for - debugging purposes. It will not send remote function calls to the scheduler - and will insead execute them in a blocking fashion. - - The mode SILENT_MODE should be used only during testing. It does not print - any information about errors because some of the tests intentionally fail. - - args: - mode: One of SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, and SILENT_MODE. + Attributes: + functions (Dict[str, Callable]): A dictionary mapping the name of a + remote function to the remote function itself. This is the set of + remote functions that can be executed by this worker. + connected (bool): True if Ray has been started and False otherwise. + mode: The mode of the worker. One of SCRIPT_MODE, PYTHON_MODE, + SILENT_MODE, and WORKER_MODE. + cached_remote_functions (List[Tuple[str, str]]): A list of pairs + representing the remote functions that were defined before the + worker called connect. The first element is the name of the remote + function, and the second element is the serialized remote function. + When the worker eventually does call connect, if it is a driver, it + will export these functions to the scheduler. If + cached_remote_functions is None, that means that connect has been + called already. + cached_functions_to_run (List): A list of functions to run on all of + the workers that should be exported as soon as connect is called. """ - self.mode = mode - colorama.init() - def store_and_register(self, object_id, value, depth=100): - """Store an object and attempt to register its class if needed. + def __init__(self): + """Initialize a Worker object.""" + # The functions field is a dictionary that maps a driver ID to a + # dictionary of functions that have been registered for that driver + # (this inner dictionary maps function IDs to a tuple of the function + # name and the function itself). 
This should only be used on workers + # that execute remote functions. + self.functions = collections.defaultdict(lambda: {}) + # The function_properties field is a dictionary that maps a driver ID + # to a dictionary of functions that have been registered for that + # driver (this inner dictionary maps function IDs to a tuple of the + # number of values returned by that function, the number of CPUs + # required by that function, and the number of GPUs required by that + # function). This is used when submitting a function (which can be done + # both on workers and on drivers). + self.function_properties = collections.defaultdict(lambda: {}) + # This is a dictionary mapping driver ID to a dictionary that maps + # remote function IDs for that driver to a counter of the number of + # times that remote function has been executed on this worker. The + # counter is incremented every time the function is executed on this + # worker. When the counter reaches the maximum number of executions + # allowed for a particular function, the worker is killed. + self.num_task_executions = collections.defaultdict(lambda: {}) + self.connected = False + self.mode = None + self.cached_remote_functions = [] + self.cached_functions_to_run = [] + self.fetch_and_register_actor = None + self.make_actor = None + self.actors = {} + # Use a defaultdict for the actor counts. If this is accessed with a + # missing key, the default value of 0 is returned, and that key value + # pair is added to the dict. + self.actor_counters = collections.defaultdict(lambda: 0) - Args: - object_id: The ID of the object to store. - value: The value to put in the object store. - depth: The maximum number of classes to recursively register. + def set_mode(self, mode): + """Set the mode of the worker. - Raises: - Exception: An exception is raised if the attempt to store the object - fails. This can happen if there is already an object with the same ID - in the object store or if the object store is full. 
- """ - counter = 0 - while True: - if counter == depth: - raise Exception("Ray exceeded the maximum number of classes that it " - "will recursively serialize when attempting to " - "serialize an object of type {}.".format(type(value))) - counter += 1 - try: - ray.numbuf.store_list(object_id.id(), self.plasma_client.conn, [value]) - break - except serialization.RaySerializationException as e: + The mode SCRIPT_MODE should be used if this Worker is a driver that is + being run as a Python script or interactively in a shell. It will print + information about task failures. + + The mode WORKER_MODE should be used if this Worker is not a driver. It + will not print information about tasks. + + The mode PYTHON_MODE should be used if this Worker is a driver and if + you want to run the driver in a manner equivalent to serial Python for + debugging purposes. It will not send remote function calls to the + scheduler and will insead execute them in a blocking fashion. + + The mode SILENT_MODE should be used only during testing. It does not + print any information about errors because some of the tests + intentionally fail. + + args: + mode: One of SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, and + SILENT_MODE. + """ + self.mode = mode + colorama.init() + + def store_and_register(self, object_id, value, depth=100): + """Store an object and attempt to register its class if needed. + + Args: + object_id: The ID of the object to store. + value: The value to put in the object store. + depth: The maximum number of classes to recursively register. + + Raises: + Exception: An exception is raised if the attempt to store the + object fails. This can happen if there is already an object + with the same ID in the object store or if the object store is + full. 
+ """ + counter = 0 + while True: + if counter == depth: + raise Exception("Ray exceeded the maximum number of classes " + "that it will recursively serialize when " + "attempting to serialize an object of " + "type {}.".format(type(value))) + counter += 1 + try: + ray.numbuf.store_list(object_id.id(), self.plasma_client.conn, + [value]) + break + except serialization.RaySerializationException as e: + try: + _register_class(type(e.example_object)) + warning_message = ("WARNING: Serializing objects of type " + "{} by expanding them as dictionaries " + "of their fields. This behavior may " + "be incorrect in some cases." + .format(type(e.example_object))) + print(warning_message) + except serialization.RayNotDictionarySerializable: + _register_class(type(e.example_object), pickle=True) + warning_message = ("WARNING: Falling back to serializing " + "objects of type {} by using pickle. " + "This may be inefficient." + .format(type(e.example_object))) + print(warning_message) + + def put_object(self, object_id, value): + """Put value in the local object store with object id objectid. + + This assumes that the value for objectid has not yet been placed in the + local object store. + + Args: + object_id (object_id.ObjectID): The object ID of the value to be + put. + value: The value to put in the object store. + + Raises: + Exception: An exception is raised if the attempt to store the + object fails. This can happen if there is already an object + with the same ID in the object store or if the object store is + full. + """ + # Make sure that the value is not an object ID. + if isinstance(value, ray.local_scheduler.ObjectID): + raise Exception("Calling 'put' on an ObjectID is not allowed " + "(similarly, returning an ObjectID from a remote " + "function is not allowed). If you really want to " + "do this, you can wrap the ObjectID in a list and " + "call 'put' on it (or return it).") + + # Serialize and put the object in the object store. 
try: - _register_class(type(e.example_object)) - warning_message = ("WARNING: Serializing objects of type {} by " - "expanding them as dictionaries of their fields. " - "This behavior may be incorrect in some cases." - .format(type(e.example_object))) - print(warning_message) - except serialization.RayNotDictionarySerializable: - _register_class(type(e.example_object), pickle=True) - warning_message = ("WARNING: Falling back to serializing objects of " - "type {} by using pickle. This may be " - "inefficient.".format(type(e.example_object))) - print(warning_message) + self.store_and_register(object_id, value) + except ray.numbuf.numbuf_plasma_object_exists_error as e: + # The object already exists in the object store, so there is no + # need to add it again. TODO(rkn): We need to compare the hashes + # and make sure that the objects are in fact the same. We also + # should return an error code to the caller instead of printing a + # message. + print("This object already exists in the object store.") - def put_object(self, object_id, value): - """Put value in the local object store with object id objectid. + global contained_objectids + # Optionally do something with the contained_objectids here. + contained_objectids = [] - This assumes that the value for objectid has not yet been placed in the - local object store. + def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10): + start_time = time.time() + # Only send the warning once. + warning_sent = False + while True: + try: + # We divide very large get requests into smaller get requests + # so that a single get request doesn't block the store for a + # long time, if the store is blocked, it can block the manager + # as well as a consequence. 
+ results = [] + get_request_size = 10000 + for i in range(0, len(object_ids), get_request_size): + results += ray.numbuf.retrieve_list( + object_ids[i:(i + get_request_size)], + self.plasma_client.conn, + timeout) + return results + except serialization.RayDeserializationException as e: + # Wait a little bit for the import thread to import the class. + # If we currently have the worker lock, we need to release it + # so that the import thread can acquire it. + if self.mode == WORKER_MODE: + self.lock.release() + time.sleep(0.01) + if self.mode == WORKER_MODE: + self.lock.acquire() - Args: - object_id (object_id.ObjectID): The object ID of the value to be put. - value: The value to put in the object store. + if time.time() - start_time > error_timeout: + warning_message = ("This worker or driver is waiting to " + "receive a class definition so that it " + "can deserialize an object from the " + "object store. This may be fine, or it " + "may be a bug.") + if not warning_sent: + self.push_error_to_driver(self.task_driver_id.id(), + "wait_for_class", + warning_message) + warning_sent = True - Raises: - Exception: An exception is raised if the attempt to store the object - fails. This can happen if there is already an object with the same ID - in the object store or if the object store is full. - """ - # Make sure that the value is not an object ID. - if isinstance(value, ray.local_scheduler.ObjectID): - raise Exception("Calling 'put' on an ObjectID is not allowed " - "(similarly, returning an ObjectID from a remote " - "function is not allowed). If you really want to do " - "this, you can wrap the ObjectID in a list and call " - "'put' on it (or return it).") + def get_object(self, object_ids): + """Get the value or values in the object store associated with the IDs. - # Serialize and put the object in the object store. 
- try: - self.store_and_register(object_id, value) - except ray.numbuf.numbuf_plasma_object_exists_error as e: - # The object already exists in the object store, so there is no need to - # add it again. TODO(rkn): We need to compare the hashes and make sure - # that the objects are in fact the same. We also should return an error - # code to the caller instead of printing a message. - print("This object already exists in the object store.") + Return the values from the local object store for object_ids. This will + block until all the values for object_ids have been written to the + local object store. - global contained_objectids - # Optionally do something with the contained_objectids here. - contained_objectids = [] + Args: + object_ids (List[object_id.ObjectID]): A list of the object IDs + whose values should be retrieved. + """ + # Make sure that the values are object IDs. + for object_id in object_ids: + if not isinstance(object_id, ray.local_scheduler.ObjectID): + raise Exception("Attempting to call `get` on the value {}, " + "which is not an ObjectID.".format(object_id)) + # Do an initial fetch for remote objects. We divide the fetch into + # smaller fetches so as to not block the manager for a prolonged period + # of time in a single call. + fetch_request_size = 10000 + plain_object_ids = [object_id.id() for object_id in object_ids] + for i in range(0, len(object_ids), fetch_request_size): + self.plasma_client.fetch( + plain_object_ids[i:(i + fetch_request_size)]) - def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10): - start_time = time.time() - # Only send the warning once. - warning_sent = False - while True: - try: - # We divide very large get requests into smaller get requests so that - # a single get request doesn't block the store for a long time, if the - # store is blocked, it can block the manager as well as a consequence. 
- results = [] - get_request_size = 10000 - for i in range(0, len(object_ids), get_request_size): - results += ray.numbuf.retrieve_list( - object_ids[i:(i + get_request_size)], - self.plasma_client.conn, - timeout) - return results - except serialization.RayDeserializationException as e: - # Wait a little bit for the import thread to import the class. If we - # currently have the worker lock, we need to release it so that the - # import thread can acquire it. - if self.mode == WORKER_MODE: - self.lock.release() - time.sleep(0.01) - if self.mode == WORKER_MODE: - self.lock.acquire() + # Get the objects. We initially try to get the objects immediately. + final_results = self.retrieve_and_deserialize( + [object_id.id() for object_id in object_ids], 0) + # Construct a dictionary mapping object IDs that we haven't gotten yet + # to their original index in the object_ids argument. + unready_ids = dict((object_id, i) for (i, (object_id, val)) in + enumerate(final_results) if val is None) + was_blocked = (len(unready_ids) > 0) + # Try reconstructing any objects we haven't gotten yet. Try to get them + # until at least GET_TIMEOUT_MILLISECONDS milliseconds passes, then + # repeat. + while len(unready_ids) > 0: + for unready_id in unready_ids: + self.local_scheduler_client.reconstruct_object(unready_id) + # Do another fetch for objects that aren't available locally yet, + # in case they were evicted since the last fetch. We divide the + # fetch into smaller fetches so as to not block the manager for a + # prolonged period of time in a single call. 
+ object_ids_to_fetch = list(unready_ids.keys()) + for i in range(0, len(object_ids_to_fetch), fetch_request_size): + self.plasma_client.fetch( + object_ids_to_fetch[i:(i + fetch_request_size)]) + results = self.retrieve_and_deserialize( + list(unready_ids.keys()), + max([GET_TIMEOUT_MILLISECONDS, int(0.01 * len(unready_ids))])) + # Remove any entries for objects we received during this iteration + # so we don't retrieve the same object twice. + for object_id, val in results: + if val is not None: + index = unready_ids[object_id] + final_results[index] = (object_id, val) + unready_ids.pop(object_id) - if time.time() - start_time > error_timeout: - warning_message = ("This worker or driver is waiting to receive a " - "class definition so that it can deserialize an " - "object from the object store. This may be fine, " - "or it may be a bug.") - if not warning_sent: - self.push_error_to_driver(self.task_driver_id.id(), - "wait_for_class", - warning_message) - warning_sent = True + # If there were objects that we weren't able to get locally, let the + # local scheduler know that we're now unblocked. + if was_blocked: + self.local_scheduler_client.notify_unblocked() - def get_object(self, object_ids): - """Get the value or values in the object store associated with object_ids. + # Unwrap the object from the list (it was wrapped put_object). + assert len(final_results) == len(object_ids) + for i in range(len(final_results)): + assert final_results[i][0] == object_ids[i].id() + return [result[1][0] for result in final_results] - Return the values from the local object store for object_ids. This will - block until all the values for object_ids have been written to the local - object store. + def submit_task(self, function_id, func_name, args, actor_id=None): + """Submit a remote task to the scheduler. - Args: - object_ids (List[object_id.ObjectID]): A list of the object IDs whose - values should be retrieved. - """ - # Make sure that the values are object IDs. 
- for object_id in object_ids: - if not isinstance(object_id, ray.local_scheduler.ObjectID): - raise Exception("Attempting to call `get` on the value {}, which is " - "not an ObjectID.".format(object_id)) - # Do an initial fetch for remote objects. We divide the fetch into smaller - # fetches so as to not block the manager for a prolonged period of time in - # a single call. - fetch_request_size = 10000 - plain_object_ids = [object_id.id() for object_id in object_ids] - for i in range(0, len(object_ids), fetch_request_size): - self.plasma_client.fetch(plain_object_ids[i:(i + fetch_request_size)]) + Tell the scheduler to schedule the execution of the function with name + func_name with arguments args. Retrieve object IDs for the outputs of + the function from the scheduler and immediately return them. - # Get the objects. We initially try to get the objects immediately. - final_results = self.retrieve_and_deserialize( - [object_id.id() for object_id in object_ids], 0) - # Construct a dictionary mapping object IDs that we haven't gotten yet to - # their original index in the object_ids argument. - unready_ids = dict((object_id, i) for (i, (object_id, val)) in - enumerate(final_results) if val is None) - was_blocked = (len(unready_ids) > 0) - # Try reconstructing any objects we haven't gotten yet. Try to get them - # until at least GET_TIMEOUT_MILLISECONDS milliseconds passes, then repeat. - while len(unready_ids) > 0: - for unready_id in unready_ids: - self.local_scheduler_client.reconstruct_object(unready_id) - # Do another fetch for objects that aren't available locally yet, in case - # they were evicted since the last fetch. We divide the fetch into - # smaller fetches so as to not block the manager for a prolonged period - # of time in a single call. 
- object_ids_to_fetch = list(unready_ids.keys()) - for i in range(0, len(object_ids_to_fetch), fetch_request_size): - self.plasma_client.fetch( - object_ids_to_fetch[i:(i + fetch_request_size)]) - results = self.retrieve_and_deserialize( - list(unready_ids.keys()), - max([GET_TIMEOUT_MILLISECONDS, int(0.01 * len(unready_ids))])) - # Remove any entries for objects we received during this iteration so we - # don't retrieve the same object twice. - for object_id, val in results: - if val is not None: - index = unready_ids[object_id] - final_results[index] = (object_id, val) - unready_ids.pop(object_id) + Args: + func_name (str): The name of the function to be executed. + args (List[Any]): The arguments to pass into the function. + Arguments can be object IDs or they can be values. If they are + values, they must be serializable objecs. + """ + with log_span("ray:submit_task", worker=self): + check_main_thread() + actor_id = (ray.local_scheduler.ObjectID(NIL_ACTOR_ID) + if actor_id is None else actor_id) + # Put large or complex arguments that are passed by value in the + # object store first. + args_for_local_scheduler = [] + for arg in args: + if isinstance(arg, ray.local_scheduler.ObjectID): + args_for_local_scheduler.append(arg) + elif ray.local_scheduler.check_simple_value(arg): + args_for_local_scheduler.append(arg) + else: + args_for_local_scheduler.append(put(arg)) - # If there were objects that we weren't able to get locally, let the local - # scheduler know that we're now unblocked. - if was_blocked: - self.local_scheduler_client.notify_unblocked() + # Look up the various function properties. + function_properties = self.function_properties[ + self.task_driver_id.id()][function_id.id()] - # Unwrap the object from the list (it was wrapped put_object). 
- assert len(final_results) == len(object_ids) - for i in range(len(final_results)): - assert final_results[i][0] == object_ids[i].id() - return [result[1][0] for result in final_results] + # Submit the task to local scheduler. + task = ray.local_scheduler.Task( + self.task_driver_id, + ray.local_scheduler.ObjectID(function_id.id()), + args_for_local_scheduler, + function_properties.num_return_vals, + self.current_task_id, + self.task_index, + actor_id, self.actor_counters[actor_id], + [function_properties.num_cpus, function_properties.num_gpus]) + # Increment the worker's task index to track how many tasks have + # been submitted by the current task so far. + self.task_index += 1 + self.actor_counters[actor_id] += 1 + self.local_scheduler_client.submit(task) - def submit_task(self, function_id, func_name, args, actor_id=None): - """Submit a remote task to the scheduler. + return task.returns() - Tell the scheduler to schedule the execution of the function with name - func_name with arguments args. Retrieve object IDs for the outputs of - the function from the scheduler and immediately return them. + def run_function_on_all_workers(self, function): + """Run arbitrary code on all of the workers. - Args: - func_name (str): The name of the function to be executed. - args (List[Any]): The arguments to pass into the function. Arguments can - be object IDs or they can be values. If they are values, they - must be serializable objecs. - """ - with log_span("ray:submit_task", worker=self): - check_main_thread() - actor_id = (ray.local_scheduler.ObjectID(NIL_ACTOR_ID) - if actor_id is None else actor_id) - # Put large or complex arguments that are passed by value in the object - # store first. 
- args_for_local_scheduler = [] - for arg in args: - if isinstance(arg, ray.local_scheduler.ObjectID): - args_for_local_scheduler.append(arg) - elif ray.local_scheduler.check_simple_value(arg): - args_for_local_scheduler.append(arg) + This function will first be run on the driver, and then it will be + exported to all of the workers to be run. It will also be run on any + new workers that register later. If ray.init has not been called yet, + then cache the function and export it later. + + Args: + function (Callable): The function to run on all of the workers. It + should not take any arguments. If it returns anything, its + return values will not be used. + """ + check_main_thread() + # If ray.init has not been called yet, then cache the function and + # export it when connect is called. Otherwise, run the function on all + # workers. + if self.mode is None: + self.cached_functions_to_run.append(function) else: - args_for_local_scheduler.append(put(arg)) + # Attempt to pickle the function before we need it. This could + # fail, and it is more convenient if the failure happens before we + # actually run the function locally. + pickled_function = pickle.dumps(function) - # Look up the various function properties. - function_properties = self.function_properties[ - self.task_driver_id.id()][function_id.id()] + function_to_run_id = random_string() + key = b"FunctionsToRun:" + function_to_run_id + # First run the function on the driver. Pass in the number of + # workers on this node that have already started executing this + # remote function, and increment that value. Subtract 1 so that the + # counter starts at 0. + counter = self.redis_client.hincrby(self.node_ip_address, + key, 1) - 1 + function({"counter": counter}) + # Run the function on all workers. 
+ self.redis_client.hmset(key, + {"driver_id": self.task_driver_id.id(), + "function_id": function_to_run_id, + "function": pickled_function}) + self.redis_client.rpush("Exports", key) - # Submit the task to local scheduler. - task = ray.local_scheduler.Task( - self.task_driver_id, - ray.local_scheduler.ObjectID(function_id.id()), - args_for_local_scheduler, - function_properties.num_return_vals, - self.current_task_id, - self.task_index, - actor_id, self.actor_counters[actor_id], - [function_properties.num_cpus, function_properties.num_gpus]) - # Increment the worker's task index to track how many tasks have been - # submitted by the current task so far. - self.task_index += 1 - self.actor_counters[actor_id] += 1 - self.local_scheduler_client.submit(task) + def push_error_to_driver(self, driver_id, error_type, message, data=None): + """Push an error message to the driver to be printed in the background. - return task.returns() - - def run_function_on_all_workers(self, function): - """Run arbitrary code on all of the workers. - - This function will first be run on the driver, and then it will be exported - to all of the workers to be run. It will also be run on any new workers - that register later. If ray.init has not been called yet, then cache the - function and export it later. - - Args: - function (Callable): The function to run on all of the workers. It should - not take any arguments. If it returns anything, its return values will - not be used. - """ - check_main_thread() - # If ray.init has not been called yet, then cache the function and export - # it when connect is called. Otherwise, run the function on all workers. - if self.mode is None: - self.cached_functions_to_run.append(function) - else: - # Attempt to pickle the function before we need it. This could fail, and - # it is more convenient if the failure happens before we actually run the - # function locally. 
- pickled_function = pickle.dumps(function) - - function_to_run_id = random_string() - key = b"FunctionsToRun:" + function_to_run_id - # First run the function on the driver. Pass in the number of workers on - # this node that have already started executing this remote function, - # and increment that value. Subtract 1 so that the counter starts at 0. - counter = self.redis_client.hincrby(self.node_ip_address, key, 1) - 1 - function({"counter": counter}) - # Run the function on all workers. - self.redis_client.hmset(key, {"driver_id": self.task_driver_id.id(), - "function_id": function_to_run_id, - "function": pickled_function}) - self.redis_client.rpush("Exports", key) - - def push_error_to_driver(self, driver_id, error_type, message, data=None): - """Push an error message to the driver to be printed in the background. - - Args: - driver_id: The ID of the driver to push the error message to. - error_type (str): The type of the error. - message (str): The message that will be printed in the background on the - driver. - data: This should be a dictionary mapping strings to strings. It will be - serialized with json and stored in Redis. - """ - error_key = ERROR_KEY_PREFIX + driver_id + b":" + random_string() - data = {} if data is None else data - self.redis_client.hmset(error_key, {"type": error_type, - "message": message, - "data": data}) - self.redis_client.rpush("ErrorKeys", error_key) + Args: + driver_id: The ID of the driver to push the error message to. + error_type (str): The type of the error. + message (str): The message that will be printed in the background + on the driver. + data: This should be a dictionary mapping strings to strings. It + will be serialized with json and stored in Redis. 
+ """ + error_key = ERROR_KEY_PREFIX + driver_id + b":" + random_string() + data = {} if data is None else data + self.redis_client.hmset(error_key, {"type": error_type, + "message": message, + "data": data}) + self.redis_client.rpush("ErrorKeys", error_key) def get_gpu_ids(): - """Get the IDs of the GPU that are available to the worker. + """Get the IDs of the GPU that are available to the worker. - Each ID is an integer in the range [0, NUM_GPUS - 1], where NUM_GPUS is the - number of GPUs that the node has. - """ - return global_worker.local_scheduler_client.gpu_ids() + Each ID is an integer in the range [0, NUM_GPUS - 1], where NUM_GPUS is the + number of GPUs that the node has. + """ + return global_worker.local_scheduler_client.gpu_ids() global_worker = Worker() @@ -575,209 +599,214 @@ global_state = state.GlobalState() class RayConnectionError(Exception): - pass + pass def check_main_thread(): - """Check that we are currently on the main thread. + """Check that we are currently on the main thread. - Raises: - Exception: An exception is raised if this is called on a thread other than - the main thread. - """ - if threading.current_thread().getName() != "MainThread": - raise Exception("The Ray methods are not thread safe and must be called " - "from the main thread. This method was called from thread " - "{}.".format(threading.current_thread().getName())) + Raises: + Exception: An exception is raised if this is called on a thread other + than the main thread. + """ + if threading.current_thread().getName() != "MainThread": + raise Exception("The Ray methods are not thread safe and must be " + "called from the main thread. This method was called " + "from thread {}." + .format(threading.current_thread().getName())) def check_connected(worker=global_worker): - """Check if the worker is connected. + """Check if the worker is connected. - Raises: - Exception: An exception is raised if the worker is not connected. 
- """ - if not worker.connected: - raise RayConnectionError("This command cannot be called before Ray has " - "been started. You can start Ray with " - "'ray.init()'.") + Raises: + Exception: An exception is raised if the worker is not connected. + """ + if not worker.connected: + raise RayConnectionError("This command cannot be called before Ray " + "has been started. You can start Ray with " + "'ray.init()'.") def print_failed_task(task_status): - """Print information about failed tasks. + """Print information about failed tasks. - Args: - task_status (Dict): A dictionary containing the name, operationid, and - error message for a failed task. - """ - print(""" - Error: Task failed - Function Name: {} - Task ID: {} - Error Message: \n{} - """.format(task_status["function_name"], task_status["operationid"], - task_status["error_message"])) + Args: + task_status (Dict): A dictionary containing the name, operationid, and + error message for a failed task. + """ + print(""" + Error: Task failed + Function Name: {} + Task ID: {} + Error Message: \n{} + """.format(task_status["function_name"], task_status["operationid"], + task_status["error_message"])) def error_applies_to_driver(error_key, worker=global_worker): - """Return True if the error is for this driver and false otherwise.""" - # TODO(rkn): Should probably check that this is only called on a driver. - # Check that the error key is formatted as in push_error_to_driver. - assert len(error_key) == (len(ERROR_KEY_PREFIX) + DRIVER_ID_LENGTH + 1 + - ERROR_ID_LENGTH), error_key - # If the driver ID in the error message is a sequence of all zeros, then the - # message is intended for all drivers. 
- generic_driver_id = DRIVER_ID_LENGTH * b"\x00" - driver_id = error_key[len(ERROR_KEY_PREFIX):(len(ERROR_KEY_PREFIX) + - DRIVER_ID_LENGTH)] - return (driver_id == worker.task_driver_id.id() or - driver_id == generic_driver_id) + """Return True if the error is for this driver and false otherwise.""" + # TODO(rkn): Should probably check that this is only called on a driver. + # Check that the error key is formatted as in push_error_to_driver. + assert len(error_key) == (len(ERROR_KEY_PREFIX) + DRIVER_ID_LENGTH + 1 + + ERROR_ID_LENGTH), error_key + # If the driver ID in the error message is a sequence of all zeros, then + # the message is intended for all drivers. + generic_driver_id = DRIVER_ID_LENGTH * b"\x00" + driver_id = error_key[len(ERROR_KEY_PREFIX):(len(ERROR_KEY_PREFIX) + + DRIVER_ID_LENGTH)] + return (driver_id == worker.task_driver_id.id() or + driver_id == generic_driver_id) def error_info(worker=global_worker): - """Return information about failed tasks.""" - check_connected(worker) - check_main_thread() - error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1) - errors = [] - for error_key in error_keys: - if error_applies_to_driver(error_key, worker=worker): - error_contents = worker.redis_client.hgetall(error_key) - # If the error is an object hash mismatch, look up the function name for - # the nondeterministic task. TODO(rkn): Change this so that we don't have - # to look up additional information. Ideally all relevant information - # would already be in error_contents. 
- error_type = error_contents[b"type"] - if error_type in [OBJECT_HASH_MISMATCH_ERROR_TYPE, - PUT_RECONSTRUCTION_ERROR_TYPE]: - function_id = error_contents[b"data"] - if function_id == NIL_FUNCTION_ID: - function_name = b"Driver" - else: - task_driver_id = worker.task_driver_id - function_name = worker.redis_client.hget( - b"RemoteFunction:" + task_driver_id.id() + b":" + function_id, - "name") - error_contents[b"data"] = function_name - errors.append(error_contents) + """Return information about failed tasks.""" + check_connected(worker) + check_main_thread() + error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1) + errors = [] + for error_key in error_keys: + if error_applies_to_driver(error_key, worker=worker): + error_contents = worker.redis_client.hgetall(error_key) + # If the error is an object hash mismatch, look up the function + # name for the nondeterministic task. TODO(rkn): Change this so + # that we don't have to look up additional information. Ideally all + # relevant information would already be in error_contents. + error_type = error_contents[b"type"] + if error_type in [OBJECT_HASH_MISMATCH_ERROR_TYPE, + PUT_RECONSTRUCTION_ERROR_TYPE]: + function_id = error_contents[b"data"] + if function_id == NIL_FUNCTION_ID: + function_name = b"Driver" + else: + task_driver_id = worker.task_driver_id + function_name = worker.redis_client.hget( + (b"RemoteFunction:" + task_driver_id.id() + + b":" + function_id), + "name") + error_contents[b"data"] = function_name + errors.append(error_contents) - return errors + return errors def initialize_numbuf(worker=global_worker): - """Initialize the serialization library. + """Initialize the serialization library. - This defines a custom serializer for object IDs and also tells numbuf to - serialize several exception classes that we define for error handling. 
- """ - ray.serialization.set_callbacks() + This defines a custom serializer for object IDs and also tells numbuf to + serialize several exception classes that we define for error handling. + """ + ray.serialization.set_callbacks() - # Define a custom serializer and deserializer for handling Object IDs. - def objectid_custom_serializer(obj): - contained_objectids.append(obj) - return obj.id() + # Define a custom serializer and deserializer for handling Object IDs. + def objectid_custom_serializer(obj): + contained_objectids.append(obj) + return obj.id() - def objectid_custom_deserializer(serialized_obj): - return ray.local_scheduler.ObjectID(serialized_obj) + def objectid_custom_deserializer(serialized_obj): + return ray.local_scheduler.ObjectID(serialized_obj) - serialization.add_class_to_whitelist( - ray.local_scheduler.ObjectID, 20 * b"\x00", pickle=False, - custom_serializer=objectid_custom_serializer, - custom_deserializer=objectid_custom_deserializer) + serialization.add_class_to_whitelist( + ray.local_scheduler.ObjectID, 20 * b"\x00", pickle=False, + custom_serializer=objectid_custom_serializer, + custom_deserializer=objectid_custom_deserializer) - # Define a custom serializer and deserializer for handling numpy arrays that - # contain objects. - def array_custom_serializer(obj): - return obj.tolist(), obj.dtype.str + # Define a custom serializer and deserializer for handling numpy arrays + # that contain objects. 
+ def array_custom_serializer(obj): + return obj.tolist(), obj.dtype.str - def array_custom_deserializer(serialized_obj): - return np.array(serialized_obj[0], dtype=np.dtype(serialized_obj[1])) + def array_custom_deserializer(serialized_obj): + return np.array(serialized_obj[0], dtype=np.dtype(serialized_obj[1])) - serialization.add_class_to_whitelist( - np.ndarray, 20 * b"\x01", pickle=False, - custom_serializer=array_custom_serializer, - custom_deserializer=array_custom_deserializer) + serialization.add_class_to_whitelist( + np.ndarray, 20 * b"\x01", pickle=False, + custom_serializer=array_custom_serializer, + custom_deserializer=array_custom_deserializer) - if worker.mode in [SCRIPT_MODE, SILENT_MODE]: - # These should only be called on the driver because _register_class will - # export the class to all of the workers. - _register_class(RayTaskError) - _register_class(RayGetError) - _register_class(RayGetArgumentError) - # Tell Ray to serialize lambdas with pickle. - _register_class(type(lambda: 0), pickle=True) - # Tell Ray to serialize sets with pickle. - _register_class(type(set()), pickle=True) - # Tell Ray to serialize types with pickle. - _register_class(type(int), pickle=True) + if worker.mode in [SCRIPT_MODE, SILENT_MODE]: + # These should only be called on the driver because _register_class + # will export the class to all of the workers. + _register_class(RayTaskError) + _register_class(RayGetError) + _register_class(RayGetArgumentError) + # Tell Ray to serialize lambdas with pickle. + _register_class(type(lambda: 0), pickle=True) + # Tell Ray to serialize sets with pickle. + _register_class(type(set()), pickle=True) + # Tell Ray to serialize types with pickle. 
+ _register_class(type(int), pickle=True) def get_address_info_from_redis_helper(redis_address, node_ip_address): - redis_ip_address, redis_port = redis_address.split(":") - # For this command to work, some other client (on the same machine as Redis) - # must have run "CONFIG SET protected-mode no". - redis_client = redis.StrictRedis(host=redis_ip_address, port=int(redis_port)) - # The client table prefix must be kept in sync with the file - # "src/common/redis_module/ray_redis_module.cc" where it is defined. - REDIS_CLIENT_TABLE_PREFIX = "CL:" - client_keys = redis_client.keys("{}*".format(REDIS_CLIENT_TABLE_PREFIX)) - # Filter to live clients on the same node and do some basic checking. - plasma_managers = [] - local_schedulers = [] - for key in client_keys: - info = redis_client.hgetall(key) + redis_ip_address, redis_port = redis_address.split(":") + # For this command to work, some other client (on the same machine as + # Redis) must have run "CONFIG SET protected-mode no". + redis_client = redis.StrictRedis(host=redis_ip_address, + port=int(redis_port)) + # The client table prefix must be kept in sync with the file + # "src/common/redis_module/ray_redis_module.cc" where it is defined. + REDIS_CLIENT_TABLE_PREFIX = "CL:" + client_keys = redis_client.keys("{}*".format(REDIS_CLIENT_TABLE_PREFIX)) + # Filter to live clients on the same node and do some basic checking. + plasma_managers = [] + local_schedulers = [] + for key in client_keys: + info = redis_client.hgetall(key) - # Ignore clients that were deleted. - deleted = info[b"deleted"] - deleted = bool(int(deleted)) - if deleted: - continue + # Ignore clients that were deleted. 
+ deleted = info[b"deleted"] + deleted = bool(int(deleted)) + if deleted: + continue - assert b"ray_client_id" in info - assert b"node_ip_address" in info - assert b"client_type" in info - if info[b"node_ip_address"].decode("ascii") == node_ip_address: - if info[b"client_type"].decode("ascii") == "plasma_manager": - plasma_managers.append(info) - elif info[b"client_type"].decode("ascii") == "local_scheduler": - local_schedulers.append(info) - # Make sure that we got at least one plasma manager and local scheduler. - assert len(plasma_managers) >= 1 - assert len(local_schedulers) >= 1 - # Build the address information. - object_store_addresses = [] - for manager in plasma_managers: - address = manager[b"address"].decode("ascii") - port = services.get_port(address) - object_store_addresses.append( - services.ObjectStoreAddress( - name=manager[b"store_socket_name"].decode("ascii"), - manager_name=manager[b"manager_socket_name"].decode("ascii"), - manager_port=port)) - scheduler_names = [scheduler[b"local_scheduler_socket_name"].decode("ascii") - for scheduler in local_schedulers] - client_info = {"node_ip_address": node_ip_address, - "redis_address": redis_address, - "object_store_addresses": object_store_addresses, - "local_scheduler_socket_names": scheduler_names, - } - return client_info + assert b"ray_client_id" in info + assert b"node_ip_address" in info + assert b"client_type" in info + if info[b"node_ip_address"].decode("ascii") == node_ip_address: + if info[b"client_type"].decode("ascii") == "plasma_manager": + plasma_managers.append(info) + elif info[b"client_type"].decode("ascii") == "local_scheduler": + local_schedulers.append(info) + # Make sure that we got at least one plasma manager and local scheduler. + assert len(plasma_managers) >= 1 + assert len(local_schedulers) >= 1 + # Build the address information. 
+ object_store_addresses = [] + for manager in plasma_managers: + address = manager[b"address"].decode("ascii") + port = services.get_port(address) + object_store_addresses.append( + services.ObjectStoreAddress( + name=manager[b"store_socket_name"].decode("ascii"), + manager_name=manager[b"manager_socket_name"].decode("ascii"), + manager_port=port)) + scheduler_names = [ + scheduler[b"local_scheduler_socket_name"].decode("ascii") + for scheduler in local_schedulers] + client_info = {"node_ip_address": node_ip_address, + "redis_address": redis_address, + "object_store_addresses": object_store_addresses, + "local_scheduler_socket_names": scheduler_names} + return client_info def get_address_info_from_redis(redis_address, node_ip_address, num_retries=5): - counter = 0 - while True: - try: - return get_address_info_from_redis_helper(redis_address, node_ip_address) - except Exception as e: - if counter == num_retries: - raise - # Some of the information may not be in Redis yet, so wait a little bit. - print("Some processes that the driver needs to connect to have not " - "registered with Redis, so retrying. Have you run " - "'ray start' on this node?") - time.sleep(1) - counter += 1 + counter = 0 + while True: + try: + return get_address_info_from_redis_helper(redis_address, + node_ip_address) + except Exception as e: + if counter == num_retries: + raise + # Some of the information may not be in Redis yet, so wait a little + # bit. + print("Some processes that the driver needs to connect to have " + "not registered with Redis, so retrying. Have you run " + "'ray start' on this node?") + time.sleep(1) + counter += 1 def _init(address_info=None, @@ -791,218 +820,223 @@ def _init(address_info=None, num_cpus=None, num_gpus=None, num_redis_shards=None): - """Helper method to connect to an existing Ray cluster or start a new one. + """Helper method to connect to an existing Ray cluster or start a new one. - This method handles two cases. 
Either a Ray cluster already exists and we - just attach this driver to it, or we start all of the processes associated - with a Ray cluster and attach to the newly started cluster. + This method handles two cases. Either a Ray cluster already exists and we + just attach this driver to it, or we start all of the processes associated + with a Ray cluster and attach to the newly started cluster. - Args: - address_info (dict): A dictionary with address information for processes in - a partially-started Ray cluster. If start_ray_local=True, any processes - not in this dictionary will be started. If provided, an updated - address_info dictionary will be returned to include processes that are - newly started. - start_ray_local (bool): If True then this will start any processes not - already in address_info, including Redis, a global scheduler, local - scheduler(s), object store(s), and worker(s). It will also kill these - processes when Python exits. If False, this will attach to an existing - Ray cluster. - object_id_seed (int): Used to seed the deterministic generation of object - IDs. The same value can be used across multiple runs of the same job in - order to generate the object IDs in a consistent manner. However, the - same ID should not be used for different jobs. - num_workers (int): The number of workers to start. This is only provided if - start_ray_local is True. - num_local_schedulers (int): The number of local schedulers to start. This - is only provided if start_ray_local is True. - driver_mode (bool): The mode in which to start the driver. This should be - one of ray.SCRIPT_MODE, ray.PYTHON_MODE, and ray.SILENT_MODE. - redirect_output (bool): True if stdout and stderr for all the processes - should be redirected to files and false otherwise. - start_workers_from_local_scheduler (bool): If this flag is True, then start - the initial workers from the local scheduler. Else, start them from - Python. The latter case is for debugging purposes only. 
- num_cpus: A list containing the number of CPUs the local schedulers should - be configured with. - num_gpus: A list containing the number of GPUs the local schedulers should - be configured with. - num_redis_shards: The number of Redis shards to start in addition to the - primary Redis shard. + Args: + address_info (dict): A dictionary with address information for + processes in a partially-started Ray cluster. If + start_ray_local=True, any processes not in this dictionary will be + started. If provided, an updated address_info dictionary will be + returned to include processes that are newly started. + start_ray_local (bool): If True then this will start any processes not + already in address_info, including Redis, a global scheduler, local + scheduler(s), object store(s), and worker(s). It will also kill + these processes when Python exits. If False, this will attach to an + existing Ray cluster. + object_id_seed (int): Used to seed the deterministic generation of + object IDs. The same value can be used across multiple runs of the + same job in order to generate the object IDs in a consistent + manner. However, the same ID should not be used for different jobs. + num_workers (int): The number of workers to start. This is only + provided if start_ray_local is True. + num_local_schedulers (int): The number of local schedulers to start. + This is only provided if start_ray_local is True. + driver_mode (bool): The mode in which to start the driver. This should + be one of ray.SCRIPT_MODE, ray.PYTHON_MODE, and ray.SILENT_MODE. + redirect_output (bool): True if stdout and stderr for all the processes + should be redirected to files and false otherwise. + start_workers_from_local_scheduler (bool): If this flag is True, then + start the initial workers from the local scheduler. Else, start + them from Python. The latter case is for debugging purposes only. + num_cpus: A list containing the number of CPUs the local schedulers + should be configured with. 
+ num_gpus: A list containing the number of GPUs the local schedulers + should be configured with. + num_redis_shards: The number of Redis shards to start in addition to + the primary Redis shard. - Returns: - Address information about the started processes. + Returns: + Address information about the started processes. - Raises: - Exception: An exception is raised if an inappropriate combination of - arguments is passed in. - """ - check_main_thread() - if driver_mode not in [SCRIPT_MODE, PYTHON_MODE, SILENT_MODE]: - raise Exception("Driver_mode must be in [ray.SCRIPT_MODE, " - "ray.PYTHON_MODE, ray.SILENT_MODE].") + Raises: + Exception: An exception is raised if an inappropriate combination of + arguments is passed in. + """ + check_main_thread() + if driver_mode not in [SCRIPT_MODE, PYTHON_MODE, SILENT_MODE]: + raise Exception("Driver_mode must be in [ray.SCRIPT_MODE, " + "ray.PYTHON_MODE, ray.SILENT_MODE].") - # Get addresses of existing services. - if address_info is None: - address_info = {} - else: - assert isinstance(address_info, dict) - node_ip_address = address_info.get("node_ip_address") - redis_address = address_info.get("redis_address") + # Get addresses of existing services. + if address_info is None: + address_info = {} + else: + assert isinstance(address_info, dict) + node_ip_address = address_info.get("node_ip_address") + redis_address = address_info.get("redis_address") - # Start any services that do not yet exist. - if driver_mode == PYTHON_MODE: - # If starting Ray in PYTHON_MODE, don't start any other processes. - pass - elif start_ray_local: - # In this case, we launch a scheduler, a new object store, and some - # workers, and we connect to them. We do not launch any processes that are - # already registered in address_info. - # Use the address 127.0.0.1 in local mode. - node_ip_address = ("127.0.0.1" if node_ip_address is None - else node_ip_address) - # Use 1 local scheduler if num_local_schedulers is not provided. 
If - # existing local schedulers are provided, use that count as - # num_local_schedulers. - local_schedulers = address_info.get("local_scheduler_socket_names", []) - if num_local_schedulers is None: - if len(local_schedulers) > 0: - num_local_schedulers = len(local_schedulers) - else: - num_local_schedulers = 1 - # Use 1 additional redis shard if num_redis_shards is not provided. - num_redis_shards = 1 if num_redis_shards is None else num_redis_shards - # Start the scheduler, object store, and some workers. These will be killed - # by the call to cleanup(), which happens when the Python script exits. - address_info = services.start_ray_head( - address_info=address_info, - node_ip_address=node_ip_address, - num_workers=num_workers, - num_local_schedulers=num_local_schedulers, - redirect_output=redirect_output, - start_workers_from_local_scheduler=start_workers_from_local_scheduler, - num_cpus=num_cpus, - num_gpus=num_gpus, - num_redis_shards=num_redis_shards) - else: - if redis_address is None: - raise Exception("When connecting to an existing cluster, redis_address " - "must be provided.") - if num_workers is not None: - raise Exception("When connecting to an existing cluster, num_workers " - "must not be provided.") - if num_local_schedulers is not None: - raise Exception("When connecting to an existing cluster, " - "num_local_schedulers must not be provided.") - if num_cpus is not None or num_gpus is not None: - raise Exception("When connecting to an existing cluster, num_cpus and " - "num_gpus must not be provided.") - if num_redis_shards is not None: - raise Exception("When connecting to an existing cluster, " - "num_redis_shards must not be provided.") - # Get the node IP address if one is not provided. - if node_ip_address is None: - node_ip_address = services.get_node_ip_address(redis_address) - # Get the address info of the processes to connect to from Redis. 
- address_info = get_address_info_from_redis(redis_address, node_ip_address) + # Start any services that do not yet exist. + if driver_mode == PYTHON_MODE: + # If starting Ray in PYTHON_MODE, don't start any other processes. + pass + elif start_ray_local: + # In this case, we launch a scheduler, a new object store, and some + # workers, and we connect to them. We do not launch any processes that + # are already registered in address_info. + # Use the address 127.0.0.1 in local mode. + node_ip_address = ("127.0.0.1" if node_ip_address is None + else node_ip_address) + # Use 1 local scheduler if num_local_schedulers is not provided. If + # existing local schedulers are provided, use that count as + # num_local_schedulers. + local_schedulers = address_info.get("local_scheduler_socket_names", []) + if num_local_schedulers is None: + if len(local_schedulers) > 0: + num_local_schedulers = len(local_schedulers) + else: + num_local_schedulers = 1 + # Use 1 additional redis shard if num_redis_shards is not provided. + num_redis_shards = 1 if num_redis_shards is None else num_redis_shards + # Start the scheduler, object store, and some workers. These will be + # killed by the call to cleanup(), which happens when the Python script + # exits. 
+ address_info = services.start_ray_head( + address_info=address_info, + node_ip_address=node_ip_address, + num_workers=num_workers, + num_local_schedulers=num_local_schedulers, + redirect_output=redirect_output, + start_workers_from_local_scheduler=( + start_workers_from_local_scheduler), + num_cpus=num_cpus, + num_gpus=num_gpus, + num_redis_shards=num_redis_shards) + else: + if redis_address is None: + raise Exception("When connecting to an existing cluster, " + "redis_address must be provided.") + if num_workers is not None: + raise Exception("When connecting to an existing cluster, " + "num_workers must not be provided.") + if num_local_schedulers is not None: + raise Exception("When connecting to an existing cluster, " + "num_local_schedulers must not be provided.") + if num_cpus is not None or num_gpus is not None: + raise Exception("When connecting to an existing cluster, num_cpus " + "and num_gpus must not be provided.") + if num_redis_shards is not None: + raise Exception("When connecting to an existing cluster, " + "num_redis_shards must not be provided.") + # Get the node IP address if one is not provided. + if node_ip_address is None: + node_ip_address = services.get_node_ip_address(redis_address) + # Get the address info of the processes to connect to from Redis. + address_info = get_address_info_from_redis(redis_address, + node_ip_address) - # Connect this driver to Redis, the object store, and the local scheduler. - # Choose the first object store and local scheduler if there are multiple. - # The corresponding call to disconnect will happen in the call to cleanup() - # when the Python script exits. 
- if driver_mode == PYTHON_MODE: - driver_address_info = {} - else: - driver_address_info = { - "node_ip_address": node_ip_address, - "redis_address": address_info["redis_address"], - "store_socket_name": address_info["object_store_addresses"][0].name, - "manager_socket_name": (address_info["object_store_addresses"][0] - .manager_name), - "local_scheduler_socket_name": (address_info - ["local_scheduler_socket_names"][0])} - connect(driver_address_info, object_id_seed=object_id_seed, mode=driver_mode, - worker=global_worker, actor_id=NIL_ACTOR_ID) - return address_info + # Connect this driver to Redis, the object store, and the local scheduler. + # Choose the first object store and local scheduler if there are multiple. + # The corresponding call to disconnect will happen in the call to cleanup() + # when the Python script exits. + if driver_mode == PYTHON_MODE: + driver_address_info = {} + else: + driver_address_info = { + "node_ip_address": node_ip_address, + "redis_address": address_info["redis_address"], + "store_socket_name": ( + address_info["object_store_addresses"][0].name), + "manager_socket_name": ( + address_info["object_store_addresses"][0].manager_name), + "local_scheduler_socket_name": ( + address_info["local_scheduler_socket_names"][0])} + connect(driver_address_info, object_id_seed=object_id_seed, + mode=driver_mode, worker=global_worker, actor_id=NIL_ACTOR_ID) + return address_info def init(redis_address=None, node_ip_address=None, object_id_seed=None, num_workers=None, driver_mode=SCRIPT_MODE, redirect_output=False, num_cpus=None, num_gpus=None, num_redis_shards=None): - """Either connect to an existing Ray cluster or start one and connect to it. + """Connect to an existing Ray cluster or start one and connect to it. - This method handles two cases. Either a Ray cluster already exists and we - just attach this driver to it, or we start all of the processes associated - with a Ray cluster and attach to the newly started cluster. 
+ This method handles two cases. Either a Ray cluster already exists and we + just attach this driver to it, or we start all of the processes associated + with a Ray cluster and attach to the newly started cluster. - Args: - node_ip_address (str): The IP address of the node that we are on. - redis_address (str): The address of the Redis server to connect to. If this - address is not provided, then this command will start Redis, a global - scheduler, a local scheduler, a plasma store, a plasma manager, and some - workers. It will also kill these processes when Python exits. - object_id_seed (int): Used to seed the deterministic generation of object - IDs. The same value can be used across multiple runs of the same job in - order to generate the object IDs in a consistent manner. However, the - same ID should not be used for different jobs. - num_workers (int): The number of workers to start. This is only provided if - redis_address is not provided. - driver_mode (bool): The mode in which to start the driver. This should be - one of ray.SCRIPT_MODE, ray.PYTHON_MODE, and ray.SILENT_MODE. - redirect_output (bool): True if stdout and stderr for all the processes - should be redirected to files and false otherwise. - num_cpus (int): Number of cpus the user wishes all local schedulers to be - configured with. - num_gpus (int): Number of gpus the user wishes all local schedulers to be - configured with. - num_redis_shards: The number of Redis shards to start in addition to the - primary Redis shard. + Args: + node_ip_address (str): The IP address of the node that we are on. + redis_address (str): The address of the Redis server to connect to. If + this address is not provided, then this command will start Redis, a + global scheduler, a local scheduler, a plasma store, a plasma + manager, and some workers. It will also kill these processes when + Python exits. + object_id_seed (int): Used to seed the deterministic generation of + object IDs. 
The same value can be used across multiple runs of the + same job in order to generate the object IDs in a consistent + manner. However, the same ID should not be used for different jobs. + num_workers (int): The number of workers to start. This is only + provided if redis_address is not provided. + driver_mode (bool): The mode in which to start the driver. This should + be one of ray.SCRIPT_MODE, ray.PYTHON_MODE, and ray.SILENT_MODE. + redirect_output (bool): True if stdout and stderr for all the processes + should be redirected to files and false otherwise. + num_cpus (int): Number of cpus the user wishes all local schedulers to + be configured with. + num_gpus (int): Number of gpus the user wishes all local schedulers to + be configured with. + num_redis_shards: The number of Redis shards to start in addition to + the primary Redis shard. - Returns: - Address information about the started processes. + Returns: + Address information about the started processes. - Raises: - Exception: An exception is raised if an inappropriate combination of - arguments is passed in. - """ - info = {"node_ip_address": node_ip_address, - "redis_address": redis_address} - return _init(address_info=info, start_ray_local=(redis_address is None), - num_workers=num_workers, driver_mode=driver_mode, - redirect_output=redirect_output, num_cpus=num_cpus, - num_gpus=num_gpus, num_redis_shards=num_redis_shards) + Raises: + Exception: An exception is raised if an inappropriate combination of + arguments is passed in. + """ + info = {"node_ip_address": node_ip_address, + "redis_address": redis_address} + return _init(address_info=info, start_ray_local=(redis_address is None), + num_workers=num_workers, driver_mode=driver_mode, + redirect_output=redirect_output, num_cpus=num_cpus, + num_gpus=num_gpus, num_redis_shards=num_redis_shards) def cleanup(worker=global_worker): - """Disconnect the worker, and terminate any processes started in init. 
+ """Disconnect the worker, and terminate any processes started in init. - This will automatically run at the end when a Python process that uses Ray - exits. It is ok to run this twice in a row. Note that we manually call - services.cleanup() in the tests because we need to start and stop many - clusters in the tests, but the import and exit only happen once. - """ - disconnect(worker) - if hasattr(worker, "local_scheduler_client"): - del worker.local_scheduler_client - if hasattr(worker, "plasma_client"): - worker.plasma_client.shutdown() + This will automatically run at the end when a Python process that uses Ray + exits. It is ok to run this twice in a row. Note that we manually call + services.cleanup() in the tests because we need to start and stop many + clusters in the tests, but the import and exit only happen once. + """ + disconnect(worker) + if hasattr(worker, "local_scheduler_client"): + del worker.local_scheduler_client + if hasattr(worker, "plasma_client"): + worker.plasma_client.shutdown() - if worker.mode in [SCRIPT_MODE, SILENT_MODE]: - # If this is a driver, push the finish time to Redis and clean up any - # other services that were started with the driver. - worker.redis_client.hmset(b"Drivers:" + worker.worker_id, - {"end_time": time.time()}) - services.cleanup() - else: - # If this is not a driver, make sure there are no orphan processes, besides - # possibly the worker itself. - for process_type, processes in services.all_processes.items(): - if process_type == services.PROCESS_TYPE_WORKER: - assert(len(processes)) <= 1 - else: - assert(len(processes) == 0) + if worker.mode in [SCRIPT_MODE, SILENT_MODE]: + # If this is a driver, push the finish time to Redis and clean up any + # other services that were started with the driver. 
+ worker.redis_client.hmset(b"Drivers:" + worker.worker_id, + {"end_time": time.time()}) + services.cleanup() + else: + # If this is not a driver, make sure there are no orphan processes, + # besides possibly the worker itself. + for process_type, processes in services.all_processes.items(): + if process_type == services.PROCESS_TYPE_WORKER: + assert(len(processes)) <= 1 + else: + assert(len(processes) == 0) - worker.set_mode(None) + worker.set_mode(None) atexit.register(cleanup) @@ -1013,1128 +1047,1166 @@ normal_excepthook = sys.excepthook def custom_excepthook(type, value, tb): - # If this is a driver, push the exception to redis. - if global_worker.mode in [SCRIPT_MODE, SILENT_MODE]: - error_message = "".join(traceback.format_tb(tb)) - global_worker.redis_client.hmset(b"Drivers:" + global_worker.worker_id, - {"exception": error_message}) - # Call the normal excepthook. - normal_excepthook(type, value, tb) + # If this is a driver, push the exception to redis. + if global_worker.mode in [SCRIPT_MODE, SILENT_MODE]: + error_message = "".join(traceback.format_tb(tb)) + global_worker.redis_client.hmset(b"Drivers:" + global_worker.worker_id, + {"exception": error_message}) + # Call the normal excepthook. + normal_excepthook(type, value, tb) sys.excepthook = custom_excepthook def print_error_messages(worker): - """Print error messages in the background on the driver. + """Print error messages in the background on the driver. - This runs in a separate thread on the driver and prints error messages in the - background. - """ - # TODO(rkn): All error messages should have a "component" field indicating - # which process the error came from (e.g., a worker or a plasma store). - # Currently all error messages come from workers. + This runs in a separate thread on the driver and prints error messages in + the background. + """ + # TODO(rkn): All error messages should have a "component" field indicating + # which process the error came from (e.g., a worker or a plasma store). 
+ # Currently all error messages come from workers. - helpful_message = """ -You can inspect errors by running + helpful_message = """ + You can inspect errors by running - ray.error_info() + ray.error_info() -If this driver is hanging, start a new one with + If this driver is hanging, start a new one with - ray.init(redis_address="{}") -""".format(worker.redis_address) + ray.init(redis_address="{}") + """.format(worker.redis_address) - worker.error_message_pubsub_client = worker.redis_client.pubsub() - # Exports that are published after the call to - # error_message_pubsub_client.psubscribe and before the call to - # error_message_pubsub_client.listen will still be processed in the loop. - worker.error_message_pubsub_client.psubscribe("__keyspace@0__:ErrorKeys") - num_errors_received = 0 + worker.error_message_pubsub_client = worker.redis_client.pubsub() + # Exports that are published after the call to + # error_message_pubsub_client.psubscribe and before the call to + # error_message_pubsub_client.listen will still be processed in the loop. + worker.error_message_pubsub_client.psubscribe("__keyspace@0__:ErrorKeys") + num_errors_received = 0 - # Get the exports that occurred before the call to psubscribe. - with worker.lock: - error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1) - for error_key in error_keys: - if error_applies_to_driver(error_key, worker=worker): - error_message = worker.redis_client.hget(error_key, - "message").decode("ascii") - print(error_message) - print(helpful_message) - num_errors_received += 1 + # Get the exports that occurred before the call to psubscribe. 
+ with worker.lock: + error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1) + for error_key in error_keys: + if error_applies_to_driver(error_key, worker=worker): + error_message = worker.redis_client.hget( + error_key, "message").decode("ascii") + print(error_message) + print(helpful_message) + num_errors_received += 1 - try: - for msg in worker.error_message_pubsub_client.listen(): - with worker.lock: - for error_key in worker.redis_client.lrange("ErrorKeys", - num_errors_received, -1): - if error_applies_to_driver(error_key, worker=worker): - error_message = worker.redis_client.hget(error_key, - "message").decode("ascii") - print(error_message) - print(helpful_message) - num_errors_received += 1 - except redis.ConnectionError: - # When Redis terminates the listen call will throw a ConnectionError, which - # we catch here. - pass + try: + for msg in worker.error_message_pubsub_client.listen(): + with worker.lock: + for error_key in worker.redis_client.lrange( + "ErrorKeys", num_errors_received, -1): + if error_applies_to_driver(error_key, worker=worker): + error_message = worker.redis_client.hget( + error_key, "message").decode("ascii") + print(error_message) + print(helpful_message) + num_errors_received += 1 + except redis.ConnectionError: + # When Redis terminates the listen call will throw a ConnectionError, + # which we catch here. 
+ pass def fetch_and_register_remote_function(key, worker=global_worker): - """Import a remote function.""" - (driver_id, function_id_str, function_name, - serialized_function, num_return_vals, module, - num_cpus, num_gpus, max_calls) = worker.redis_client.hmget( - key, ["driver_id", - "function_id", - "name", - "function", - "num_return_vals", - "module", - "num_cpus", - "num_gpus", - "max_calls"]) - function_id = ray.local_scheduler.ObjectID(function_id_str) - function_name = function_name.decode("ascii") - function_properties = FunctionProperties( - num_return_vals=int(num_return_vals), - num_cpus=int(num_cpus), - num_gpus=int(num_gpus), - max_calls=int(max_calls)) - module = module.decode("ascii") + """Import a remote function.""" + (driver_id, function_id_str, function_name, + serialized_function, num_return_vals, module, + num_cpus, num_gpus, max_calls) = worker.redis_client.hmget( + key, ["driver_id", + "function_id", + "name", + "function", + "num_return_vals", + "module", + "num_cpus", + "num_gpus", + "max_calls"]) + function_id = ray.local_scheduler.ObjectID(function_id_str) + function_name = function_name.decode("ascii") + function_properties = FunctionProperties( + num_return_vals=int(num_return_vals), + num_cpus=int(num_cpus), + num_gpus=int(num_gpus), + max_calls=int(max_calls)) + module = module.decode("ascii") - # This is a placeholder in case the function can't be unpickled. This will be - # overwritten if the function is successfully registered. - def f(): - raise Exception("This function was not imported properly.") - remote_f_placeholder = remote(function_id=function_id)(lambda *xs: f()) - worker.functions[driver_id][function_id.id()] = (function_name, - remote_f_placeholder) - worker.function_properties[driver_id][function_id.id()] = function_properties - worker.num_task_executions[driver_id][function_id.id()] = 0 + # This is a placeholder in case the function can't be unpickled. 
This will + # be overwritten if the function is successfully registered. + def f(): + raise Exception("This function was not imported properly.") + remote_f_placeholder = remote(function_id=function_id)(lambda *xs: f()) + worker.functions[driver_id][function_id.id()] = (function_name, + remote_f_placeholder) + worker.function_properties[driver_id][function_id.id()] = ( + function_properties) + worker.num_task_executions[driver_id][function_id.id()] = 0 - try: - function = pickle.loads(serialized_function) - except: - # If an exception was thrown when the remote function was imported, we - # record the traceback and notify the scheduler of the failure. - traceback_str = format_error_message(traceback.format_exc()) - # Log the error message. - worker.push_error_to_driver(driver_id, "register_remote_function", - traceback_str, - data={"function_id": function_id.id(), - "function_name": function_name}) - else: - # TODO(rkn): Why is the below line necessary? - function.__module__ = module - worker.functions[driver_id][function_id.id()] = ( - function_name, remote(function_id=function_id)(function)) - # Add the function to the function table. - worker.redis_client.rpush(b"FunctionTable:" + function_id.id(), - worker.worker_id) + try: + function = pickle.loads(serialized_function) + except: + # If an exception was thrown when the remote function was imported, we + # record the traceback and notify the scheduler of the failure. + traceback_str = format_error_message(traceback.format_exc()) + # Log the error message. + worker.push_error_to_driver(driver_id, "register_remote_function", + traceback_str, + data={"function_id": function_id.id(), + "function_name": function_name}) + else: + # TODO(rkn): Why is the below line necessary? + function.__module__ = module + worker.functions[driver_id][function_id.id()] = ( + function_name, remote(function_id=function_id)(function)) + # Add the function to the function table. 
+ worker.redis_client.rpush(b"FunctionTable:" + function_id.id(), + worker.worker_id) def fetch_and_execute_function_to_run(key, worker=global_worker): - """Run on arbitrary function on the worker.""" - driver_id, serialized_function = worker.redis_client.hmget( - key, ["driver_id", "function"]) - # Get the number of workers on this node that have already started executing - # this remote function, and increment that value. Subtract 1 so the counter - # starts at 0. - counter = worker.redis_client.hincrby(worker.node_ip_address, key, 1) - 1 - try: - # Deserialize the function. - function = pickle.loads(serialized_function) - # Run the function. - function({"counter": counter}) - except: - # If an exception was thrown when the function was run, we record the - # traceback and notify the scheduler of the failure. - traceback_str = traceback.format_exc() - # Log the error message. - name = function.__name__ if ("function" in locals() and - hasattr(function, "__name__")) else "" - worker.push_error_to_driver(driver_id, "function_to_run", traceback_str, - data={"name": name}) + """Run on arbitrary function on the worker.""" + driver_id, serialized_function = worker.redis_client.hmget( + key, ["driver_id", "function"]) + # Get the number of workers on this node that have already started + # executing this remote function, and increment that value. Subtract 1 so + # the counter starts at 0. + counter = worker.redis_client.hincrby(worker.node_ip_address, key, 1) - 1 + try: + # Deserialize the function. + function = pickle.loads(serialized_function) + # Run the function. + function({"counter": counter}) + except: + # If an exception was thrown when the function was run, we record the + # traceback and notify the scheduler of the failure. + traceback_str = traceback.format_exc() + # Log the error message. 
+ name = function.__name__ if ("function" in locals() and + hasattr(function, "__name__")) else "" + worker.push_error_to_driver(driver_id, "function_to_run", + traceback_str, data={"name": name}) def import_thread(worker, mode): - worker.import_pubsub_client = worker.redis_client.pubsub() - # Exports that are published after the call to - # import_pubsub_client.psubscribe and before the call to - # import_pubsub_client.listen will still be processed in the loop. - worker.import_pubsub_client.psubscribe("__keyspace@0__:Exports") - # Keep track of the number of imports that we've imported. - num_imported = 0 + worker.import_pubsub_client = worker.redis_client.pubsub() + # Exports that are published after the call to + # import_pubsub_client.psubscribe and before the call to + # import_pubsub_client.listen will still be processed in the loop. + worker.import_pubsub_client.psubscribe("__keyspace@0__:Exports") + # Keep track of the number of imports that we've imported. + num_imported = 0 - # Get the exports that occurred before the call to psubscribe. - with worker.lock: - export_keys = worker.redis_client.lrange("Exports", 0, -1) - for key in export_keys: - num_imported += 1 + # Get the exports that occurred before the call to psubscribe. + with worker.lock: + export_keys = worker.redis_client.lrange("Exports", 0, -1) + for key in export_keys: + num_imported += 1 - # Handle the driver case first. - if mode != WORKER_MODE: - if key.startswith(b"FunctionsToRun"): - fetch_and_execute_function_to_run(key, worker=worker) - # Continue because FunctionsToRun are the only things that the driver - # should import. - continue + # Handle the driver case first. + if mode != WORKER_MODE: + if key.startswith(b"FunctionsToRun"): + fetch_and_execute_function_to_run(key, worker=worker) + # Continue because FunctionsToRun are the only things that the + # driver should import. 
+ continue - if key.startswith(b"RemoteFunction"): - fetch_and_register_remote_function(key, worker=worker) - elif key.startswith(b"FunctionsToRun"): - fetch_and_execute_function_to_run(key, worker=worker) - elif key.startswith(b"ActorClass"): - # If this worker is an actor that is supposed to construct this class, - # fetch the actor and class information and construct the class. - class_id = key.split(b":", 1)[1] - if worker.actor_id != NIL_ACTOR_ID and worker.class_id == class_id: - worker.fetch_and_register_actor(key, worker) - else: - raise Exception("This code should be unreachable.") - - try: - for msg in worker.import_pubsub_client.listen(): - with worker.lock: - if msg["type"] == "psubscribe": - continue - assert msg["data"] == b"rpush" - num_imports = worker.redis_client.llen("Exports") - assert num_imports >= num_imported - for i in range(num_imported, num_imports): - num_imported += 1 - key = worker.redis_client.lindex("Exports", i) - - # Handle the driver case first. - if mode != WORKER_MODE: - if key.startswith(b"FunctionsToRun"): - with log_span("ray:import_function_to_run", worker=worker): + if key.startswith(b"RemoteFunction"): + fetch_and_register_remote_function(key, worker=worker) + elif key.startswith(b"FunctionsToRun"): fetch_and_execute_function_to_run(key, worker=worker) - # Continue because FunctionsToRun are the only things that the - # driver should import. - continue + elif key.startswith(b"ActorClass"): + # If this worker is an actor that is supposed to construct this + # class, fetch the actor and class information and construct + # the class. 
+ class_id = key.split(b":", 1)[1] + if (worker.actor_id != NIL_ACTOR_ID and + worker.class_id == class_id): + worker.fetch_and_register_actor(key, worker) + else: + raise Exception("This code should be unreachable.") - if key.startswith(b"RemoteFunction"): - with log_span("ray:import_remote_function", worker=worker): - fetch_and_register_remote_function(key, worker=worker) - elif key.startswith(b"FunctionsToRun"): - with log_span("ray:import_function_to_run", worker=worker): - fetch_and_execute_function_to_run(key, worker=worker) - elif key.startswith(b"Actor"): - # Only get the actor if the actor ID matches the actor ID of this - # worker. - actor_id, = worker.redis_client.hmget(key, "actor_id") - if worker.actor_id == actor_id: - worker.fetch_and_register["Actor"](key, worker) - else: - raise Exception("This code should be unreachable.") - except redis.ConnectionError: - # When Redis terminates the listen call will throw a ConnectionError, which - # we catch here. - pass + try: + for msg in worker.import_pubsub_client.listen(): + with worker.lock: + if msg["type"] == "psubscribe": + continue + assert msg["data"] == b"rpush" + num_imports = worker.redis_client.llen("Exports") + assert num_imports >= num_imported + for i in range(num_imported, num_imports): + num_imported += 1 + key = worker.redis_client.lindex("Exports", i) + + # Handle the driver case first. + if mode != WORKER_MODE: + if key.startswith(b"FunctionsToRun"): + with log_span("ray:import_function_to_run", + worker=worker): + fetch_and_execute_function_to_run( + key, worker=worker) + # Continue because FunctionsToRun are the only things + # that the driver should import. 
+ continue + + if key.startswith(b"RemoteFunction"): + with log_span("ray:import_remote_function", + worker=worker): + fetch_and_register_remote_function(key, + worker=worker) + elif key.startswith(b"FunctionsToRun"): + with log_span("ray:import_function_to_run", + worker=worker): + fetch_and_execute_function_to_run(key, + worker=worker) + elif key.startswith(b"Actor"): + # Only get the actor if the actor ID matches the actor + # ID of this worker. + actor_id, = worker.redis_client.hmget(key, "actor_id") + if worker.actor_id == actor_id: + worker.fetch_and_register["Actor"](key, worker) + else: + raise Exception("This code should be unreachable.") + except redis.ConnectionError: + # When Redis terminates the listen call will throw a ConnectionError, + # which we catch here. + pass def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, actor_id=NIL_ACTOR_ID): - """Connect this worker to the local scheduler, to Plasma, and to Redis. + """Connect this worker to the local scheduler, to Plasma, and to Redis. - Args: - info (dict): A dictionary with address of the Redis server and the sockets - of the plasma store, plasma manager, and local scheduler. - object_id_seed: A seed to use to make the generation of object IDs - deterministic. - mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, - and SILENT_MODE. - actor_id: The ID of the actor running on this worker. If this worker is not - an actor, then this is NIL_ACTOR_ID. - """ - check_main_thread() - # Do some basic checking to make sure we didn't call ray.init twice. - error_message = "Perhaps you called ray.init twice by accident?" - assert not worker.connected, error_message - assert worker.cached_functions_to_run is not None, error_message - assert worker.cached_remote_functions is not None, error_message - # Initialize some fields. 
- worker.worker_id = random_string() - worker.actor_id = actor_id - worker.connected = True - worker.set_mode(mode) - # Redirect worker output and error to their own files. - if mode == WORKER_MODE: - log_stdout_file, log_stderr_file = services.new_log_files("worker", True) - sys.stdout = log_stdout_file - sys.stderr = log_stderr_file - services.record_log_files_in_redis(info["redis_address"], - info["node_ip_address"], - [log_stdout_file, log_stderr_file]) - # The worker.events field is used to aggregate logging information and - # display it in the web UI. Note that Python lists protected by the GIL, - # which is important because we will append to this field from multiple - # threads. - worker.events = [] - # If running Ray in PYTHON_MODE, there is no need to create call - # create_worker or to start the worker service. - if mode == PYTHON_MODE: - return - # Set the node IP address. - worker.node_ip_address = info["node_ip_address"] - worker.redis_address = info["redis_address"] - # Create a Redis client. - redis_ip_address, redis_port = info["redis_address"].split(":") - worker.redis_client = redis.StrictRedis(host=redis_ip_address, - port=int(redis_port)) - worker.lock = threading.Lock() + Args: + info (dict): A dictionary with address of the Redis server and the + sockets of the plasma store, plasma manager, and local scheduler. + object_id_seed: A seed to use to make the generation of object IDs + deterministic. + mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, + PYTHON_MODE, and SILENT_MODE. + actor_id: The ID of the actor running on this worker. If this worker is + not an actor, then this is NIL_ACTOR_ID. + """ + check_main_thread() + # Do some basic checking to make sure we didn't call ray.init twice. + error_message = "Perhaps you called ray.init twice by accident?" 
+ assert not worker.connected, error_message + assert worker.cached_functions_to_run is not None, error_message + assert worker.cached_remote_functions is not None, error_message + # Initialize some fields. + worker.worker_id = random_string() + worker.actor_id = actor_id + worker.connected = True + worker.set_mode(mode) + # Redirect worker output and error to their own files. + if mode == WORKER_MODE: + log_stdout_file, log_stderr_file = services.new_log_files("worker", + True) + sys.stdout = log_stdout_file + sys.stderr = log_stderr_file + services.record_log_files_in_redis(info["redis_address"], + info["node_ip_address"], + [log_stdout_file, log_stderr_file]) + # The worker.events field is used to aggregate logging information and + # display it in the web UI. Note that Python lists protected by the GIL, + # which is important because we will append to this field from multiple + # threads. + worker.events = [] + # If running Ray in PYTHON_MODE, there is no need to create call + # create_worker or to start the worker service. + if mode == PYTHON_MODE: + return + # Set the node IP address. + worker.node_ip_address = info["node_ip_address"] + worker.redis_address = info["redis_address"] + # Create a Redis client. + redis_ip_address, redis_port = info["redis_address"].split(":") + worker.redis_client = redis.StrictRedis(host=redis_ip_address, + port=int(redis_port)) + worker.lock = threading.Lock() - # Create an object for interfacing with the global state. - global_state._initialize_global_state(redis_ip_address, int(redis_port)) + # Create an object for interfacing with the global state. + global_state._initialize_global_state(redis_ip_address, int(redis_port)) - # Register the worker with Redis. - if mode in [SCRIPT_MODE, SILENT_MODE]: - # The concept of a driver is the same as the concept of a "job". Register - # the driver/job with Redis here. 
- import __main__ as main - driver_info = { - "node_ip_address": worker.node_ip_address, - "driver_id": worker.worker_id, - "start_time": time.time(), - "plasma_store_socket": info["store_socket_name"], - "plasma_manager_socket": info["manager_socket_name"], - "local_scheduler_socket": info["local_scheduler_socket_name"]} - driver_info["name"] = (main.__file__ if hasattr(main, "__file__") - else "INTERACTIVE MODE") - worker.redis_client.hmset(b"Drivers:" + worker.worker_id, driver_info) - is_worker = False - elif mode == WORKER_MODE: # Register the worker with Redis. - worker.redis_client.hmset( - b"Workers:" + worker.worker_id, - {"node_ip_address": worker.node_ip_address, - "stdout_file": os.path.abspath(log_stdout_file.name), - "stderr_file": os.path.abspath(log_stderr_file.name), - "plasma_store_socket": info["store_socket_name"], - "plasma_manager_socket": info["manager_socket_name"], - "local_scheduler_socket": info["local_scheduler_socket_name"]}) - is_worker = True - else: - raise Exception("This code should be unreachable.") - - # Create an object store client. - worker.plasma_client = ray.plasma.PlasmaClient(info["store_socket_name"], - info["manager_socket_name"]) - # Create the local scheduler client. - if worker.actor_id != NIL_ACTOR_ID: - num_gpus = int(worker.redis_client.hget(b"Actor:" + actor_id, - "num_gpus")) - else: - num_gpus = 0 - worker.local_scheduler_client = ray.local_scheduler.LocalSchedulerClient( - info["local_scheduler_socket_name"], worker.worker_id, worker.actor_id, - is_worker, num_gpus) - - # If this is a driver, set the current task ID, the task driver ID, and set - # the task index to 0. - if mode in [SCRIPT_MODE, SILENT_MODE]: - # If the user provided an object_id_seed, then set the current task ID - # deterministically based on that seed (without altering the state of the - # user's random number generator). Otherwise, set the current task ID - # randomly to avoid object ID collisions. 
- numpy_state = np.random.get_state() - if object_id_seed is not None: - np.random.seed(object_id_seed) + if mode in [SCRIPT_MODE, SILENT_MODE]: + # The concept of a driver is the same as the concept of a "job". + # Register the driver/job with Redis here. + import __main__ as main + driver_info = { + "node_ip_address": worker.node_ip_address, + "driver_id": worker.worker_id, + "start_time": time.time(), + "plasma_store_socket": info["store_socket_name"], + "plasma_manager_socket": info["manager_socket_name"], + "local_scheduler_socket": info["local_scheduler_socket_name"]} + driver_info["name"] = (main.__file__ if hasattr(main, "__file__") + else "INTERACTIVE MODE") + worker.redis_client.hmset(b"Drivers:" + worker.worker_id, driver_info) + is_worker = False + elif mode == WORKER_MODE: + # Register the worker with Redis. + worker.redis_client.hmset( + b"Workers:" + worker.worker_id, + {"node_ip_address": worker.node_ip_address, + "stdout_file": os.path.abspath(log_stdout_file.name), + "stderr_file": os.path.abspath(log_stderr_file.name), + "plasma_store_socket": info["store_socket_name"], + "plasma_manager_socket": info["manager_socket_name"], + "local_scheduler_socket": info["local_scheduler_socket_name"]}) + is_worker = True else: - # Try to use true randomness. - np.random.seed(None) - worker.current_task_id = ray.local_scheduler.ObjectID(np.random.bytes(20)) - # When tasks are executed on remote workers in the context of multiple - # drivers, the task driver ID is used to keep track of which driver is - # responsible for the task so that error messages will be propagated to the - # correct driver. - worker.task_driver_id = ray.local_scheduler.ObjectID(worker.worker_id) - # Reset the state of the numpy random number generator. - np.random.set_state(numpy_state) - # Set other fields needed for computing task IDs. 
- worker.task_index = 0 - worker.put_index = 0 + raise Exception("This code should be unreachable.") - # Create an entry for the driver task in the task table. This task is added - # immediately with status RUNNING. This allows us to push errors related to - # this driver task back to the driver. For example, if the driver creates - # an object that is later evicted, we should notify the user that we're - # unable to reconstruct the object, since we cannot rerun the driver. - driver_task = ray.local_scheduler.Task( - worker.task_driver_id, - ray.local_scheduler.ObjectID(NIL_FUNCTION_ID), - [], - 0, - worker.current_task_id, - worker.task_index, - ray.local_scheduler.ObjectID(NIL_ACTOR_ID), - worker.actor_counters[actor_id], - [0, 0]) - global_state._execute_command( - driver_task.task_id(), - "RAY.TASK_TABLE_ADD", - driver_task.task_id().id(), - TASK_STATUS_RUNNING, - NIL_LOCAL_SCHEDULER_ID, - ray.local_scheduler.task_to_string(driver_task)) - # Set the driver's current task ID to the task ID assigned to the driver - # task. - worker.current_task_id = driver_task.task_id() + # Create an object store client. + worker.plasma_client = ray.plasma.PlasmaClient(info["store_socket_name"], + info["manager_socket_name"]) + # Create the local scheduler client. + if worker.actor_id != NIL_ACTOR_ID: + num_gpus = int(worker.redis_client.hget(b"Actor:" + actor_id, + "num_gpus")) + else: + num_gpus = 0 + worker.local_scheduler_client = ray.local_scheduler.LocalSchedulerClient( + info["local_scheduler_socket_name"], worker.worker_id, worker.actor_id, + is_worker, num_gpus) - # If this is an actor, get the ID of the corresponding class for the actor. - if worker.actor_id != NIL_ACTOR_ID: - actor_key = b"Actor:" + worker.actor_id - class_id = worker.redis_client.hget(actor_key, "class_id") - worker.class_id = class_id + # If this is a driver, set the current task ID, the task driver ID, and set + # the task index to 0. 
+ if mode in [SCRIPT_MODE, SILENT_MODE]: + # If the user provided an object_id_seed, then set the current task ID + # deterministically based on that seed (without altering the state of + # the user's random number generator). Otherwise, set the current task + # ID randomly to avoid object ID collisions. + numpy_state = np.random.get_state() + if object_id_seed is not None: + np.random.seed(object_id_seed) + else: + # Try to use true randomness. + np.random.seed(None) + worker.current_task_id = ray.local_scheduler.ObjectID( + np.random.bytes(20)) + # When tasks are executed on remote workers in the context of multiple + # drivers, the task driver ID is used to keep track of which driver is + # responsible for the task so that error messages will be propagated to + # the correct driver. + worker.task_driver_id = ray.local_scheduler.ObjectID(worker.worker_id) + # Reset the state of the numpy random number generator. + np.random.set_state(numpy_state) + # Set other fields needed for computing task IDs. + worker.task_index = 0 + worker.put_index = 0 - # Start a thread to import exports from the driver or from other workers. - # Note that the driver also has an import thread, which is used only to - # import custom class definitions from calls to _register_class that happen - # under the hood on workers. - t = threading.Thread(target=import_thread, args=(worker, mode)) - # Making the thread a daemon causes it to exit when the main thread exits. - t.daemon = True - t.start() + # Create an entry for the driver task in the task table. This task is + # added immediately with status RUNNING. This allows us to push errors + # related to this driver task back to the driver. For example, if the + # driver creates an object that is later evicted, we should notify the + # user that we're unable to reconstruct the object, since we cannot + # rerun the driver. 
+ driver_task = ray.local_scheduler.Task( + worker.task_driver_id, + ray.local_scheduler.ObjectID(NIL_FUNCTION_ID), + [], + 0, + worker.current_task_id, + worker.task_index, + ray.local_scheduler.ObjectID(NIL_ACTOR_ID), + worker.actor_counters[actor_id], + [0, 0]) + global_state._execute_command( + driver_task.task_id(), + "RAY.TASK_TABLE_ADD", + driver_task.task_id().id(), + TASK_STATUS_RUNNING, + NIL_LOCAL_SCHEDULER_ID, + ray.local_scheduler.task_to_string(driver_task)) + # Set the driver's current task ID to the task ID assigned to the + # driver task. + worker.current_task_id = driver_task.task_id() - # If this is a driver running in SCRIPT_MODE, start a thread to print error - # messages asynchronously in the background. Ideally the scheduler would push - # messages to the driver's worker service, but we ran into bugs when trying - # to properly shutdown the driver's worker service, so we are temporarily - # using this implementation which constantly queries the scheduler for new - # error messages. - if mode == SCRIPT_MODE: - t = threading.Thread(target=print_error_messages, args=(worker,)) + # If this is an actor, get the ID of the corresponding class for the actor. + if worker.actor_id != NIL_ACTOR_ID: + actor_key = b"Actor:" + worker.actor_id + class_id = worker.redis_client.hget(actor_key, "class_id") + worker.class_id = class_id + + # Start a thread to import exports from the driver or from other workers. + # Note that the driver also has an import thread, which is used only to + # import custom class definitions from calls to _register_class that happen + # under the hood on workers. + t = threading.Thread(target=import_thread, args=(worker, mode)) # Making the thread a daemon causes it to exit when the main thread exits. t.daemon = True t.start() - # Initialize the serialization library. This registers some classes, and so - # it must be run before we export all of the cached remote functions. 
- initialize_numbuf() - if mode in [SCRIPT_MODE, SILENT_MODE]: - # Add the directory containing the script that is running to the Python - # paths of the workers. Also add the current directory. Note that this - # assumes that the directory structures on the machines in the clusters are - # the same. - script_directory = os.path.abspath(os.path.dirname(sys.argv[0])) - current_directory = os.path.abspath(os.path.curdir) - worker.run_function_on_all_workers( - lambda worker_info: sys.path.insert(1, script_directory)) - worker.run_function_on_all_workers( - lambda worker_info: sys.path.insert(1, current_directory)) - # TODO(rkn): Here we first export functions to run, then remote functions. - # The order matters. For example, one of the functions to run may set the - # Python path, which is needed to import a module used to define a remote - # function. We may want to change the order to simply be the order in which - # the exports were defined on the driver. In addition, we will need to - # retain the ability to decide what the first few exports are (mostly to - # set the Python path). Additionally, note that the first exports to be - # defined on the driver will be the ones defined in separate modules that - # are imported by the driver. - # Export cached functions_to_run. - for function in worker.cached_functions_to_run: - worker.run_function_on_all_workers(function) - # Export cached remote functions to the workers. - for info in worker.cached_remote_functions: - (function_id, func_name, func, func_invoker, function_properties) = info - export_remote_function(function_id, func_name, func, func_invoker, - function_properties, worker) - worker.cached_functions_to_run = None - worker.cached_remote_functions = None + + # If this is a driver running in SCRIPT_MODE, start a thread to print error + # messages asynchronously in the background. 
Ideally the scheduler would + # push messages to the driver's worker service, but we ran into bugs when + # trying to properly shutdown the driver's worker service, so we are + # temporarily using this implementation which constantly queries the + # scheduler for new error messages. + if mode == SCRIPT_MODE: + t = threading.Thread(target=print_error_messages, args=(worker,)) + # Making the thread a daemon causes it to exit when the main thread + # exits. + t.daemon = True + t.start() + # Initialize the serialization library. This registers some classes, and so + # it must be run before we export all of the cached remote functions. + initialize_numbuf() + if mode in [SCRIPT_MODE, SILENT_MODE]: + # Add the directory containing the script that is running to the Python + # paths of the workers. Also add the current directory. Note that this + # assumes that the directory structures on the machines in the clusters + # are the same. + script_directory = os.path.abspath(os.path.dirname(sys.argv[0])) + current_directory = os.path.abspath(os.path.curdir) + worker.run_function_on_all_workers( + lambda worker_info: sys.path.insert(1, script_directory)) + worker.run_function_on_all_workers( + lambda worker_info: sys.path.insert(1, current_directory)) + # TODO(rkn): Here we first export functions to run, then remote + # functions. The order matters. For example, one of the functions to + # run may set the Python path, which is needed to import a module used + # to define a remote function. We may want to change the order to + # simply be the order in which the exports were defined on the driver. + # In addition, we will need to retain the ability to decide what the + # first few exports are (mostly to set the Python path). Additionally, + # note that the first exports to be defined on the driver will be the + # ones defined in separate modules that are imported by the driver. + # Export cached functions_to_run. 
+ for function in worker.cached_functions_to_run: + worker.run_function_on_all_workers(function) + # Export cached remote functions to the workers. + for info in worker.cached_remote_functions: + (function_id, func_name, func, + func_invoker, function_properties) = info + export_remote_function(function_id, func_name, func, func_invoker, + function_properties, worker) + worker.cached_functions_to_run = None + worker.cached_remote_functions = None def disconnect(worker=global_worker): - """Disconnect this worker from the scheduler and object store.""" - # Reset the list of cached remote functions so that if more remote functions - # are defined and then connect is called again, the remote functions will be - # exported. This is mostly relevant for the tests. - worker.connected = False - worker.cached_functions_to_run = [] - worker.cached_remote_functions = [] - serialization.clear_state() + """Disconnect this worker from the scheduler and object store.""" + # Reset the list of cached remote functions so that if more remote + # functions are defined and then connect is called again, the remote + # functions will be exported. This is mostly relevant for the tests. + worker.connected = False + worker.cached_functions_to_run = [] + worker.cached_remote_functions = [] + serialization.clear_state() def register_class(cls, pickle=False, worker=global_worker): - raise Exception("The function ray.register_class is deprecated. It should " - "be safe to remove any calls to this function.") + raise Exception("The function ray.register_class is deprecated. It should " + "be safe to remove any calls to this function.") def _register_class(cls, pickle=False, worker=global_worker): - """Enable workers to serialize or deserialize objects of a particular class. + """Enable serialization and deserialization for a particular class. 
- This method runs the register_class function defined below on every worker, - which will enable numbuf to properly serialize and deserialize objects of - this class. + This method runs the register_class function defined below on every worker, + which will enable numbuf to properly serialize and deserialize objects of + this class. - Args: - cls (type): The class that numbuf should serialize. - pickle (bool): If False then objects of this class will be serialized by - turning their __dict__ fields into a dictionary. If True, then objects - of this class will be serialized using pickle. + Args: + cls (type): The class that numbuf should serialize. + pickle (bool): If False then objects of this class will be serialized + by turning their __dict__ fields into a dictionary. If True, then + objects of this class will be serialized using pickle. - Raises: - Exception: An exception is raised if pickle=False and the class cannot be - efficiently serialized by Ray. - """ - class_id = random_string() + Raises: + Exception: An exception is raised if pickle=False and the class cannot + be efficiently serialized by Ray. + """ + class_id = random_string() - def register_class_for_serialization(worker_info): - serialization.add_class_to_whitelist(cls, class_id, pickle=pickle) + def register_class_for_serialization(worker_info): + serialization.add_class_to_whitelist(cls, class_id, pickle=pickle) - if not pickle: - # Raise an exception if cls cannot be serialized efficiently by Ray. - serialization.check_serializable(cls) - worker.run_function_on_all_workers(register_class_for_serialization) - else: - # Since we are pickling objects of this class, we don't actually need to - # ship the class definition. - register_class_for_serialization({}) + if not pickle: + # Raise an exception if cls cannot be serialized efficiently by Ray. 
+ serialization.check_serializable(cls) + worker.run_function_on_all_workers(register_class_for_serialization) + else: + # Since we are pickling objects of this class, we don't actually need + # to ship the class definition. + register_class_for_serialization({}) class RayLogSpan(object): - """An object used to enable logging a span of events with a with statement. + """An object used to enable logging a span of events with a with statement. - Attributes: - event_type (str): The type of the event being logged. - contents: Additional information to log. - """ - def __init__(self, event_type, contents=None, worker=global_worker): - """Initialize a RayLogSpan object.""" - self.event_type = event_type - self.contents = contents - self.worker = worker + Attributes: + event_type (str): The type of the event being logged. + contents: Additional information to log. + """ + def __init__(self, event_type, contents=None, worker=global_worker): + """Initialize a RayLogSpan object.""" + self.event_type = event_type + self.contents = contents + self.worker = worker - def __enter__(self): - """Log the beginning of a span event.""" - log(event_type=self.event_type, - contents=self.contents, - kind=LOG_SPAN_START, - worker=self.worker) + def __enter__(self): + """Log the beginning of a span event.""" + log(event_type=self.event_type, + contents=self.contents, + kind=LOG_SPAN_START, + worker=self.worker) - def __exit__(self, type, value, tb): - """Log the end of a span event. Log any exception that occurred.""" - if type is None: - log(event_type=self.event_type, kind=LOG_SPAN_END, worker=self.worker) - else: - log(event_type=self.event_type, - contents={"type": str(type), - "value": value, - "traceback": traceback.format_exc()}, - kind=LOG_SPAN_END, - worker=self.worker) + def __exit__(self, type, value, tb): + """Log the end of a span event. 
Log any exception that occurred.""" + if type is None: + log(event_type=self.event_type, kind=LOG_SPAN_END, + worker=self.worker) + else: + log(event_type=self.event_type, + contents={"type": str(type), + "value": value, + "traceback": traceback.format_exc()}, + kind=LOG_SPAN_END, + worker=self.worker) def log_span(event_type, contents=None, worker=global_worker): - return RayLogSpan(event_type, contents=contents, worker=worker) + return RayLogSpan(event_type, contents=contents, worker=worker) def log_event(event_type, contents=None, worker=global_worker): - log(event_type, kind=LOG_POINT, contents=contents, worker=worker) + log(event_type, kind=LOG_POINT, contents=contents, worker=worker) def log(event_type, kind, contents=None, worker=global_worker): - """Log an event to the global state store. + """Log an event to the global state store. - This adds the event to a buffer of events locally. The buffer can be flushed - and written to the global state store by calling flush_log(). + This adds the event to a buffer of events locally. The buffer can be + flushed and written to the global state store by calling flush_log(). - Args: - event_type (str): The type of the event. - contents: More general data to store with the event. - kind (int): Either LOG_POINT, LOG_SPAN_START, or LOG_SPAN_END. This is - LOG_POINT if the event being logged happens at a single point in time. It - is LOG_SPAN_START if we are starting to log a span of time, and it is - LOG_SPAN_END if we are finishing logging a span of time. - """ - # TODO(rkn): This code currently takes around half a microsecond. Since we - # call it tens of times per task, this adds up. We will need to redo the - # logging code, perhaps in C. - contents = {} if contents is None else contents - assert isinstance(contents, dict) - # Make sure all of the keys and values in the dictionary are strings. 
- contents = {str(k): str(v) for k, v in contents.items()} - worker.events.append((time.time(), event_type, kind, contents)) + Args: + event_type (str): The type of the event. + contents: More general data to store with the event. + kind (int): Either LOG_POINT, LOG_SPAN_START, or LOG_SPAN_END. This is + LOG_POINT if the event being logged happens at a single point in + time. It is LOG_SPAN_START if we are starting to log a span of + time, and it is LOG_SPAN_END if we are finishing logging a span of + time. + """ + # TODO(rkn): This code currently takes around half a microsecond. Since we + # call it tens of times per task, this adds up. We will need to redo the + # logging code, perhaps in C. + contents = {} if contents is None else contents + assert isinstance(contents, dict) + # Make sure all of the keys and values in the dictionary are strings. + contents = {str(k): str(v) for k, v in contents.items()} + worker.events.append((time.time(), event_type, kind, contents)) def flush_log(worker=global_worker): - """Send the logged worker events to the global state store.""" - event_log_key = b"event_log:" + worker.worker_id - event_log_value = json.dumps(worker.events) - worker.local_scheduler_client.log_event(event_log_key, - event_log_value, - time.time()) - worker.events = [] + """Send the logged worker events to the global state store.""" + event_log_key = b"event_log:" + worker.worker_id + event_log_value = json.dumps(worker.events) + worker.local_scheduler_client.log_event(event_log_key, + event_log_value, + time.time()) + worker.events = [] def get(object_ids, worker=global_worker): - """Get a remote object or a list of remote objects from the object store. + """Get a remote object or a list of remote objects from the object store. - This method blocks until the object corresponding to the object ID is - available in the local object store. 
If this object is not in the local - object store, it will be shipped from an object store that has it (once the - object has been created). If object_ids is a list, then the objects - corresponding to each object in the list will be returned. + This method blocks until the object corresponding to the object ID is + available in the local object store. If this object is not in the local + object store, it will be shipped from an object store that has it (once the + object has been created). If object_ids is a list, then the objects + corresponding to each object in the list will be returned. - Args: - object_ids: Object ID of the object to get or a list of object IDs to get. + Args: + object_ids: Object ID of the object to get or a list of object IDs to + get. - Returns: - A Python object or a list of Python objects. - """ - check_connected(worker) - with log_span("ray:get", worker=worker): - check_main_thread() + Returns: + A Python object or a list of Python objects. + """ + check_connected(worker) + with log_span("ray:get", worker=worker): + check_main_thread() - if worker.mode == PYTHON_MODE: - # In PYTHON_MODE, ray.get is the identity operation (the input will - # actually be a value not an objectid). - return object_ids - if isinstance(object_ids, list): - values = worker.get_object(object_ids) - for i, value in enumerate(values): - if isinstance(value, RayTaskError): - raise RayGetError(object_ids[i], value) - return values - else: - value = worker.get_object([object_ids])[0] - if isinstance(value, RayTaskError): - # If the result is a RayTaskError, then the task that created this - # object failed, and we should propagate the error message here. - raise RayGetError(object_ids, value) - return value + if worker.mode == PYTHON_MODE: + # In PYTHON_MODE, ray.get is the identity operation (the input will + # actually be a value not an objectid). 
+ return object_ids + if isinstance(object_ids, list): + values = worker.get_object(object_ids) + for i, value in enumerate(values): + if isinstance(value, RayTaskError): + raise RayGetError(object_ids[i], value) + return values + else: + value = worker.get_object([object_ids])[0] + if isinstance(value, RayTaskError): + # If the result is a RayTaskError, then the task that created + # this object failed, and we should propagate the error message + # here. + raise RayGetError(object_ids, value) + return value def put(value, worker=global_worker): - """Store an object in the object store. + """Store an object in the object store. - Args: - value: The Python object to be stored. + Args: + value: The Python object to be stored. - Returns: - The object ID assigned to this value. - """ - check_connected(worker) - with log_span("ray:put", worker=worker): - check_main_thread() + Returns: + The object ID assigned to this value. + """ + check_connected(worker) + with log_span("ray:put", worker=worker): + check_main_thread() - if worker.mode == PYTHON_MODE: - # In PYTHON_MODE, ray.put is the identity operation. - return value - object_id = worker.local_scheduler_client.compute_put_id( - worker.current_task_id, worker.put_index) - worker.put_object(object_id, value) - worker.put_index += 1 - return object_id + if worker.mode == PYTHON_MODE: + # In PYTHON_MODE, ray.put is the identity operation. + return value + object_id = worker.local_scheduler_client.compute_put_id( + worker.current_task_id, worker.put_index) + worker.put_object(object_id, value) + worker.put_index += 1 + return object_id def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): - """Return a list of IDs that are ready and a list of IDs that are not ready. + """Return a list of IDs that are ready and a list of IDs that are not. - If timeout is set, the function returns either when the requested number of - IDs are ready or when the timeout is reached, whichever occurs first. 
If it - is not set, the function simply waits until that number of objects is ready - and returns that exact number of objectids. + If timeout is set, the function returns either when the requested number of + IDs are ready or when the timeout is reached, whichever occurs first. If it + is not set, the function simply waits until that number of objects is ready + and returns that exact number of objectids. - This method returns two lists. The first list consists of object IDs that - correspond to objects that are stored in the object store. The second list - corresponds to the rest of the object IDs (which may or may not be ready). + This method returns two lists. The first list consists of object IDs that + correspond to objects that are stored in the object store. The second list + corresponds to the rest of the object IDs (which may or may not be ready). - Args: - object_ids (List[ObjectID]): List of object IDs for objects that may or may - not be ready. Note that these IDs must be unique. - num_returns (int): The number of object IDs that should be returned. - timeout (int): The maximum amount of time in milliseconds to wait before - returning. + Args: + object_ids (List[ObjectID]): List of object IDs for objects that may or + may not be ready. Note that these IDs must be unique. + num_returns (int): The number of object IDs that should be returned. + timeout (int): The maximum amount of time in milliseconds to wait + before returning. - Returns: - A list of object IDs that are ready and a list of the remaining object IDs. 
- """ - check_connected(worker) - with log_span("ray:wait", worker=worker): - check_main_thread() - object_id_strs = [object_id.id() for object_id in object_ids] - timeout = timeout if timeout is not None else 2 ** 30 - ready_ids, remaining_ids = worker.plasma_client.wait(object_id_strs, - timeout, num_returns) - ready_ids = [ray.local_scheduler.ObjectID(object_id) - for object_id in ready_ids] - remaining_ids = [ray.local_scheduler.ObjectID(object_id) - for object_id in remaining_ids] - return ready_ids, remaining_ids + Returns: + A list of object IDs that are ready and a list of the remaining object + IDs. + """ + check_connected(worker) + with log_span("ray:wait", worker=worker): + check_main_thread() + object_id_strs = [object_id.id() for object_id in object_ids] + timeout = timeout if timeout is not None else 2 ** 30 + ready_ids, remaining_ids = worker.plasma_client.wait(object_id_strs, + timeout, + num_returns) + ready_ids = [ray.local_scheduler.ObjectID(object_id) + for object_id in ready_ids] + remaining_ids = [ray.local_scheduler.ObjectID(object_id) + for object_id in remaining_ids] + return ready_ids, remaining_ids def wait_for_function(function_id, driver_id, timeout=10, worker=global_worker): - """Wait until the function to be executed is present on this worker. + """Wait until the function to be executed is present on this worker. - This method will simply loop until the import thread has imported the - relevant function. If we spend too long in this loop, that may indicate a - problem somewhere and we will push an error message to the user. + This method will simply loop until the import thread has imported the + relevant function. If we spend too long in this loop, that may indicate a + problem somewhere and we will push an error message to the user. - If this worker is an actor, then this will wait until the actor has been - defined. + If this worker is an actor, then this will wait until the actor has been + defined. 
- Args: - is_actor (bool): True if this worker is an actor, and false otherwise. - function_id (str): The ID of the function that we want to execute. - driver_id (str): The ID of the driver to push the error message to if this - times out. - """ - start_time = time.time() - # Only send the warning once. - warning_sent = False - while True: - with worker.lock: - if worker.actor_id == NIL_ACTOR_ID and (function_id.id() in - worker.functions[driver_id]): - break - elif worker.actor_id != NIL_ACTOR_ID and (worker.actor_id in - worker.actors): - break - if time.time() - start_time > timeout: - warning_message = ("This worker was asked to execute a function that " - "it does not have registered. You may have to " - "restart Ray.") - if not warning_sent: - worker.push_error_to_driver(driver_id, "wait_for_function", - warning_message) - warning_sent = True - time.sleep(0.001) + Args: + is_actor (bool): True if this worker is an actor, and false otherwise. + function_id (str): The ID of the function that we want to execute. + driver_id (str): The ID of the driver to push the error message to if + this times out. + """ + start_time = time.time() + # Only send the warning once. + warning_sent = False + while True: + with worker.lock: + if (worker.actor_id == NIL_ACTOR_ID and + (function_id.id() in worker.functions[driver_id])): + break + elif worker.actor_id != NIL_ACTOR_ID and (worker.actor_id in + worker.actors): + break + if time.time() - start_time > timeout: + warning_message = ("This worker was asked to execute a " + "function that it does not have " + "registered. You may have to restart Ray.") + if not warning_sent: + worker.push_error_to_driver(driver_id, "wait_for_function", + warning_message) + warning_sent = True + time.sleep(0.001) def format_error_message(exception_message, task_exception=False): - """Improve the formatting of an exception thrown by a remote function. + """Improve the formatting of an exception thrown by a remote function. 
- This method takes a traceback from an exception and makes it nicer by - removing a few uninformative lines and adding some space to indent the - remaining lines nicely. + This method takes a traceback from an exception and makes it nicer by + removing a few uninformative lines and adding some space to indent the + remaining lines nicely. - Args: - exception_message (str): A message generated by traceback.format_exc(). + Args: + exception_message (str): A message generated by traceback.format_exc(). - Returns: - A string of the formatted exception message. - """ - lines = exception_message.split("\n") - if task_exception: - # For errors that occur inside of tasks, remove lines 1, 2, 3, and 4, - # which are always the same, they just contain information about the main - # loop. - lines = lines[0:1] + lines[5:] - return "\n".join(lines) + Returns: + A string of the formatted exception message. + """ + lines = exception_message.split("\n") + if task_exception: + # For errors that occur inside of tasks, remove lines 1, 2, 3, and 4, + # which are always the same, they just contain information about the + # main loop. + lines = lines[0:1] + lines[5:] + return "\n".join(lines) def main_loop(worker=global_worker): - """The main loop a worker runs to receive and execute tasks.""" + """The main loop a worker runs to receive and execute tasks.""" - def exit(signum, frame): - cleanup(worker=worker) - sys.exit(0) + def exit(signum, frame): + cleanup(worker=worker) + sys.exit(0) - signal.signal(signal.SIGTERM, exit) + signal.signal(signal.SIGTERM, exit) - def process_task(task): - """Execute a task assigned to this worker. + def process_task(task): + """Execute a task assigned to this worker. - This method deserializes a task from the scheduler, and attempts to execute - the task. If the task succeeds, the outputs are stored in the local object - store. 
If the task throws an exception, RayTaskError objects are stored in - the object store to represent the failed task (these will be retrieved by - calls to get or by subsequent tasks that use the outputs of this task). - """ - try: - # The ID of the driver that this task belongs to. This is needed so that - # if the task throws an exception, we propagate the error message to the - # correct driver. - worker.task_driver_id = task.driver_id() - worker.current_task_id = task.task_id() - worker.current_function_id = task.function_id().id() - worker.task_index = 0 - worker.put_index = 0 - function_id = task.function_id() - args = task.arguments() - return_object_ids = task.returns() - function_name, function_executor = (worker.functions - [worker.task_driver_id.id()] - [function_id.id()]) + This method deserializes a task from the scheduler, and attempts to + execute the task. If the task succeeds, the outputs are stored in the + local object store. If the task throws an exception, RayTaskError + objects are stored in the object store to represent the failed task + (these will be retrieved by calls to get or by subsequent tasks that + use the outputs of this task). + """ + try: + # The ID of the driver that this task belongs to. This is needed so + # that if the task throws an exception, we propagate the error + # message to the correct driver. + worker.task_driver_id = task.driver_id() + worker.current_task_id = task.task_id() + worker.current_function_id = task.function_id().id() + worker.task_index = 0 + worker.put_index = 0 + function_id = task.function_id() + args = task.arguments() + return_object_ids = task.returns() + function_name, function_executor = (worker.functions + [worker.task_driver_id.id()] + [function_id.id()]) - # Get task arguments from the object store. - with log_span("ray:task:get_arguments", worker=worker): - arguments = get_arguments_for_execution(function_name, args, worker) + # Get task arguments from the object store. 
+ with log_span("ray:task:get_arguments", worker=worker): + arguments = get_arguments_for_execution(function_name, args, + worker) - # Execute the task. - with log_span("ray:task:execute", worker=worker): - if task.actor_id().id() == NIL_ACTOR_ID: - outputs = function_executor.executor(arguments) - else: - outputs = function_executor( - worker.actors[task.actor_id().id()], *arguments) + # Execute the task. + with log_span("ray:task:execute", worker=worker): + if task.actor_id().id() == NIL_ACTOR_ID: + outputs = function_executor.executor(arguments) + else: + outputs = function_executor( + worker.actors[task.actor_id().id()], *arguments) - # Store the outputs in the local object store. - with log_span("ray:task:store_outputs", worker=worker): - if len(return_object_ids) == 1: - outputs = (outputs,) - store_outputs_in_objstore(return_object_ids, outputs, worker) - except Exception as e: - # We determine whether the exception was caused by the call to - # get_arguments_for_execution or by the execution of the remote function - # or by the call to store_outputs_in_objstore. Depending on which case - # occurred, we format the error message differently. - # whether the variables "arguments" and "outputs" are defined. - if "arguments" in locals() and "outputs" not in locals(): - if task.actor_id().id() == NIL_ACTOR_ID: - # The error occurred during the task execution. - traceback_str = format_error_message(traceback.format_exc(), - task_exception=True) - else: - # The error occurred during the execution of an actor task. - traceback_str = format_error_message(traceback.format_exc()) - elif "arguments" in locals() and "outputs" in locals(): - # The error occurred after the task executed. - traceback_str = format_error_message(traceback.format_exc()) - else: - # The error occurred before the task execution. - if isinstance(e, RayGetError) or isinstance(e, RayGetArgumentError): - # In this case, getting the task arguments failed. 
- traceback_str = None - else: - traceback_str = traceback.format_exc() - failure_object = RayTaskError(function_name, e, traceback_str) - failure_objects = [failure_object for _ in range(len(return_object_ids))] - store_outputs_in_objstore(return_object_ids, failure_objects, worker) - # Log the error message. - worker.push_error_to_driver(worker.task_driver_id.id(), "task", - str(failure_object), - data={"function_id": function_id.id(), - "function_name": function_name}) + # Store the outputs in the local object store. + with log_span("ray:task:store_outputs", worker=worker): + if len(return_object_ids) == 1: + outputs = (outputs,) + store_outputs_in_objstore(return_object_ids, outputs, worker) + except Exception as e: + # We determine whether the exception was caused by the call to + # get_arguments_for_execution or by the execution of the remote + # function or by the call to store_outputs_in_objstore. Depending + # on which case occurred, we format the error message differently. + # whether the variables "arguments" and "outputs" are defined. + if "arguments" in locals() and "outputs" not in locals(): + if task.actor_id().id() == NIL_ACTOR_ID: + # The error occurred during the task execution. + traceback_str = format_error_message( + traceback.format_exc(), task_exception=True) + else: + # The error occurred during the execution of an actor task. + traceback_str = format_error_message( + traceback.format_exc()) + elif "arguments" in locals() and "outputs" in locals(): + # The error occurred after the task executed. + traceback_str = format_error_message(traceback.format_exc()) + else: + # The error occurred before the task execution. + if (isinstance(e, RayGetError) or + isinstance(e, RayGetArgumentError)): + # In this case, getting the task arguments failed. 
+ traceback_str = None + else: + traceback_str = traceback.format_exc() + failure_object = RayTaskError(function_name, e, traceback_str) + failure_objects = [failure_object for _ + in range(len(return_object_ids))] + store_outputs_in_objstore(return_object_ids, failure_objects, + worker) + # Log the error message. + worker.push_error_to_driver(worker.task_driver_id.id(), "task", + str(failure_object), + data={"function_id": function_id.id(), + "function_name": function_name}) - check_main_thread() - while True: - with log_span("ray:get_task", worker=worker): - task = worker.local_scheduler_client.get_task() + check_main_thread() + while True: + with log_span("ray:get_task", worker=worker): + task = worker.local_scheduler_client.get_task() - function_id = task.function_id() - # Wait until the function to be executed has actually been registered on - # this worker. We will push warnings to the user if we spend too long in - # this loop. - with log_span("ray:wait_for_function", worker=worker): - wait_for_function(function_id, task.driver_id().id(), worker=worker) + function_id = task.function_id() + # Wait until the function to be executed has actually been registered + # on this worker. We will push warnings to the user if we spend too + # long in this loop. + with log_span("ray:wait_for_function", worker=worker): + wait_for_function(function_id, task.driver_id().id(), + worker=worker) - # Execute the task. - # TODO(rkn): Consider acquiring this lock with a timeout and pushing a - # warning to the user if we are waiting too long to acquire the lock - # because that may indicate that the system is hanging, and it'd be good to - # know where the system is hanging. - log(event_type="ray:acquire_lock", kind=LOG_SPAN_START, worker=worker) - with worker.lock: - log(event_type="ray:acquire_lock", kind=LOG_SPAN_END, worker=worker) + # Execute the task. 
+ # TODO(rkn): Consider acquiring this lock with a timeout and pushing a + # warning to the user if we are waiting too long to acquire the lock + # because that may indicate that the system is hanging, and it'd be + # good to know where the system is hanging. + log(event_type="ray:acquire_lock", kind=LOG_SPAN_START, worker=worker) + with worker.lock: + log(event_type="ray:acquire_lock", kind=LOG_SPAN_END, + worker=worker) - function_name, _ = (worker.functions[task.driver_id().id()] - [function_id.id()]) - contents = {"function_name": function_name, - "task_id": task.task_id().hex(), - "worker_id": binary_to_hex(worker.worker_id)} - with log_span("ray:task", contents=contents, worker=worker): - process_task(task) + function_name, _ = (worker.functions[task.driver_id().id()] + [function_id.id()]) + contents = {"function_name": function_name, + "task_id": task.task_id().hex(), + "worker_id": binary_to_hex(worker.worker_id)} + with log_span("ray:task", contents=contents, worker=worker): + process_task(task) - # Push all of the log events to the global state store. - flush_log() + # Push all of the log events to the global state store. + flush_log() - # Increase the task execution counter. - worker.num_task_executions[task.driver_id().id()][function_id.id()] += 1 + # Increase the task execution counter. 
+ (worker.num_task_executions[task.driver_id().id()] + [function_id.id()]) += 1 - reached_max_executions = ( - worker.num_task_executions[task.driver_id().id()][function_id.id()] == - worker.function_properties[task.driver_id().id()] - [function_id.id()].max_calls) - if reached_max_executions: - ray.worker.global_worker.local_scheduler_client.disconnect() - os._exit(0) + reached_max_executions = ( + worker.num_task_executions[task.driver_id().id()] + [function_id.id()] == + worker.function_properties[task.driver_id().id()] + [function_id.id()].max_calls) + if reached_max_executions: + ray.worker.global_worker.local_scheduler_client.disconnect() + os._exit(0) def _submit_task(function_id, func_name, args, worker=global_worker): - """This is a wrapper around worker.submit_task. + """This is a wrapper around worker.submit_task. - We use this wrapper so that in the remote decorator, we can call _submit_task - instead of worker.submit_task. The difference is that when we attempt to - serialize remote functions, we don't attempt to serialize the worker object, - which cannot be serialized. - """ - return worker.submit_task(function_id, func_name, args) + We use this wrapper so that in the remote decorator, we can call + _submit_task instead of worker.submit_task. The difference is that when we + attempt to serialize remote functions, we don't attempt to serialize the + worker object, which cannot be serialized. + """ + return worker.submit_task(function_id, func_name, args) def _mode(worker=global_worker): - """This is a wrapper around worker.mode. + """This is a wrapper around worker.mode. - We use this wrapper so that in the remote decorator, we can call _mode() - instead of worker.mode. The difference is that when we attempt to serialize - remote functions, we don't attempt to serialize the worker object, which - cannot be serialized. - """ - return worker.mode + We use this wrapper so that in the remote decorator, we can call _mode() + instead of worker.mode. 
The difference is that when we attempt to serialize + remote functions, we don't attempt to serialize the worker object, which + cannot be serialized. + """ + return worker.mode def export_remote_function(function_id, func_name, func, func_invoker, function_properties, worker=global_worker): - check_main_thread() - if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]: - raise Exception("export_remote_function can only be called on a driver.") + check_main_thread() + if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]: + raise Exception("export_remote_function can only be called on a " + "driver.") - worker.function_properties[ - worker.task_driver_id.id()][function_id.id()] = function_properties - task_driver_id = worker.task_driver_id - key = b"RemoteFunction:" + task_driver_id.id() + b":" + function_id.id() + worker.function_properties[ + worker.task_driver_id.id()][function_id.id()] = function_properties + task_driver_id = worker.task_driver_id + key = b"RemoteFunction:" + task_driver_id.id() + b":" + function_id.id() - # Work around limitations of Python pickling. - func_name_global_valid = func.__name__ in func.__globals__ - func_name_global_value = func.__globals__.get(func.__name__) - # Allow the function to reference itself as a global variable - func.__globals__[func.__name__] = func_invoker - try: - pickled_func = pickle.dumps(func) - finally: - # Undo our changes - if func_name_global_valid: - func.__globals__[func.__name__] = func_name_global_value - else: - del func.__globals__[func.__name__] + # Work around limitations of Python pickling. 
+ func_name_global_valid = func.__name__ in func.__globals__ + func_name_global_value = func.__globals__.get(func.__name__) + # Allow the function to reference itself as a global variable + func.__globals__[func.__name__] = func_invoker + try: + pickled_func = pickle.dumps(func) + finally: + # Undo our changes + if func_name_global_valid: + func.__globals__[func.__name__] = func_name_global_value + else: + del func.__globals__[func.__name__] - worker.redis_client.hmset(key, { - "driver_id": worker.task_driver_id.id(), - "function_id": function_id.id(), - "name": func_name, - "module": func.__module__, - "function": pickled_func, - "num_return_vals": function_properties.num_return_vals, - "num_cpus": function_properties.num_cpus, - "num_gpus": function_properties.num_gpus, - "max_calls": function_properties.max_calls}) - worker.redis_client.rpush("Exports", key) + worker.redis_client.hmset(key, { + "driver_id": worker.task_driver_id.id(), + "function_id": function_id.id(), + "name": func_name, + "module": func.__module__, + "function": pickled_func, + "num_return_vals": function_properties.num_return_vals, + "num_cpus": function_properties.num_cpus, + "num_gpus": function_properties.num_gpus, + "max_calls": function_properties.max_calls}) + worker.redis_client.rpush("Exports", key) def in_ipython(): - """Return true if we are in an IPython interpreter and false otherwise.""" - try: - __IPYTHON__ - return True - except NameError: - return False + """Return true if we are in an IPython interpreter and false otherwise.""" + try: + __IPYTHON__ + return True + except NameError: + return False def compute_function_id(func_name, func): - """Compute an function ID for a function. + """Compute an function ID for a function. - Args: - func_name: The name of the function (this includes the module name plus the - function name). - func: The actual function. + Args: + func_name: The name of the function (this includes the module name plus + the function name). 
+ func: The actual function. - Returns: - This returns the function ID. - """ - function_id_hash = hashlib.sha1() - # Include the function name in the hash. - function_id_hash.update(func_name.encode("ascii")) - # If we are running a script or are in IPython, include the source code in - # the hash. If we are in a regular Python interpreter we skip this part - # because the source code is not accessible. - import __main__ as main - if hasattr(main, "__file__") or in_ipython(): - function_id_hash.update(inspect.getsource(func).encode("ascii")) - # Compute the function ID. - function_id = function_id_hash.digest() - assert len(function_id) == 20 - function_id = FunctionID(function_id) + Returns: + This returns the function ID. + """ + function_id_hash = hashlib.sha1() + # Include the function name in the hash. + function_id_hash.update(func_name.encode("ascii")) + # If we are running a script or are in IPython, include the source code in + # the hash. If we are in a regular Python interpreter we skip this part + # because the source code is not accessible. + import __main__ as main + if hasattr(main, "__file__") or in_ipython(): + function_id_hash.update(inspect.getsource(func).encode("ascii")) + # Compute the function ID. + function_id = function_id_hash.digest() + assert len(function_id) == 20 + function_id = FunctionID(function_id) - return function_id + return function_id def remote(*args, **kwargs): - """This decorator is used to define remote functions and to define actors. + """This decorator is used to define remote functions and to define actors. - Args: - num_return_vals (int): The number of object IDs that a call to this - function should return. - num_cpus (int): The number of CPUs needed to execute this function. This - should only be passed in when defining the remote function on the driver. - num_gpus (int): The number of GPUs needed to execute this function. This - should only be passed in when defining the remote function on the driver. 
- max_calls (int): The maximum number of tasks of this kind that can be - run on a worker before the worker needs to be restarted. - """ - worker = global_worker + Args: + num_return_vals (int): The number of object IDs that a call to this + function should return. + num_cpus (int): The number of CPUs needed to execute this function. + This should only be passed in when defining the remote function on + the driver. + num_gpus (int): The number of GPUs needed to execute this function. + This should only be passed in when defining the remote function on + the driver. + max_calls (int): The maximum number of tasks of this kind that can be + run on a worker before the worker needs to be restarted. + """ + worker = global_worker - def make_remote_decorator(num_return_vals, num_cpus, num_gpus, - max_calls, func_id=None): - def remote_decorator(func_or_class): - if inspect.isfunction(func_or_class): - function_properties = FunctionProperties( - num_return_vals=num_return_vals, - num_cpus=num_cpus, - num_gpus=num_gpus, - max_calls=max_calls) - return remote_function_decorator(func_or_class, function_properties) - if inspect.isclass(func_or_class): - return worker.make_actor(func_or_class, num_cpus, num_gpus) - raise Exception("The @ray.remote decorator must be applied to either a " - "function or to a class.") + def make_remote_decorator(num_return_vals, num_cpus, num_gpus, + max_calls, func_id=None): + def remote_decorator(func_or_class): + if inspect.isfunction(func_or_class): + function_properties = FunctionProperties( + num_return_vals=num_return_vals, + num_cpus=num_cpus, + num_gpus=num_gpus, + max_calls=max_calls) + return remote_function_decorator(func_or_class, + function_properties) + if inspect.isclass(func_or_class): + return worker.make_actor(func_or_class, num_cpus, num_gpus) + raise Exception("The @ray.remote decorator must be applied to " + "either a function or to a class.") - def remote_function_decorator(func, function_properties): - func_name = 
"{}.{}".format(func.__module__, func.__name__) - if func_id is None: - function_id = compute_function_id(func_name, func) - else: - function_id = func_id + def remote_function_decorator(func, function_properties): + func_name = "{}.{}".format(func.__module__, func.__name__) + if func_id is None: + function_id = compute_function_id(func_name, func) + else: + function_id = func_id - def func_call(*args, **kwargs): - """This gets run immediately when a worker calls a remote function.""" - check_connected() - check_main_thread() - args = signature.extend_args(function_signature, args, kwargs) + def func_call(*args, **kwargs): + """This runs immediately when a remote function is called.""" + check_connected() + check_main_thread() + args = signature.extend_args(function_signature, args, kwargs) - if _mode() == PYTHON_MODE: - # In PYTHON_MODE, remote calls simply execute the function. We copy - # the arguments to prevent the function call from mutating them and - # to match the usual behavior of immutable remote objects. - result = func(*copy.deepcopy(args)) - return result - objectids = _submit_task(function_id, func_name, args) - if len(objectids) == 1: - return objectids[0] - elif len(objectids) > 1: - return objectids + if _mode() == PYTHON_MODE: + # In PYTHON_MODE, remote calls simply execute the function. + # We copy the arguments to prevent the function call from + # mutating them and to match the usual behavior of + # immutable remote objects. 
+ result = func(*copy.deepcopy(args)) + return result + objectids = _submit_task(function_id, func_name, args) + if len(objectids) == 1: + return objectids[0] + elif len(objectids) > 1: + return objectids - def func_executor(arguments): - """This gets run when the remote function is executed.""" - result = func(*arguments) - return result + def func_executor(arguments): + """This gets run when the remote function is executed.""" + result = func(*arguments) + return result - def func_invoker(*args, **kwargs): - """This is used to invoke the function.""" - raise Exception("Remote functions cannot be called directly. Instead " - "of running '{}()', try '{}.remote()'." - .format(func_name, func_name)) - func_invoker.remote = func_call - func_invoker.executor = func_executor - func_invoker.is_remote = True - func_name = "{}.{}".format(func.__module__, func.__name__) - func_invoker.func_name = func_name - if sys.version_info >= (3, 0): - func_invoker.__doc__ = func.__doc__ - else: - func_invoker.func_doc = func.func_doc + def func_invoker(*args, **kwargs): + """This is used to invoke the function.""" + raise Exception("Remote functions cannot be called directly. " + "Instead of running '{}()', try '{}.remote()'." 
+ .format(func_name, func_name)) + func_invoker.remote = func_call + func_invoker.executor = func_executor + func_invoker.is_remote = True + func_name = "{}.{}".format(func.__module__, func.__name__) + func_invoker.func_name = func_name + if sys.version_info >= (3, 0): + func_invoker.__doc__ = func.__doc__ + else: + func_invoker.func_doc = func.func_doc - signature.check_signature_supported(func) - function_signature = signature.extract_signature(func) + signature.check_signature_supported(func) + function_signature = signature.extract_signature(func) - # Everything ready - export the function - if worker.mode in [SCRIPT_MODE, SILENT_MODE]: - export_remote_function(function_id, func_name, func, func_invoker, - function_properties) - elif worker.mode is None: - worker.cached_remote_functions.append((function_id, func_name, func, - func_invoker, - function_properties)) - return func_invoker + # Everything ready - export the function + if worker.mode in [SCRIPT_MODE, SILENT_MODE]: + export_remote_function(function_id, func_name, func, + func_invoker, function_properties) + elif worker.mode is None: + worker.cached_remote_functions.append((function_id, func_name, + func, func_invoker, + function_properties)) + return func_invoker - return remote_decorator + return remote_decorator - num_return_vals = (kwargs["num_return_vals"] if "num_return_vals" - in kwargs else 1) - num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs else 1 - num_gpus = kwargs["num_gpus"] if "num_gpus" in kwargs else 0 - max_calls = kwargs["max_calls"] if "max_calls" in kwargs else 0 + num_return_vals = (kwargs["num_return_vals"] if "num_return_vals" + in kwargs else 1) + num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs else 1 + num_gpus = kwargs["num_gpus"] if "num_gpus" in kwargs else 0 + max_calls = kwargs["max_calls"] if "max_calls" in kwargs else 0 - if _mode() == WORKER_MODE: - if "function_id" in kwargs: - function_id = kwargs["function_id"] - return 
make_remote_decorator(num_return_vals, num_cpus, num_gpus, - max_calls, function_id) + if _mode() == WORKER_MODE: + if "function_id" in kwargs: + function_id = kwargs["function_id"] + return make_remote_decorator(num_return_vals, num_cpus, num_gpus, + max_calls, function_id) - if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): - # This is the case where the decorator is just @ray.remote. - return make_remote_decorator(num_return_vals, num_cpus, - num_gpus, max_calls)(args[0]) - else: - # This is the case where the decorator is something like - # @ray.remote(num_return_vals=2). - error_string = ("The @ray.remote decorator must be applied either with no " - "arguments and no parentheses, for example '@ray.remote', " - "or it must be applied using some of the arguments " - "'num_return_vals', 'num_cpus', 'num_gpus', or 'max_calls'" - ", like '@ray.remote(num_return_vals=2)'.") - assert len(args) == 0 and ("num_return_vals" in kwargs or - "num_cpus" in kwargs or - "num_gpus" in kwargs or - "max_calls" in kwargs), error_string - for key in kwargs: - assert key in ["num_return_vals", "num_cpus", - "num_gpus", "max_calls"], error_string - assert "function_id" not in kwargs - return make_remote_decorator(num_return_vals, num_cpus, num_gpus, - max_calls) + if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): + # This is the case where the decorator is just @ray.remote. + return make_remote_decorator(num_return_vals, num_cpus, + num_gpus, max_calls)(args[0]) + else: + # This is the case where the decorator is something like + # @ray.remote(num_return_vals=2). 
+ error_string = ("The @ray.remote decorator must be applied either " + "with no arguments and no parentheses, for example " + "'@ray.remote', or it must be applied using some of " + "the arguments 'num_return_vals', 'num_cpus', " + "'num_gpus', or 'max_calls', like " + "'@ray.remote(num_return_vals=2)'.") + assert len(args) == 0 and ("num_return_vals" in kwargs or + "num_cpus" in kwargs or + "num_gpus" in kwargs or + "max_calls" in kwargs), error_string + for key in kwargs: + assert key in ["num_return_vals", "num_cpus", + "num_gpus", "max_calls"], error_string + assert "function_id" not in kwargs + return make_remote_decorator(num_return_vals, num_cpus, num_gpus, + max_calls) def get_arguments_for_execution(function_name, serialized_args, worker=global_worker): - """Retrieve the arguments for the remote function. + """Retrieve the arguments for the remote function. - This retrieves the values for the arguments to the remote function that were - passed in as object IDs. Argumens that were passed by value are not changed. - This is called by the worker that is executing the remote function. + This retrieves the values for the arguments to the remote function that + were passed in as object IDs. Argumens that were passed by value are not + changed. This is called by the worker that is executing the remote + function. - Args: - function_name (str): The name of the remote function whose arguments are - being retrieved. - serialized_args (List): The arguments to the function. These are either - strings representing serialized objects passed by value or they are - ObjectIDs. + Args: + function_name (str): The name of the remote function whose arguments + are being retrieved. + serialized_args (List): The arguments to the function. These are either + strings representing serialized objects passed by value or they are + ObjectIDs. - Returns: - The retrieved arguments in addition to the arguments that were passed by - value. 
+ Returns: + The retrieved arguments in addition to the arguments that were passed + by value. - Raises: - RayGetArgumentError: This exception is raised if a task that created one of - the arguments failed. - """ - arguments = [] - for (i, arg) in enumerate(serialized_args): - if isinstance(arg, ray.local_scheduler.ObjectID): - # get the object from the local object store - argument = worker.get_object([arg])[0] - if isinstance(argument, RayTaskError): - # If the result is a RayTaskError, then the task that created this - # object failed, and we should propagate the error message here. - raise RayGetArgumentError(function_name, i, arg, argument) - else: - # pass the argument by value - argument = arg + Raises: + RayGetArgumentError: This exception is raised if a task that created + one of the arguments failed. + """ + arguments = [] + for (i, arg) in enumerate(serialized_args): + if isinstance(arg, ray.local_scheduler.ObjectID): + # get the object from the local object store + argument = worker.get_object([arg])[0] + if isinstance(argument, RayTaskError): + # If the result is a RayTaskError, then the task that created + # this object failed, and we should propagate the error message + # here. + raise RayGetArgumentError(function_name, i, arg, argument) + else: + # pass the argument by value + argument = arg - arguments.append(argument) - return arguments + arguments.append(argument) + return arguments def store_outputs_in_objstore(objectids, outputs, worker=global_worker): - """Store the outputs of a remote function in the local object store. + """Store the outputs of a remote function in the local object store. - This stores the values that were returned by a remote function in the local - object store. If any of the return values are object IDs, then these object - IDs are aliased with the object IDs that the scheduler assigned for the - return values. This is called by the worker that executes the remote - function. 
+ This stores the values that were returned by a remote function in the local + object store. If any of the return values are object IDs, then these object + IDs are aliased with the object IDs that the scheduler assigned for the + return values. This is called by the worker that executes the remote + function. - Note: - The arguments objectids and outputs should have the same length. + Note: + The arguments objectids and outputs should have the same length. - Args: - objectids (List[ObjectID]): The object IDs that were assigned to the - outputs of the remote function call. - outputs (Tuple): The value returned by the remote function. If the remote - function was supposed to only return one value, then its output was - wrapped in a tuple with one element prior to being passed into this - function. - """ - for i in range(len(objectids)): - worker.put_object(objectids[i], outputs[i]) + Args: + objectids (List[ObjectID]): The object IDs that were assigned to the + outputs of the remote function call. + outputs (Tuple): The value returned by the remote function. If the + remote function was supposed to only return one value, then its + output was wrapped in a tuple with one element prior to being + passed into this function. 
+ """ + for i in range(len(objectids)): + worker.put_object(objectids[i], outputs[i]) diff --git a/python/ray/workers/default_worker.py b/python/ray/workers/default_worker.py index 009ae6138..b2ef0674e 100644 --- a/python/ray/workers/default_worker.py +++ b/python/ray/workers/default_worker.py @@ -27,59 +27,61 @@ parser.add_argument("--actor-id", required=False, type=str, def random_string(): - return np.random.bytes(20) + return np.random.bytes(20) if __name__ == "__main__": - args = parser.parse_args() - info = {"node_ip_address": args.node_ip_address, - "redis_address": args.redis_address, - "store_socket_name": args.object_store_name, - "manager_socket_name": args.object_store_manager_name, - "local_scheduler_socket_name": args.local_scheduler_name} + args = parser.parse_args() + info = {"node_ip_address": args.node_ip_address, + "redis_address": args.redis_address, + "store_socket_name": args.object_store_name, + "manager_socket_name": args.object_store_manager_name, + "local_scheduler_socket_name": args.local_scheduler_name} - if args.actor_id is not None: - actor_id = binascii.unhexlify(args.actor_id) - else: - actor_id = ray.worker.NIL_ACTOR_ID + if args.actor_id is not None: + actor_id = binascii.unhexlify(args.actor_id) + else: + actor_id = ray.worker.NIL_ACTOR_ID - ray.worker.connect(info, mode=ray.WORKER_MODE, actor_id=actor_id) + ray.worker.connect(info, mode=ray.WORKER_MODE, actor_id=actor_id) - error_explanation = """ -This error is unexpected and should not have happened. Somehow a worker crashed -in an unanticipated way causing the main_loop to throw an exception, which is -being caught in "python/ray/workers/default_worker.py". -""" + error_explanation = """ + This error is unexpected and should not have happened. Somehow a worker + crashed in an unanticipated way causing the main_loop to throw an exception, + which is being caught in "python/ray/workers/default_worker.py". 
+ """ - while True: - try: - # This call to main_loop should never return if things are working. Most - # exceptions that are thrown (e.g., inside the execution of a task) - # should be caught and handled inside of the call to main_loop. If an - # exception is thrown here, then that means that there is some error that - # we didn't anticipate. - ray.worker.main_loop() - except Exception as e: - traceback_str = traceback.format_exc() + error_explanation - DRIVER_ID_LENGTH = 20 - # We use a driver ID of all zeros to push an error message to all - # drivers. - driver_id = DRIVER_ID_LENGTH * b"\x00" - error_key = b"Error:" + driver_id + b":" + random_string() - redis_ip_address, redis_port = args.redis_address.split(":") - # For this command to work, some other client (on the same machine as - # Redis) must have run "CONFIG SET protected-mode no". - redis_client = redis.StrictRedis(host=redis_ip_address, - port=int(redis_port)) - redis_client.hmset(error_key, {"type": "worker_crash", - "message": traceback_str, - "note": ("This error is unexpected and " - "should not have happened.")}) - redis_client.rpush("ErrorKeys", error_key) - # TODO(rkn): Note that if the worker was in the middle of executing a - # task, the any worker or driver that is blocking in a get call and - # waiting for the output of that task will hang. We need to address this. + while True: + try: + # This call to main_loop should never return if things are working. + # Most exceptions that are thrown (e.g., inside the execution of a + # task) should be caught and handled inside of the call to + # main_loop. If an exception is thrown here, then that means that + # there is some error that we didn't anticipate. + ray.worker.main_loop() + except Exception as e: + traceback_str = traceback.format_exc() + error_explanation + DRIVER_ID_LENGTH = 20 + # We use a driver ID of all zeros to push an error message to all + # drivers. 
+ driver_id = DRIVER_ID_LENGTH * b"\x00" + error_key = b"Error:" + driver_id + b":" + random_string() + redis_ip_address, redis_port = args.redis_address.split(":") + # For this command to work, some other client (on the same machine + # as Redis) must have run "CONFIG SET protected-mode no". + redis_client = redis.StrictRedis(host=redis_ip_address, + port=int(redis_port)) + redis_client.hmset(error_key, {"type": "worker_crash", + "message": traceback_str, + "note": ("This error is unexpected " + "and should not have " + "happened.")}) + redis_client.rpush("ErrorKeys", error_key) + # TODO(rkn): Note that if the worker was in the middle of executing + # a task, the any worker or driver that is blocking in a get call + # and waiting for the output of that task will hang. We need to + # address this. - # After putting the error message in Redis, this worker will attempt to - # reenter the main loop. TODO(rkn): We should probably reset it's state and - # call connect again. + # After putting the error message in Redis, this worker will attempt to + # reenter the main loop. TODO(rkn): We should probably reset it's state + # and call connect again. diff --git a/python/setup.py b/python/setup.py index 9ee34aece..40ba85501 100644 --- a/python/setup.py +++ b/python/setup.py @@ -11,32 +11,34 @@ import setuptools.command.build_ext as _build_ext class build_ext(_build_ext.build_ext): - def run(self): - subprocess.check_call(["../build.sh"]) - # Ideally, we could include these files by putting them in a MANIFEST.in or - # using the package_data argument to setup, but the MANIFEST.in gets - # applied at the very beginning when setup.py runs before these files have - # been created, so we have to move the files manually. - for filename in files_to_include: - self.move_file(filename) - # Copy over the autogenerated flatbuffer Python bindings. 
- generated_python_directory = "ray/core/generated" - for filename in os.listdir(generated_python_directory): - if filename[-3:] == ".py": - self.move_file(os.path.join(generated_python_directory, filename)) + def run(self): + subprocess.check_call(["../build.sh"]) + # Ideally, we could include these files by putting them in a + # MANIFEST.in or using the package_data argument to setup, but the + # MANIFEST.in gets applied at the very beginning when setup.py runs + # before these files have been created, so we have to move the files + # manually. + for filename in files_to_include: + self.move_file(filename) + # Copy over the autogenerated flatbuffer Python bindings. + generated_python_directory = "ray/core/generated" + for filename in os.listdir(generated_python_directory): + if filename[-3:] == ".py": + self.move_file(os.path.join(generated_python_directory, + filename)) - def move_file(self, filename): - # TODO(rkn): This feels very brittle. It may not handle all cases. See - # https://github.com/apache/arrow/blob/master/python/setup.py for an - # example. - source = filename - destination = os.path.join(self.build_lib, filename) - # Create the target directory if it doesn't already exist. - parent_directory = os.path.dirname(destination) - if not os.path.exists(parent_directory): - os.makedirs(parent_directory) - print("Copying {} to {}.".format(source, destination)) - shutil.copy(source, destination) + def move_file(self, filename): + # TODO(rkn): This feels very brittle. It may not handle all cases. See + # https://github.com/apache/arrow/blob/master/python/setup.py for an + # example. + source = filename + destination = os.path.join(self.build_lib, filename) + # Create the target directory if it doesn't already exist. 
+ parent_directory = os.path.dirname(destination) + if not os.path.exists(parent_directory): + os.makedirs(parent_directory) + print("Copying {} to {}.".format(source, destination)) + shutil.copy(source, destination) files_to_include = [ @@ -54,8 +56,8 @@ files_to_include = [ class BinaryDistribution(Distribution): - def has_ext_modules(self): - return True + def has_ext_modules(self): + return True setup(name="ray", diff --git a/src/numbuf/python/test/runtest.py b/src/numbuf/python/test/runtest.py index 1932bb58b..9ed234af9 100644 --- a/src/numbuf/python/test/runtest.py +++ b/src/numbuf/python/test/runtest.py @@ -22,152 +22,152 @@ TEST_OBJECTS = [{(1, 2): 1}, {(): 2}, [1, "hello", 3.0], 42, 43, np.float32(1.0), np.float64(1.0)] if sys.version_info < (3, 0): - TEST_OBJECTS += [long(42), long(1 << 62)] # noqa: F821 + TEST_OBJECTS += [long(42), long(1 << 62)] # noqa: F821 class SerializationTests(unittest.TestCase): - def roundTripTest(self, data): - size, serialized = numbuf.serialize_list(data) - result = numbuf.deserialize_list(serialized) - assert_equal(data, result) + def roundTripTest(self, data): + size, serialized = numbuf.serialize_list(data) + result = numbuf.deserialize_list(serialized) + assert_equal(data, result) - def testSimple(self): - self.roundTripTest([1, 2, 3]) - self.roundTripTest([1.0, 2.0, 3.0]) - self.roundTripTest(['hello', 'world']) - self.roundTripTest([1, 'hello', 1.0]) - self.roundTripTest([{'hello': 1.0, 'world': 42}]) - self.roundTripTest([True, False]) + def testSimple(self): + self.roundTripTest([1, 2, 3]) + self.roundTripTest([1.0, 2.0, 3.0]) + self.roundTripTest(['hello', 'world']) + self.roundTripTest([1, 'hello', 1.0]) + self.roundTripTest([{'hello': 1.0, 'world': 42}]) + self.roundTripTest([True, False]) - def testNone(self): - self.roundTripTest([1, 2, None, 3]) + def testNone(self): + self.roundTripTest([1, 2, None, 3]) - def testNested(self): - self.roundTripTest([{"hello": {"world": (1, 2, 3)}}]) - self.roundTripTest([((1,), 
(1, 2, 3, (4, 5, 6), "string"))]) - self.roundTripTest([{"hello": [1, 2, 3]}]) - self.roundTripTest([{"hello": [1, [2, 3]]}]) - self.roundTripTest([{"hello": (None, 2, [3, 4])}]) - self.roundTripTest( - [{"hello": (None, 2, [3, 4], np.array([1.0, 2.0, 3.0]))}]) + def testNested(self): + self.roundTripTest([{"hello": {"world": (1, 2, 3)}}]) + self.roundTripTest([((1,), (1, 2, 3, (4, 5, 6), "string"))]) + self.roundTripTest([{"hello": [1, 2, 3]}]) + self.roundTripTest([{"hello": [1, [2, 3]]}]) + self.roundTripTest([{"hello": (None, 2, [3, 4])}]) + self.roundTripTest( + [{"hello": (None, 2, [3, 4], np.array([1.0, 2.0, 3.0]))}]) - def numpyTest(self, t): - a = np.random.randint(0, 10, size=(100, 100)).astype(t) - self.roundTripTest([a]) + def numpyTest(self, t): + a = np.random.randint(0, 10, size=(100, 100)).astype(t) + self.roundTripTest([a]) - def testArrays(self): - for t in ["int8", "uint8", "int16", "uint16", "int32", "uint32", "float32", - "float64"]: - self.numpyTest(t) + def testArrays(self): + for t in ["int8", "uint8", "int16", "uint16", "int32", "uint32", + "float32", "float64"]: + self.numpyTest(t) - def testRay(self): - for obj in TEST_OBJECTS: - self.roundTripTest([obj]) + def testRay(self): + for obj in TEST_OBJECTS: + self.roundTripTest([obj]) - def testCallback(self): + def testCallback(self): - class Foo(object): - def __init__(self): - self.x = 1 + class Foo(object): + def __init__(self): + self.x = 1 - class Bar(object): - def __init__(self): - self.foo = Foo() + class Bar(object): + def __init__(self): + self.foo = Foo() - def serialize(obj): - return dict(obj.__dict__, **{"_pytype_": type(obj).__name__}) + def serialize(obj): + return dict(obj.__dict__, **{"_pytype_": type(obj).__name__}) - def deserialize(obj): - if obj["_pytype_"] == "Foo": - result = Foo() - elif obj["_pytype_"] == "Bar": - result = Bar() + def deserialize(obj): + if obj["_pytype_"] == "Foo": + result = Foo() + elif obj["_pytype_"] == "Bar": + result = Bar() - 
obj.pop("_pytype_", None) - result.__dict__ = obj - return result + obj.pop("_pytype_", None) + result.__dict__ = obj + return result - bar = Bar() - bar.foo.x = 42 + bar = Bar() + bar.foo.x = 42 - numbuf.register_callbacks(serialize, deserialize) + numbuf.register_callbacks(serialize, deserialize) - size, serialized = numbuf.serialize_list([bar]) - self.assertEqual(numbuf.deserialize_list(serialized)[0].foo.x, 42) + size, serialized = numbuf.serialize_list([bar]) + self.assertEqual(numbuf.deserialize_list(serialized)[0].foo.x, 42) - def testObjectArray(self): - x = np.array([1, 2, "hello"], dtype=object) - y = np.array([[1, 2], [3, 4]], dtype=object) + def testObjectArray(self): + x = np.array([1, 2, "hello"], dtype=object) + y = np.array([[1, 2], [3, 4]], dtype=object) - def myserialize(obj): - return {"_pytype_": "numpy.array", "data": obj.tolist()} + def myserialize(obj): + return {"_pytype_": "numpy.array", "data": obj.tolist()} - def mydeserialize(obj): - if obj["_pytype_"] == "numpy.array": - return np.array(obj["data"], dtype=object) + def mydeserialize(obj): + if obj["_pytype_"] == "numpy.array": + return np.array(obj["data"], dtype=object) - numbuf.register_callbacks(myserialize, mydeserialize) + numbuf.register_callbacks(myserialize, mydeserialize) - size, serialized = numbuf.serialize_list([x, y]) + size, serialized = numbuf.serialize_list([x, y]) - assert_equal(numbuf.deserialize_list(serialized), [x, y]) + assert_equal(numbuf.deserialize_list(serialized), [x, y]) - def testBuffer(self): - for (i, obj) in enumerate(TEST_OBJECTS): - size, batch = numbuf.serialize_list([obj]) - size = size - buff = np.zeros(size, dtype="uint8") - numbuf.write_to_buffer(batch, memoryview(buff)) - array = numbuf.read_from_buffer(memoryview(buff)) - result = numbuf.deserialize_list(array) - assert_equal(result[0], obj) + def testBuffer(self): + for (i, obj) in enumerate(TEST_OBJECTS): + size, batch = numbuf.serialize_list([obj]) + size = size + buff = np.zeros(size, 
dtype="uint8") + numbuf.write_to_buffer(batch, memoryview(buff)) + array = numbuf.read_from_buffer(memoryview(buff)) + result = numbuf.deserialize_list(array) + assert_equal(result[0], obj) - def testObjectArrayImmutable(self): - obj = np.zeros([10]) - size, serialized = numbuf.serialize_list([obj]) - result = numbuf.deserialize_list(serialized) - assert_equal(result[0], obj) - with self.assertRaises(ValueError): - result[0][0] = 1 + def testObjectArrayImmutable(self): + obj = np.zeros([10]) + size, serialized = numbuf.serialize_list([obj]) + result = numbuf.deserialize_list(serialized) + assert_equal(result[0], obj) + with self.assertRaises(ValueError): + result[0][0] = 1 - def testArrowLimits(self): - # Test that objects that are too large for Arrow throw a Python exception. - # These tests give out of memory errors on Travis and need to be run on a - # machine with lots of RAM. - if os.getenv("TRAVIS") is None: - l = 2 ** 29 * [1.0] - with self.assertRaises(numbuf.numbuf_error): - self.roundTripTest(l) - self.roundTripTest([l]) - del l - l = 2 ** 29 * ["s"] - with self.assertRaises(numbuf.numbuf_error): - self.roundTripTest(l) - self.roundTripTest([l]) - del l - l = 2 ** 29 * [["1"], 2, 3, [{"s": 4}]] - with self.assertRaises(numbuf.numbuf_error): - self.roundTripTest(l) - self.roundTripTest([l]) - del l - with self.assertRaises(numbuf.numbuf_error): - l = 2 ** 29 * [{"s": 1}] + 2 ** 29 * [1.0] - self.roundTripTest(l) - self.roundTripTest([l]) - del l - with self.assertRaises(numbuf.numbuf_error): - l = np.zeros(2 ** 25) - self.roundTripTest([l]) - del l - with self.assertRaises(numbuf.numbuf_error): - l = [np.zeros(2 ** 18) for _ in range(2 ** 7)] - self.roundTripTest(l) - self.roundTripTest([l]) - del l - else: - print("Not running testArrowLimits on Travis because of the test's " - "memory requirements.") + def testArrowLimits(self): + # Test that objects that are too large for Arrow throw a Python + # exception. 
These tests give out of memory errors on Travis and need + # to be run on a machine with lots of RAM. + if os.getenv("TRAVIS") is None: + l = 2 ** 29 * [1.0] + with self.assertRaises(numbuf.numbuf_error): + self.roundTripTest(l) + self.roundTripTest([l]) + del l + l = 2 ** 29 * ["s"] + with self.assertRaises(numbuf.numbuf_error): + self.roundTripTest(l) + self.roundTripTest([l]) + del l + l = 2 ** 29 * [["1"], 2, 3, [{"s": 4}]] + with self.assertRaises(numbuf.numbuf_error): + self.roundTripTest(l) + self.roundTripTest([l]) + del l + with self.assertRaises(numbuf.numbuf_error): + l = 2 ** 29 * [{"s": 1}] + 2 ** 29 * [1.0] + self.roundTripTest(l) + self.roundTripTest([l]) + del l + with self.assertRaises(numbuf.numbuf_error): + l = np.zeros(2 ** 25) + self.roundTripTest([l]) + del l + with self.assertRaises(numbuf.numbuf_error): + l = [np.zeros(2 ** 18) for _ in range(2 ** 7)] + self.roundTripTest(l) + self.roundTripTest([l]) + del l + else: + print("Not running testArrowLimits on Travis because of the " + "test's memory requirements.") if __name__ == "__main__": diff --git a/src/plasma/setup.py b/src/plasma/setup.py index 1df0d4b00..2a6d3b160 100644 --- a/src/plasma/setup.py +++ b/src/plasma/setup.py @@ -9,17 +9,18 @@ import subprocess class install(_install.install): - def run(self): - subprocess.check_call(["make"]) - subprocess.check_call(["cp", "build/plasma_store", "plasma/plasma_store"]) - subprocess.check_call(["cp", "build/plasma_manager", - "plasma/plasma_manager"]) - subprocess.check_call(["cmake", ".."], cwd="./build") - subprocess.check_call(["make", "install"], cwd="./build") - # Calling _install.install.run(self) does not fetch required packages and - # instead performs an old-style install. See command/install.py in - # setuptools. So, calling do_egg_install() manually here. 
- self.do_egg_install() + def run(self): + subprocess.check_call(["make"]) + subprocess.check_call(["cp", "build/plasma_store", + "plasma/plasma_store"]) + subprocess.check_call(["cp", "build/plasma_manager", + "plasma/plasma_manager"]) + subprocess.check_call(["cmake", ".."], cwd="./build") + subprocess.check_call(["make", "install"], cwd="./build") + # Calling _install.install.run(self) does not fetch required packages + # and instead performs an old-style install. See command/install.py in + # setuptools. So, calling do_egg_install() manually here. + self.do_egg_install() setup(name="Plasma", diff --git a/test/actor_test.py b/test/actor_test.py index 638f5499a..8032b0774 100644 --- a/test/actor_test.py +++ b/test/actor_test.py @@ -13,1063 +13,1083 @@ import ray class ActorAPI(unittest.TestCase): - def testKeywordArgs(self): - ray.init(num_workers=0, driver_mode=ray.SILENT_MODE) + def testKeywordArgs(self): + ray.init(num_workers=0, driver_mode=ray.SILENT_MODE) - @ray.remote - class Actor(object): - def __init__(self, arg0, arg1=1, arg2="a"): - self.arg0 = arg0 - self.arg1 = arg1 - self.arg2 = arg2 + @ray.remote + class Actor(object): + def __init__(self, arg0, arg1=1, arg2="a"): + self.arg0 = arg0 + self.arg1 = arg1 + self.arg2 = arg2 - def get_values(self, arg0, arg1=2, arg2="b"): - return self.arg0 + arg0, self.arg1 + arg1, self.arg2 + arg2 + def get_values(self, arg0, arg1=2, arg2="b"): + return self.arg0 + arg0, self.arg1 + arg1, self.arg2 + arg2 - actor = Actor.remote(0) - self.assertEqual(ray.get(actor.get_values.remote(1)), (1, 3, "ab")) + actor = Actor.remote(0) + self.assertEqual(ray.get(actor.get_values.remote(1)), (1, 3, "ab")) - actor = Actor.remote(1, 2) - self.assertEqual(ray.get(actor.get_values.remote(2, 3)), (3, 5, "ab")) + actor = Actor.remote(1, 2) + self.assertEqual(ray.get(actor.get_values.remote(2, 3)), (3, 5, "ab")) - actor = Actor.remote(1, 2, "c") - self.assertEqual(ray.get(actor.get_values.remote(2, 3, "d")), (3, 5, "cd")) + actor = 
Actor.remote(1, 2, "c") + self.assertEqual(ray.get(actor.get_values.remote(2, 3, "d")), + (3, 5, "cd")) - actor = Actor.remote(1, arg2="c") - self.assertEqual(ray.get(actor.get_values.remote(0, arg2="d")), - (1, 3, "cd")) - self.assertEqual(ray.get(actor.get_values.remote(0, arg2="d", arg1=0)), - (1, 1, "cd")) + actor = Actor.remote(1, arg2="c") + self.assertEqual(ray.get(actor.get_values.remote(0, arg2="d")), + (1, 3, "cd")) + self.assertEqual(ray.get(actor.get_values.remote(0, arg2="d", arg1=0)), + (1, 1, "cd")) - actor = Actor.remote(1, arg2="c", arg1=2) - self.assertEqual(ray.get(actor.get_values.remote(0, arg2="d")), - (1, 4, "cd")) - self.assertEqual(ray.get(actor.get_values.remote(0, arg2="d", arg1=0)), - (1, 2, "cd")) + actor = Actor.remote(1, arg2="c", arg1=2) + self.assertEqual(ray.get(actor.get_values.remote(0, arg2="d")), + (1, 4, "cd")) + self.assertEqual(ray.get(actor.get_values.remote(0, arg2="d", arg1=0)), + (1, 2, "cd")) - # Make sure we get an exception if the constructor is called incorrectly. - with self.assertRaises(Exception): - actor = Actor.remote() + # Make sure we get an exception if the constructor is called + # incorrectly. + with self.assertRaises(Exception): + actor = Actor.remote() - with self.assertRaises(Exception): - actor = Actor.remote(0, 1, 2, arg3=3) + with self.assertRaises(Exception): + actor = Actor.remote(0, 1, 2, arg3=3) - # Make sure we get an exception if the method is called incorrectly. - actor = Actor.remote(1) - with self.assertRaises(Exception): - ray.get(actor.get_values.remote()) + # Make sure we get an exception if the method is called incorrectly. 
+ actor = Actor.remote(1) + with self.assertRaises(Exception): + ray.get(actor.get_values.remote()) - ray.worker.cleanup() + ray.worker.cleanup() - def testVariableNumberOfArgs(self): - ray.init(num_workers=0) + def testVariableNumberOfArgs(self): + ray.init(num_workers=0) - @ray.remote - class Actor(object): - def __init__(self, arg0, arg1=1, *args): - self.arg0 = arg0 - self.arg1 = arg1 - self.args = args + @ray.remote + class Actor(object): + def __init__(self, arg0, arg1=1, *args): + self.arg0 = arg0 + self.arg1 = arg1 + self.args = args - def get_values(self, arg0, arg1=2, *args): - return self.arg0 + arg0, self.arg1 + arg1, self.args, args + def get_values(self, arg0, arg1=2, *args): + return self.arg0 + arg0, self.arg1 + arg1, self.args, args - actor = Actor.remote(0) - self.assertEqual(ray.get(actor.get_values.remote(1)), (1, 3, (), ())) + actor = Actor.remote(0) + self.assertEqual(ray.get(actor.get_values.remote(1)), (1, 3, (), ())) - actor = Actor.remote(1, 2) - self.assertEqual(ray.get(actor.get_values.remote(2, 3)), (3, 5, (), ())) + actor = Actor.remote(1, 2) + self.assertEqual(ray.get(actor.get_values.remote(2, 3)), + (3, 5, (), ())) - actor = Actor.remote(1, 2, "c") - self.assertEqual(ray.get(actor.get_values.remote(2, 3, "d")), - (3, 5, ("c",), ("d",))) + actor = Actor.remote(1, 2, "c") + self.assertEqual(ray.get(actor.get_values.remote(2, 3, "d")), + (3, 5, ("c",), ("d",))) - actor = Actor.remote(1, 2, "a", "b", "c", "d") - self.assertEqual(ray.get(actor.get_values.remote(2, 3, 1, 2, 3, 4)), - (3, 5, ("a", "b", "c", "d"), (1, 2, 3, 4))) + actor = Actor.remote(1, 2, "a", "b", "c", "d") + self.assertEqual(ray.get(actor.get_values.remote(2, 3, 1, 2, 3, 4)), + (3, 5, ("a", "b", "c", "d"), (1, 2, 3, 4))) - @ray.remote - class Actor(object): - def __init__(self, *args): - self.args = args + @ray.remote + class Actor(object): + def __init__(self, *args): + self.args = args - def get_values(self, *args): - return self.args, args + def get_values(self, 
*args): + return self.args, args - a = Actor.remote() - self.assertEqual(ray.get(a.get_values.remote()), ((), ())) - a = Actor.remote(1) - self.assertEqual(ray.get(a.get_values.remote(2)), ((1,), (2,))) - a = Actor.remote(1, 2) - self.assertEqual(ray.get(a.get_values.remote(3, 4)), ((1, 2), (3, 4))) + a = Actor.remote() + self.assertEqual(ray.get(a.get_values.remote()), ((), ())) + a = Actor.remote(1) + self.assertEqual(ray.get(a.get_values.remote(2)), ((1,), (2,))) + a = Actor.remote(1, 2) + self.assertEqual(ray.get(a.get_values.remote(3, 4)), ((1, 2), (3, 4))) - ray.worker.cleanup() + ray.worker.cleanup() - def testNoArgs(self): - ray.init(num_workers=0) + def testNoArgs(self): + ray.init(num_workers=0) - @ray.remote - class Actor(object): - def __init__(self): - pass + @ray.remote + class Actor(object): + def __init__(self): + pass - def get_values(self): - pass + def get_values(self): + pass - actor = Actor.remote() - self.assertEqual(ray.get(actor.get_values.remote()), None) + actor = Actor.remote() + self.assertEqual(ray.get(actor.get_values.remote()), None) - ray.worker.cleanup() + ray.worker.cleanup() - def testNoConstructor(self): - # If no __init__ method is provided, that should not be a problem. - ray.init(num_workers=0) + def testNoConstructor(self): + # If no __init__ method is provided, that should not be a problem. 
+ ray.init(num_workers=0) - @ray.remote - class Actor(object): - def get_values(self): - pass + @ray.remote + class Actor(object): + def get_values(self): + pass - actor = Actor.remote() - self.assertEqual(ray.get(actor.get_values.remote()), None) + actor = Actor.remote() + self.assertEqual(ray.get(actor.get_values.remote()), None) - ray.worker.cleanup() + ray.worker.cleanup() - def testCustomClasses(self): - ray.init(num_workers=0) + def testCustomClasses(self): + ray.init(num_workers=0) - class Foo(object): - def __init__(self, x): - self.x = x + class Foo(object): + def __init__(self, x): + self.x = x - @ray.remote - class Actor(object): - def __init__(self, f2): - self.f1 = Foo(1) - self.f2 = f2 + @ray.remote + class Actor(object): + def __init__(self, f2): + self.f1 = Foo(1) + self.f2 = f2 - def get_values1(self): - return self.f1, self.f2 + def get_values1(self): + return self.f1, self.f2 - def get_values2(self, f3): - return self.f1, self.f2, f3 + def get_values2(self, f3): + return self.f1, self.f2, f3 - actor = Actor.remote(Foo(2)) - results1 = ray.get(actor.get_values1.remote()) - self.assertEqual(results1[0].x, 1) - self.assertEqual(results1[1].x, 2) - results2 = ray.get(actor.get_values2.remote(Foo(3))) - self.assertEqual(results2[0].x, 1) - self.assertEqual(results2[1].x, 2) - self.assertEqual(results2[2].x, 3) + actor = Actor.remote(Foo(2)) + results1 = ray.get(actor.get_values1.remote()) + self.assertEqual(results1[0].x, 1) + self.assertEqual(results1[1].x, 2) + results2 = ray.get(actor.get_values2.remote(Foo(3))) + self.assertEqual(results2[0].x, 1) + self.assertEqual(results2[1].x, 2) + self.assertEqual(results2[2].x, 3) - ray.worker.cleanup() + ray.worker.cleanup() - # def testCachingActors(self): - # # TODO(rkn): Implement this. - # pass + # def testCachingActors(self): + # # TODO(rkn): Implement this. 
+ # pass - def testDecoratorArgs(self): - ray.init(num_workers=0, driver_mode=ray.SILENT_MODE) + def testDecoratorArgs(self): + ray.init(num_workers=0, driver_mode=ray.SILENT_MODE) - # This is an invalid way of using the actor decorator. - with self.assertRaises(Exception): - @ray.remote() - class Actor(object): - def __init__(self): - pass + # This is an invalid way of using the actor decorator. + with self.assertRaises(Exception): + @ray.remote() + class Actor(object): + def __init__(self): + pass - # This is an invalid way of using the actor decorator. - with self.assertRaises(Exception): - @ray.remote(invalid_kwarg=0) # noqa: F811 - class Actor(object): - def __init__(self): - pass + # This is an invalid way of using the actor decorator. + with self.assertRaises(Exception): + @ray.remote(invalid_kwarg=0) # noqa: F811 + class Actor(object): + def __init__(self): + pass - # This is an invalid way of using the actor decorator. - with self.assertRaises(Exception): - @ray.remote(num_cpus=0, invalid_kwarg=0) # noqa: F811 - class Actor(object): - def __init__(self): - pass + # This is an invalid way of using the actor decorator. + with self.assertRaises(Exception): + @ray.remote(num_cpus=0, invalid_kwarg=0) # noqa: F811 + class Actor(object): + def __init__(self): + pass - # This is a valid way of using the decorator. - @ray.remote(num_cpus=1) # noqa: F811 - class Actor(object): - def __init__(self): - pass + # This is a valid way of using the decorator. + @ray.remote(num_cpus=1) # noqa: F811 + class Actor(object): + def __init__(self): + pass - # This is a valid way of using the decorator. - @ray.remote(num_gpus=1) # noqa: F811 - class Actor(object): - def __init__(self): - pass + # This is a valid way of using the decorator. + @ray.remote(num_gpus=1) # noqa: F811 + class Actor(object): + def __init__(self): + pass - # This is a valid way of using the decorator. 
- @ray.remote(num_cpus=1, num_gpus=1) # noqa: F811 - class Actor(object): - def __init__(self): - pass + # This is a valid way of using the decorator. + @ray.remote(num_cpus=1, num_gpus=1) # noqa: F811 + class Actor(object): + def __init__(self): + pass - ray.worker.cleanup() + ray.worker.cleanup() - def testRandomIDGeneration(self): - ray.init(num_workers=0) + def testRandomIDGeneration(self): + ray.init(num_workers=0) - @ray.remote - class Foo(object): - def __init__(self): - pass + @ray.remote + class Foo(object): + def __init__(self): + pass - # Make sure that seeding numpy does not interfere with the generation of - # actor IDs. - np.random.seed(1234) - random.seed(1234) - f1 = Foo.remote() - np.random.seed(1234) - random.seed(1234) - f2 = Foo.remote() + # Make sure that seeding numpy does not interfere with the generation + # of actor IDs. + np.random.seed(1234) + random.seed(1234) + f1 = Foo.remote() + np.random.seed(1234) + random.seed(1234) + f2 = Foo.remote() - self.assertNotEqual(f1._ray_actor_id.id(), f2._ray_actor_id.id()) + self.assertNotEqual(f1._ray_actor_id.id(), f2._ray_actor_id.id()) - ray.worker.cleanup() + ray.worker.cleanup() - def testActorClassName(self): - ray.init(num_workers=0) + def testActorClassName(self): + ray.init(num_workers=0) - @ray.remote - class Foo(object): - def __init__(self): - pass + @ray.remote + class Foo(object): + def __init__(self): + pass - Foo.remote() + Foo.remote() - r = ray.worker.global_worker.redis_client - actor_keys = r.keys("ActorClass*") - self.assertEqual(len(actor_keys), 1) - actor_class_info = r.hgetall(actor_keys[0]) - self.assertEqual(actor_class_info[b"class_name"], b"Foo") - self.assertEqual(actor_class_info[b"module"], b"__main__") + r = ray.worker.global_worker.redis_client + actor_keys = r.keys("ActorClass*") + self.assertEqual(len(actor_keys), 1) + actor_class_info = r.hgetall(actor_keys[0]) + self.assertEqual(actor_class_info[b"class_name"], b"Foo") + 
self.assertEqual(actor_class_info[b"module"], b"__main__") - ray.worker.cleanup() + ray.worker.cleanup() class ActorMethods(unittest.TestCase): - def testDefineActor(self): - ray.init() + def testDefineActor(self): + ray.init() - @ray.remote - class Test(object): - def __init__(self, x): - self.x = x + @ray.remote + class Test(object): + def __init__(self, x): + self.x = x - def f(self, y): - return self.x + y + def f(self, y): + return self.x + y - t = Test.remote(2) - self.assertEqual(ray.get(t.f.remote(1)), 3) + t = Test.remote(2) + self.assertEqual(ray.get(t.f.remote(1)), 3) - # Make sure that calling an actor method directly raises an exception. - with self.assertRaises(Exception): - t.f(1) + # Make sure that calling an actor method directly raises an exception. + with self.assertRaises(Exception): + t.f(1) - ray.worker.cleanup() + ray.worker.cleanup() - def testActorState(self): - ray.init() + def testActorState(self): + ray.init() - @ray.remote - class Counter(object): - def __init__(self): - self.value = 0 + @ray.remote + class Counter(object): + def __init__(self): + self.value = 0 - def increase(self): - self.value += 1 + def increase(self): + self.value += 1 - def value(self): - return self.value + def value(self): + return self.value - c1 = Counter.remote() - c1.increase.remote() - self.assertEqual(ray.get(c1.value.remote()), 1) + c1 = Counter.remote() + c1.increase.remote() + self.assertEqual(ray.get(c1.value.remote()), 1) - c2 = Counter.remote() - c2.increase.remote() - c2.increase.remote() - self.assertEqual(ray.get(c2.value.remote()), 2) + c2 = Counter.remote() + c2.increase.remote() + c2.increase.remote() + self.assertEqual(ray.get(c2.value.remote()), 2) - ray.worker.cleanup() + ray.worker.cleanup() - def testMultipleActors(self): - # Create a bunch of actors and call a bunch of methods on all of them. - ray.init(num_workers=0) + def testMultipleActors(self): + # Create a bunch of actors and call a bunch of methods on all of them. 
+ ray.init(num_workers=0) - @ray.remote - class Counter(object): - def __init__(self, value): - self.value = value + @ray.remote + class Counter(object): + def __init__(self, value): + self.value = value - def increase(self): - self.value += 1 - return self.value + def increase(self): + self.value += 1 + return self.value - def reset(self): - self.value = 0 + def reset(self): + self.value = 0 - num_actors = 20 - num_increases = 50 - # Create multiple actors. - actors = [Counter.remote(i) for i in range(num_actors)] - results = [] - # Call each actor's method a bunch of times. - for i in range(num_actors): - results += [actors[i].increase.remote() for _ in range(num_increases)] - result_values = ray.get(results) - for i in range(num_actors): - self.assertEqual( - result_values[(num_increases * i):(num_increases * (i + 1))], - list(range(i + 1, num_increases + i + 1))) + num_actors = 20 + num_increases = 50 + # Create multiple actors. + actors = [Counter.remote(i) for i in range(num_actors)] + results = [] + # Call each actor's method a bunch of times. + for i in range(num_actors): + results += [actors[i].increase.remote() + for _ in range(num_increases)] + result_values = ray.get(results) + for i in range(num_actors): + self.assertEqual( + result_values[(num_increases * i):(num_increases * (i + 1))], + list(range(i + 1, num_increases + i + 1))) - # Reset the actor values. - [actor.reset.remote() for actor in actors] + # Reset the actor values. + [actor.reset.remote() for actor in actors] - # Interweave the method calls on the different actors. - results = [] - for j in range(num_increases): - results += [actor.increase.remote() for actor in actors] - result_values = ray.get(results) - for j in range(num_increases): - self.assertEqual(result_values[(num_actors * j):(num_actors * (j + 1))], - num_actors * [j + 1]) + # Interweave the method calls on the different actors. 
+ results = [] + for j in range(num_increases): + results += [actor.increase.remote() for actor in actors] + result_values = ray.get(results) + for j in range(num_increases): + self.assertEqual( + result_values[(num_actors * j):(num_actors * (j + 1))], + num_actors * [j + 1]) - ray.worker.cleanup() + ray.worker.cleanup() class ActorNesting(unittest.TestCase): - def testRemoteFunctionWithinActor(self): - # Make sure we can use remote funtions within actors. - ray.init(num_cpus=100) + def testRemoteFunctionWithinActor(self): + # Make sure we can use remote funtions within actors. + ray.init(num_cpus=100) - # Create some values to close over. - val1 = 1 - val2 = 2 + # Create some values to close over. + val1 = 1 + val2 = 2 - @ray.remote - def f(x): - return val1 + x + @ray.remote + def f(x): + return val1 + x - @ray.remote - def g(x): - return ray.get(f.remote(x)) + @ray.remote + def g(x): + return ray.get(f.remote(x)) - @ray.remote - class Actor(object): - def __init__(self, x): - self.x = x - self.y = val2 - self.object_ids = [f.remote(i) for i in range(5)] - self.values2 = ray.get([f.remote(i) for i in range(5)]) + @ray.remote + class Actor(object): + def __init__(self, x): + self.x = x + self.y = val2 + self.object_ids = [f.remote(i) for i in range(5)] + self.values2 = ray.get([f.remote(i) for i in range(5)]) - def get_values(self): - return self.x, self.y, self.object_ids, self.values2 + def get_values(self): + return self.x, self.y, self.object_ids, self.values2 - def f(self): - return [f.remote(i) for i in range(5)] + def f(self): + return [f.remote(i) for i in range(5)] - def g(self): - return ray.get([g.remote(i) for i in range(5)]) + def g(self): + return ray.get([g.remote(i) for i in range(5)]) - def h(self, object_ids): - return ray.get(object_ids) + def h(self, object_ids): + return ray.get(object_ids) - actor = Actor.remote(1) - values = ray.get(actor.get_values.remote()) - self.assertEqual(values[0], 1) - self.assertEqual(values[1], val2) - 
self.assertEqual(ray.get(values[2]), list(range(1, 6))) - self.assertEqual(values[3], list(range(1, 6))) + actor = Actor.remote(1) + values = ray.get(actor.get_values.remote()) + self.assertEqual(values[0], 1) + self.assertEqual(values[1], val2) + self.assertEqual(ray.get(values[2]), list(range(1, 6))) + self.assertEqual(values[3], list(range(1, 6))) - self.assertEqual(ray.get(ray.get(actor.f.remote())), list(range(1, 6))) - self.assertEqual(ray.get(actor.g.remote()), list(range(1, 6))) - self.assertEqual(ray.get(actor.h.remote([f.remote(i) for i in range(5)])), - list(range(1, 6))) + self.assertEqual(ray.get(ray.get(actor.f.remote())), list(range(1, 6))) + self.assertEqual(ray.get(actor.g.remote()), list(range(1, 6))) + self.assertEqual( + ray.get(actor.h.remote([f.remote(i) for i in range(5)])), + list(range(1, 6))) - ray.worker.cleanup() + ray.worker.cleanup() - def testDefineActorWithinActor(self): - # Make sure we can use remote funtions within actors. - ray.init(num_cpus=10) + def testDefineActorWithinActor(self): + # Make sure we can use remote funtions within actors. + ray.init(num_cpus=10) - @ray.remote - class Actor1(object): - def __init__(self, x): - self.x = x + @ray.remote + class Actor1(object): + def __init__(self, x): + self.x = x + + def new_actor(self, z): + @ray.remote + class Actor2(object): + def __init__(self, x): + self.x = x + + def get_value(self): + return self.x + self.actor2 = Actor2.remote(z) + + def get_values(self, z): + self.new_actor(z) + return self.x, ray.get(self.actor2.get_value.remote()) + + actor1 = Actor1.remote(3) + self.assertEqual(ray.get(actor1.get_values.remote(5)), (3, 5)) + + ray.worker.cleanup() + + def testUseActorWithinActor(self): + # Make sure we can use actors within actors. 
+ ray.init(num_cpus=10) + + @ray.remote + class Actor1(object): + def __init__(self, x): + self.x = x + + def get_val(self): + return self.x - def new_actor(self, z): @ray.remote class Actor2(object): - def __init__(self, x): - self.x = x + def __init__(self, x, y): + self.x = x + self.actor1 = Actor1.remote(y) - def get_value(self): - return self.x - self.actor2 = Actor2.remote(z) + def get_values(self, z): + return self.x, ray.get(self.actor1.get_val.remote()) - def get_values(self, z): - self.new_actor(z) - return self.x, ray.get(self.actor2.get_value.remote()) + actor2 = Actor2.remote(3, 4) + self.assertEqual(ray.get(actor2.get_values.remote(5)), (3, 4)) - actor1 = Actor1.remote(3) - self.assertEqual(ray.get(actor1.get_values.remote(5)), (3, 5)) + ray.worker.cleanup() - ray.worker.cleanup() + def testDefineActorWithinRemoteFunction(self): + # Make sure we can define and actors within remote funtions. + ray.init(num_cpus=10) - def testUseActorWithinActor(self): - # Make sure we can use actors within actors. - ray.init(num_cpus=10) + @ray.remote + def f(x, n): + @ray.remote + class Actor1(object): + def __init__(self, x): + self.x = x - @ray.remote - class Actor1(object): - def __init__(self, x): - self.x = x + def get_value(self): + return self.x + actor = Actor1.remote(x) + return ray.get([actor.get_value.remote() for _ in range(n)]) - def get_val(self): - return self.x + self.assertEqual(ray.get(f.remote(3, 1)), [3]) + self.assertEqual(ray.get([f.remote(i, 20) for i in range(10)]), + [20 * [i] for i in range(10)]) - @ray.remote - class Actor2(object): - def __init__(self, x, y): - self.x = x - self.actor1 = Actor1.remote(y) + ray.worker.cleanup() - def get_values(self, z): - return self.x, ray.get(self.actor1.get_val.remote()) + def testUseActorWithinRemoteFunction(self): + # Make sure we can create and use actors within remote funtions. 
+ ray.init(num_cpus=10) - actor2 = Actor2.remote(3, 4) - self.assertEqual(ray.get(actor2.get_values.remote(5)), (3, 4)) + @ray.remote + class Actor1(object): + def __init__(self, x): + self.x = x - ray.worker.cleanup() + def get_values(self): + return self.x - def testDefineActorWithinRemoteFunction(self): - # Make sure we can define and actors within remote funtions. - ray.init(num_cpus=10) + @ray.remote + def f(x): + actor = Actor1.remote(x) + return ray.get(actor.get_values.remote()) - @ray.remote - def f(x, n): - @ray.remote - class Actor1(object): - def __init__(self, x): - self.x = x + self.assertEqual(ray.get(f.remote(3)), 3) - def get_value(self): - return self.x - actor = Actor1.remote(x) - return ray.get([actor.get_value.remote() for _ in range(n)]) + ray.worker.cleanup() - self.assertEqual(ray.get(f.remote(3, 1)), [3]) - self.assertEqual(ray.get([f.remote(i, 20) for i in range(10)]), - [20 * [i] for i in range(10)]) + def testActorImportCounter(self): + # This is mostly a test of the export counters to make sure that when + # an actor is imported, all of the necessary remote functions have been + # imported. + ray.init(num_cpus=10) - ray.worker.cleanup() + # Export a bunch of remote functions. + num_remote_functions = 50 + for i in range(num_remote_functions): + @ray.remote + def f(): + return i - def testUseActorWithinRemoteFunction(self): - # Make sure we can create and use actors within remote funtions. - ray.init(num_cpus=10) + @ray.remote + def g(): + @ray.remote + class Actor(object): + def __init__(self): + # This should use the last version of f. 
+ self.x = ray.get(f.remote()) - @ray.remote - class Actor1(object): - def __init__(self, x): - self.x = x + def get_val(self): + return self.x - def get_values(self): - return self.x + actor = Actor.remote() + return ray.get(actor.get_val.remote()) - @ray.remote - def f(x): - actor = Actor1.remote(x) - return ray.get(actor.get_values.remote()) + self.assertEqual(ray.get(g.remote()), num_remote_functions - 1) - self.assertEqual(ray.get(f.remote(3)), 3) - - ray.worker.cleanup() - - def testActorImportCounter(self): - # This is mostly a test of the export counters to make sure that when an - # actor is imported, all of the necessary remote functions have been - # imported. - ray.init(num_cpus=10) - - # Export a bunch of remote functions. - num_remote_functions = 50 - for i in range(num_remote_functions): - @ray.remote - def f(): - return i - - @ray.remote - def g(): - @ray.remote - class Actor(object): - def __init__(self): - # This should use the last version of f. - self.x = ray.get(f.remote()) - - def get_val(self): - return self.x - - actor = Actor.remote() - return ray.get(actor.get_val.remote()) - - self.assertEqual(ray.get(g.remote()), num_remote_functions - 1) - - ray.worker.cleanup() + ray.worker.cleanup() class ActorInheritance(unittest.TestCase): - def testInheritActorFromClass(self): - # Make sure we can define an actor by inheriting from a regular class. Note - # that actors cannot inherit from other actors. - ray.init() + def testInheritActorFromClass(self): + # Make sure we can define an actor by inheriting from a regular class. + # Note that actors cannot inherit from other actors. 
+ ray.init() - class Foo(object): - def __init__(self, x): - self.x = x + class Foo(object): + def __init__(self, x): + self.x = x - def f(self): - return self.x + def f(self): + return self.x - def g(self, y): - return self.x + y + def g(self, y): + return self.x + y - @ray.remote - class Actor(Foo): - def __init__(self, x): - Foo.__init__(self, x) + @ray.remote + class Actor(Foo): + def __init__(self, x): + Foo.__init__(self, x) - def get_value(self): - return self.f() + def get_value(self): + return self.f() - actor = Actor.remote(1) - self.assertEqual(ray.get(actor.get_value.remote()), 1) - self.assertEqual(ray.get(actor.g.remote(5)), 6) + actor = Actor.remote(1) + self.assertEqual(ray.get(actor.get_value.remote()), 1) + self.assertEqual(ray.get(actor.g.remote(5)), 6) - ray.worker.cleanup() + ray.worker.cleanup() class ActorSchedulingProperties(unittest.TestCase): - def testRemoteFunctionsNotScheduledOnActors(self): - # Make sure that regular remote functions are not scheduled on actors. - ray.init(num_workers=0) + def testRemoteFunctionsNotScheduledOnActors(self): + # Make sure that regular remote functions are not scheduled on actors. 
+ ray.init(num_workers=0) - @ray.remote - class Actor(object): - def __init__(self): - pass + @ray.remote + class Actor(object): + def __init__(self): + pass - def get_id(self): - return ray.worker.global_worker.worker_id + def get_id(self): + return ray.worker.global_worker.worker_id - a = Actor.remote() - actor_id = ray.get(a.get_id.remote()) + a = Actor.remote() + actor_id = ray.get(a.get_id.remote()) - @ray.remote - def f(): - return ray.worker.global_worker.worker_id + @ray.remote + def f(): + return ray.worker.global_worker.worker_id - resulting_ids = ray.get([f.remote() for _ in range(100)]) - self.assertNotIn(actor_id, resulting_ids) + resulting_ids = ray.get([f.remote() for _ in range(100)]) + self.assertNotIn(actor_id, resulting_ids) - ray.worker.cleanup() + ray.worker.cleanup() class ActorsOnMultipleNodes(unittest.TestCase): - def testActorsOnNodesWithNoCPUs(self): - ray.init(num_cpus=0) + def testActorsOnNodesWithNoCPUs(self): + ray.init(num_cpus=0) - @ray.remote - class Foo(object): - def __init__(self): - pass + @ray.remote + class Foo(object): + def __init__(self): + pass - with self.assertRaises(Exception): - Foo.remote() + with self.assertRaises(Exception): + Foo.remote() - ray.worker.cleanup() + ray.worker.cleanup() - def testActorLoadBalancing(self): - num_local_schedulers = 3 - ray.worker._init(start_ray_local=True, num_workers=0, - num_local_schedulers=num_local_schedulers) + def testActorLoadBalancing(self): + num_local_schedulers = 3 + ray.worker._init(start_ray_local=True, num_workers=0, + num_local_schedulers=num_local_schedulers) - @ray.remote - class Actor1(object): - def __init__(self): - pass + @ray.remote + class Actor1(object): + def __init__(self): + pass - def get_location(self): - return ray.worker.global_worker.plasma_client.store_socket_name + def get_location(self): + return ray.worker.global_worker.plasma_client.store_socket_name - # Create a bunch of actors. 
- num_actors = 30 - num_attempts = 20 - minimum_count = 5 + # Create a bunch of actors. + num_actors = 30 + num_attempts = 20 + minimum_count = 5 - # Make sure that actors are spread between the local schedulers. - attempts = 0 - while attempts < num_attempts: - actors = [Actor1.remote() for _ in range(num_actors)] - locations = ray.get([actor.get_location.remote() for actor in actors]) - names = set(locations) - counts = [locations.count(name) for name in names] - print("Counts are {}.".format(counts)) - if len(names) == num_local_schedulers and all([count >= minimum_count - for count in counts]): - break - attempts += 1 - self.assertLess(attempts, num_attempts) + # Make sure that actors are spread between the local schedulers. + attempts = 0 + while attempts < num_attempts: + actors = [Actor1.remote() for _ in range(num_actors)] + locations = ray.get([actor.get_location.remote() + for actor in actors]) + names = set(locations) + counts = [locations.count(name) for name in names] + print("Counts are {}.".format(counts)) + if (len(names) == num_local_schedulers and + all([count >= minimum_count for count in counts])): + break + attempts += 1 + self.assertLess(attempts, num_attempts) - # Make sure we can get the results of a bunch of tasks. - results = [] - for _ in range(1000): - index = np.random.randint(num_actors) - results.append(actors[index].get_location.remote()) - ray.get(results) + # Make sure we can get the results of a bunch of tasks. 
+ results = [] + for _ in range(1000): + index = np.random.randint(num_actors) + results.append(actors[index].get_location.remote()) + ray.get(results) - ray.worker.cleanup() + ray.worker.cleanup() class ActorsWithGPUs(unittest.TestCase): - def testActorGPUs(self): - num_local_schedulers = 3 - num_gpus_per_scheduler = 4 - ray.worker._init( - start_ray_local=True, num_workers=0, - num_local_schedulers=num_local_schedulers, - num_gpus=(num_local_schedulers * [num_gpus_per_scheduler])) - - @ray.remote(num_gpus=1) - class Actor1(object): - def __init__(self): - self.gpu_ids = ray.get_gpu_ids() - - def get_location_and_ids(self): - assert ray.get_gpu_ids() == self.gpu_ids - return (ray.worker.global_worker.plasma_client.store_socket_name, - tuple(self.gpu_ids)) - - # Create one actor per GPU. - actors = [Actor1.remote() for _ - in range(num_local_schedulers * num_gpus_per_scheduler)] - # Make sure that no two actors are assigned to the same GPU. - locations_and_ids = ray.get([actor.get_location_and_ids.remote() - for actor in actors]) - node_names = set([location for location, gpu_id in locations_and_ids]) - self.assertEqual(len(node_names), num_local_schedulers) - location_actor_combinations = [] - for node_name in node_names: - for gpu_id in range(num_gpus_per_scheduler): - location_actor_combinations.append((node_name, (gpu_id,))) - self.assertEqual(set(locations_and_ids), set(location_actor_combinations)) - - # Creating a new actor should fail because all of the GPUs are being used. 
- with self.assertRaises(Exception): - Actor1.remote() - - ray.worker.cleanup() - - def testActorMultipleGPUs(self): - num_local_schedulers = 3 - num_gpus_per_scheduler = 5 - ray.worker._init( - start_ray_local=True, num_workers=0, - num_local_schedulers=num_local_schedulers, - num_gpus=(num_local_schedulers * [num_gpus_per_scheduler])) - - @ray.remote(num_gpus=2) - class Actor1(object): - def __init__(self): - self.gpu_ids = ray.get_gpu_ids() - - def get_location_and_ids(self): - return (ray.worker.global_worker.plasma_client.store_socket_name, - tuple(self.gpu_ids)) - - # Create some actors. - actors1 = [Actor1.remote() for _ in range(num_local_schedulers * 2)] - # Make sure that no two actors are assigned to the same GPU. - locations_and_ids = ray.get([actor.get_location_and_ids.remote() - for actor in actors1]) - node_names = set([location for location, gpu_id in locations_and_ids]) - self.assertEqual(len(node_names), num_local_schedulers) - - # Keep track of which GPU IDs are being used for each location. - gpus_in_use = {node_name: [] for node_name in node_names} - for location, gpu_ids in locations_and_ids: - gpus_in_use[location].extend(gpu_ids) - for node_name in node_names: - self.assertEqual(len(set(gpus_in_use[node_name])), 4) - - # Creating a new actor should fail because all of the GPUs are being used. - with self.assertRaises(Exception): - Actor1.remote() - - # We should be able to create more actors that use only a single GPU. - @ray.remote(num_gpus=1) - class Actor2(object): - def __init__(self): - self.gpu_ids = ray.get_gpu_ids() - - def get_location_and_ids(self): - return (ray.worker.global_worker.plasma_client.store_socket_name, - tuple(self.gpu_ids)) - - # Create some actors. - actors2 = [Actor2.remote() for _ in range(num_local_schedulers)] - # Make sure that no two actors are assigned to the same GPU. 
- locations_and_ids = ray.get([actor.get_location_and_ids.remote() - for actor in actors2]) - self.assertEqual(node_names, - set([location for location, gpu_id in locations_and_ids])) - for location, gpu_ids in locations_and_ids: - gpus_in_use[location].extend(gpu_ids) - for node_name in node_names: - self.assertEqual(len(gpus_in_use[node_name]), 5) - self.assertEqual(set(gpus_in_use[node_name]), set(range(5))) - - # Creating a new actor should fail because all of the GPUs are being used. - with self.assertRaises(Exception): - Actor2.remote() - - ray.worker.cleanup() - - def testActorDifferentNumbersOfGPUs(self): - # Test that we can create actors on two nodes that have different numbers - # of GPUs. - ray.worker._init(start_ray_local=True, num_workers=0, - num_local_schedulers=3, num_gpus=[0, 5, 10]) - - @ray.remote(num_gpus=1) - class Actor1(object): - def __init__(self): - self.gpu_ids = ray.get_gpu_ids() - - def get_location_and_ids(self): - return (ray.worker.global_worker.plasma_client.store_socket_name, - tuple(self.gpu_ids)) - - # Create some actors. - actors = [Actor1.remote() for _ in range(0 + 5 + 10)] - # Make sure that no two actors are assigned to the same GPU. - locations_and_ids = ray.get([actor.get_location_and_ids.remote() - for actor in actors]) - node_names = set([location for location, gpu_id in locations_and_ids]) - self.assertEqual(len(node_names), 2) - for node_name in node_names: - node_gpu_ids = [gpu_id for location, gpu_id in locations_and_ids - if location == node_name] - self.assertIn(len(node_gpu_ids), [5, 10]) - self.assertEqual(set(node_gpu_ids), - set([(i,) for i in range(len(node_gpu_ids))])) - - # Creating a new actor should fail because all of the GPUs are being used. 
- with self.assertRaises(Exception): - Actor1.remote() - - ray.worker.cleanup() - - def testActorMultipleGPUsFromMultipleTasks(self): - num_local_schedulers = 10 - num_gpus_per_scheduler = 10 - ray.worker._init( - start_ray_local=True, num_workers=0, - num_local_schedulers=num_local_schedulers, redirect_output=True, - num_gpus=(num_local_schedulers * [num_gpus_per_scheduler])) - - @ray.remote - def create_actors(n): - @ray.remote(num_gpus=1) - class Actor(object): - def __init__(self): - self.gpu_ids = ray.get_gpu_ids() - - def get_location_and_ids(self): - return (ray.worker.global_worker.plasma_client.store_socket_name, - tuple(self.gpu_ids)) - # Create n actors. - for _ in range(n): - Actor.remote() - - ray.get([create_actors.remote(num_gpus_per_scheduler) - for _ in range(num_local_schedulers)]) - - @ray.remote(num_gpus=1) - class Actor(object): - def __init__(self): - self.gpu_ids = ray.get_gpu_ids() - - def get_location_and_ids(self): - return (ray.worker.global_worker.plasma_client.store_socket_name, - tuple(self.gpu_ids)) - - # All the GPUs should be used up now. - with self.assertRaises(Exception): - Actor.remote() - - ray.worker.cleanup() - - def testActorsAndTasksWithGPUs(self): - num_local_schedulers = 3 - num_gpus_per_scheduler = 6 - ray.worker._init( - start_ray_local=True, num_workers=0, - num_local_schedulers=num_local_schedulers, - num_cpus=num_gpus_per_scheduler, - num_gpus=(num_local_schedulers * [num_gpus_per_scheduler])) - - def check_intervals_non_overlapping(list_of_intervals): - for i in range(len(list_of_intervals)): - for j in range(i): - first_interval = list_of_intervals[i] - second_interval = list_of_intervals[j] - # Check that list_of_intervals[i] and list_of_intervals[j] don't - # overlap. 
- assert first_interval[0] < first_interval[1] - assert second_interval[0] < second_interval[1] - assert (first_interval[1] < second_interval[0] or - second_interval[1] < first_interval[0]) - - @ray.remote(num_gpus=1) - def f1(): - t1 = time.time() - time.sleep(0.1) - t2 = time.time() - gpu_ids = ray.get_gpu_ids() - assert len(gpu_ids) == 1 - assert gpu_ids[0] in range(num_gpus_per_scheduler) - return (ray.worker.global_worker.plasma_client.store_socket_name, - tuple(gpu_ids), [t1, t2]) - - @ray.remote(num_gpus=2) - def f2(): - t1 = time.time() - time.sleep(0.1) - t2 = time.time() - gpu_ids = ray.get_gpu_ids() - assert len(gpu_ids) == 2 - assert gpu_ids[0] in range(num_gpus_per_scheduler) - assert gpu_ids[1] in range(num_gpus_per_scheduler) - return (ray.worker.global_worker.plasma_client.store_socket_name, - tuple(gpu_ids), [t1, t2]) - - @ray.remote(num_gpus=1) - class Actor1(object): - def __init__(self): - self.gpu_ids = ray.get_gpu_ids() - assert len(self.gpu_ids) == 1 - assert self.gpu_ids[0] in range(num_gpus_per_scheduler) - - def get_location_and_ids(self): - assert ray.get_gpu_ids() == self.gpu_ids - return (ray.worker.global_worker.plasma_client.store_socket_name, - tuple(self.gpu_ids)) - - def locations_to_intervals_for_many_tasks(): - # Launch a bunch of GPU tasks. - locations_ids_and_intervals = ray.get( - [f1.remote() for _ - in range(5 * num_local_schedulers * num_gpus_per_scheduler)] + - [f2.remote() for _ - in range(5 * num_local_schedulers * num_gpus_per_scheduler)] + - [f1.remote() for _ - in range(5 * num_local_schedulers * num_gpus_per_scheduler)]) - - locations_to_intervals = collections.defaultdict(lambda: []) - for location, gpu_ids, interval in locations_ids_and_intervals: - for gpu_id in gpu_ids: - locations_to_intervals[(location, gpu_id)].append(interval) - return locations_to_intervals - - # Run a bunch of GPU tasks. - locations_to_intervals = locations_to_intervals_for_many_tasks() - # Make sure that all GPUs were used. 
- self.assertEqual(len(locations_to_intervals), - num_local_schedulers * num_gpus_per_scheduler) - # For each GPU, verify that the set of tasks that used this specific GPU - # did not overlap in time. - for locations in locations_to_intervals: - check_intervals_non_overlapping(locations_to_intervals[locations]) - - # Create an actor that uses a GPU. - a = Actor1.remote() - actor_location = ray.get(a.get_location_and_ids.remote()) - actor_location = (actor_location[0], actor_location[1][0]) - # This check makes sure that actor_location is formatted the same way that - # the keys of locations_to_intervals are formatted. - self.assertIn(actor_location, locations_to_intervals) - - # Run a bunch of GPU tasks. - locations_to_intervals = locations_to_intervals_for_many_tasks() - # Make sure that all but one of the GPUs were used. - self.assertEqual(len(locations_to_intervals), - num_local_schedulers * num_gpus_per_scheduler - 1) - # For each GPU, verify that the set of tasks that used this specific GPU - # did not overlap in time. - for locations in locations_to_intervals: - check_intervals_non_overlapping(locations_to_intervals[locations]) - # Make sure that the actor's GPU was not used. - self.assertNotIn(actor_location, locations_to_intervals) - - # Create several more actors that use GPUs. - actors = [Actor1.remote() for _ in range(3)] - actor_locations = ray.get([actor.get_location_and_ids.remote() - for actor in actors]) - - # Run a bunch of GPU tasks. - locations_to_intervals = locations_to_intervals_for_many_tasks() - # Make sure that all but 11 of the GPUs were used. - self.assertEqual(len(locations_to_intervals), - num_local_schedulers * num_gpus_per_scheduler - 1 - 3) - # For each GPU, verify that the set of tasks that used this specific GPU - # did not overlap in time. - for locations in locations_to_intervals: - check_intervals_non_overlapping(locations_to_intervals[locations]) - # Make sure that the GPUs were not used. 
- self.assertNotIn(actor_location, locations_to_intervals) - for location in actor_locations: - self.assertNotIn(location, locations_to_intervals) - - # Create more actors to fill up all the GPUs. - more_actors = [Actor1.remote() for _ in - range(num_local_schedulers * - num_gpus_per_scheduler - 1 - 3)] - # Wait for the actors to finish being created. - ray.get([actor.get_location_and_ids.remote() for actor in more_actors]) - - # Now if we run some GPU tasks, they should not be scheduled. - results = [f1.remote() for _ in range(30)] - ready_ids, remaining_ids = ray.wait(results, timeout=1000) - self.assertEqual(len(ready_ids), 0) - - ray.worker.cleanup() - - def testActorsAndTasksWithGPUsVersionTwo(self): - # Create tasks and actors that both use GPUs and make sure that they are - # given different GPUs - ray.init(num_cpus=10, num_gpus=10) - - @ray.remote(num_gpus=1) - def f(): - time.sleep(4) - gpu_ids = ray.get_gpu_ids() - assert len(gpu_ids) == 1 - return gpu_ids[0] - - @ray.remote(num_gpus=1) - class Actor(object): - def __init__(self): - self.gpu_ids = ray.get_gpu_ids() - assert len(self.gpu_ids) == 1 - - def get_gpu_id(self): - assert ray.get_gpu_ids() == self.gpu_ids - return self.gpu_ids[0] - - results = [] - actors = [] - for _ in range(5): - results.append(f.remote()) - a = Actor.remote() - results.append(a.get_gpu_id.remote()) - # Prevent the actor handle from going out of scope so that its GPU - # resources don't get released. 
- actors.append(a) - - gpu_ids = ray.get(results) - self.assertEqual(set(gpu_ids), set(range(10))) - - ray.worker.cleanup() - - def testActorsAndTaskResourceBookkeeping(self): - ray.init(num_cpus=1) - - @ray.remote - class Foo(object): - def __init__(self): - start = time.time() - time.sleep(0.1) - end = time.time() - self.interval = (start, end) - - def get_interval(self): - return self.interval - - def sleep(self): - start = time.time() - time.sleep(0.01) - end = time.time() - return start, end - - # First make sure that we do not have more actor methods running at a time - # than we have CPUs. - actors = [Foo.remote() for _ in range(4)] - interval_ids = [] - interval_ids += [actor.get_interval.remote() for actor in actors] - for _ in range(4): - interval_ids += [actor.sleep.remote() for actor in actors] - - # Make sure that the intervals don't overlap. - intervals = ray.get(interval_ids) - intervals.sort(key=lambda x: x[0]) - for interval1, interval2 in zip(intervals[:-1], intervals[1:]): - self.assertLess(interval1[0], interval1[1]) - self.assertLess(interval1[1], interval2[0]) - self.assertLess(interval2[0], interval2[1]) - - ray.worker.cleanup() - - def testBlockingActorTask(self): - ray.init(num_cpus=1, num_gpus=1) - - @ray.remote(num_gpus=1) - def f(): - return 1 - - @ray.remote - class Foo(object): - def __init__(self): - pass - - def blocking_method(self): - ray.get(f.remote()) - - # Make sure we can execute a blocking actor method even if there is only - # one CPU. - actor = Foo.remote() - ray.get(actor.blocking_method.remote()) - - @ray.remote(num_gpus=1) - class GPUFoo(object): - def __init__(self): - pass - - def blocking_method(self): - ray.get(f.remote()) - - # Make sure that we GPU resources are not released when actors block. 
- actor = GPUFoo.remote() - x_id = actor.blocking_method.remote() - ready_ids, remaining_ids = ray.wait([x_id], timeout=500) - self.assertEqual(ready_ids, []) - self.assertEqual(remaining_ids, [x_id]) - - ray.worker.cleanup() + def testActorGPUs(self): + num_local_schedulers = 3 + num_gpus_per_scheduler = 4 + ray.worker._init( + start_ray_local=True, num_workers=0, + num_local_schedulers=num_local_schedulers, + num_gpus=(num_local_schedulers * [num_gpus_per_scheduler])) + + @ray.remote(num_gpus=1) + class Actor1(object): + def __init__(self): + self.gpu_ids = ray.get_gpu_ids() + + def get_location_and_ids(self): + assert ray.get_gpu_ids() == self.gpu_ids + return ( + ray.worker.global_worker.plasma_client.store_socket_name, + tuple(self.gpu_ids)) + + # Create one actor per GPU. + actors = [Actor1.remote() for _ + in range(num_local_schedulers * num_gpus_per_scheduler)] + # Make sure that no two actors are assigned to the same GPU. + locations_and_ids = ray.get([actor.get_location_and_ids.remote() + for actor in actors]) + node_names = set([location for location, gpu_id in locations_and_ids]) + self.assertEqual(len(node_names), num_local_schedulers) + location_actor_combinations = [] + for node_name in node_names: + for gpu_id in range(num_gpus_per_scheduler): + location_actor_combinations.append((node_name, (gpu_id,))) + self.assertEqual(set(locations_and_ids), + set(location_actor_combinations)) + + # Creating a new actor should fail because all of the GPUs are being + # used. 
+ with self.assertRaises(Exception): + Actor1.remote() + + ray.worker.cleanup() + + def testActorMultipleGPUs(self): + num_local_schedulers = 3 + num_gpus_per_scheduler = 5 + ray.worker._init( + start_ray_local=True, num_workers=0, + num_local_schedulers=num_local_schedulers, + num_gpus=(num_local_schedulers * [num_gpus_per_scheduler])) + + @ray.remote(num_gpus=2) + class Actor1(object): + def __init__(self): + self.gpu_ids = ray.get_gpu_ids() + + def get_location_and_ids(self): + return ( + ray.worker.global_worker.plasma_client.store_socket_name, + tuple(self.gpu_ids)) + + # Create some actors. + actors1 = [Actor1.remote() for _ in range(num_local_schedulers * 2)] + # Make sure that no two actors are assigned to the same GPU. + locations_and_ids = ray.get([actor.get_location_and_ids.remote() + for actor in actors1]) + node_names = set([location for location, gpu_id in locations_and_ids]) + self.assertEqual(len(node_names), num_local_schedulers) + + # Keep track of which GPU IDs are being used for each location. + gpus_in_use = {node_name: [] for node_name in node_names} + for location, gpu_ids in locations_and_ids: + gpus_in_use[location].extend(gpu_ids) + for node_name in node_names: + self.assertEqual(len(set(gpus_in_use[node_name])), 4) + + # Creating a new actor should fail because all of the GPUs are being + # used. + with self.assertRaises(Exception): + Actor1.remote() + + # We should be able to create more actors that use only a single GPU. + @ray.remote(num_gpus=1) + class Actor2(object): + def __init__(self): + self.gpu_ids = ray.get_gpu_ids() + + def get_location_and_ids(self): + return ( + ray.worker.global_worker.plasma_client.store_socket_name, + tuple(self.gpu_ids)) + + # Create some actors. + actors2 = [Actor2.remote() for _ in range(num_local_schedulers)] + # Make sure that no two actors are assigned to the same GPU. 
+ locations_and_ids = ray.get([actor.get_location_and_ids.remote() + for actor in actors2]) + self.assertEqual(node_names, + set([location for location, gpu_id + in locations_and_ids])) + for location, gpu_ids in locations_and_ids: + gpus_in_use[location].extend(gpu_ids) + for node_name in node_names: + self.assertEqual(len(gpus_in_use[node_name]), 5) + self.assertEqual(set(gpus_in_use[node_name]), set(range(5))) + + # Creating a new actor should fail because all of the GPUs are being + # used. + with self.assertRaises(Exception): + Actor2.remote() + + ray.worker.cleanup() + + def testActorDifferentNumbersOfGPUs(self): + # Test that we can create actors on two nodes that have different + # numbers of GPUs. + ray.worker._init(start_ray_local=True, num_workers=0, + num_local_schedulers=3, num_gpus=[0, 5, 10]) + + @ray.remote(num_gpus=1) + class Actor1(object): + def __init__(self): + self.gpu_ids = ray.get_gpu_ids() + + def get_location_and_ids(self): + return ( + ray.worker.global_worker.plasma_client.store_socket_name, + tuple(self.gpu_ids)) + + # Create some actors. + actors = [Actor1.remote() for _ in range(0 + 5 + 10)] + # Make sure that no two actors are assigned to the same GPU. + locations_and_ids = ray.get([actor.get_location_and_ids.remote() + for actor in actors]) + node_names = set([location for location, gpu_id in locations_and_ids]) + self.assertEqual(len(node_names), 2) + for node_name in node_names: + node_gpu_ids = [gpu_id for location, gpu_id in locations_and_ids + if location == node_name] + self.assertIn(len(node_gpu_ids), [5, 10]) + self.assertEqual(set(node_gpu_ids), + set([(i,) for i in range(len(node_gpu_ids))])) + + # Creating a new actor should fail because all of the GPUs are being + # used. 
+ with self.assertRaises(Exception): + Actor1.remote() + + ray.worker.cleanup() + + def testActorMultipleGPUsFromMultipleTasks(self): + num_local_schedulers = 10 + num_gpus_per_scheduler = 10 + ray.worker._init( + start_ray_local=True, num_workers=0, + num_local_schedulers=num_local_schedulers, redirect_output=True, + num_gpus=(num_local_schedulers * [num_gpus_per_scheduler])) + + @ray.remote + def create_actors(n): + @ray.remote(num_gpus=1) + class Actor(object): + def __init__(self): + self.gpu_ids = ray.get_gpu_ids() + + def get_location_and_ids(self): + return ((ray.worker.global_worker.plasma_client + .store_socket_name), + tuple(self.gpu_ids)) + # Create n actors. + for _ in range(n): + Actor.remote() + + ray.get([create_actors.remote(num_gpus_per_scheduler) + for _ in range(num_local_schedulers)]) + + @ray.remote(num_gpus=1) + class Actor(object): + def __init__(self): + self.gpu_ids = ray.get_gpu_ids() + + def get_location_and_ids(self): + return ( + ray.worker.global_worker.plasma_client.store_socket_name, + tuple(self.gpu_ids)) + + # All the GPUs should be used up now. + with self.assertRaises(Exception): + Actor.remote() + + ray.worker.cleanup() + + def testActorsAndTasksWithGPUs(self): + num_local_schedulers = 3 + num_gpus_per_scheduler = 6 + ray.worker._init( + start_ray_local=True, num_workers=0, + num_local_schedulers=num_local_schedulers, + num_cpus=num_gpus_per_scheduler, + num_gpus=(num_local_schedulers * [num_gpus_per_scheduler])) + + def check_intervals_non_overlapping(list_of_intervals): + for i in range(len(list_of_intervals)): + for j in range(i): + first_interval = list_of_intervals[i] + second_interval = list_of_intervals[j] + # Check that list_of_intervals[i] and list_of_intervals[j] + # don't overlap. 
+ assert first_interval[0] < first_interval[1] + assert second_interval[0] < second_interval[1] + assert (first_interval[1] < second_interval[0] or + second_interval[1] < first_interval[0]) + + @ray.remote(num_gpus=1) + def f1(): + t1 = time.time() + time.sleep(0.1) + t2 = time.time() + gpu_ids = ray.get_gpu_ids() + assert len(gpu_ids) == 1 + assert gpu_ids[0] in range(num_gpus_per_scheduler) + return (ray.worker.global_worker.plasma_client.store_socket_name, + tuple(gpu_ids), [t1, t2]) + + @ray.remote(num_gpus=2) + def f2(): + t1 = time.time() + time.sleep(0.1) + t2 = time.time() + gpu_ids = ray.get_gpu_ids() + assert len(gpu_ids) == 2 + assert gpu_ids[0] in range(num_gpus_per_scheduler) + assert gpu_ids[1] in range(num_gpus_per_scheduler) + return (ray.worker.global_worker.plasma_client.store_socket_name, + tuple(gpu_ids), [t1, t2]) + + @ray.remote(num_gpus=1) + class Actor1(object): + def __init__(self): + self.gpu_ids = ray.get_gpu_ids() + assert len(self.gpu_ids) == 1 + assert self.gpu_ids[0] in range(num_gpus_per_scheduler) + + def get_location_and_ids(self): + assert ray.get_gpu_ids() == self.gpu_ids + return ( + ray.worker.global_worker.plasma_client.store_socket_name, + tuple(self.gpu_ids)) + + def locations_to_intervals_for_many_tasks(): + # Launch a bunch of GPU tasks. + locations_ids_and_intervals = ray.get( + [f1.remote() for _ + in range(5 * num_local_schedulers * num_gpus_per_scheduler)] + + [f2.remote() for _ + in range(5 * num_local_schedulers * num_gpus_per_scheduler)] + + [f1.remote() for _ + in range(5 * num_local_schedulers * num_gpus_per_scheduler)]) + + locations_to_intervals = collections.defaultdict(lambda: []) + for location, gpu_ids, interval in locations_ids_and_intervals: + for gpu_id in gpu_ids: + locations_to_intervals[(location, gpu_id)].append(interval) + return locations_to_intervals + + # Run a bunch of GPU tasks. + locations_to_intervals = locations_to_intervals_for_many_tasks() + # Make sure that all GPUs were used. 
+ self.assertEqual(len(locations_to_intervals), + num_local_schedulers * num_gpus_per_scheduler) + # For each GPU, verify that the set of tasks that used this specific + # GPU did not overlap in time. + for locations in locations_to_intervals: + check_intervals_non_overlapping(locations_to_intervals[locations]) + + # Create an actor that uses a GPU. + a = Actor1.remote() + actor_location = ray.get(a.get_location_and_ids.remote()) + actor_location = (actor_location[0], actor_location[1][0]) + # This check makes sure that actor_location is formatted the same way + # that the keys of locations_to_intervals are formatted. + self.assertIn(actor_location, locations_to_intervals) + + # Run a bunch of GPU tasks. + locations_to_intervals = locations_to_intervals_for_many_tasks() + # Make sure that all but one of the GPUs were used. + self.assertEqual(len(locations_to_intervals), + num_local_schedulers * num_gpus_per_scheduler - 1) + # For each GPU, verify that the set of tasks that used this specific + # GPU did not overlap in time. + for locations in locations_to_intervals: + check_intervals_non_overlapping(locations_to_intervals[locations]) + # Make sure that the actor's GPU was not used. + self.assertNotIn(actor_location, locations_to_intervals) + + # Create several more actors that use GPUs. + actors = [Actor1.remote() for _ in range(3)] + actor_locations = ray.get([actor.get_location_and_ids.remote() + for actor in actors]) + + # Run a bunch of GPU tasks. + locations_to_intervals = locations_to_intervals_for_many_tasks() + # Make sure that all but 11 of the GPUs were used. + self.assertEqual(len(locations_to_intervals), + num_local_schedulers * num_gpus_per_scheduler - 1 - 3) + # For each GPU, verify that the set of tasks that used this specific + # GPU did not overlap in time. + for locations in locations_to_intervals: + check_intervals_non_overlapping(locations_to_intervals[locations]) + # Make sure that the GPUs were not used. 
+ self.assertNotIn(actor_location, locations_to_intervals) + for location in actor_locations: + self.assertNotIn(location, locations_to_intervals) + + # Create more actors to fill up all the GPUs. + more_actors = [Actor1.remote() for _ in + range(num_local_schedulers * + num_gpus_per_scheduler - 1 - 3)] + # Wait for the actors to finish being created. + ray.get([actor.get_location_and_ids.remote() for actor in more_actors]) + + # Now if we run some GPU tasks, they should not be scheduled. + results = [f1.remote() for _ in range(30)] + ready_ids, remaining_ids = ray.wait(results, timeout=1000) + self.assertEqual(len(ready_ids), 0) + + ray.worker.cleanup() + + def testActorsAndTasksWithGPUsVersionTwo(self): + # Create tasks and actors that both use GPUs and make sure that they + # are given different GPUs + ray.init(num_cpus=10, num_gpus=10) + + @ray.remote(num_gpus=1) + def f(): + time.sleep(4) + gpu_ids = ray.get_gpu_ids() + assert len(gpu_ids) == 1 + return gpu_ids[0] + + @ray.remote(num_gpus=1) + class Actor(object): + def __init__(self): + self.gpu_ids = ray.get_gpu_ids() + assert len(self.gpu_ids) == 1 + + def get_gpu_id(self): + assert ray.get_gpu_ids() == self.gpu_ids + return self.gpu_ids[0] + + results = [] + actors = [] + for _ in range(5): + results.append(f.remote()) + a = Actor.remote() + results.append(a.get_gpu_id.remote()) + # Prevent the actor handle from going out of scope so that its GPU + # resources don't get released. 
+ actors.append(a) + + gpu_ids = ray.get(results) + self.assertEqual(set(gpu_ids), set(range(10))) + + ray.worker.cleanup() + + def testActorsAndTaskResourceBookkeeping(self): + ray.init(num_cpus=1) + + @ray.remote + class Foo(object): + def __init__(self): + start = time.time() + time.sleep(0.1) + end = time.time() + self.interval = (start, end) + + def get_interval(self): + return self.interval + + def sleep(self): + start = time.time() + time.sleep(0.01) + end = time.time() + return start, end + + # First make sure that we do not have more actor methods running at a + # time than we have CPUs. + actors = [Foo.remote() for _ in range(4)] + interval_ids = [] + interval_ids += [actor.get_interval.remote() for actor in actors] + for _ in range(4): + interval_ids += [actor.sleep.remote() for actor in actors] + + # Make sure that the intervals don't overlap. + intervals = ray.get(interval_ids) + intervals.sort(key=lambda x: x[0]) + for interval1, interval2 in zip(intervals[:-1], intervals[1:]): + self.assertLess(interval1[0], interval1[1]) + self.assertLess(interval1[1], interval2[0]) + self.assertLess(interval2[0], interval2[1]) + + ray.worker.cleanup() + + def testBlockingActorTask(self): + ray.init(num_cpus=1, num_gpus=1) + + @ray.remote(num_gpus=1) + def f(): + return 1 + + @ray.remote + class Foo(object): + def __init__(self): + pass + + def blocking_method(self): + ray.get(f.remote()) + + # Make sure we can execute a blocking actor method even if there is + # only one CPU. + actor = Foo.remote() + ray.get(actor.blocking_method.remote()) + + @ray.remote(num_gpus=1) + class GPUFoo(object): + def __init__(self): + pass + + def blocking_method(self): + ray.get(f.remote()) + + # Make sure that we GPU resources are not released when actors block. 
+ actor = GPUFoo.remote() + x_id = actor.blocking_method.remote() + ready_ids, remaining_ids = ray.wait([x_id], timeout=500) + self.assertEqual(ready_ids, []) + self.assertEqual(remaining_ids, [x_id]) + + ray.worker.cleanup() if __name__ == "__main__": - unittest.main(verbosity=2) + unittest.main(verbosity=2) diff --git a/test/array_test.py b/test/array_test.py index 6068e907f..f3418e86f 100644 --- a/test/array_test.py +++ b/test/array_test.py @@ -12,223 +12,226 @@ import ray.experimental.array.remote as ra import ray.experimental.array.distributed as da if sys.version_info >= (3, 0): - from importlib import reload + from importlib import reload class RemoteArrayTest(unittest.TestCase): - def testMethods(self): - for module in [ra.core, ra.random, ra.linalg, da.core, da.random, - da.linalg]: - reload(module) - ray.init() + def testMethods(self): + for module in [ra.core, ra.random, ra.linalg, da.core, da.random, + da.linalg]: + reload(module) + ray.init() - # test eye - object_id = ra.eye.remote(3) - val = ray.get(object_id) - assert_almost_equal(val, np.eye(3)) + # test eye + object_id = ra.eye.remote(3) + val = ray.get(object_id) + assert_almost_equal(val, np.eye(3)) - # test zeros - object_id = ra.zeros.remote([3, 4, 5]) - val = ray.get(object_id) - assert_equal(val, np.zeros([3, 4, 5])) + # test zeros + object_id = ra.zeros.remote([3, 4, 5]) + val = ray.get(object_id) + assert_equal(val, np.zeros([3, 4, 5])) - # test qr - pass by value - a_val = np.random.normal(size=[10, 11]) - q_id, r_id = ra.linalg.qr.remote(a_val) - q_val = ray.get(q_id) - r_val = ray.get(r_id) - assert_almost_equal(np.dot(q_val, r_val), a_val) + # test qr - pass by value + a_val = np.random.normal(size=[10, 11]) + q_id, r_id = ra.linalg.qr.remote(a_val) + q_val = ray.get(q_id) + r_val = ray.get(r_id) + assert_almost_equal(np.dot(q_val, r_val), a_val) - # test qr - pass by objectid - a = ra.random.normal.remote([10, 13]) - q_id, r_id = ra.linalg.qr.remote(a) - a_val = ray.get(a) - q_val = 
ray.get(q_id) - r_val = ray.get(r_id) - assert_almost_equal(np.dot(q_val, r_val), a_val) + # test qr - pass by objectid + a = ra.random.normal.remote([10, 13]) + q_id, r_id = ra.linalg.qr.remote(a) + a_val = ray.get(a) + q_val = ray.get(q_id) + r_val = ray.get(r_id) + assert_almost_equal(np.dot(q_val, r_val), a_val) - ray.worker.cleanup() + ray.worker.cleanup() class DistributedArrayTest(unittest.TestCase): - def testAssemble(self): - for module in [ra.core, ra.random, ra.linalg, da.core, da.random, - da.linalg]: - reload(module) - ray.init() + def testAssemble(self): + for module in [ra.core, ra.random, ra.linalg, da.core, da.random, + da.linalg]: + reload(module) + ray.init() - a = ra.ones.remote([da.BLOCK_SIZE, da.BLOCK_SIZE]) - b = ra.zeros.remote([da.BLOCK_SIZE, da.BLOCK_SIZE]) - x = da.DistArray([2 * da.BLOCK_SIZE, da.BLOCK_SIZE], np.array([[a], [b]])) - assert_equal(x.assemble(), - np.vstack([np.ones([da.BLOCK_SIZE, da.BLOCK_SIZE]), - np.zeros([da.BLOCK_SIZE, da.BLOCK_SIZE])])) + a = ra.ones.remote([da.BLOCK_SIZE, da.BLOCK_SIZE]) + b = ra.zeros.remote([da.BLOCK_SIZE, da.BLOCK_SIZE]) + x = da.DistArray([2 * da.BLOCK_SIZE, da.BLOCK_SIZE], + np.array([[a], [b]])) + assert_equal(x.assemble(), + np.vstack([np.ones([da.BLOCK_SIZE, da.BLOCK_SIZE]), + np.zeros([da.BLOCK_SIZE, da.BLOCK_SIZE])])) - ray.worker.cleanup() + ray.worker.cleanup() - def testMethods(self): - for module in [ra.core, ra.random, ra.linalg, da.core, da.random, - da.linalg]: - reload(module) - ray.worker._init(start_ray_local=True, num_local_schedulers=2, - num_cpus=[10, 10]) + def testMethods(self): + for module in [ra.core, ra.random, ra.linalg, da.core, da.random, + da.linalg]: + reload(module) + ray.worker._init(start_ray_local=True, num_local_schedulers=2, + num_cpus=[10, 10]) - x = da.zeros.remote([9, 25, 51], "float") - assert_equal(ray.get(da.assemble.remote(x)), np.zeros([9, 25, 51])) + x = da.zeros.remote([9, 25, 51], "float") + assert_equal(ray.get(da.assemble.remote(x)), np.zeros([9, 
25, 51])) - x = da.ones.remote([11, 25, 49], dtype_name="float") - assert_equal(ray.get(da.assemble.remote(x)), np.ones([11, 25, 49])) + x = da.ones.remote([11, 25, 49], dtype_name="float") + assert_equal(ray.get(da.assemble.remote(x)), np.ones([11, 25, 49])) - x = da.random.normal.remote([11, 25, 49]) - y = da.copy.remote(x) - assert_equal(ray.get(da.assemble.remote(x)), - ray.get(da.assemble.remote(y))) + x = da.random.normal.remote([11, 25, 49]) + y = da.copy.remote(x) + assert_equal(ray.get(da.assemble.remote(x)), + ray.get(da.assemble.remote(y))) - x = da.eye.remote(25, dtype_name="float") - assert_equal(ray.get(da.assemble.remote(x)), np.eye(25)) + x = da.eye.remote(25, dtype_name="float") + assert_equal(ray.get(da.assemble.remote(x)), np.eye(25)) - x = da.random.normal.remote([25, 49]) - y = da.triu.remote(x) - assert_equal(ray.get(da.assemble.remote(y)), - np.triu(ray.get(da.assemble.remote(x)))) + x = da.random.normal.remote([25, 49]) + y = da.triu.remote(x) + assert_equal(ray.get(da.assemble.remote(y)), + np.triu(ray.get(da.assemble.remote(x)))) - x = da.random.normal.remote([25, 49]) - y = da.tril.remote(x) - assert_equal(ray.get(da.assemble.remote(y)), - np.tril(ray.get(da.assemble.remote(x)))) + x = da.random.normal.remote([25, 49]) + y = da.tril.remote(x) + assert_equal(ray.get(da.assemble.remote(y)), + np.tril(ray.get(da.assemble.remote(x)))) - x = da.random.normal.remote([25, 49]) - y = da.random.normal.remote([49, 18]) - z = da.dot.remote(x, y) - w = da.assemble.remote(z) - u = da.assemble.remote(x) - v = da.assemble.remote(y) - assert_almost_equal(ray.get(w), np.dot(ray.get(u), ray.get(v))) - assert_almost_equal(ray.get(w), np.dot(ray.get(u), ray.get(v))) + x = da.random.normal.remote([25, 49]) + y = da.random.normal.remote([49, 18]) + z = da.dot.remote(x, y) + w = da.assemble.remote(z) + u = da.assemble.remote(x) + v = da.assemble.remote(y) + assert_almost_equal(ray.get(w), np.dot(ray.get(u), ray.get(v))) + assert_almost_equal(ray.get(w), 
np.dot(ray.get(u), ray.get(v))) - # test add - x = da.random.normal.remote([23, 42]) - y = da.random.normal.remote([23, 42]) - z = da.add.remote(x, y) - assert_almost_equal(ray.get(da.assemble.remote(z)), - ray.get(da.assemble.remote(x)) + - ray.get(da.assemble.remote(y))) + # test add + x = da.random.normal.remote([23, 42]) + y = da.random.normal.remote([23, 42]) + z = da.add.remote(x, y) + assert_almost_equal(ray.get(da.assemble.remote(z)), + ray.get(da.assemble.remote(x)) + + ray.get(da.assemble.remote(y))) - # test subtract - x = da.random.normal.remote([33, 40]) - y = da.random.normal.remote([33, 40]) - z = da.subtract.remote(x, y) - assert_almost_equal(ray.get(da.assemble.remote(z)), - ray.get(da.assemble.remote(x)) - - ray.get(da.assemble.remote(y))) + # test subtract + x = da.random.normal.remote([33, 40]) + y = da.random.normal.remote([33, 40]) + z = da.subtract.remote(x, y) + assert_almost_equal(ray.get(da.assemble.remote(z)), + ray.get(da.assemble.remote(x)) - + ray.get(da.assemble.remote(y))) - # test transpose - x = da.random.normal.remote([234, 432]) - y = da.transpose.remote(x) - assert_equal(ray.get(da.assemble.remote(x)).T, - ray.get(da.assemble.remote(y))) + # test transpose + x = da.random.normal.remote([234, 432]) + y = da.transpose.remote(x) + assert_equal(ray.get(da.assemble.remote(x)).T, + ray.get(da.assemble.remote(y))) - # test numpy_to_dist - x = da.random.normal.remote([23, 45]) - y = da.assemble.remote(x) - z = da.numpy_to_dist.remote(y) - w = da.assemble.remote(z) - assert_equal(ray.get(da.assemble.remote(x)), - ray.get(da.assemble.remote(z))) - assert_equal(ray.get(y), ray.get(w)) + # test numpy_to_dist + x = da.random.normal.remote([23, 45]) + y = da.assemble.remote(x) + z = da.numpy_to_dist.remote(y) + w = da.assemble.remote(z) + assert_equal(ray.get(da.assemble.remote(x)), + ray.get(da.assemble.remote(z))) + assert_equal(ray.get(y), ray.get(w)) - # test da.tsqr - for shape in [[123, da.BLOCK_SIZE], [7, da.BLOCK_SIZE], - 
[da.BLOCK_SIZE, da.BLOCK_SIZE], [da.BLOCK_SIZE, 7], - [10 * da.BLOCK_SIZE, da.BLOCK_SIZE]]: - x = da.random.normal.remote(shape) - K = min(shape) - q, r = da.linalg.tsqr.remote(x) - x_val = ray.get(da.assemble.remote(x)) - q_val = ray.get(da.assemble.remote(q)) - r_val = ray.get(r) - self.assertTrue(r_val.shape == (K, shape[1])) - assert_equal(r_val, np.triu(r_val)) - assert_almost_equal(x_val, np.dot(q_val, r_val)) - assert_almost_equal(np.dot(q_val.T, q_val), np.eye(K)) + # test da.tsqr + for shape in [[123, da.BLOCK_SIZE], [7, da.BLOCK_SIZE], + [da.BLOCK_SIZE, da.BLOCK_SIZE], [da.BLOCK_SIZE, 7], + [10 * da.BLOCK_SIZE, da.BLOCK_SIZE]]: + x = da.random.normal.remote(shape) + K = min(shape) + q, r = da.linalg.tsqr.remote(x) + x_val = ray.get(da.assemble.remote(x)) + q_val = ray.get(da.assemble.remote(q)) + r_val = ray.get(r) + self.assertTrue(r_val.shape == (K, shape[1])) + assert_equal(r_val, np.triu(r_val)) + assert_almost_equal(x_val, np.dot(q_val, r_val)) + assert_almost_equal(np.dot(q_val.T, q_val), np.eye(K)) - # test da.linalg.modified_lu - def test_modified_lu(d1, d2): - print("testing dist_modified_lu with d1 = " + str(d1) + - ", d2 = " + str(d2)) - assert d1 >= d2 - m = ra.random.normal.remote([d1, d2]) - q, r = ra.linalg.qr.remote(m) - l, u, s = da.linalg.modified_lu.remote(da.numpy_to_dist.remote(q)) - q_val = ray.get(q) - ray.get(r) - l_val = ray.get(da.assemble.remote(l)) - u_val = ray.get(u) - s_val = ray.get(s) - s_mat = np.zeros((d1, d2)) - for i in range(len(s_val)): - s_mat[i, i] = s_val[i] - # Check that q - s = l * u. - assert_almost_equal(q_val - s_mat, np.dot(l_val, u_val)) - # Check that u is upper triangular. - assert_equal(np.triu(u_val), u_val) - # Check that l is lower triangular. 
- assert_equal(np.tril(l_val), l_val) + # test da.linalg.modified_lu + def test_modified_lu(d1, d2): + print("testing dist_modified_lu with d1 = " + str(d1) + + ", d2 = " + str(d2)) + assert d1 >= d2 + m = ra.random.normal.remote([d1, d2]) + q, r = ra.linalg.qr.remote(m) + l, u, s = da.linalg.modified_lu.remote(da.numpy_to_dist.remote(q)) + q_val = ray.get(q) + ray.get(r) + l_val = ray.get(da.assemble.remote(l)) + u_val = ray.get(u) + s_val = ray.get(s) + s_mat = np.zeros((d1, d2)) + for i in range(len(s_val)): + s_mat[i, i] = s_val[i] + # Check that q - s = l * u. + assert_almost_equal(q_val - s_mat, np.dot(l_val, u_val)) + # Check that u is upper triangular. + assert_equal(np.triu(u_val), u_val) + # Check that l is lower triangular. + assert_equal(np.tril(l_val), l_val) - for d1, d2 in [(100, 100), (99, 98), (7, 5), (7, 7), (20, 7), (20, 10)]: - test_modified_lu(d1, d2) + for d1, d2 in [(100, 100), (99, 98), (7, 5), (7, 7), (20, 7), + (20, 10)]: + test_modified_lu(d1, d2) - # test dist_tsqr_hr - def test_dist_tsqr_hr(d1, d2): - print("testing dist_tsqr_hr with d1 = " + str(d1) + ", d2 = " + str(d2)) - a = da.random.normal.remote([d1, d2]) - y, t, y_top, r = da.linalg.tsqr_hr.remote(a) - a_val = ray.get(da.assemble.remote(a)) - y_val = ray.get(da.assemble.remote(y)) - t_val = ray.get(t) - y_top_val = ray.get(y_top) - r_val = ray.get(r) - tall_eye = np.zeros((d1, min(d1, d2))) - np.fill_diagonal(tall_eye, 1) - q = tall_eye - np.dot(y_val, np.dot(t_val, y_top_val.T)) - # Check that q.T * q = I. - assert_almost_equal(np.dot(q.T, q), np.eye(min(d1, d2))) - # Check that a = (I - y * t * y_top.T) * r. 
- assert_almost_equal(np.dot(q, r_val), a_val) + # test dist_tsqr_hr + def test_dist_tsqr_hr(d1, d2): + print("testing dist_tsqr_hr with d1 = " + str(d1) + ", d2 = " + + str(d2)) + a = da.random.normal.remote([d1, d2]) + y, t, y_top, r = da.linalg.tsqr_hr.remote(a) + a_val = ray.get(da.assemble.remote(a)) + y_val = ray.get(da.assemble.remote(y)) + t_val = ray.get(t) + y_top_val = ray.get(y_top) + r_val = ray.get(r) + tall_eye = np.zeros((d1, min(d1, d2))) + np.fill_diagonal(tall_eye, 1) + q = tall_eye - np.dot(y_val, np.dot(t_val, y_top_val.T)) + # Check that q.T * q = I. + assert_almost_equal(np.dot(q.T, q), np.eye(min(d1, d2))) + # Check that a = (I - y * t * y_top.T) * r. + assert_almost_equal(np.dot(q, r_val), a_val) - for d1, d2 in [(123, da.BLOCK_SIZE), (7, da.BLOCK_SIZE), - (da.BLOCK_SIZE, da.BLOCK_SIZE), (da.BLOCK_SIZE, 7), - (10 * da.BLOCK_SIZE, da.BLOCK_SIZE)]: - test_dist_tsqr_hr(d1, d2) + for d1, d2 in [(123, da.BLOCK_SIZE), (7, da.BLOCK_SIZE), + (da.BLOCK_SIZE, da.BLOCK_SIZE), (da.BLOCK_SIZE, 7), + (10 * da.BLOCK_SIZE, da.BLOCK_SIZE)]: + test_dist_tsqr_hr(d1, d2) - def test_dist_qr(d1, d2): - print("testing qr with d1 = {}, and d2 = {}.".format(d1, d2)) - a = da.random.normal.remote([d1, d2]) - K = min(d1, d2) - q, r = da.linalg.qr.remote(a) - a_val = ray.get(da.assemble.remote(a)) - q_val = ray.get(da.assemble.remote(q)) - r_val = ray.get(da.assemble.remote(r)) - self.assertEqual(q_val.shape, (d1, K)) - self.assertEqual(r_val.shape, (K, d2)) - assert_almost_equal(np.dot(q_val.T, q_val), np.eye(K)) - assert_equal(r_val, np.triu(r_val)) - assert_almost_equal(a_val, np.dot(q_val, r_val)) + def test_dist_qr(d1, d2): + print("testing qr with d1 = {}, and d2 = {}.".format(d1, d2)) + a = da.random.normal.remote([d1, d2]) + K = min(d1, d2) + q, r = da.linalg.qr.remote(a) + a_val = ray.get(da.assemble.remote(a)) + q_val = ray.get(da.assemble.remote(q)) + r_val = ray.get(da.assemble.remote(r)) + self.assertEqual(q_val.shape, (d1, K)) + 
self.assertEqual(r_val.shape, (K, d2)) + assert_almost_equal(np.dot(q_val.T, q_val), np.eye(K)) + assert_equal(r_val, np.triu(r_val)) + assert_almost_equal(a_val, np.dot(q_val, r_val)) - for d1, d2 in [(123, da.BLOCK_SIZE), (7, da.BLOCK_SIZE), - (da.BLOCK_SIZE, da.BLOCK_SIZE), (da.BLOCK_SIZE, 7), - (13, 21), (34, 35), (8, 7)]: - test_dist_qr(d1, d2) - test_dist_qr(d2, d1) - for _ in range(20): - d1 = np.random.randint(1, 35) - d2 = np.random.randint(1, 35) - test_dist_qr(d1, d2) + for d1, d2 in [(123, da.BLOCK_SIZE), (7, da.BLOCK_SIZE), + (da.BLOCK_SIZE, da.BLOCK_SIZE), (da.BLOCK_SIZE, 7), + (13, 21), (34, 35), (8, 7)]: + test_dist_qr(d1, d2) + test_dist_qr(d2, d1) + for _ in range(20): + d1 = np.random.randint(1, 35) + d2 = np.random.randint(1, 35) + test_dist_qr(d1, d2) - ray.worker.cleanup() + ray.worker.cleanup() if __name__ == "__main__": - unittest.main(verbosity=2) + unittest.main(verbosity=2) diff --git a/test/component_failures_test.py b/test/component_failures_test.py index fc383af1e..72ab30f47 100644 --- a/test/component_failures_test.py +++ b/test/component_failures_test.py @@ -9,241 +9,250 @@ import unittest class ComponentFailureTest(unittest.TestCase): - def tearDown(self): - ray.worker.cleanup() + def tearDown(self): + ray.worker.cleanup() - # This test checks that when a worker dies in the middle of a get, the plasma - # store and manager will not die. - def testDyingWorkerGet(self): - obj_id = 20 * b"a" + # This test checks that when a worker dies in the middle of a get, the + # plasma store and manager will not die. 
+ def testDyingWorkerGet(self): + obj_id = 20 * b"a" - @ray.remote - def f(): - ray.worker.global_worker.plasma_client.get(obj_id) + @ray.remote + def f(): + ray.worker.global_worker.plasma_client.get(obj_id) - ray.worker._init(num_workers=1, - driver_mode=ray.SILENT_MODE, - start_workers_from_local_scheduler=False, - start_ray_local=True, - redirect_output=True) + ray.worker._init(num_workers=1, + driver_mode=ray.SILENT_MODE, + start_workers_from_local_scheduler=False, + start_ray_local=True, + redirect_output=True) - # Have the worker wait in a get call. - f.remote() + # Have the worker wait in a get call. + f.remote() - # Kill the worker. - time.sleep(1) - ray.services.all_processes[ray.services.PROCESS_TYPE_WORKER][0].terminate() - time.sleep(0.1) + # Kill the worker. + time.sleep(1) + (ray.services + .all_processes[ray.services.PROCESS_TYPE_WORKER][0].terminate()) + time.sleep(0.1) - # Seal the object so the store attempts to notify the worker that the get - # has been fulfilled. - ray.worker.global_worker.plasma_client.create(obj_id, 100) - ray.worker.global_worker.plasma_client.seal(obj_id) - time.sleep(0.1) + # Seal the object so the store attempts to notify the worker that the + # get has been fulfilled. + ray.worker.global_worker.plasma_client.create(obj_id, 100) + ray.worker.global_worker.plasma_client.seal(obj_id) + time.sleep(0.1) - # Make sure that nothing has died. - self.assertTrue(ray.services.all_processes_alive( - exclude=[ray.services.PROCESS_TYPE_WORKER])) + # Make sure that nothing has died. + self.assertTrue(ray.services.all_processes_alive( + exclude=[ray.services.PROCESS_TYPE_WORKER])) - # This test checks that when a worker dies in the middle of a wait, the - # plasma store and manager will not die. - def testDyingWorkerWait(self): - obj_id = 20 * b"a" + # This test checks that when a worker dies in the middle of a wait, the + # plasma store and manager will not die. 
+ def testDyingWorkerWait(self): + obj_id = 20 * b"a" - @ray.remote - def f(): - ray.worker.global_worker.plasma_client.wait([obj_id]) + @ray.remote + def f(): + ray.worker.global_worker.plasma_client.wait([obj_id]) - ray.worker._init(num_workers=1, - driver_mode=ray.SILENT_MODE, - start_workers_from_local_scheduler=False, - start_ray_local=True, - redirect_output=True) + ray.worker._init(num_workers=1, + driver_mode=ray.SILENT_MODE, + start_workers_from_local_scheduler=False, + start_ray_local=True, + redirect_output=True) - # Have the worker wait in a get call. - f.remote() + # Have the worker wait in a get call. + f.remote() - # Kill the worker. - time.sleep(1) - ray.services.all_processes[ray.services.PROCESS_TYPE_WORKER][0].terminate() - time.sleep(0.1) + # Kill the worker. + time.sleep(1) + (ray.services + .all_processes[ray.services.PROCESS_TYPE_WORKER][0].terminate()) + time.sleep(0.1) - # Seal the object so the store attempts to notify the worker that the get - # has been fulfilled. - ray.worker.global_worker.plasma_client.create(obj_id, 100) - ray.worker.global_worker.plasma_client.seal(obj_id) - time.sleep(0.1) + # Seal the object so the store attempts to notify the worker that the + # get has been fulfilled. + ray.worker.global_worker.plasma_client.create(obj_id, 100) + ray.worker.global_worker.plasma_client.seal(obj_id) + time.sleep(0.1) - # Make sure that nothing has died. - self.assertTrue(ray.services.all_processes_alive( - exclude=[ray.services.PROCESS_TYPE_WORKER])) + # Make sure that nothing has died. 
+ self.assertTrue(ray.services.all_processes_alive( + exclude=[ray.services.PROCESS_TYPE_WORKER])) - def _testWorkerFailed(self, num_local_schedulers): - @ray.remote - def f(x): - time.sleep(0.5) - return x + def _testWorkerFailed(self, num_local_schedulers): + @ray.remote + def f(x): + time.sleep(0.5) + return x - num_initial_workers = 4 - ray.worker._init(num_workers=num_initial_workers * num_local_schedulers, - num_local_schedulers=num_local_schedulers, - start_workers_from_local_scheduler=False, - start_ray_local=True, - num_cpus=[num_initial_workers] * num_local_schedulers, - redirect_output=True) - # Submit more tasks than there are workers so that all workers and cores - # are utilized. - object_ids = [f.remote(i) for i - in range(num_initial_workers * num_local_schedulers)] - object_ids += [f.remote(object_id) for object_id in object_ids] - # Allow the tasks some time to begin executing. - time.sleep(0.1) - # Kill the workers as the tasks execute. - for worker in ray.services.all_processes[ray.services.PROCESS_TYPE_WORKER]: - worker.terminate() - time.sleep(0.1) - # Make sure that we can still get the objects after the executing tasks - # died. - ray.get(object_ids) + num_initial_workers = 4 + ray.worker._init(num_workers=(num_initial_workers * + num_local_schedulers), + num_local_schedulers=num_local_schedulers, + start_workers_from_local_scheduler=False, + start_ray_local=True, + num_cpus=[num_initial_workers] * num_local_schedulers, + redirect_output=True) + # Submit more tasks than there are workers so that all workers and + # cores are utilized. + object_ids = [f.remote(i) for i + in range(num_initial_workers * num_local_schedulers)] + object_ids += [f.remote(object_id) for object_id in object_ids] + # Allow the tasks some time to begin executing. + time.sleep(0.1) + # Kill the workers as the tasks execute. 
+ for worker in (ray.services + .all_processes[ray.services.PROCESS_TYPE_WORKER]): + worker.terminate() + time.sleep(0.1) + # Make sure that we can still get the objects after the executing tasks + # died. + ray.get(object_ids) - def testWorkerFailed(self): - self._testWorkerFailed(1) + def testWorkerFailed(self): + self._testWorkerFailed(1) - def testWorkerFailedMultinode(self): - self._testWorkerFailed(4) + def testWorkerFailedMultinode(self): + self._testWorkerFailed(4) - def _testComponentFailed(self, component_type): - """Kill a component on all worker nodes and check workload succeeds.""" - @ray.remote - def f(x, j): - time.sleep(0.2) - return x + def _testComponentFailed(self, component_type): + """Kill a component on all worker nodes and check workload succeeds.""" + @ray.remote + def f(x, j): + time.sleep(0.2) + return x - # Start with 4 workers and 4 cores. - num_local_schedulers = 4 - num_workers_per_scheduler = 8 - ray.worker._init( - num_workers=num_local_schedulers * num_workers_per_scheduler, - num_local_schedulers=num_local_schedulers, - start_ray_local=True, - num_cpus=[num_workers_per_scheduler] * num_local_schedulers, - redirect_output=True) + # Start with 4 workers and 4 cores. + num_local_schedulers = 4 + num_workers_per_scheduler = 8 + ray.worker._init( + num_workers=num_local_schedulers * num_workers_per_scheduler, + num_local_schedulers=num_local_schedulers, + start_ray_local=True, + num_cpus=[num_workers_per_scheduler] * num_local_schedulers, + redirect_output=True) - # Submit more tasks than there are workers so that all workers and cores - # are utilized. - object_ids = [f.remote(i, 0) for i - in range(num_workers_per_scheduler * num_local_schedulers)] - object_ids += [f.remote(object_id, 1) for object_id in object_ids] - object_ids += [f.remote(object_id, 2) for object_id in object_ids] + # Submit more tasks than there are workers so that all workers and + # cores are utilized. 
+ object_ids = [f.remote(i, 0) for i + in range(num_workers_per_scheduler * + num_local_schedulers)] + object_ids += [f.remote(object_id, 1) for object_id in object_ids] + object_ids += [f.remote(object_id, 2) for object_id in object_ids] - # Kill the component on all nodes except the head node as the tasks - # execute. - time.sleep(0.1) - components = ray.services.all_processes[component_type] - for process in components[1:]: - process.terminate() - time.sleep(0.1) - process.kill() - process.wait() - self.assertNotEqual(process.poll(), None) - time.sleep(1) + # Kill the component on all nodes except the head node as the tasks + # execute. + time.sleep(0.1) + components = ray.services.all_processes[component_type] + for process in components[1:]: + process.terminate() + time.sleep(0.1) + process.kill() + process.wait() + self.assertNotEqual(process.poll(), None) + time.sleep(1) - # Make sure that we can still get the objects after the executing tasks - # died. - results = ray.get(object_ids) - expected_results = 4 * list(range( - num_workers_per_scheduler * num_local_schedulers)) - self.assertEqual(results, expected_results) + # Make sure that we can still get the objects after the executing tasks + # died. + results = ray.get(object_ids) + expected_results = 4 * list(range( + num_workers_per_scheduler * num_local_schedulers)) + self.assertEqual(results, expected_results) - def check_components_alive(self, component_type, check_component_alive): - """Check that a given component type is alive on all worker nodes. 
- """ - components = ray.services.all_processes[component_type][1:] - for component in components: - if check_component_alive: - self.assertTrue(component.poll() is None) - else: - print("waiting for " + component_type + " with PID " + - str(component.pid) + "to terminate") - component.wait() - print("done waiting for " + component_type + " with PID " + - str(component.pid) + "to terminate") - self.assertTrue(not component.poll() is None) + def check_components_alive(self, component_type, check_component_alive): + """Check that a given component type is alive on all worker nodes. + """ + components = ray.services.all_processes[component_type][1:] + for component in components: + if check_component_alive: + self.assertTrue(component.poll() is None) + else: + print("waiting for " + component_type + " with PID " + + str(component.pid) + "to terminate") + component.wait() + print("done waiting for " + component_type + " with PID " + + str(component.pid) + "to terminate") + self.assertTrue(not component.poll() is None) - def testLocalSchedulerFailed(self): - # Kill all local schedulers on worker nodes. - self._testComponentFailed(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER) + def testLocalSchedulerFailed(self): + # Kill all local schedulers on worker nodes. + self._testComponentFailed(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER) - # The plasma stores and plasma managers should still be alive on the worker - # nodes. - self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_STORE, True) - self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_MANAGER, True) - self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, - False) + # The plasma stores and plasma managers should still be alive on the + # worker nodes. 
+ self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_STORE, + True) + self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_MANAGER, + True) + self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, + False) - def testPlasmaManagerFailed(self): - # Kill all plasma managers on worker nodes. - self._testComponentFailed(ray.services.PROCESS_TYPE_PLASMA_MANAGER) + def testPlasmaManagerFailed(self): + # Kill all plasma managers on worker nodes. + self._testComponentFailed(ray.services.PROCESS_TYPE_PLASMA_MANAGER) - # The plasma stores should still be alive (but unreachable) on the worker - # nodes. - self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_STORE, True) - self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_MANAGER, - False) - self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, - False) + # The plasma stores should still be alive (but unreachable) on the + # worker nodes. + self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_STORE, + True) + self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_MANAGER, + False) + self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, + False) - def testPlasmaStoreFailed(self): - # Kill all plasma stores on worker nodes. - self._testComponentFailed(ray.services.PROCESS_TYPE_PLASMA_STORE) + def testPlasmaStoreFailed(self): + # Kill all plasma stores on worker nodes. + self._testComponentFailed(ray.services.PROCESS_TYPE_PLASMA_STORE) - # No processes should be left alive on the worker nodes. - self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_STORE, False) - self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_MANAGER, - False) - self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, - False) + # No processes should be left alive on the worker nodes. 
+ self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_STORE, + False) + self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_MANAGER, + False) + self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, + False) - def testDriverLivesSequential(self): - ray.worker.init(redirect_output=True) - all_processes = ray.services.all_processes - processes = [ - all_processes[ray.services.PROCESS_TYPE_PLASMA_STORE][0], - all_processes[ray.services.PROCESS_TYPE_PLASMA_MANAGER][0], - all_processes[ray.services.PROCESS_TYPE_LOCAL_SCHEDULER][0], - all_processes[ray.services.PROCESS_TYPE_GLOBAL_SCHEDULER][0]] + def testDriverLivesSequential(self): + ray.worker.init(redirect_output=True) + all_processes = ray.services.all_processes + processes = [ + all_processes[ray.services.PROCESS_TYPE_PLASMA_STORE][0], + all_processes[ray.services.PROCESS_TYPE_PLASMA_MANAGER][0], + all_processes[ray.services.PROCESS_TYPE_LOCAL_SCHEDULER][0], + all_processes[ray.services.PROCESS_TYPE_GLOBAL_SCHEDULER][0]] - # Kill all the components sequentially. - for process in processes: - process.terminate() - time.sleep(0.1) - process.kill() - process.wait() + # Kill all the components sequentially. + for process in processes: + process.terminate() + time.sleep(0.1) + process.kill() + process.wait() - # If the driver can reach the tearDown method, then it is still alive. + # If the driver can reach the tearDown method, then it is still alive. 
- def testDriverLivesParallel(self): - ray.worker.init(redirect_output=True) - all_processes = ray.services.all_processes - processes = [ - all_processes[ray.services.PROCESS_TYPE_PLASMA_STORE][0], - all_processes[ray.services.PROCESS_TYPE_PLASMA_MANAGER][0], - all_processes[ray.services.PROCESS_TYPE_LOCAL_SCHEDULER][0], - all_processes[ray.services.PROCESS_TYPE_GLOBAL_SCHEDULER][0]] + def testDriverLivesParallel(self): + ray.worker.init(redirect_output=True) + all_processes = ray.services.all_processes + processes = [ + all_processes[ray.services.PROCESS_TYPE_PLASMA_STORE][0], + all_processes[ray.services.PROCESS_TYPE_PLASMA_MANAGER][0], + all_processes[ray.services.PROCESS_TYPE_LOCAL_SCHEDULER][0], + all_processes[ray.services.PROCESS_TYPE_GLOBAL_SCHEDULER][0]] - # Kill all the components in parallel. - for process in processes: - process.terminate() + # Kill all the components in parallel. + for process in processes: + process.terminate() - time.sleep(0.1) - for process in processes: - process.kill() + time.sleep(0.1) + for process in processes: + process.kill() - for process in processes: - process.wait() + for process in processes: + process.wait() - # If the driver can reach the tearDown method, then it is still alive. + # If the driver can reach the tearDown method, then it is still alive. 
if __name__ == "__main__": - unittest.main(verbosity=2) + unittest.main(verbosity=2) diff --git a/test/failure_test.py b/test/failure_test.py index 4976a113d..e336fa0be 100644 --- a/test/failure_test.py +++ b/test/failure_test.py @@ -12,266 +12,270 @@ import unittest import ray.test.test_functions as test_functions if sys.version_info >= (3, 0): - from importlib import reload + from importlib import reload def relevant_errors(error_type): - return [info for info in ray.error_info() if info[b"type"] == error_type] + return [info for info in ray.error_info() if info[b"type"] == error_type] def wait_for_errors(error_type, num_errors, timeout=10): - start_time = time.time() - while time.time() - start_time < timeout: - if len(relevant_errors(error_type)) >= num_errors: - return - time.sleep(0.1) - print("Timing out of wait.") + start_time = time.time() + while time.time() - start_time < timeout: + if len(relevant_errors(error_type)) >= num_errors: + return + time.sleep(0.1) + print("Timing out of wait.") class TaskStatusTest(unittest.TestCase): - def testFailedTask(self): - reload(test_functions) - ray.init(num_workers=3, driver_mode=ray.SILENT_MODE) + def testFailedTask(self): + reload(test_functions) + ray.init(num_workers=3, driver_mode=ray.SILENT_MODE) - test_functions.throw_exception_fct1.remote() - test_functions.throw_exception_fct1.remote() - wait_for_errors(b"task", 2) - self.assertEqual(len(relevant_errors(b"task")), 2) - for task in relevant_errors(b"task"): - self.assertIn(b"Test function 1 intentionally failed.", - task.get(b"message")) + test_functions.throw_exception_fct1.remote() + test_functions.throw_exception_fct1.remote() + wait_for_errors(b"task", 2) + self.assertEqual(len(relevant_errors(b"task")), 2) + for task in relevant_errors(b"task"): + self.assertIn(b"Test function 1 intentionally failed.", + task.get(b"message")) - x = test_functions.throw_exception_fct2.remote() - try: - ray.get(x) - except Exception as e: - self.assertIn("Test function 2 
intentionally failed.", str(e)) - else: - # ray.get should throw an exception. - self.assertTrue(False) + x = test_functions.throw_exception_fct2.remote() + try: + ray.get(x) + except Exception as e: + self.assertIn("Test function 2 intentionally failed.", str(e)) + else: + # ray.get should throw an exception. + self.assertTrue(False) - x, y, z = test_functions.throw_exception_fct3.remote(1.0) - for ref in [x, y, z]: - try: - ray.get(ref) - except Exception as e: - self.assertIn("Test function 3 intentionally failed.", str(e)) - else: - # ray.get should throw an exception. - self.assertTrue(False) + x, y, z = test_functions.throw_exception_fct3.remote(1.0) + for ref in [x, y, z]: + try: + ray.get(ref) + except Exception as e: + self.assertIn("Test function 3 intentionally failed.", str(e)) + else: + # ray.get should throw an exception. + self.assertTrue(False) - ray.worker.cleanup() + ray.worker.cleanup() - def testFailImportingRemoteFunction(self): - ray.init(num_workers=2, driver_mode=ray.SILENT_MODE) + def testFailImportingRemoteFunction(self): + ray.init(num_workers=2, driver_mode=ray.SILENT_MODE) - # Create the contents of a temporary Python file. - temporary_python_file = """ + # Create the contents of a temporary Python file. + temporary_python_file = """ def temporary_helper_function(): - return 1 + return 1 """ - f = tempfile.NamedTemporaryFile(suffix=".py") - f.write(temporary_python_file.encode("ascii")) - f.flush() - directory = os.path.dirname(f.name) - # Get the module name and strip ".py" from the end. - module_name = os.path.basename(f.name)[:-3] - sys.path.append(directory) - module = __import__(module_name) + f = tempfile.NamedTemporaryFile(suffix=".py") + f.write(temporary_python_file.encode("ascii")) + f.flush() + directory = os.path.dirname(f.name) + # Get the module name and strip ".py" from the end. 
+ module_name = os.path.basename(f.name)[:-3] + sys.path.append(directory) + module = __import__(module_name) - # Define a function that closes over this temporary module. This should - # fail when it is unpickled. - @ray.remote - def g(): - return module.temporary_python_file() + # Define a function that closes over this temporary module. This should + # fail when it is unpickled. + @ray.remote + def g(): + return module.temporary_python_file() - wait_for_errors(b"register_remote_function", 2) - self.assertIn(b"No module named", ray.error_info()[0][b"message"]) - self.assertIn(b"No module named", ray.error_info()[1][b"message"]) + wait_for_errors(b"register_remote_function", 2) + self.assertIn(b"No module named", ray.error_info()[0][b"message"]) + self.assertIn(b"No module named", ray.error_info()[1][b"message"]) - # Check that if we try to call the function it throws an exception and does - # not hang. - for _ in range(10): - self.assertRaises(Exception, lambda: ray.get(g.remote())) + # Check that if we try to call the function it throws an exception and + # does not hang. + for _ in range(10): + self.assertRaises(Exception, lambda: ray.get(g.remote())) - f.close() + f.close() - # Clean up the junk we added to sys.path. - sys.path.pop(-1) - ray.worker.cleanup() + # Clean up the junk we added to sys.path. + sys.path.pop(-1) + ray.worker.cleanup() - def testFailedFunctionToRun(self): - ray.init(num_workers=2, driver_mode=ray.SILENT_MODE) + def testFailedFunctionToRun(self): + ray.init(num_workers=2, driver_mode=ray.SILENT_MODE) - def f(worker): - if ray.worker.global_worker.mode == ray.WORKER_MODE: - raise Exception("Function to run failed.") - ray.worker.global_worker.run_function_on_all_workers(f) - wait_for_errors(b"function_to_run", 2) - # Check that the error message is in the task info. 
- self.assertEqual(len(ray.error_info()), 2) - self.assertIn(b"Function to run failed.", ray.error_info()[0][b"message"]) - self.assertIn(b"Function to run failed.", ray.error_info()[1][b"message"]) + def f(worker): + if ray.worker.global_worker.mode == ray.WORKER_MODE: + raise Exception("Function to run failed.") + ray.worker.global_worker.run_function_on_all_workers(f) + wait_for_errors(b"function_to_run", 2) + # Check that the error message is in the task info. + self.assertEqual(len(ray.error_info()), 2) + self.assertIn(b"Function to run failed.", + ray.error_info()[0][b"message"]) + self.assertIn(b"Function to run failed.", + ray.error_info()[1][b"message"]) - ray.worker.cleanup() + ray.worker.cleanup() - def testFailImportingActor(self): - ray.init(num_workers=2, driver_mode=ray.SILENT_MODE) + def testFailImportingActor(self): + ray.init(num_workers=2, driver_mode=ray.SILENT_MODE) - # Create the contents of a temporary Python file. - temporary_python_file = """ + # Create the contents of a temporary Python file. + temporary_python_file = """ def temporary_helper_function(): - return 1 + return 1 """ - f = tempfile.NamedTemporaryFile(suffix=".py") - f.write(temporary_python_file.encode("ascii")) - f.flush() - directory = os.path.dirname(f.name) - # Get the module name and strip ".py" from the end. - module_name = os.path.basename(f.name)[:-3] - sys.path.append(directory) - module = __import__(module_name) + f = tempfile.NamedTemporaryFile(suffix=".py") + f.write(temporary_python_file.encode("ascii")) + f.flush() + directory = os.path.dirname(f.name) + # Get the module name and strip ".py" from the end. + module_name = os.path.basename(f.name)[:-3] + sys.path.append(directory) + module = __import__(module_name) - # Define an actor that closes over this temporary module. This should fail - # when it is unpickled. 
- @ray.remote - class Foo(object): - def __init__(self): - self.x = module.temporary_python_file() + # Define an actor that closes over this temporary module. This should + # fail when it is unpickled. + @ray.remote + class Foo(object): + def __init__(self): + self.x = module.temporary_python_file() - def get_val(self): - return 1 + def get_val(self): + return 1 - # There should be no errors yet. - self.assertEqual(len(ray.error_info()), 0) + # There should be no errors yet. + self.assertEqual(len(ray.error_info()), 0) - # Create an actor. - foo = Foo.remote() + # Create an actor. + foo = Foo.remote() - # Wait for the error to arrive. - wait_for_errors(b"register_actor", 1) - self.assertIn(b"No module named", ray.error_info()[0][b"message"]) + # Wait for the error to arrive. + wait_for_errors(b"register_actor", 1) + self.assertIn(b"No module named", ray.error_info()[0][b"message"]) - # Wait for the error from when the __init__ tries to run. - wait_for_errors(b"task", 1) - self.assertIn(b"failed to be imported, and so cannot execute this method", - ray.error_info()[1][b"message"]) + # Wait for the error from when the __init__ tries to run. + wait_for_errors(b"task", 1) + self.assertIn( + b"failed to be imported, and so cannot execute this method", + ray.error_info()[1][b"message"]) - # Check that if we try to get the function it throws an exception and does - # not hang. - with self.assertRaises(Exception): - ray.get(foo.get_val.remote()) + # Check that if we try to get the function it throws an exception and + # does not hang. + with self.assertRaises(Exception): + ray.get(foo.get_val.remote()) - # Wait for the error from when the call to get_val. - wait_for_errors(b"task", 2) - self.assertIn(b"failed to be imported, and so cannot execute this method", - ray.error_info()[2][b"message"]) + # Wait for the error from when the call to get_val. 
+ wait_for_errors(b"task", 2) + self.assertIn( + b"failed to be imported, and so cannot execute this method", + ray.error_info()[2][b"message"]) - f.close() + f.close() - # Clean up the junk we added to sys.path. - sys.path.pop(-1) - ray.worker.cleanup() + # Clean up the junk we added to sys.path. + sys.path.pop(-1) + ray.worker.cleanup() class ActorTest(unittest.TestCase): - def testFailedActorInit(self): - ray.init(num_workers=0, driver_mode=ray.SILENT_MODE) + def testFailedActorInit(self): + ray.init(num_workers=0, driver_mode=ray.SILENT_MODE) - error_message1 = "actor constructor failed" - error_message2 = "actor method failed" + error_message1 = "actor constructor failed" + error_message2 = "actor method failed" - @ray.remote - class FailedActor(object): - def __init__(self): - raise Exception(error_message1) + @ray.remote + class FailedActor(object): + def __init__(self): + raise Exception(error_message1) - def get_val(self): - return 1 + def get_val(self): + return 1 - def fail_method(self): - raise Exception(error_message2) + def fail_method(self): + raise Exception(error_message2) - a = FailedActor.remote() + a = FailedActor.remote() - # Make sure that we get errors from a failed constructor. - wait_for_errors(b"task", 1) - self.assertEqual(len(ray.error_info()), 1) - self.assertIn(error_message1, - ray.error_info()[0][b"message"].decode("ascii")) + # Make sure that we get errors from a failed constructor. + wait_for_errors(b"task", 1) + self.assertEqual(len(ray.error_info()), 1) + self.assertIn(error_message1, + ray.error_info()[0][b"message"].decode("ascii")) - # Make sure that we get errors from a failed method. - a.fail_method.remote() - wait_for_errors(b"task", 2) - self.assertEqual(len(ray.error_info()), 2) - self.assertIn(error_message2, - ray.error_info()[1][b"message"].decode("ascii")) + # Make sure that we get errors from a failed method. 
+ a.fail_method.remote() + wait_for_errors(b"task", 2) + self.assertEqual(len(ray.error_info()), 2) + self.assertIn(error_message2, + ray.error_info()[1][b"message"].decode("ascii")) - ray.worker.cleanup() + ray.worker.cleanup() - def testIncorrectMethodCalls(self): - ray.init(num_workers=0, driver_mode=ray.SILENT_MODE) + def testIncorrectMethodCalls(self): + ray.init(num_workers=0, driver_mode=ray.SILENT_MODE) - @ray.remote - class Actor(object): - def __init__(self, missing_variable_name): - pass + @ray.remote + class Actor(object): + def __init__(self, missing_variable_name): + pass - def get_val(self, x): - pass + def get_val(self, x): + pass - # Make sure that we get errors if we call the constructor incorrectly. + # Make sure that we get errors if we call the constructor incorrectly. - # Create an actor with too few arguments. - with self.assertRaises(Exception): - a = Actor.remote() + # Create an actor with too few arguments. + with self.assertRaises(Exception): + a = Actor.remote() - # Create an actor with too many arguments. - with self.assertRaises(Exception): - a = Actor.remote(1, 2) + # Create an actor with too many arguments. + with self.assertRaises(Exception): + a = Actor.remote(1, 2) - # Create an actor the correct number of arguments. - a = Actor.remote(1) + # Create an actor the correct number of arguments. + a = Actor.remote(1) - # Call a method with too few arguments. - with self.assertRaises(Exception): - a.get_val.remote() + # Call a method with too few arguments. + with self.assertRaises(Exception): + a.get_val.remote() - # Call a method with too many arguments. - with self.assertRaises(Exception): - a.get_val.remote(1, 2) - # Call a method that doesn't exist. - with self.assertRaises(AttributeError): - a.nonexistent_method() - with self.assertRaises(AttributeError): - a.nonexistent_method.remote() + # Call a method with too many arguments. + with self.assertRaises(Exception): + a.get_val.remote(1, 2) + # Call a method that doesn't exist. 
+ with self.assertRaises(AttributeError): + a.nonexistent_method() + with self.assertRaises(AttributeError): + a.nonexistent_method.remote() - ray.worker.cleanup() + ray.worker.cleanup() class WorkerDeath(unittest.TestCase): - def testWorkerDying(self): - ray.init(num_workers=0, driver_mode=ray.SILENT_MODE) + def testWorkerDying(self): + ray.init(num_workers=0, driver_mode=ray.SILENT_MODE) - # Define a remote function that will kill the worker that runs it. - @ray.remote - def f(): - eval("exit()") + # Define a remote function that will kill the worker that runs it. + @ray.remote + def f(): + eval("exit()") - f.remote() + f.remote() - wait_for_errors(b"worker_died", 1) + wait_for_errors(b"worker_died", 1) - self.assertEqual(len(ray.error_info()), 1) - self.assertIn("A worker died or was killed while executing a task.", - ray.error_info()[0][b"message"].decode("ascii")) + self.assertEqual(len(ray.error_info()), 1) + self.assertIn("A worker died or was killed while executing a task.", + ray.error_info()[0][b"message"].decode("ascii")) - ray.worker.cleanup() + ray.worker.cleanup() if __name__ == "__main__": - unittest.main(verbosity=2) + unittest.main(verbosity=2) diff --git a/test/jenkins_tests/multi_node_docker_test.py b/test/jenkins_tests/multi_node_docker_test.py index 39113d893..9764e41f0 100644 --- a/test/jenkins_tests/multi_node_docker_test.py +++ b/test/jenkins_tests/multi_node_docker_test.py @@ -11,314 +11,326 @@ import sys def wait_for_output(proc): - """This is a convenience method to parse a process's stdout and stderr. + """This is a convenience method to parse a process's stdout and stderr. - Args: - proc: A process started by subprocess.Popen. + Args: + proc: A process started by subprocess.Popen. - Returns: - A tuple of the stdout and stderr of the process as strings. - """ - stdout_data, stderr_data = proc.communicate() + Returns: + A tuple of the stdout and stderr of the process as strings. 
+ """ + stdout_data, stderr_data = proc.communicate() - if stdout_data is not None: - try: - # NOTE(rkn): This try/except block is here because I once saw an - # exception raised here and want to print more information if that - # happens again. - stdout_data = stdout_data.decode("ascii") - except UnicodeDecodeError: - raise Exception("Failed to decode stdout_data:", stdout_data) + if stdout_data is not None: + try: + # NOTE(rkn): This try/except block is here because I once saw an + # exception raised here and want to print more information if that + # happens again. + stdout_data = stdout_data.decode("ascii") + except UnicodeDecodeError: + raise Exception("Failed to decode stdout_data:", stdout_data) - if stderr_data is not None: - try: - # NOTE(rkn): This try/except block is here because I once saw an - # exception raised here and want to print more information if that - # happens again. - stderr_data = stderr_data.decode("ascii") - except UnicodeDecodeError: - raise Exception("Failed to decode stderr_data:", stderr_data) + if stderr_data is not None: + try: + # NOTE(rkn): This try/except block is here because I once saw an + # exception raised here and want to print more information if that + # happens again. + stderr_data = stderr_data.decode("ascii") + except UnicodeDecodeError: + raise Exception("Failed to decode stderr_data:", stderr_data) - return stdout_data, stderr_data + return stdout_data, stderr_data class DockerRunner(object): - """This class manages the logistics of running multiple nodes in Docker. + """This class manages the logistics of running multiple nodes in Docker. - This class is used for starting multiple Ray nodes within Docker, stopping - Ray, running a workload, and determining the success or failure of the - workload. + This class is used for starting multiple Ray nodes within Docker, stopping + Ray, running a workload, and determining the success or failure of the + workload. 
- Attributes: - head_container_id: The ID of the docker container that runs the head node. - worker_container_ids: A list of the docker container IDs of the Ray worker - nodes. - head_container_ip: The IP address of the docker container that runs the - head node. - """ - def __init__(self): - """Initialize the DockerRunner.""" - self.head_container_id = None - self.worker_container_ids = [] - self.head_container_ip = None - - def _get_container_id(self, stdout_data): - """Parse the docker container ID from stdout_data. - - Args: - stdout_data: This should be a string with the standard output of a call - to a docker command. - - Returns: - The container ID of the docker container. + Attributes: + head_container_id: The ID of the docker container that runs the head + node. + worker_container_ids: A list of the docker container IDs of the Ray + worker nodes. + head_container_ip: The IP address of the docker container that runs the + head node. """ - p = re.compile("([0-9a-f]{64})\n") - m = p.match(stdout_data) - if m is None: - return None - else: - return m.group(1) + def __init__(self): + """Initialize the DockerRunner.""" + self.head_container_id = None + self.worker_container_ids = [] + self.head_container_ip = None - def _get_container_ip(self, container_id): - """Get the IP address of a specific docker container. + def _get_container_id(self, stdout_data): + """Parse the docker container ID from stdout_data. - Args: - container_id: The docker container ID of the relevant docker container. + Args: + stdout_data: This should be a string with the standard output of a + call to a docker command. - Returns: - The IP address of the container. 
- """ - proc = subprocess.Popen(["docker", "inspect", - "--format={{.NetworkSettings.Networks.bridge" - ".IPAddress}}", - container_id], - stdout=subprocess.PIPE, stderr=subprocess.PIPE) - stdout_data, _ = wait_for_output(proc) - p = re.compile("([0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})") - m = p.match(stdout_data) - if m is None: - raise RuntimeError("Container IP not found.") - else: - return m.group(1) + Returns: + The container ID of the docker container. + """ + p = re.compile("([0-9a-f]{64})\n") + m = p.match(stdout_data) + if m is None: + return None + else: + return m.group(1) - def _start_head_node(self, docker_image, mem_size, shm_size, - num_redis_shards, num_cpus, num_gpus, development_mode): - """Start the Ray head node inside a docker container.""" - mem_arg = ["--memory=" + mem_size] if mem_size else [] - shm_arg = ["--shm-size=" + shm_size] if shm_size else [] - volume_arg = (["-v", - "{}:{}".format(os.path.dirname(os.path.realpath(__file__)), - "/ray/test/jenkins_tests")] - if development_mode else []) + def _get_container_ip(self, container_id): + """Get the IP address of a specific docker container. - command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg + - [docker_image, "ray", "start", "--head", "--block", - "--redis-port=6379", - "--num-redis-shards={}".format(num_redis_shards), - "--num-cpus={}".format(num_cpus), - "--num-gpus={}".format(num_gpus)]) - print("Starting head node with command:{}".format(command)) + Args: + container_id: The docker container ID of the relevant docker + container. - proc = subprocess.Popen(command, - stdout=subprocess.PIPE, stderr=subprocess.PIPE) - stdout_data, _ = wait_for_output(proc) - container_id = self._get_container_id(stdout_data) - if container_id is None: - raise RuntimeError("Failed to find container ID.") - self.head_container_id = container_id - self.head_container_ip = self._get_container_ip(container_id) + Returns: + The IP address of the container. 
+ """ + proc = subprocess.Popen(["docker", "inspect", + "--format={{.NetworkSettings.Networks.bridge" + ".IPAddress}}", + container_id], + stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout_data, _ = wait_for_output(proc) + p = re.compile("([0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})") + m = p.match(stdout_data) + if m is None: + raise RuntimeError("Container IP not found.") + else: + return m.group(1) - def _start_worker_node(self, docker_image, mem_size, shm_size, num_cpus, - num_gpus, development_mode): - """Start a Ray worker node inside a docker container.""" - mem_arg = ["--memory=" + mem_size] if mem_size else [] - shm_arg = ["--shm-size=" + shm_size] if shm_size else [] - volume_arg = (["-v", - "{}:{}".format(os.path.dirname(os.path.realpath(__file__)), - "/ray/test/jenkins_tests")] - if development_mode else []) - command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg + - ["--shm-size=" + shm_size, docker_image, - "ray", "start", "--block", - "--redis-address={:s}:6379".format(self.head_container_ip), - "--num-cpus={}".format(num_cpus), - "--num-gpus={}".format(num_gpus)]) - print("Starting worker node with command:{}".format(command)) - proc = subprocess.Popen(command, stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - stdout_data, _ = wait_for_output(proc) - container_id = self._get_container_id(stdout_data) - if container_id is None: - raise RuntimeError("Failed to find container id") - self.worker_container_ids.append(container_id) + def _start_head_node(self, docker_image, mem_size, shm_size, + num_redis_shards, num_cpus, num_gpus, + development_mode): + """Start the Ray head node inside a docker container.""" + mem_arg = ["--memory=" + mem_size] if mem_size else [] + shm_arg = ["--shm-size=" + shm_size] if shm_size else [] + volume_arg = (["-v", + "{}:{}".format(os.path.dirname( + os.path.realpath(__file__)), + "/ray/test/jenkins_tests")] + if development_mode else []) - def start_ray(self, docker_image=None, mem_size=None, 
shm_size=None, - num_nodes=None, num_redis_shards=1, num_cpus=None, - num_gpus=None, development_mode=None): - """Start a Ray cluster within docker. + command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg + + [docker_image, "ray", "start", "--head", "--block", + "--redis-port=6379", + "--num-redis-shards={}".format(num_redis_shards), + "--num-cpus={}".format(num_cpus), + "--num-gpus={}".format(num_gpus)]) + print("Starting head node with command:{}".format(command)) - This starts one docker container running the head node and num_nodes - 1 - docker containers running the Ray worker nodes. + proc = subprocess.Popen(command, + stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout_data, _ = wait_for_output(proc) + container_id = self._get_container_id(stdout_data) + if container_id is None: + raise RuntimeError("Failed to find container ID.") + self.head_container_id = container_id + self.head_container_ip = self._get_container_ip(container_id) - Args: - docker_image: The docker image to use for all of the nodes. - mem_size: The amount of memory to start each docker container with. This - will be passed into `docker run` as the --memory flag. If this is None, - then no --memory flag will be used. - shm_size: The amount of shared memory to start each docker container - with. This will be passed into `docker run` as the `--shm-size` flag. - num_nodes: The number of nodes to use in the cluster (this counts the - head node as well). - num_redis_shards: The number of Redis shards to use on the head node. - num_cpus: A list of the number of CPUs to start each node with. - num_gpus: A list of the number of GPUs to start each node with. - development_mode: True if you want to mount the local copy of - test/jenkins_test on the head node so we can avoid rebuilding docker - images during development. 
- """ - assert len(num_cpus) == num_nodes - assert len(num_gpus) == num_nodes + def _start_worker_node(self, docker_image, mem_size, shm_size, num_cpus, + num_gpus, development_mode): + """Start a Ray worker node inside a docker container.""" + mem_arg = ["--memory=" + mem_size] if mem_size else [] + shm_arg = ["--shm-size=" + shm_size] if shm_size else [] + volume_arg = (["-v", + "{}:{}".format(os.path.dirname( + os.path.realpath(__file__)), + "/ray/test/jenkins_tests")] + if development_mode else []) + command = (["docker", "run", "-d"] + mem_arg + shm_arg + volume_arg + + ["--shm-size=" + shm_size, docker_image, + "ray", "start", "--block", + "--redis-address={:s}:6379".format(self.head_container_ip), + "--num-cpus={}".format(num_cpus), + "--num-gpus={}".format(num_gpus)]) + print("Starting worker node with command:{}".format(command)) + proc = subprocess.Popen(command, stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + stdout_data, _ = wait_for_output(proc) + container_id = self._get_container_id(stdout_data) + if container_id is None: + raise RuntimeError("Failed to find container id") + self.worker_container_ids.append(container_id) - # Launch the head node. - self._start_head_node(docker_image, mem_size, shm_size, num_redis_shards, - num_cpus[0], num_gpus[0], development_mode) - # Start the worker nodes. - for i in range(num_nodes - 1): - self._start_worker_node(docker_image, mem_size, shm_size, - num_cpus[1 + i], num_gpus[1 + i], + def start_ray(self, docker_image=None, mem_size=None, shm_size=None, + num_nodes=None, num_redis_shards=1, num_cpus=None, + num_gpus=None, development_mode=None): + """Start a Ray cluster within docker. + + This starts one docker container running the head node and + num_nodes - 1 docker containers running the Ray worker nodes. + + Args: + docker_image: The docker image to use for all of the nodes. + mem_size: The amount of memory to start each docker container with. + This will be passed into `docker run` as the --memory flag. 
If + this is None, then no --memory flag will be used. + shm_size: The amount of shared memory to start each docker + container with. This will be passed into `docker run` as the + `--shm-size` flag. + num_nodes: The number of nodes to use in the cluster (this counts + the head node as well). + num_redis_shards: The number of Redis shards to use on the head + node. + num_cpus: A list of the number of CPUs to start each node with. + num_gpus: A list of the number of GPUs to start each node with. + development_mode: True if you want to mount the local copy of + test/jenkins_test on the head node so we can avoid rebuilding + docker images during development. + """ + assert len(num_cpus) == num_nodes + assert len(num_gpus) == num_nodes + + # Launch the head node. + self._start_head_node(docker_image, mem_size, shm_size, + num_redis_shards, num_cpus[0], num_gpus[0], development_mode) + # Start the worker nodes. + for i in range(num_nodes - 1): + self._start_worker_node(docker_image, mem_size, shm_size, + num_cpus[1 + i], num_gpus[1 + i], + development_mode) - def _stop_node(self, container_id): - """Stop a node in the Ray cluster.""" - proc = subprocess.Popen(["docker", "kill", container_id], - stdout=subprocess.PIPE, stderr=subprocess.PIPE) - stdout_data, _ = wait_for_output(proc) - stopped_container_id = self._get_container_id(stdout_data) - if not container_id == stopped_container_id: - raise Exception("Failed to stop container {}.".format(container_id)) + def _stop_node(self, container_id): + """Stop a node in the Ray cluster.""" + proc = subprocess.Popen(["docker", "kill", container_id], + stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout_data, _ = wait_for_output(proc) + stopped_container_id = self._get_container_id(stdout_data) + if not container_id == stopped_container_id: + raise Exception("Failed to stop container {}." 
+ .format(container_id)) - proc = subprocess.Popen(["docker", "rm", "-f", container_id], - stdout=subprocess.PIPE, stderr=subprocess.PIPE) - stdout_data, _ = wait_for_output(proc) - removed_container_id = self._get_container_id(stdout_data) - if not container_id == removed_container_id: - raise Exception("Failed to remove container {}.".format(container_id)) + proc = subprocess.Popen(["docker", "rm", "-f", container_id], + stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout_data, _ = wait_for_output(proc) + removed_container_id = self._get_container_id(stdout_data) + if not container_id == removed_container_id: + raise Exception("Failed to remove container {}." + .format(container_id)) - print("stop_node", {"container_id": container_id, - "is_head": container_id == self.head_container_id}) + print("stop_node", {"container_id": container_id, + "is_head": container_id == self.head_container_id}) - def stop_ray(self): - """Stop the Ray cluster.""" - self._stop_node(self.head_container_id) - for container_id in self.worker_container_ids: - self._stop_node(container_id) + def stop_ray(self): + """Stop the Ray cluster.""" + self._stop_node(self.head_container_id) + for container_id in self.worker_container_ids: + self._stop_node(container_id) - def run_test(self, test_script, num_drivers, driver_locations=None): - """Run a test script. + def run_test(self, test_script, num_drivers, driver_locations=None): + """Run a test script. - Run a test using the Ray cluster. + Run a test using the Ray cluster. - Args: - test_script: The test script to run. - num_drivers: The number of copies of the test script to run. - driver_locations: A list of the indices of the containers that the - different copies of the test script should be run on. If this is None, - then the containers will be chosen randomly. + Args: + test_script: The test script to run. + num_drivers: The number of copies of the test script to run. 
+ driver_locations: A list of the indices of the containers that the + different copies of the test script should be run on. If this + is None, then the containers will be chosen randomly. - Returns: - A dictionary with information about the test script run. - """ - all_container_ids = [self.head_container_id] + self.worker_container_ids - if driver_locations is None: - driver_locations = [np.random.randint(0, len(all_container_ids)) - for _ in range(num_drivers)] + Returns: + A dictionary with information about the test script run. + """ + all_container_ids = ([self.head_container_id] + + self.worker_container_ids) + if driver_locations is None: + driver_locations = [np.random.randint(0, len(all_container_ids)) + for _ in range(num_drivers)] - # Start the different drivers. - driver_processes = [] - for i in range(len(driver_locations)): - # Get the container ID to run the ith driver in. - container_id = all_container_ids[driver_locations[i]] - command = ["docker", "exec", container_id, "/bin/bash", "-c", - ("RAY_REDIS_ADDRESS={}:6379 RAY_DRIVER_INDEX={} python {}" - .format(self.head_container_ip, i, test_script))] - print("Starting driver with command {}.".format(test_script)) - # Start the driver. - p = subprocess.Popen(command, stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - driver_processes.append(p) + # Start the different drivers. + driver_processes = [] + for i in range(len(driver_locations)): + # Get the container ID to run the ith driver in. + container_id = all_container_ids[driver_locations[i]] + command = ["docker", "exec", container_id, "/bin/bash", "-c", + ("RAY_REDIS_ADDRESS={}:6379 RAY_DRIVER_INDEX={} python " + "{}".format(self.head_container_ip, i, test_script))] + print("Starting driver with command {}.".format(test_script)) + # Start the driver. + p = subprocess.Popen(command, stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + driver_processes.append(p) - # Wait for the drivers to finish. 
- results = [] - for p in driver_processes: - stdout_data, stderr_data = wait_for_output(p) - print("STDOUT:") - print(stdout_data) - print("STDERR:") - print(stderr_data) - results.append({"success": p.returncode == 0, - "return_code": p.returncode}) - return results + # Wait for the drivers to finish. + results = [] + for p in driver_processes: + stdout_data, stderr_data = wait_for_output(p) + print("STDOUT:") + print(stdout_data) + print("STDERR:") + print(stderr_data) + results.append({"success": p.returncode == 0, + "return_code": p.returncode}) + return results if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Run multinode tests in Docker.") - parser.add_argument("--docker-image", default="ray-project/deploy", - help="docker image") - parser.add_argument("--mem-size", help="memory size") - parser.add_argument("--shm-size", default="1G", help="shared memory size") - parser.add_argument("--num-nodes", default=1, type=int, - help="number of nodes to use in the cluster") - parser.add_argument("--num-redis-shards", default=1, type=int, - help=("the number of Redis shards to start on the head " - "node")) - parser.add_argument("--num-cpus", type=str, - help=("a comma separated list of values representing " - "the number of CPUs to start each node with")) - parser.add_argument("--num-gpus", type=str, - help=("a comma separated list of values representing " - "the number of GPUs to start each node with")) - parser.add_argument("--num-drivers", default=1, type=int, - help="number of drivers to run") - parser.add_argument("--driver-locations", type=str, - help=("a comma separated list of indices of the " - "containers to run the drivers in")) - parser.add_argument("--test-script", required=True, help="test script") - parser.add_argument("--development-mode", action="store_true", - help="use local copies of the test scripts") - args = parser.parse_args() + parser = argparse.ArgumentParser( + description="Run multinode tests in Docker.") + 
parser.add_argument("--docker-image", default="ray-project/deploy", + help="docker image") + parser.add_argument("--mem-size", help="memory size") + parser.add_argument("--shm-size", default="1G", help="shared memory size") + parser.add_argument("--num-nodes", default=1, type=int, + help="number of nodes to use in the cluster") + parser.add_argument("--num-redis-shards", default=1, type=int, + help=("the number of Redis shards to start on the " + "head node")) + parser.add_argument("--num-cpus", type=str, + help=("a comma separated list of values representing " + "the number of CPUs to start each node with")) + parser.add_argument("--num-gpus", type=str, + help=("a comma separated list of values representing " + "the number of GPUs to start each node with")) + parser.add_argument("--num-drivers", default=1, type=int, + help="number of drivers to run") + parser.add_argument("--driver-locations", type=str, + help=("a comma separated list of indices of the " + "containers to run the drivers in")) + parser.add_argument("--test-script", required=True, help="test script") + parser.add_argument("--development-mode", action="store_true", + help="use local copies of the test scripts") + args = parser.parse_args() - # Parse the number of CPUs and GPUs to use for each worker. - num_nodes = args.num_nodes - num_cpus = ([int(i) for i in args.num_cpus.split(",")] - if args.num_cpus is not None else num_nodes * [10]) - num_gpus = ([int(i) for i in args.num_gpus.split(",")] - if args.num_gpus is not None else num_nodes * [0]) + # Parse the number of CPUs and GPUs to use for each worker. + num_nodes = args.num_nodes + num_cpus = ([int(i) for i in args.num_cpus.split(",")] + if args.num_cpus is not None else num_nodes * [10]) + num_gpus = ([int(i) for i in args.num_gpus.split(",")] + if args.num_gpus is not None else num_nodes * [0]) - # Parse the driver locations. 
- driver_locations = (None if args.driver_locations is None - else [int(i) for i in args.driver_locations.split(",")]) + # Parse the driver locations. + driver_locations = (None if args.driver_locations is None + else [int(i) for i + in args.driver_locations.split(",")]) - d = DockerRunner() - d.start_ray(docker_image=args.docker_image, mem_size=args.mem_size, - shm_size=args.shm_size, num_nodes=num_nodes, - num_redis_shards=args.num_redis_shards, num_cpus=num_cpus, - num_gpus=num_gpus, development_mode=args.development_mode) - try: - run_results = d.run_test(args.test_script, args.num_drivers, - driver_locations=driver_locations) - finally: - d.stop_ray() + d = DockerRunner() + d.start_ray(docker_image=args.docker_image, mem_size=args.mem_size, + shm_size=args.shm_size, num_nodes=num_nodes, + num_redis_shards=args.num_redis_shards, num_cpus=num_cpus, + num_gpus=num_gpus, development_mode=args.development_mode) + try: + run_results = d.run_test(args.test_script, args.num_drivers, + driver_locations=driver_locations) + finally: + d.stop_ray() - any_failed = False - for run_result in run_results: - if "success" in run_result and run_result["success"]: - print("RESULT: Test {} succeeded.".format(args.test_script)) + any_failed = False + for run_result in run_results: + if "success" in run_result and run_result["success"]: + print("RESULT: Test {} succeeded.".format(args.test_script)) + else: + print("RESULT: Test {} failed.".format(args.test_script)) + any_failed = True + + if any_failed: + sys.exit(1) else: - print("RESULT: Test {} failed.".format(args.test_script)) - any_failed = True - - if any_failed: - sys.exit(1) - else: - sys.exit(0) + sys.exit(0) diff --git a/test/jenkins_tests/multi_node_tests/large_memory_test.py b/test/jenkins_tests/multi_node_tests/large_memory_test.py index 1072bdc7b..af31afce0 100644 --- a/test/jenkins_tests/multi_node_tests/large_memory_test.py +++ b/test/jenkins_tests/multi_node_tests/large_memory_test.py @@ -8,42 +8,42 @@ import ray 
if __name__ == "__main__": - ray.init(num_workers=0) + ray.init(num_workers=0) - A = np.ones(2 ** 31 + 1, dtype="int8") - a = ray.put(A) - assert np.sum(ray.get(a)) == np.sum(A) - del A - del a - print("Successfully put A.") + A = np.ones(2 ** 31 + 1, dtype="int8") + a = ray.put(A) + assert np.sum(ray.get(a)) == np.sum(A) + del A + del a + print("Successfully put A.") - B = {"hello": np.zeros(2 ** 30 + 1), - "world": np.ones(2 ** 30 + 1)} - b = ray.put(B) - assert np.sum(ray.get(b)["hello"]) == np.sum(B["hello"]) - assert np.sum(ray.get(b)["world"]) == np.sum(B["world"]) - del B - del b - print("Successfully put B.") + B = {"hello": np.zeros(2 ** 30 + 1), + "world": np.ones(2 ** 30 + 1)} + b = ray.put(B) + assert np.sum(ray.get(b)["hello"]) == np.sum(B["hello"]) + assert np.sum(ray.get(b)["world"]) == np.sum(B["world"]) + del B + del b + print("Successfully put B.") - C = [np.ones(2 ** 30 + 1), 42.0 * np.ones(2 ** 30 + 1)] - c = ray.put(C) - assert np.sum(ray.get(c)[0]) == np.sum(C[0]) - assert np.sum(ray.get(c)[1]) == np.sum(C[1]) - del C - del c - print("Successfully put C.") + C = [np.ones(2 ** 30 + 1), 42.0 * np.ones(2 ** 30 + 1)] + c = ray.put(C) + assert np.sum(ray.get(c)[0]) == np.sum(C[0]) + assert np.sum(ray.get(c)[1]) == np.sum(C[1]) + del C + del c + print("Successfully put C.") - # D = (2 ** 30 + 1) * ["h"] - # d = ray.put(D) - # assert ray.get(d) == D - # del D - # del d - # print("Successfully put D.") + # D = (2 ** 30 + 1) * ["h"] + # d = ray.put(D) + # assert ray.get(d) == D + # del D + # del d + # print("Successfully put D.") - # E = (2 ** 30 + 1) * ("i",) - # e = ray.put(E) - # assert ray.get(e) == E - # del E - # del e - # print("Successfully put E.") + # E = (2 ** 30 + 1) * ("i",) + # e = ray.put(E) + # assert ray.get(e) == E + # del E + # del e + # print("Successfully put E.") diff --git a/test/jenkins_tests/multi_node_tests/many_drivers_test.py b/test/jenkins_tests/multi_node_tests/many_drivers_test.py index 64ca334e1..ed9b24776 100644 --- 
a/test/jenkins_tests/multi_node_tests/many_drivers_test.py +++ b/test/jenkins_tests/multi_node_tests/many_drivers_test.py @@ -22,59 +22,59 @@ num_gpus_per_driver = 5 @ray.remote(num_gpus=1) class Actor1(object): - def __init__(self): - assert len(ray.get_gpu_ids()) == 1 + def __init__(self): + assert len(ray.get_gpu_ids()) == 1 - def check_ids(self): - assert len(ray.get_gpu_ids()) == 1 + def check_ids(self): + assert len(ray.get_gpu_ids()) == 1 def driver(redis_address, driver_index): - """The script for driver 0. + """The script for driver 0. - This driver should create five actors that each use one GPU and some actors - that use no GPUs. After a while, it should exit. - """ - ray.init(redis_address=redis_address) + This driver should create five actors that each use one GPU and some actors + that use no GPUs. After a while, it should exit. + """ + ray.init(redis_address=redis_address) - # Wait for all the nodes to join the cluster. - _wait_for_nodes_to_join(total_num_nodes) + # Wait for all the nodes to join the cluster. + _wait_for_nodes_to_join(total_num_nodes) - # Limit the number of drivers running concurrently. - for i in range(driver_index - max_concurrent_drivers + 1): - _wait_for_event("DRIVER_{}_DONE".format(i), redis_address) + # Limit the number of drivers running concurrently. + for i in range(driver_index - max_concurrent_drivers + 1): + _wait_for_event("DRIVER_{}_DONE".format(i), redis_address) - def try_to_create_actor(actor_class, timeout=100): - # Try to create an actor, but allow failures while we wait for the monitor - # to release the resources for the removed drivers. - start_time = time.time() - while time.time() - start_time < timeout: - try: - actor = actor_class.remote() - except Exception as e: - time.sleep(0.1) - else: - return actor - # If we are here, then we timed out while looping. 
- raise Exception("Timed out while trying to create actor.") + def try_to_create_actor(actor_class, timeout=100): + # Try to create an actor, but allow failures while we wait for the + # monitor to release the resources for the removed drivers. + start_time = time.time() + while time.time() - start_time < timeout: + try: + actor = actor_class.remote() + except Exception as e: + time.sleep(0.1) + else: + return actor + # If we are here, then we timed out while looping. + raise Exception("Timed out while trying to create actor.") - # Create some actors that require one GPU. - actors_one_gpu = [] - for _ in range(num_gpus_per_driver): - actors_one_gpu.append(try_to_create_actor(Actor1)) + # Create some actors that require one GPU. + actors_one_gpu = [] + for _ in range(num_gpus_per_driver): + actors_one_gpu.append(try_to_create_actor(Actor1)) - for _ in range(100): - ray.get([actor.check_ids.remote() for actor in actors_one_gpu]) + for _ in range(100): + ray.get([actor.check_ids.remote() for actor in actors_one_gpu]) - _broadcast_event("DRIVER_{}_DONE".format(driver_index), redis_address) + _broadcast_event("DRIVER_{}_DONE".format(driver_index), redis_address) if __name__ == "__main__": - driver_index = int(os.environ["RAY_DRIVER_INDEX"]) - redis_address = os.environ["RAY_REDIS_ADDRESS"] - print("Driver {} started at {}.".format(driver_index, time.time())) + driver_index = int(os.environ["RAY_DRIVER_INDEX"]) + redis_address = os.environ["RAY_REDIS_ADDRESS"] + print("Driver {} started at {}.".format(driver_index, time.time())) - # In this test, all drivers will run the same script. - driver(redis_address, driver_index) + # In this test, all drivers will run the same script. 
+ driver(redis_address, driver_index) - print("Driver {} finished at {}.".format(driver_index, time.time())) + print("Driver {} finished at {}.".format(driver_index, time.time())) diff --git a/test/jenkins_tests/multi_node_tests/remove_driver_test.py b/test/jenkins_tests/multi_node_tests/remove_driver_test.py index 27e0506cc..f857fc2fe 100644 --- a/test/jenkins_tests/multi_node_tests/remove_driver_test.py +++ b/test/jenkins_tests/multi_node_tests/remove_driver_test.py @@ -19,21 +19,21 @@ total_num_nodes = 5 def actor_event_name(driver_index, actor_index): - return "DRIVER_{}_ACTOR_{}_RUNNING".format(driver_index, actor_index) + return "DRIVER_{}_ACTOR_{}_RUNNING".format(driver_index, actor_index) def remote_function_event_name(driver_index, task_index): - return "DRIVER_{}_TASK_{}_RUNNING".format(driver_index, task_index) + return "DRIVER_{}_TASK_{}_RUNNING".format(driver_index, task_index) @ray.remote def long_running_task(driver_index, task_index, redis_address): - _broadcast_event(remote_function_event_name(driver_index, task_index), - redis_address, - data=(ray.services.get_node_ip_address(), os.getpid())) - # Loop forever. - while True: - time.sleep(100) + _broadcast_event(remote_function_event_name(driver_index, task_index), + redis_address, + data=(ray.services.get_node_ip_address(), os.getpid())) + # Loop forever. 
+ while True: + time.sleep(100) num_long_running_tasks_per_driver = 2 @@ -41,228 +41,235 @@ num_long_running_tasks_per_driver = 2 @ray.remote class Actor0(object): - def __init__(self, driver_index, actor_index, redis_address): - _broadcast_event(actor_event_name(driver_index, actor_index), - redis_address, - data=(ray.services.get_node_ip_address(), os.getpid())) - assert len(ray.get_gpu_ids()) == 0 + def __init__(self, driver_index, actor_index, redis_address): + _broadcast_event(actor_event_name(driver_index, actor_index), + redis_address, + data=(ray.services.get_node_ip_address(), + os.getpid())) + assert len(ray.get_gpu_ids()) == 0 - def check_ids(self): - assert len(ray.get_gpu_ids()) == 0 + def check_ids(self): + assert len(ray.get_gpu_ids()) == 0 - def long_running_method(self): - # Loop forever. - while True: - time.sleep(100) + def long_running_method(self): + # Loop forever. + while True: + time.sleep(100) @ray.remote(num_gpus=1) class Actor1(object): - def __init__(self, driver_index, actor_index, redis_address): - _broadcast_event(actor_event_name(driver_index, actor_index), - redis_address, - data=(ray.services.get_node_ip_address(), os.getpid())) - assert len(ray.get_gpu_ids()) == 1 + def __init__(self, driver_index, actor_index, redis_address): + _broadcast_event(actor_event_name(driver_index, actor_index), + redis_address, + data=(ray.services.get_node_ip_address(), + os.getpid())) + assert len(ray.get_gpu_ids()) == 1 - def check_ids(self): - assert len(ray.get_gpu_ids()) == 1 + def check_ids(self): + assert len(ray.get_gpu_ids()) == 1 - def long_running_method(self): - # Loop forever. - while True: - time.sleep(100) + def long_running_method(self): + # Loop forever. 
+ while True: + time.sleep(100) @ray.remote(num_gpus=2) class Actor2(object): - def __init__(self, driver_index, actor_index, redis_address): - _broadcast_event(actor_event_name(driver_index, actor_index), - redis_address, - data=(ray.services.get_node_ip_address(), os.getpid())) - assert len(ray.get_gpu_ids()) == 2 + def __init__(self, driver_index, actor_index, redis_address): + _broadcast_event(actor_event_name(driver_index, actor_index), + redis_address, + data=(ray.services.get_node_ip_address(), + os.getpid())) + assert len(ray.get_gpu_ids()) == 2 - def check_ids(self): - assert len(ray.get_gpu_ids()) == 2 + def check_ids(self): + assert len(ray.get_gpu_ids()) == 2 - def long_running_method(self): - # Loop forever. - while True: - time.sleep(100) + def long_running_method(self): + # Loop forever. + while True: + time.sleep(100) def driver_0(redis_address, driver_index): - """The script for driver 0. + """The script for driver 0. - This driver should create five actors that each use one GPU and some actors - that use no GPUs. After a while, it should exit. - """ - ray.init(redis_address=redis_address) + This driver should create five actors that each use one GPU and some actors + that use no GPUs. After a while, it should exit. + """ + ray.init(redis_address=redis_address) - # Wait for all the nodes to join the cluster. - _wait_for_nodes_to_join(total_num_nodes) + # Wait for all the nodes to join the cluster. + _wait_for_nodes_to_join(total_num_nodes) - # Start some long running task. Driver 2 will make sure the worker running - # this task has been killed. - for i in range(num_long_running_tasks_per_driver): - long_running_task.remote(driver_index, i, redis_address) + # Start some long running task. Driver 2 will make sure the worker running + # this task has been killed. + for i in range(num_long_running_tasks_per_driver): + long_running_task.remote(driver_index, i, redis_address) - # Create some actors that require one GPU. 
- actors_one_gpu = [Actor1.remote(driver_index, i, redis_address) - for i in range(5)] - # Create some actors that don't require any GPUs. - actors_no_gpus = [Actor0.remote(driver_index, 5 + i, redis_address) - for i in range(5)] + # Create some actors that require one GPU. + actors_one_gpu = [Actor1.remote(driver_index, i, redis_address) + for i in range(5)] + # Create some actors that don't require any GPUs. + actors_no_gpus = [Actor0.remote(driver_index, 5 + i, redis_address) + for i in range(5)] - for _ in range(1000): - ray.get([actor.check_ids.remote() for actor in actors_one_gpu]) - ray.get([actor.check_ids.remote() for actor in actors_no_gpus]) + for _ in range(1000): + ray.get([actor.check_ids.remote() for actor in actors_one_gpu]) + ray.get([actor.check_ids.remote() for actor in actors_no_gpus]) - # Start a long-running method on one actor and make sure this doesn't affect - # anything. - actors_no_gpus[0].long_running_method.remote() + # Start a long-running method on one actor and make sure this doesn't + # affect anything. + actors_no_gpus[0].long_running_method.remote() - _broadcast_event("DRIVER_0_DONE", redis_address) + _broadcast_event("DRIVER_0_DONE", redis_address) def driver_1(redis_address, driver_index): - """The script for driver 1. + """The script for driver 1. - This driver should create one actor that uses two GPUs, three actors that - each use one GPU (the one requiring two must be created first), and some - actors that don't use any GPUs. After a while, it should exit. - """ - ray.init(redis_address=redis_address) + This driver should create one actor that uses two GPUs, three actors that + each use one GPU (the one requiring two must be created first), and some + actors that don't use any GPUs. After a while, it should exit. + """ + ray.init(redis_address=redis_address) - # Wait for all the nodes to join the cluster. - _wait_for_nodes_to_join(total_num_nodes) + # Wait for all the nodes to join the cluster. 
+ _wait_for_nodes_to_join(total_num_nodes) - # Start some long running task. Driver 2 will make sure the worker running - # this task has been killed. - for i in range(num_long_running_tasks_per_driver): - long_running_task.remote(driver_index, i, redis_address) + # Start some long running task. Driver 2 will make sure the worker running + # this task has been killed. + for i in range(num_long_running_tasks_per_driver): + long_running_task.remote(driver_index, i, redis_address) - # Create an actor that requires two GPUs. - actors_two_gpus = [Actor2.remote(driver_index, i, redis_address) - for i in range(1)] - # Create some actors that require one GPU. - actors_one_gpu = [Actor1.remote(driver_index, 1 + i, redis_address) - for i in range(3)] - # Create some actors that don't require any GPUs. - actors_no_gpus = [Actor0.remote(driver_index, 1 + 3 + i, redis_address) - for i in range(5)] + # Create an actor that requires two GPUs. + actors_two_gpus = [Actor2.remote(driver_index, i, redis_address) + for i in range(1)] + # Create some actors that require one GPU. + actors_one_gpu = [Actor1.remote(driver_index, 1 + i, redis_address) + for i in range(3)] + # Create some actors that don't require any GPUs. + actors_no_gpus = [Actor0.remote(driver_index, 1 + 3 + i, redis_address) + for i in range(5)] - for _ in range(1000): - ray.get([actor.check_ids.remote() for actor in actors_two_gpus]) - ray.get([actor.check_ids.remote() for actor in actors_one_gpu]) - ray.get([actor.check_ids.remote() for actor in actors_no_gpus]) + for _ in range(1000): + ray.get([actor.check_ids.remote() for actor in actors_two_gpus]) + ray.get([actor.check_ids.remote() for actor in actors_one_gpu]) + ray.get([actor.check_ids.remote() for actor in actors_no_gpus]) - # Start a long-running method on one actor and make sure this doesn't affect - # anything. - actors_one_gpu[0].long_running_method.remote() + # Start a long-running method on one actor and make sure this doesn't + # affect anything. 
+ actors_one_gpu[0].long_running_method.remote() - _broadcast_event("DRIVER_1_DONE", redis_address) + _broadcast_event("DRIVER_1_DONE", redis_address) def cleanup_driver(redis_address, driver_index): - """The script for drivers 2 through 6. + """The script for drivers 2 through 6. - This driver should wait for the first two drivers to finish. Then it should - create some actors that use a total of ten GPUs. - """ - ray.init(redis_address=redis_address) + This driver should wait for the first two drivers to finish. Then it should + create some actors that use a total of ten GPUs. + """ + ray.init(redis_address=redis_address) - # Only one of the cleanup drivers should create more actors. - if driver_index == 2: - # We go ahead and create some actors that don't require any GPUs. We don't - # need to wait for the other drivers to finish. We call methods on these - # actors later to make sure they haven't been killed. - actors_no_gpus = [Actor0.remote(driver_index, i, redis_address) - for i in range(10)] + # Only one of the cleanup drivers should create more actors. + if driver_index == 2: + # We go ahead and create some actors that don't require any GPUs. We + # don't need to wait for the other drivers to finish. We call methods + # on these actors later to make sure they haven't been killed. + actors_no_gpus = [Actor0.remote(driver_index, i, redis_address) + for i in range(10)] - _wait_for_event("DRIVER_0_DONE", redis_address) - _wait_for_event("DRIVER_1_DONE", redis_address) + _wait_for_event("DRIVER_0_DONE", redis_address) + _wait_for_event("DRIVER_1_DONE", redis_address) - def try_to_create_actor(actor_class, driver_index, actor_index, timeout=20): - # Try to create an actor, but allow failures while we wait for the monitor - # to release the resources for the removed drivers. 
- start_time = time.time() - while time.time() - start_time < timeout: - try: - actor = actor_class.remote(driver_index, actor_index, redis_address) - except Exception as e: - time.sleep(0.1) - else: - return actor - # If we are here, then we timed out while looping. - raise Exception("Timed out while trying to create actor.") - - # Only one of the cleanup drivers should create more actors. - if driver_index == 2: - # Create some actors that require two GPUs. - actors_two_gpus = [] - for i in range(3): - actors_two_gpus.append(try_to_create_actor(Actor2, driver_index, 10 + i)) - # Create some actors that require one GPU. - actors_one_gpu = [] - for i in range(4): - actors_one_gpu.append(try_to_create_actor(Actor1, driver_index, - 10 + 3 + i)) - - removed_workers = 0 - - # Make sure that the PIDs for the long-running tasks from driver 0 and driver - # 1 have been killed. - for i in range(num_long_running_tasks_per_driver): - node_ip_address, pid = _wait_for_event(remote_function_event_name(0, i), + def try_to_create_actor(actor_class, driver_index, actor_index, + timeout=20): + # Try to create an actor, but allow failures while we wait for the + # monitor to release the resources for the removed drivers. + start_time = time.time() + while time.time() - start_time < timeout: + try: + actor = actor_class.remote(driver_index, actor_index, redis_address) - if node_ip_address == ray.services.get_node_ip_address(): - wait_for_pid_to_exit(pid) - removed_workers += 1 - for i in range(num_long_running_tasks_per_driver): - node_ip_address, pid = _wait_for_event(remote_function_event_name(1, i), - redis_address) - if node_ip_address == ray.services.get_node_ip_address(): - wait_for_pid_to_exit(pid) - removed_workers += 1 - # Make sure that the PIDs for the actors from driver 0 and driver 1 have been - # killed. 
- for i in range(10): - node_ip_address, pid = _wait_for_event(actor_event_name(0, i), - redis_address) - if node_ip_address == ray.services.get_node_ip_address(): - wait_for_pid_to_exit(pid) - removed_workers += 1 - for i in range(9): - node_ip_address, pid = _wait_for_event(actor_event_name(1, i), - redis_address) - if node_ip_address == ray.services.get_node_ip_address(): - wait_for_pid_to_exit(pid) - removed_workers += 1 + except Exception as e: + time.sleep(0.1) + else: + return actor + # If we are here, then we timed out while looping. + raise Exception("Timed out while trying to create actor.") - print("{} workers/actors were removed on this node.".format(removed_workers)) + # Only one of the cleanup drivers should create more actors. + if driver_index == 2: + # Create some actors that require two GPUs. + actors_two_gpus = [] + for i in range(3): + actors_two_gpus.append(try_to_create_actor(Actor2, driver_index, + 10 + i)) + # Create some actors that require one GPU. + actors_one_gpu = [] + for i in range(4): + actors_one_gpu.append(try_to_create_actor(Actor1, driver_index, + 10 + 3 + i)) - # Only one of the cleanup drivers should create and use more actors. - if driver_index == 2: - for _ in range(1000): - ray.get([actor.check_ids.remote() for actor in actors_two_gpus]) - ray.get([actor.check_ids.remote() for actor in actors_one_gpu]) - ray.get([actor.check_ids.remote() for actor in actors_no_gpus]) + removed_workers = 0 - _broadcast_event("DRIVER_{}_DONE".format(driver_index), redis_address) + # Make sure that the PIDs for the long-running tasks from driver 0 and + # driver 1 have been killed. 
+ for i in range(num_long_running_tasks_per_driver): + node_ip_address, pid = _wait_for_event( + remote_function_event_name(0, i), redis_address) + if node_ip_address == ray.services.get_node_ip_address(): + wait_for_pid_to_exit(pid) + removed_workers += 1 + for i in range(num_long_running_tasks_per_driver): + node_ip_address, pid = _wait_for_event( + remote_function_event_name(1, i), redis_address) + if node_ip_address == ray.services.get_node_ip_address(): + wait_for_pid_to_exit(pid) + removed_workers += 1 + # Make sure that the PIDs for the actors from driver 0 and driver 1 have + # been killed. + for i in range(10): + node_ip_address, pid = _wait_for_event(actor_event_name(0, i), + redis_address) + if node_ip_address == ray.services.get_node_ip_address(): + wait_for_pid_to_exit(pid) + removed_workers += 1 + for i in range(9): + node_ip_address, pid = _wait_for_event(actor_event_name(1, i), + redis_address) + if node_ip_address == ray.services.get_node_ip_address(): + wait_for_pid_to_exit(pid) + removed_workers += 1 + + print("{} workers/actors were removed on this node." + .format(removed_workers)) + + # Only one of the cleanup drivers should create and use more actors. 
+ if driver_index == 2: + for _ in range(1000): + ray.get([actor.check_ids.remote() for actor in actors_two_gpus]) + ray.get([actor.check_ids.remote() for actor in actors_one_gpu]) + ray.get([actor.check_ids.remote() for actor in actors_no_gpus]) + + _broadcast_event("DRIVER_{}_DONE".format(driver_index), redis_address) if __name__ == "__main__": - driver_index = int(os.environ["RAY_DRIVER_INDEX"]) - redis_address = os.environ["RAY_REDIS_ADDRESS"] - print("Driver {} started at {}.".format(driver_index, time.time())) + driver_index = int(os.environ["RAY_DRIVER_INDEX"]) + redis_address = os.environ["RAY_REDIS_ADDRESS"] + print("Driver {} started at {}.".format(driver_index, time.time())) - if driver_index == 0: - driver_0(redis_address, driver_index) - elif driver_index == 1: - driver_1(redis_address, driver_index) - elif driver_index in [2, 3, 4, 5, 6]: - cleanup_driver(redis_address, driver_index) - else: - raise Exception("This code should be unreachable.") + if driver_index == 0: + driver_0(redis_address, driver_index) + elif driver_index == 1: + driver_1(redis_address, driver_index) + elif driver_index in [2, 3, 4, 5, 6]: + cleanup_driver(redis_address, driver_index) + else: + raise Exception("This code should be unreachable.") - print("Driver {} finished at {}.".format(driver_index, time.time())) + print("Driver {} finished at {}.".format(driver_index, time.time())) diff --git a/test/jenkins_tests/multi_node_tests/test_0.py b/test/jenkins_tests/multi_node_tests/test_0.py index 8710a57bd..62ace9cbb 100644 --- a/test/jenkins_tests/multi_node_tests/test_0.py +++ b/test/jenkins_tests/multi_node_tests/test_0.py @@ -10,25 +10,26 @@ import ray @ray.remote def f(): - time.sleep(0.1) - return ray.services.get_node_ip_address() + time.sleep(0.1) + return ray.services.get_node_ip_address() if __name__ == "__main__": - driver_index = int(os.environ["RAY_DRIVER_INDEX"]) - redis_address = os.environ["RAY_REDIS_ADDRESS"] - print("Driver {} started at {}.".format(driver_index, 
time.time())) + driver_index = int(os.environ["RAY_DRIVER_INDEX"]) + redis_address = os.environ["RAY_REDIS_ADDRESS"] + print("Driver {} started at {}.".format(driver_index, time.time())) - ray.init(redis_address=redis_address) - # Check that tasks are scheduled on all nodes. - num_attempts = 30 - for i in range(num_attempts): - ip_addresses = ray.get([f.remote() for i in range(1000)]) - distinct_addresses = set(ip_addresses) - counts = [ip_addresses.count(address) for address in distinct_addresses] - print("Counts are {}".format(counts)) - if len(counts) == 5: - break - assert len(counts) == 5 + ray.init(redis_address=redis_address) + # Check that tasks are scheduled on all nodes. + num_attempts = 30 + for i in range(num_attempts): + ip_addresses = ray.get([f.remote() for i in range(1000)]) + distinct_addresses = set(ip_addresses) + counts = [ip_addresses.count(address) for address + in distinct_addresses] + print("Counts are {}".format(counts)) + if len(counts) == 5: + break + assert len(counts) == 5 - print("Driver {} finished at {}.".format(driver_index, time.time())) + print("Driver {} finished at {}.".format(driver_index, time.time())) diff --git a/test/microbenchmarks.py b/test/microbenchmarks.py index bb2537e16..f31a8fe56 100644 --- a/test/microbenchmarks.py +++ b/test/microbenchmarks.py @@ -12,111 +12,111 @@ import numpy as np import ray.test.test_functions as test_functions if sys.version_info >= (3, 0): - from importlib import reload + from importlib import reload class MicroBenchmarkTest(unittest.TestCase): - def testTiming(self): - reload(test_functions) - ray.init(num_workers=3) + def testTiming(self): + reload(test_functions) + ray.init(num_workers=3) - # Measure the time required to submit a remote task to the scheduler. 
- elapsed_times = [] - for _ in range(1000): - start_time = time.time() - test_functions.empty_function.remote() - end_time = time.time() - elapsed_times.append(end_time - start_time) - elapsed_times = np.sort(elapsed_times) - average_elapsed_time = sum(elapsed_times) / 1000 - print("Time required to submit an empty function call:") - print(" Average: {}".format(average_elapsed_time)) - print(" 90th percentile: {}".format(elapsed_times[900])) - print(" 99th percentile: {}".format(elapsed_times[990])) - print(" worst: {}".format(elapsed_times[999])) - # average_elapsed_time should be about 0.00038. + # Measure the time required to submit a remote task to the scheduler. + elapsed_times = [] + for _ in range(1000): + start_time = time.time() + test_functions.empty_function.remote() + end_time = time.time() + elapsed_times.append(end_time - start_time) + elapsed_times = np.sort(elapsed_times) + average_elapsed_time = sum(elapsed_times) / 1000 + print("Time required to submit an empty function call:") + print(" Average: {}".format(average_elapsed_time)) + print(" 90th percentile: {}".format(elapsed_times[900])) + print(" 99th percentile: {}".format(elapsed_times[990])) + print(" worst: {}".format(elapsed_times[999])) + # average_elapsed_time should be about 0.00038. - # Measure the time required to submit a remote task to the scheduler - # (where the remote task returns one value). 
- elapsed_times = [] - for _ in range(1000): - start_time = time.time() - test_functions.trivial_function.remote() - end_time = time.time() - elapsed_times.append(end_time - start_time) - elapsed_times = np.sort(elapsed_times) - average_elapsed_time = sum(elapsed_times) / 1000 - print("Time required to submit a trivial function call:") - print(" Average: {}".format(average_elapsed_time)) - print(" 90th percentile: {}".format(elapsed_times[900])) - print(" 99th percentile: {}".format(elapsed_times[990])) - print(" worst: {}".format(elapsed_times[999])) - # average_elapsed_time should be about 0.001. + # Measure the time required to submit a remote task to the scheduler + # (where the remote task returns one value). + elapsed_times = [] + for _ in range(1000): + start_time = time.time() + test_functions.trivial_function.remote() + end_time = time.time() + elapsed_times.append(end_time - start_time) + elapsed_times = np.sort(elapsed_times) + average_elapsed_time = sum(elapsed_times) / 1000 + print("Time required to submit a trivial function call:") + print(" Average: {}".format(average_elapsed_time)) + print(" 90th percentile: {}".format(elapsed_times[900])) + print(" 99th percentile: {}".format(elapsed_times[990])) + print(" worst: {}".format(elapsed_times[999])) + # average_elapsed_time should be about 0.001. - # Measure the time required to submit a remote task to the scheduler and - # get the result. 
- elapsed_times = [] - for _ in range(1000): - start_time = time.time() - x = test_functions.trivial_function.remote() - ray.get(x) - end_time = time.time() - elapsed_times.append(end_time - start_time) - elapsed_times = np.sort(elapsed_times) - average_elapsed_time = sum(elapsed_times) / 1000 - print("Time required to submit a trivial function call and get the " - "result:") - print(" Average: {}".format(average_elapsed_time)) - print(" 90th percentile: {}".format(elapsed_times[900])) - print(" 99th percentile: {}".format(elapsed_times[990])) - print(" worst: {}".format(elapsed_times[999])) - # average_elapsed_time should be about 0.0013. + # Measure the time required to submit a remote task to the scheduler + # and get the result. + elapsed_times = [] + for _ in range(1000): + start_time = time.time() + x = test_functions.trivial_function.remote() + ray.get(x) + end_time = time.time() + elapsed_times.append(end_time - start_time) + elapsed_times = np.sort(elapsed_times) + average_elapsed_time = sum(elapsed_times) / 1000 + print("Time required to submit a trivial function call and get the " + "result:") + print(" Average: {}".format(average_elapsed_time)) + print(" 90th percentile: {}".format(elapsed_times[900])) + print(" 99th percentile: {}".format(elapsed_times[990])) + print(" worst: {}".format(elapsed_times[999])) + # average_elapsed_time should be about 0.0013. - # Measure the time required to do do a put. - elapsed_times = [] - for _ in range(1000): - start_time = time.time() - ray.put(1) - end_time = time.time() - elapsed_times.append(end_time - start_time) - elapsed_times = np.sort(elapsed_times) - average_elapsed_time = sum(elapsed_times) / 1000 - print("Time required to put an int:") - print(" Average: {}".format(average_elapsed_time)) - print(" 90th percentile: {}".format(elapsed_times[900])) - print(" 99th percentile: {}".format(elapsed_times[990])) - print(" worst: {}".format(elapsed_times[999])) - # average_elapsed_time should be about 0.00087. 
+ # Measure the time required to do do a put. + elapsed_times = [] + for _ in range(1000): + start_time = time.time() + ray.put(1) + end_time = time.time() + elapsed_times.append(end_time - start_time) + elapsed_times = np.sort(elapsed_times) + average_elapsed_time = sum(elapsed_times) / 1000 + print("Time required to put an int:") + print(" Average: {}".format(average_elapsed_time)) + print(" 90th percentile: {}".format(elapsed_times[900])) + print(" 99th percentile: {}".format(elapsed_times[990])) + print(" worst: {}".format(elapsed_times[999])) + # average_elapsed_time should be about 0.00087. - ray.worker.cleanup() + ray.worker.cleanup() - def testCache(self): - ray.init(num_workers=1) + def testCache(self): + ray.init(num_workers=1) - A = np.random.rand(1, 1000000) - v = np.random.rand(1000000) - A_id = ray.put(A) - v_id = ray.put(v) - a = time.time() - for i in range(100): - A.dot(v) - b = time.time() - a - c = time.time() - for i in range(100): - ray.get(A_id).dot(ray.get(v_id)) - d = time.time() - c + A = np.random.rand(1, 1000000) + v = np.random.rand(1000000) + A_id = ray.put(A) + v_id = ray.put(v) + a = time.time() + for i in range(100): + A.dot(v) + b = time.time() - a + c = time.time() + for i in range(100): + ray.get(A_id).dot(ray.get(v_id)) + d = time.time() - c - if d > 1.5 * b: - if os.getenv("TRAVIS") is None: - raise Exception("The caching test was too slow. " - "d = {}, b = {}".format(d, b)) - else: - print("WARNING: The caching test was too slow. " - "d = {}, b = {}".format(d, b)) + if d > 1.5 * b: + if os.getenv("TRAVIS") is None: + raise Exception("The caching test was too slow. " + "d = {}, b = {}".format(d, b)) + else: + print("WARNING: The caching test was too slow. 
" + "d = {}, b = {}".format(d, b)) - ray.worker.cleanup() + ray.worker.cleanup() if __name__ == "__main__": - unittest.main(verbosity=2) + unittest.main(verbosity=2) diff --git a/test/multi_node_test.py b/test/multi_node_test.py index 31aef4aa3..0abe0e738 100644 --- a/test/multi_node_test.py +++ b/test/multi_node_test.py @@ -11,52 +11,53 @@ import time class MultiNodeTest(unittest.TestCase): - def setUp(self): - # Start the Ray processes on this machine. - out = subprocess.check_output(["ray", "start", "--head"]).decode("ascii") - # Get the redis address from the output. - redis_substring_prefix = "redis_address=\"" - redis_address_location = (out.find(redis_substring_prefix) + - len(redis_substring_prefix)) - redis_address = out[redis_address_location:] - self.redis_address = redis_address.split("\"")[0] + def setUp(self): + # Start the Ray processes on this machine. + out = subprocess.check_output( + ["ray", "start", "--head"]).decode("ascii") + # Get the redis address from the output. + redis_substring_prefix = "redis_address=\"" + redis_address_location = (out.find(redis_substring_prefix) + + len(redis_substring_prefix)) + redis_address = out[redis_address_location:] + self.redis_address = redis_address.split("\"")[0] - def tearDown(self): - # Kill the Ray cluster. - subprocess.Popen(["ray", "stop"]).wait() + def tearDown(self): + # Kill the Ray cluster. + subprocess.Popen(["ray", "stop"]).wait() - def testErrorIsolation(self): - # Connect a driver to the Ray cluster. - ray.init(redis_address=self.redis_address, driver_mode=ray.SILENT_MODE) + def testErrorIsolation(self): + # Connect a driver to the Ray cluster. + ray.init(redis_address=self.redis_address, driver_mode=ray.SILENT_MODE) - # There shouldn't be any errors yet. - self.assertEqual(len(ray.error_info()), 0) + # There shouldn't be any errors yet. 
+ self.assertEqual(len(ray.error_info()), 0) - error_string1 = "error_string1" - error_string2 = "error_string2" + error_string1 = "error_string1" + error_string2 = "error_string2" - @ray.remote - def f(): - raise Exception(error_string1) + @ray.remote + def f(): + raise Exception(error_string1) - # Run a remote function that throws an error. - with self.assertRaises(Exception): - ray.get(f.remote()) + # Run a remote function that throws an error. + with self.assertRaises(Exception): + ray.get(f.remote()) - # Wait for the error to appear in Redis. - while len(ray.error_info()) != 1: - time.sleep(0.1) - print("Waiting for error to appear.") + # Wait for the error to appear in Redis. + while len(ray.error_info()) != 1: + time.sleep(0.1) + print("Waiting for error to appear.") - # Make sure we got the error. - self.assertEqual(len(ray.error_info()), 1) - self.assertIn(error_string1, - ray.error_info()[0][b"message"].decode("ascii")) + # Make sure we got the error. + self.assertEqual(len(ray.error_info()), 1) + self.assertIn(error_string1, + ray.error_info()[0][b"message"].decode("ascii")) - # Start another driver and make sure that it does not receive this error. - # Make the other driver throw an error, and make sure it receives that - # error. - driver_script = """ + # Start another driver and make sure that it does not receive this + # error. Make the other driver throw an error, and make sure it + # receives that error. 
+ driver_script = """ import ray import time @@ -67,16 +68,16 @@ assert len(ray.error_info()) == 0 @ray.remote def f(): - raise Exception("{}") + raise Exception("{}") try: - ray.get(f.remote()) + ray.get(f.remote()) except Exception as e: - pass + pass while len(ray.error_info()) != 1: - print(len(ray.error_info())) - time.sleep(0.1) + print(len(ray.error_info())) + time.sleep(0.1) assert len(ray.error_info()) == 1 assert "{}" in ray.error_info()[0][b"message"].decode("ascii") @@ -84,126 +85,126 @@ assert "{}" in ray.error_info()[0][b"message"].decode("ascii") print("success") """.format(self.redis_address, error_string2, error_string2) - # Save the driver script as a file so we can call it using subprocess. - with tempfile.NamedTemporaryFile() as f: - f.write(driver_script.encode("ascii")) - f.flush() - out = subprocess.check_output(["python", f.name]).decode("ascii") + # Save the driver script as a file so we can call it using subprocess. + with tempfile.NamedTemporaryFile() as f: + f.write(driver_script.encode("ascii")) + f.flush() + out = subprocess.check_output(["python", f.name]).decode("ascii") - # Make sure the other driver succeeded. - self.assertIn("success", out) + # Make sure the other driver succeeded. + self.assertIn("success", out) - # Make sure that the other error message doesn't show up for this driver. - self.assertEqual(len(ray.error_info()), 1) - self.assertIn(error_string1, - ray.error_info()[0][b"message"].decode("ascii")) + # Make sure that the other error message doesn't show up for this + # driver. + self.assertEqual(len(ray.error_info()), 1) + self.assertIn(error_string1, + ray.error_info()[0][b"message"].decode("ascii")) - ray.worker.cleanup() + ray.worker.cleanup() - def testRemoteFunctionIsolation(self): - # This test will run multiple remote functions with the same names in two - # different drivers. - # Connect a driver to the Ray cluster. 
- ray.init(redis_address=self.redis_address, driver_mode=ray.SILENT_MODE) + def testRemoteFunctionIsolation(self): + # This test will run multiple remote functions with the same names in + # two different drivers. Connect a driver to the Ray cluster. + ray.init(redis_address=self.redis_address, driver_mode=ray.SILENT_MODE) - # Start another driver and make sure that it can define and call its own - # commands with the same names. - driver_script = """ + # Start another driver and make sure that it can define and call its + # own commands with the same names. + driver_script = """ import ray import time ray.init(redis_address="{}") @ray.remote def f(): - return 3 + return 3 @ray.remote def g(x, y): - return 4 + return 4 for _ in range(10000): - result = ray.get([f.remote(), g.remote(0, 0)]) - assert result == [3, 4] + result = ray.get([f.remote(), g.remote(0, 0)]) + assert result == [3, 4] print("success") """.format(self.redis_address) - # Save the driver script as a file so we can call it using subprocess. - with tempfile.NamedTemporaryFile() as f: - f.write(driver_script.encode("ascii")) - f.flush() - out = subprocess.check_output(["python", f.name]).decode("ascii") + # Save the driver script as a file so we can call it using subprocess. + with tempfile.NamedTemporaryFile() as f: + f.write(driver_script.encode("ascii")) + f.flush() + out = subprocess.check_output(["python", f.name]).decode("ascii") - @ray.remote - def f(): - return 1 + @ray.remote + def f(): + return 1 - @ray.remote - def g(x): - return 2 + @ray.remote + def g(x): + return 2 - for _ in range(10000): - result = ray.get([f.remote(), g.remote(0)]) - self.assertEqual(result, [1, 2]) + for _ in range(10000): + result = ray.get([f.remote(), g.remote(0)]) + self.assertEqual(result, [1, 2]) - # Make sure the other driver succeeded. - self.assertIn("success", out) + # Make sure the other driver succeeded. 
+ self.assertIn("success", out) - ray.worker.cleanup() + ray.worker.cleanup() class StartRayScriptTest(unittest.TestCase): - def testCallingStartRayHead(self): - # Test that we can call start-ray.sh with various command line parameters. - # TODO(rkn): This test only tests the --head code path. We should also test - # the non-head node code path. + def testCallingStartRayHead(self): + # Test that we can call start-ray.sh with various command line + # parameters. TODO(rkn): This test only tests the --head code path. We + # should also test the non-head node code path. - # Test starting Ray with no arguments. - subprocess.check_output(["ray", "start", "--head"]).decode("ascii") - subprocess.Popen(["ray", "stop"]).wait() + # Test starting Ray with no arguments. + subprocess.check_output(["ray", "start", "--head"]).decode("ascii") + subprocess.Popen(["ray", "stop"]).wait() - # Test starting Ray with a number of workers specified. - subprocess.check_output(["ray", "start", "--head", "--num-workers", - "20"]) - subprocess.Popen(["ray", "stop"]).wait() + # Test starting Ray with a number of workers specified. + subprocess.check_output(["ray", "start", "--head", "--num-workers", + "20"]) + subprocess.Popen(["ray", "stop"]).wait() - # Test starting Ray with a redis port specified. - subprocess.check_output(["ray", "start", "--head", - "--redis-port", "6379"]) - subprocess.Popen(["ray", "stop"]).wait() + # Test starting Ray with a redis port specified. + subprocess.check_output(["ray", "start", "--head", + "--redis-port", "6379"]) + subprocess.Popen(["ray", "stop"]).wait() - # Test starting Ray with a node IP address specified. - subprocess.check_output(["ray", "start", "--head", - "--node-ip-address", "127.0.0.1"]) - subprocess.Popen(["ray", "stop"]).wait() + # Test starting Ray with a node IP address specified. 
+ subprocess.check_output(["ray", "start", "--head", + "--node-ip-address", "127.0.0.1"]) + subprocess.Popen(["ray", "stop"]).wait() - # Test starting Ray with an object manager port specified. - subprocess.check_output(["ray", "start", "--head", - "--object-manager-port", "12345"]) - subprocess.Popen(["ray", "stop"]).wait() + # Test starting Ray with an object manager port specified. + subprocess.check_output(["ray", "start", "--head", + "--object-manager-port", "12345"]) + subprocess.Popen(["ray", "stop"]).wait() - # Test starting Ray with the number of CPUs specified. - subprocess.check_output(["ray", "start", "--head", - "--num-cpus", "100"]) - subprocess.Popen(["ray", "stop"]).wait() + # Test starting Ray with the number of CPUs specified. + subprocess.check_output(["ray", "start", "--head", + "--num-cpus", "100"]) + subprocess.Popen(["ray", "stop"]).wait() - # Test starting Ray with the number of GPUs specified. - subprocess.check_output(["ray", "start", "--head", - "--num-gpus", "100"]) - subprocess.Popen(["ray", "stop"]).wait() + # Test starting Ray with the number of GPUs specified. + subprocess.check_output(["ray", "start", "--head", + "--num-gpus", "100"]) + subprocess.Popen(["ray", "stop"]).wait() - # Test starting Ray with all arguments specified. - subprocess.check_output(["ray", "start", "--head", - "--num-workers", "20", - "--redis-port", "6379", - "--object-manager-port", "12345", - "--num-cpus", "100", - "--num-gpus", "0"]) - subprocess.Popen(["ray", "stop"]).wait() + # Test starting Ray with all arguments specified. + subprocess.check_output(["ray", "start", "--head", + "--num-workers", "20", + "--redis-port", "6379", + "--object-manager-port", "12345", + "--num-cpus", "100", + "--num-gpus", "0"]) + subprocess.Popen(["ray", "stop"]).wait() - # Test starting Ray with invalid arguments. 
- with self.assertRaises(Exception): - subprocess.check_output(["ray", "start", "--head", - "--redis-address", "127.0.0.1:6379"]) - subprocess.Popen(["ray", "stop"]).wait() + # Test starting Ray with invalid arguments. + with self.assertRaises(Exception): + subprocess.check_output(["ray", "start", "--head", + "--redis-address", "127.0.0.1:6379"]) + subprocess.Popen(["ray", "stop"]).wait() if __name__ == "__main__": - unittest.main(verbosity=2) + unittest.main(verbosity=2) diff --git a/test/recursion_test.py b/test/recursion_test.py index 247781907..a769c52e8 100644 --- a/test/recursion_test.py +++ b/test/recursion_test.py @@ -13,9 +13,9 @@ ray.init() @ray.remote def factorial(n): - if n == 0: - return 1 - return n * ray.get(factorial.remote(n - 1)) + if n == 0: + return 1 + return n * ray.get(factorial.remote(n - 1)) assert ray.get(factorial.remote(0)) == 1 diff --git a/test/runtest.py b/test/runtest.py index d496cc779..4da81d773 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -17,61 +17,67 @@ import ray.test.test_functions as test_functions import ray.test.test_utils if sys.version_info >= (3, 0): - from importlib import reload + from importlib import reload def assert_equal(obj1, obj2): - module_numpy = (type(obj1).__module__ == np.__name__ or - type(obj2).__module__ == np.__name__) - if module_numpy: - empty_shape = ((hasattr(obj1, "shape") and obj1.shape == ()) or - (hasattr(obj2, "shape") and obj2.shape == ())) - if empty_shape: - # This is a special case because currently np.testing.assert_equal fails - # because we do not properly handle different numerical types. - assert obj1 == obj2, ("Objects {} and {} are " - "different.".format(obj1, obj2)) - else: - np.testing.assert_equal(obj1, obj2) - elif hasattr(obj1, "__dict__") and hasattr(obj2, "__dict__"): - special_keys = ["_pytype_"] - assert (set(list(obj1.__dict__.keys()) + special_keys) == - set(list(obj2.__dict__.keys()) + special_keys)), ("Objects {} and " - "{} are " - "different." 
- .format(obj1, + module_numpy = (type(obj1).__module__ == np.__name__ or + type(obj2).__module__ == np.__name__) + if module_numpy: + empty_shape = ((hasattr(obj1, "shape") and obj1.shape == ()) or + (hasattr(obj2, "shape") and obj2.shape == ())) + if empty_shape: + # This is a special case because currently np.testing.assert_equal + # fails because we do not properly handle different numerical + # types. + assert obj1 == obj2, ("Objects {} and {} are " + "different.".format(obj1, obj2)) + else: + np.testing.assert_equal(obj1, obj2) + elif hasattr(obj1, "__dict__") and hasattr(obj2, "__dict__"): + special_keys = ["_pytype_"] + assert (set(list(obj1.__dict__.keys()) + special_keys) == + set(list(obj2.__dict__.keys()) + special_keys)), ("Objects {} " + "and {} are " + "different." + .format( + obj1, obj2)) - for key in obj1.__dict__.keys(): - if key not in special_keys: - assert_equal(obj1.__dict__[key], obj2.__dict__[key]) - elif type(obj1) is dict or type(obj2) is dict: - assert_equal(obj1.keys(), obj2.keys()) - for key in obj1.keys(): - assert_equal(obj1[key], obj2[key]) - elif type(obj1) is list or type(obj2) is list: - assert len(obj1) == len(obj2), ("Objects {} and {} are lists with " - "different lengths.".format(obj1, obj2)) - for i in range(len(obj1)): - assert_equal(obj1[i], obj2[i]) - elif type(obj1) is tuple or type(obj2) is tuple: - assert len(obj1) == len(obj2), ("Objects {} and {} are tuples with " - "different lengths.".format(obj1, obj2)) - for i in range(len(obj1)): - assert_equal(obj1[i], obj2[i]) - elif (ray.serialization.is_named_tuple(type(obj1)) or - ray.serialization.is_named_tuple(type(obj2))): - assert len(obj1) == len(obj2), ("Objects {} and {} are named tuples with " - "different lengths.".format(obj1, obj2)) - for i in range(len(obj1)): - assert_equal(obj1[i], obj2[i]) - else: - assert obj1 == obj2, "Objects {} and {} are different.".format(obj1, obj2) + for key in obj1.__dict__.keys(): + if key not in special_keys: + 
assert_equal(obj1.__dict__[key], obj2.__dict__[key]) + elif type(obj1) is dict or type(obj2) is dict: + assert_equal(obj1.keys(), obj2.keys()) + for key in obj1.keys(): + assert_equal(obj1[key], obj2[key]) + elif type(obj1) is list or type(obj2) is list: + assert len(obj1) == len(obj2), ("Objects {} and {} are lists with " + "different lengths." + .format(obj1, obj2)) + for i in range(len(obj1)): + assert_equal(obj1[i], obj2[i]) + elif type(obj1) is tuple or type(obj2) is tuple: + assert len(obj1) == len(obj2), ("Objects {} and {} are tuples with " + "different lengths." + .format(obj1, obj2)) + for i in range(len(obj1)): + assert_equal(obj1[i], obj2[i]) + elif (ray.serialization.is_named_tuple(type(obj1)) or + ray.serialization.is_named_tuple(type(obj2))): + assert len(obj1) == len(obj2), ("Objects {} and {} are named tuples " + "with different lengths." + .format(obj1, obj2)) + for i in range(len(obj1)): + assert_equal(obj1[i], obj2[i]) + else: + assert obj1 == obj2, "Objects {} and {} are different.".format(obj1, + obj2) if sys.version_info >= (3, 0): - long_extras = [0, np.array([["hi", u"hi"], [1.3, 1]])] + long_extras = [0, np.array([["hi", u"hi"], [1.3, 1]])] else: - long_extras = [long(0), np.array([["hi", u"hi"], [1.3, long(1)]])] # noqa: E501,F821 + long_extras = [long(0), np.array([["hi", u"hi"], [1.3, long(1)]])] # noqa: E501,F821 PRIMITIVE_OBJECTS = [0, 0.0, 0.9, 1 << 62, "a", string.printable, "\u262F", u"hello world", u"\xff\xfe\x9c\x001\x000\x00", None, True, @@ -91,43 +97,43 @@ COMPLEX_OBJECTS = [ class Foo(object): - def __init__(self, value=0): - self.value = value + def __init__(self, value=0): + self.value = value - def __hash__(self): - return hash(self.value) + def __hash__(self): + return hash(self.value) - def __eq__(self, other): - return other.value == self.value + def __eq__(self, other): + return other.value == self.value class Bar(object): - def __init__(self): - for i, val in enumerate(PRIMITIVE_OBJECTS + COMPLEX_OBJECTS): - 
setattr(self, "field{}".format(i), val) + def __init__(self): + for i, val in enumerate(PRIMITIVE_OBJECTS + COMPLEX_OBJECTS): + setattr(self, "field{}".format(i), val) class Baz(object): - def __init__(self): - self.foo = Foo() - self.bar = Bar() + def __init__(self): + self.foo = Foo() + self.bar = Bar() - def method(self, arg): - pass + def method(self, arg): + pass class Qux(object): - def __init__(self): - self.objs = [Foo(), Bar(), Baz()] + def __init__(self): + self.objs = [Foo(), Bar(), Baz()] class SubQux(Qux): - def __init__(self): - Qux.__init__(self) + def __init__(self): + Qux.__init__(self) class CustomError(Exception): - pass + pass Point = namedtuple("Point", ["x", "y"]) @@ -154,1512 +160,1530 @@ RAY_TEST_OBJECTS = BASE_OBJECTS + LIST_OBJECTS + TUPLE_OBJECTS + DICT_OBJECTS # Check that the correct version of cloudpickle is installed. try: - import cloudpickle - cloudpickle.dumps(Point) + import cloudpickle + cloudpickle.dumps(Point) except AttributeError: - cloudpickle_command = "pip install --upgrade cloudpickle" - raise Exception("You have an older version of cloudpickle that is not able " - "to serialize namedtuples. Try running " - "\n\n{}\n\n".format(cloudpickle_command)) + cloudpickle_command = "pip install --upgrade cloudpickle" + raise Exception("You have an older version of cloudpickle that is not " + "able to serialize namedtuples. Try running " + "\n\n{}\n\n".format(cloudpickle_command)) class SerializationTest(unittest.TestCase): - def testRecursiveObjects(self): - ray.init(num_workers=0) + def testRecursiveObjects(self): + ray.init(num_workers=0) - class ClassA(object): - pass + class ClassA(object): + pass - # Make a list that contains itself. - l = [] - l.append(l) - # Make an object that contains itself as a field. - a1 = ClassA() - a1.field = a1 - # Make two objects that contain each other as fields. - a2 = ClassA() - a3 = ClassA() - a2.field = a3 - a3.field = a2 - # Make a dictionary that contains itself. 
- d1 = {} - d1["key"] = d1 - # Create a list of recursive objects. - recursive_objects = [l, a1, a2, a3, d1] + # Make a list that contains itself. + l = [] + l.append(l) + # Make an object that contains itself as a field. + a1 = ClassA() + a1.field = a1 + # Make two objects that contain each other as fields. + a2 = ClassA() + a3 = ClassA() + a2.field = a3 + a3.field = a2 + # Make a dictionary that contains itself. + d1 = {} + d1["key"] = d1 + # Create a list of recursive objects. + recursive_objects = [l, a1, a2, a3, d1] - # Check that exceptions are thrown when we serialize the recursive objects. - for obj in recursive_objects: - self.assertRaises(Exception, lambda: ray.put(obj)) + # Check that exceptions are thrown when we serialize the recursive + # objects. + for obj in recursive_objects: + self.assertRaises(Exception, lambda: ray.put(obj)) - ray.worker.cleanup() + ray.worker.cleanup() - def testPassingArgumentsByValue(self): - ray.init(num_workers=1) + def testPassingArgumentsByValue(self): + ray.init(num_workers=1) - @ray.remote - def f(x): - return x + @ray.remote + def f(x): + return x - # Check that we can pass arguments by value to remote functions and that - # they are uncorrupted. - for obj in RAY_TEST_OBJECTS: - assert_equal(obj, ray.get(f.remote(obj))) + # Check that we can pass arguments by value to remote functions and + # that they are uncorrupted. + for obj in RAY_TEST_OBJECTS: + assert_equal(obj, ray.get(f.remote(obj))) - ray.worker.cleanup() + ray.worker.cleanup() - def testPassingArgumentsByValueOutOfTheBox(self): - ray.init(num_workers=1) + def testPassingArgumentsByValueOutOfTheBox(self): + ray.init(num_workers=1) - @ray.remote - def f(x): - return x + @ray.remote + def f(x): + return x - # Test passing lambdas. + # Test passing lambdas. 
- def temp(): - return 1 + def temp(): + return 1 - self.assertEqual(ray.get(f.remote(temp))(), 1) - self.assertEqual(ray.get(f.remote(lambda x: x + 1))(3), 4) + self.assertEqual(ray.get(f.remote(temp))(), 1) + self.assertEqual(ray.get(f.remote(lambda x: x + 1))(3), 4) - # Test sets. - self.assertEqual(ray.get(f.remote(set())), set()) - s = set([1, (1, 2, "hi")]) - self.assertEqual(ray.get(f.remote(s)), s) + # Test sets. + self.assertEqual(ray.get(f.remote(set())), set()) + s = set([1, (1, 2, "hi")]) + self.assertEqual(ray.get(f.remote(s)), s) - # Test types. - self.assertEqual(ray.get(f.remote(int)), int) - self.assertEqual(ray.get(f.remote(float)), float) - self.assertEqual(ray.get(f.remote(str)), str) + # Test types. + self.assertEqual(ray.get(f.remote(int)), int) + self.assertEqual(ray.get(f.remote(float)), float) + self.assertEqual(ray.get(f.remote(str)), str) - class Foo(object): - def __init__(self): - pass + class Foo(object): + def __init__(self): + pass - # Make sure that we can put and get a custom type. Note that the result - # won't be "equal" to Foo. - ray.get(ray.put(Foo)) + # Make sure that we can put and get a custom type. Note that the result + # won't be "equal" to Foo. + ray.get(ray.put(Foo)) - ray.worker.cleanup() + ray.worker.cleanup() class WorkerTest(unittest.TestCase): - def testPythonWorkers(self): - # Test the codepath for starting workers from the Python script, instead of - # the local scheduler. This codepath is for debugging purposes only. - num_workers = 4 - ray.worker._init(num_workers=num_workers, - start_workers_from_local_scheduler=False, - start_ray_local=True) + def testPythonWorkers(self): + # Test the codepath for starting workers from the Python script, + # instead of the local scheduler. This codepath is for debugging + # purposes only. 
+ num_workers = 4 + ray.worker._init(num_workers=num_workers, + start_workers_from_local_scheduler=False, + start_ray_local=True) - @ray.remote - def f(x): - return x + @ray.remote + def f(x): + return x - values = ray.get([f.remote(1) for i in range(num_workers * 2)]) - self.assertEqual(values, [1] * (num_workers * 2)) - ray.worker.cleanup() + values = ray.get([f.remote(1) for i in range(num_workers * 2)]) + self.assertEqual(values, [1] * (num_workers * 2)) + ray.worker.cleanup() - def testPutGet(self): - ray.init(num_workers=0) + def testPutGet(self): + ray.init(num_workers=0) - for i in range(100): - value_before = i * 10 ** 6 - objectid = ray.put(value_before) - value_after = ray.get(objectid) - self.assertEqual(value_before, value_after) + for i in range(100): + value_before = i * 10 ** 6 + objectid = ray.put(value_before) + value_after = ray.get(objectid) + self.assertEqual(value_before, value_after) - for i in range(100): - value_before = i * 10 ** 6 * 1.0 - objectid = ray.put(value_before) - value_after = ray.get(objectid) - self.assertEqual(value_before, value_after) + for i in range(100): + value_before = i * 10 ** 6 * 1.0 + objectid = ray.put(value_before) + value_after = ray.get(objectid) + self.assertEqual(value_before, value_after) - for i in range(100): - value_before = "h" * i - objectid = ray.put(value_before) - value_after = ray.get(objectid) - self.assertEqual(value_before, value_after) + for i in range(100): + value_before = "h" * i + objectid = ray.put(value_before) + value_after = ray.get(objectid) + self.assertEqual(value_before, value_after) - for i in range(100): - value_before = [1] * i - objectid = ray.put(value_before) - value_after = ray.get(objectid) - self.assertEqual(value_before, value_after) + for i in range(100): + value_before = [1] * i + objectid = ray.put(value_before) + value_after = ray.get(objectid) + self.assertEqual(value_before, value_after) - ray.worker.cleanup() + ray.worker.cleanup() class APITest(unittest.TestCase): - 
def init_ray(self, kwargs=None): - if kwargs is None: - kwargs = {} - ray.init(**kwargs) - - def tearDown(self): - ray.worker.cleanup() - - def testRegisterClass(self): - self.init_ray({"num_workers": 2}) - - # Check that putting an object of a class that has not been registered - # throws an exception. - class TempClass(object): - pass - ray.get(ray.put(TempClass())) - - # Note that the below actually returns a dictionary and not a defaultdict. - # This is a bug (https://github.com/ray-project/ray/issues/512). - ray.get(ray.put(defaultdict(lambda: 0))) + def init_ray(self, kwargs=None): + if kwargs is None: + kwargs = {} + ray.init(**kwargs) + + def tearDown(self): + ray.worker.cleanup() + + def testRegisterClass(self): + self.init_ray({"num_workers": 2}) + + # Check that putting an object of a class that has not been registered + # throws an exception. + class TempClass(object): + pass + ray.get(ray.put(TempClass())) + + # Note that the below actually returns a dictionary and not a + # defaultdict. This is a bug + # (https://github.com/ray-project/ray/issues/512). + ray.get(ray.put(defaultdict(lambda: 0))) - # Test passing custom classes into remote functions from the driver. - @ray.remote - def f(x): - return x + # Test passing custom classes into remote functions from the driver. + @ray.remote + def f(x): + return x - foo = ray.get(f.remote(Foo(7))) - self.assertEqual(foo, Foo(7)) + foo = ray.get(f.remote(Foo(7))) + self.assertEqual(foo, Foo(7)) - regex = re.compile(r"\d+\.\d*") - new_regex = ray.get(f.remote(regex)) - # This seems to fail on the system Python 3 that comes with - # Ubuntu, so it is commented out for now: - # self.assertEqual(regex, new_regex) - # Instead, we do this: - self.assertEqual(regex.pattern, new_regex.pattern) - - # Test returning custom classes created on workers. 
- @ray.remote - def g(): - return SubQux(), Qux() - - subqux, qux = ray.get(g.remote()) - self.assertEqual(subqux.objs[2].foo.value, 0) - - # Test exporting custom class definitions from one worker to another when - # the worker is blocked in a get. - class NewTempClass(object): - def __init__(self, value): - self.value = value - - @ray.remote - def h1(x): - return NewTempClass(x) - - @ray.remote - def h2(x): - return ray.get(h1.remote(x)) - - self.assertEqual(ray.get(h2.remote(10)).value, 10) - - # Test registering multiple classes with the same name. - @ray.remote(num_return_vals=3) - def j(): - class Class0(object): - def method0(self): - pass - - c0 = Class0() - - class Class0(object): - def method1(self): - pass - - c1 = Class0() - - class Class0(object): - def method2(self): - pass - - c2 = Class0() - - return c0, c1, c2 - - results = [] - for _ in range(5): - results += j.remote() - for i in range(len(results) // 3): - c0, c1, c2 = ray.get(results[(3 * i):(3 * (i + 1))]) - - c0.method0() - c1.method1() - c2.method2() - - self.assertFalse(hasattr(c0, "method1")) - self.assertFalse(hasattr(c0, "method2")) - self.assertFalse(hasattr(c1, "method0")) - self.assertFalse(hasattr(c1, "method2")) - self.assertFalse(hasattr(c2, "method0")) - self.assertFalse(hasattr(c2, "method1")) - - @ray.remote - def k(): - class Class0(object): - def method0(self): - pass - - c0 = Class0() - - class Class0(object): - def method1(self): - pass - - c1 = Class0() - - class Class0(object): - def method2(self): - pass - - c2 = Class0() - - return c0, c1, c2 - - results = ray.get([k.remote() for _ in range(5)]) - for c0, c1, c2 in results: - c0.method0() - c1.method1() - c2.method2() - - self.assertFalse(hasattr(c0, "method1")) - self.assertFalse(hasattr(c0, "method2")) - self.assertFalse(hasattr(c1, "method0")) - self.assertFalse(hasattr(c1, "method2")) - self.assertFalse(hasattr(c2, "method0")) - self.assertFalse(hasattr(c2, "method1")) - - def testKeywordArgs(self): - 
reload(test_functions) - self.init_ray() - - x = test_functions.keyword_fct1.remote(1) - self.assertEqual(ray.get(x), "1 hello") - x = test_functions.keyword_fct1.remote(1, "hi") - self.assertEqual(ray.get(x), "1 hi") - x = test_functions.keyword_fct1.remote(1, b="world") - self.assertEqual(ray.get(x), "1 world") - - x = test_functions.keyword_fct2.remote(a="w", b="hi") - self.assertEqual(ray.get(x), "w hi") - x = test_functions.keyword_fct2.remote(b="hi", a="w") - self.assertEqual(ray.get(x), "w hi") - x = test_functions.keyword_fct2.remote(a="w") - self.assertEqual(ray.get(x), "w world") - x = test_functions.keyword_fct2.remote(b="hi") - self.assertEqual(ray.get(x), "hello hi") - x = test_functions.keyword_fct2.remote("w") - self.assertEqual(ray.get(x), "w world") - x = test_functions.keyword_fct2.remote("w", "hi") - self.assertEqual(ray.get(x), "w hi") - - x = test_functions.keyword_fct3.remote(0, 1, c="w", d="hi") - self.assertEqual(ray.get(x), "0 1 w hi") - x = test_functions.keyword_fct3.remote(0, 1, d="hi", c="w") - self.assertEqual(ray.get(x), "0 1 w hi") - x = test_functions.keyword_fct3.remote(0, 1, c="w") - self.assertEqual(ray.get(x), "0 1 w world") - x = test_functions.keyword_fct3.remote(0, 1, d="hi") - self.assertEqual(ray.get(x), "0 1 hello hi") - x = test_functions.keyword_fct3.remote(0, 1) - self.assertEqual(ray.get(x), "0 1 hello world") - - # Check that we cannot pass invalid keyword arguments to functions. - @ray.remote - def f1(): - return - - @ray.remote - def f2(x, y=0, z=0): - return - - # Make sure we get an exception if too many arguments are passed in. - with self.assertRaises(Exception): - f1.remote(3) - - with self.assertRaises(Exception): - f1.remote(x=3) - - with self.assertRaises(Exception): - f2.remote(0, w=0) - - # Make sure we get an exception if too many arguments are passed in. 
- with self.assertRaises(Exception): - f2.remote(1, 2, 3, 4) - - @ray.remote - def f3(x): - return x - - self.assertEqual(ray.get(f3.remote(4)), 4) - - def testVariableNumberOfArgs(self): - reload(test_functions) - self.init_ray() - - x = test_functions.varargs_fct1.remote(0, 1, 2) - self.assertEqual(ray.get(x), "0 1 2") - x = test_functions.varargs_fct2.remote(0, 1, 2) - self.assertEqual(ray.get(x), "1 2") - - self.assertTrue(test_functions.kwargs_exception_thrown) - self.assertTrue(test_functions.varargs_and_kwargs_exception_thrown) - - @ray.remote - def f1(*args): - return args - - @ray.remote - def f2(x, y, *args): - return x, y, args - - self.assertEqual(ray.get(f1.remote()), ()) - self.assertEqual(ray.get(f1.remote(1)), (1,)) - self.assertEqual(ray.get(f1.remote(1, 2, 3)), (1, 2, 3)) - with self.assertRaises(Exception): - f2.remote() - with self.assertRaises(Exception): - f2.remote(1) - self.assertEqual(ray.get(f2.remote(1, 2)), (1, 2, ())) - self.assertEqual(ray.get(f2.remote(1, 2, 3)), (1, 2, (3,))) - self.assertEqual(ray.get(f2.remote(1, 2, 3, 4)), (1, 2, (3, 4))) - - def testNoArgs(self): - reload(test_functions) - self.init_ray() - - ray.get(test_functions.no_op.remote()) - - def testDefiningRemoteFunctions(self): - self.init_ray({"num_cpus": 3}) - - # Test that we can define a remote function in the shell. - @ray.remote - def f(x): - return x + 1 - self.assertEqual(ray.get(f.remote(0)), 1) - - # Test that we can redefine the remote function. - @ray.remote - def f(x): - return x + 10 - while True: - val = ray.get(f.remote(0)) - self.assertTrue(val in [1, 10]) - if val == 10: - break - else: - print("Still using old definition of f, trying again.") - - # Test that we can close over plain old data. - data = [np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 1 << 62], 1 << 60, - {"a": np.zeros(3)}] - - @ray.remote - def g(): - return data - ray.get(g.remote()) - - # Test that we can close over modules. 
- @ray.remote - def h(): - return np.zeros([3, 5]) - assert_equal(ray.get(h.remote()), np.zeros([3, 5])) - - @ray.remote - def j(): - return time.time() - ray.get(j.remote()) - - # Test that we can define remote functions that call other remote - # functions. - @ray.remote - def k(x): - return x + 1 - - @ray.remote - def l(x): - return ray.get(k.remote(x)) - - @ray.remote - def m(x): - return ray.get(l.remote(x)) - self.assertEqual(ray.get(k.remote(1)), 2) - self.assertEqual(ray.get(l.remote(1)), 2) - self.assertEqual(ray.get(m.remote(1)), 2) - - def testGetMultiple(self): - self.init_ray() - object_ids = [ray.put(i) for i in range(10)] - self.assertEqual(ray.get(object_ids), list(range(10))) - - # Get a random choice of object IDs with duplicates. - indices = list(np.random.choice(range(10), 5)) - indices += indices - results = ray.get([object_ids[i] for i in indices]) - self.assertEqual(results, indices) - - def testWait(self): - self.init_ray({"num_cpus": 1}) - - @ray.remote - def f(delay): - time.sleep(delay) - return 1 - - objectids = [f.remote(1.0), f.remote(0.5), f.remote(0.5), f.remote(0.5)] - ready_ids, remaining_ids = ray.wait(objectids) - self.assertEqual(len(ready_ids), 1) - self.assertEqual(len(remaining_ids), 3) - ready_ids, remaining_ids = ray.wait(objectids, num_returns=4) - self.assertEqual(set(ready_ids), set(objectids)) - self.assertEqual(remaining_ids, []) - - objectids = [f.remote(0.5), f.remote(0.5), f.remote(0.5), f.remote(0.5)] - start_time = time.time() - ready_ids, remaining_ids = ray.wait(objectids, timeout=1750, num_returns=4) - self.assertLess(time.time() - start_time, 2) - self.assertEqual(len(ready_ids), 3) - self.assertEqual(len(remaining_ids), 1) - ray.wait(objectids) - objectids = [f.remote(1.0), f.remote(0.5), f.remote(0.5), f.remote(0.5)] - start_time = time.time() - ready_ids, remaining_ids = ray.wait(objectids, timeout=5000) - self.assertTrue(time.time() - start_time < 5) - self.assertEqual(len(ready_ids), 1) - 
self.assertEqual(len(remaining_ids), 3) - - # Verify that calling wait with duplicate object IDs throws an exception. - x = ray.put(1) - self.assertRaises(Exception, lambda: ray.wait([x, x])) - - def testMultipleWaitsAndGets(self): - # It is important to use three workers here, so that the three tasks - # launched in this experiment can run at the same time. - self.init_ray() - - @ray.remote - def f(delay): - time.sleep(delay) - return 1 - - @ray.remote - def g(l): - # The argument l should be a list containing one object ID. - ray.wait([l[0]]) - - @ray.remote - def h(l): - # The argument l should be a list containing one object ID. - ray.get(l[0]) - - # Make sure that multiple wait requests involving the same object ID all - # return. - x = f.remote(1) - ray.get([g.remote([x]), g.remote([x])]) - - # Make sure that multiple get requests involving the same object ID all - # return. - x = f.remote(1) - ray.get([h.remote([x]), h.remote([x])]) - - def testCachingFunctionsToRun(self): - # Test that we export functions to run on all workers before the driver is - # connected. - def f(worker_info): - sys.path.append(1) - ray.worker.global_worker.run_function_on_all_workers(f) - - def f(worker_info): - sys.path.append(2) - ray.worker.global_worker.run_function_on_all_workers(f) - - def g(worker_info): - sys.path.append(3) - ray.worker.global_worker.run_function_on_all_workers(g) - - def f(worker_info): - sys.path.append(4) - ray.worker.global_worker.run_function_on_all_workers(f) - - self.init_ray() - - @ray.remote - def get_state(): - time.sleep(1) - return sys.path[-4], sys.path[-3], sys.path[-2], sys.path[-1] - - res1 = get_state.remote() - res2 = get_state.remote() - self.assertEqual(ray.get(res1), (1, 2, 3, 4)) - self.assertEqual(ray.get(res2), (1, 2, 3, 4)) - - # Clean up the path on the workers. 
- def f(worker_info): - sys.path.pop() - sys.path.pop() - sys.path.pop() - sys.path.pop() - ray.worker.global_worker.run_function_on_all_workers(f) - - def testRunningFunctionOnAllWorkers(self): - self.init_ray() - - def f(worker_info): - sys.path.append("fake_directory") - ray.worker.global_worker.run_function_on_all_workers(f) - - @ray.remote - def get_path1(): - return sys.path - self.assertEqual("fake_directory", ray.get(get_path1.remote())[-1]) - - def f(worker_info): - sys.path.pop(-1) - ray.worker.global_worker.run_function_on_all_workers(f) - - # Create a second remote function to guarantee that when we call - # get_path2.remote(), the second function to run will have been run on the - # worker. - @ray.remote - def get_path2(): - return sys.path - self.assertTrue("fake_directory" not in ray.get(get_path2.remote())) - - def testLoggingAPI(self): - self.init_ray({"driver_mode": ray.SILENT_MODE}) - - def events(): - # This is a hack for getting the event log. It is not part of the API. - keys = ray.worker.global_worker.redis_client.keys("event_log:*") - res = [] - for key in keys: - res.extend(ray.worker.global_worker.redis_client.zrange(key, 0, -1)) - return res - - def wait_for_num_events(num_events, timeout=10): - start_time = time.time() - while time.time() - start_time < timeout: - if len(events()) >= num_events: - return - time.sleep(0.1) - print("Timing out of wait.") - - @ray.remote - def test_log_event(): - ray.log_event("event_type1", contents={"key": "val"}) - - @ray.remote - def test_log_span(): - with ray.log_span("event_type2", contents={"key": "val"}): - pass - - # Make sure that we can call ray.log_event in a remote function. - ray.get(test_log_event.remote()) - # Wait for the event to appear in the event log. - wait_for_num_events(1) - self.assertEqual(len(events()), 1) - - # Make sure that we can call ray.log_span in a remote function. - ray.get(test_log_span.remote()) - - # Wait for the events to appear in the event log. 
- wait_for_num_events(2) - self.assertEqual(len(events()), 2) - - @ray.remote - def test_log_span_exception(): - with ray.log_span("event_type2", contents={"key": "val"}): - raise Exception("This failed.") - - # Make sure that logging a span works if an exception is thrown. - test_log_span_exception.remote() - # Wait for the events to appear in the event log. - wait_for_num_events(3) - self.assertEqual(len(events()), 3) - - def testIdenticalFunctionNames(self): - # Define a bunch of remote functions and make sure that we don't - # accidentally call an older version. - self.init_ray() - - num_calls = 200 - - @ray.remote - def f(): - return 1 - results1 = [f.remote() for _ in range(num_calls)] - - @ray.remote - def f(): - return 2 - results2 = [f.remote() for _ in range(num_calls)] - - @ray.remote - def f(): - return 3 - results3 = [f.remote() for _ in range(num_calls)] - - @ray.remote - def f(): - return 4 - results4 = [f.remote() for _ in range(num_calls)] - - @ray.remote - def f(): - return 5 - results5 = [f.remote() for _ in range(num_calls)] - - self.assertEqual(ray.get(results1), num_calls * [1]) - self.assertEqual(ray.get(results2), num_calls * [2]) - self.assertEqual(ray.get(results3), num_calls * [3]) - self.assertEqual(ray.get(results4), num_calls * [4]) - self.assertEqual(ray.get(results5), num_calls * [5]) - - @ray.remote - def g(): - return 1 - - @ray.remote # noqa: F811 - def g(): - return 2 - - @ray.remote # noqa: F811 - def g(): - return 3 - - @ray.remote # noqa: F811 - def g(): - return 4 - - @ray.remote # noqa: F811 - def g(): - return 5 - - result_values = ray.get([g.remote() for _ in range(num_calls)]) - self.assertEqual(result_values, num_calls * [5]) - - def testIllegalAPICalls(self): - self.init_ray() - - # Verify that we cannot call put on an ObjectID. - x = ray.put(1) - with self.assertRaises(Exception): - ray.put(x) - # Verify that we cannot call get on a regular value. 
- with self.assertRaises(Exception): - ray.get(3) + regex = re.compile(r"\d+\.\d*") + new_regex = ray.get(f.remote(regex)) + # This seems to fail on the system Python 3 that comes with + # Ubuntu, so it is commented out for now: + # self.assertEqual(regex, new_regex) + # Instead, we do this: + self.assertEqual(regex.pattern, new_regex.pattern) + + # Test returning custom classes created on workers. + @ray.remote + def g(): + return SubQux(), Qux() + + subqux, qux = ray.get(g.remote()) + self.assertEqual(subqux.objs[2].foo.value, 0) + + # Test exporting custom class definitions from one worker to another + # when the worker is blocked in a get. + class NewTempClass(object): + def __init__(self, value): + self.value = value + + @ray.remote + def h1(x): + return NewTempClass(x) + + @ray.remote + def h2(x): + return ray.get(h1.remote(x)) + + self.assertEqual(ray.get(h2.remote(10)).value, 10) + + # Test registering multiple classes with the same name. + @ray.remote(num_return_vals=3) + def j(): + class Class0(object): + def method0(self): + pass + + c0 = Class0() + + class Class0(object): + def method1(self): + pass + + c1 = Class0() + + class Class0(object): + def method2(self): + pass + + c2 = Class0() + + return c0, c1, c2 + + results = [] + for _ in range(5): + results += j.remote() + for i in range(len(results) // 3): + c0, c1, c2 = ray.get(results[(3 * i):(3 * (i + 1))]) + + c0.method0() + c1.method1() + c2.method2() + + self.assertFalse(hasattr(c0, "method1")) + self.assertFalse(hasattr(c0, "method2")) + self.assertFalse(hasattr(c1, "method0")) + self.assertFalse(hasattr(c1, "method2")) + self.assertFalse(hasattr(c2, "method0")) + self.assertFalse(hasattr(c2, "method1")) + + @ray.remote + def k(): + class Class0(object): + def method0(self): + pass + + c0 = Class0() + + class Class0(object): + def method1(self): + pass + + c1 = Class0() + + class Class0(object): + def method2(self): + pass + + c2 = Class0() + + return c0, c1, c2 + + results = ray.get([k.remote() 
for _ in range(5)]) + for c0, c1, c2 in results: + c0.method0() + c1.method1() + c2.method2() + + self.assertFalse(hasattr(c0, "method1")) + self.assertFalse(hasattr(c0, "method2")) + self.assertFalse(hasattr(c1, "method0")) + self.assertFalse(hasattr(c1, "method2")) + self.assertFalse(hasattr(c2, "method0")) + self.assertFalse(hasattr(c2, "method1")) + + def testKeywordArgs(self): + reload(test_functions) + self.init_ray() + + x = test_functions.keyword_fct1.remote(1) + self.assertEqual(ray.get(x), "1 hello") + x = test_functions.keyword_fct1.remote(1, "hi") + self.assertEqual(ray.get(x), "1 hi") + x = test_functions.keyword_fct1.remote(1, b="world") + self.assertEqual(ray.get(x), "1 world") + + x = test_functions.keyword_fct2.remote(a="w", b="hi") + self.assertEqual(ray.get(x), "w hi") + x = test_functions.keyword_fct2.remote(b="hi", a="w") + self.assertEqual(ray.get(x), "w hi") + x = test_functions.keyword_fct2.remote(a="w") + self.assertEqual(ray.get(x), "w world") + x = test_functions.keyword_fct2.remote(b="hi") + self.assertEqual(ray.get(x), "hello hi") + x = test_functions.keyword_fct2.remote("w") + self.assertEqual(ray.get(x), "w world") + x = test_functions.keyword_fct2.remote("w", "hi") + self.assertEqual(ray.get(x), "w hi") + + x = test_functions.keyword_fct3.remote(0, 1, c="w", d="hi") + self.assertEqual(ray.get(x), "0 1 w hi") + x = test_functions.keyword_fct3.remote(0, 1, d="hi", c="w") + self.assertEqual(ray.get(x), "0 1 w hi") + x = test_functions.keyword_fct3.remote(0, 1, c="w") + self.assertEqual(ray.get(x), "0 1 w world") + x = test_functions.keyword_fct3.remote(0, 1, d="hi") + self.assertEqual(ray.get(x), "0 1 hello hi") + x = test_functions.keyword_fct3.remote(0, 1) + self.assertEqual(ray.get(x), "0 1 hello world") + + # Check that we cannot pass invalid keyword arguments to functions. + @ray.remote + def f1(): + return + + @ray.remote + def f2(x, y=0, z=0): + return + + # Make sure we get an exception if too many arguments are passed in. 
+ with self.assertRaises(Exception): + f1.remote(3) + + with self.assertRaises(Exception): + f1.remote(x=3) + + with self.assertRaises(Exception): + f2.remote(0, w=0) + + # Make sure we get an exception if too many arguments are passed in. + with self.assertRaises(Exception): + f2.remote(1, 2, 3, 4) + + @ray.remote + def f3(x): + return x + + self.assertEqual(ray.get(f3.remote(4)), 4) + + def testVariableNumberOfArgs(self): + reload(test_functions) + self.init_ray() + + x = test_functions.varargs_fct1.remote(0, 1, 2) + self.assertEqual(ray.get(x), "0 1 2") + x = test_functions.varargs_fct2.remote(0, 1, 2) + self.assertEqual(ray.get(x), "1 2") + + self.assertTrue(test_functions.kwargs_exception_thrown) + self.assertTrue(test_functions.varargs_and_kwargs_exception_thrown) + + @ray.remote + def f1(*args): + return args + + @ray.remote + def f2(x, y, *args): + return x, y, args + + self.assertEqual(ray.get(f1.remote()), ()) + self.assertEqual(ray.get(f1.remote(1)), (1,)) + self.assertEqual(ray.get(f1.remote(1, 2, 3)), (1, 2, 3)) + with self.assertRaises(Exception): + f2.remote() + with self.assertRaises(Exception): + f2.remote(1) + self.assertEqual(ray.get(f2.remote(1, 2)), (1, 2, ())) + self.assertEqual(ray.get(f2.remote(1, 2, 3)), (1, 2, (3,))) + self.assertEqual(ray.get(f2.remote(1, 2, 3, 4)), (1, 2, (3, 4))) + + def testNoArgs(self): + reload(test_functions) + self.init_ray() + + ray.get(test_functions.no_op.remote()) + + def testDefiningRemoteFunctions(self): + self.init_ray({"num_cpus": 3}) + + # Test that we can define a remote function in the shell. + @ray.remote + def f(x): + return x + 1 + self.assertEqual(ray.get(f.remote(0)), 1) + + # Test that we can redefine the remote function. + @ray.remote + def f(x): + return x + 10 + while True: + val = ray.get(f.remote(0)) + self.assertTrue(val in [1, 10]) + if val == 10: + break + else: + print("Still using old definition of f, trying again.") + + # Test that we can close over plain old data. 
+ data = [np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 1 << 62], 1 << 60, + {"a": np.zeros(3)}] + + @ray.remote + def g(): + return data + ray.get(g.remote()) + + # Test that we can close over modules. + @ray.remote + def h(): + return np.zeros([3, 5]) + assert_equal(ray.get(h.remote()), np.zeros([3, 5])) + + @ray.remote + def j(): + return time.time() + ray.get(j.remote()) + + # Test that we can define remote functions that call other remote + # functions. + @ray.remote + def k(x): + return x + 1 + + @ray.remote + def l(x): + return ray.get(k.remote(x)) + + @ray.remote + def m(x): + return ray.get(l.remote(x)) + self.assertEqual(ray.get(k.remote(1)), 2) + self.assertEqual(ray.get(l.remote(1)), 2) + self.assertEqual(ray.get(m.remote(1)), 2) + + def testGetMultiple(self): + self.init_ray() + object_ids = [ray.put(i) for i in range(10)] + self.assertEqual(ray.get(object_ids), list(range(10))) + + # Get a random choice of object IDs with duplicates. + indices = list(np.random.choice(range(10), 5)) + indices += indices + results = ray.get([object_ids[i] for i in indices]) + self.assertEqual(results, indices) + + def testWait(self): + self.init_ray({"num_cpus": 1}) + + @ray.remote + def f(delay): + time.sleep(delay) + return 1 + + objectids = [f.remote(1.0), f.remote(0.5), f.remote(0.5), + f.remote(0.5)] + ready_ids, remaining_ids = ray.wait(objectids) + self.assertEqual(len(ready_ids), 1) + self.assertEqual(len(remaining_ids), 3) + ready_ids, remaining_ids = ray.wait(objectids, num_returns=4) + self.assertEqual(set(ready_ids), set(objectids)) + self.assertEqual(remaining_ids, []) + + objectids = [f.remote(0.5), f.remote(0.5), f.remote(0.5), + f.remote(0.5)] + start_time = time.time() + ready_ids, remaining_ids = ray.wait(objectids, timeout=1750, + num_returns=4) + self.assertLess(time.time() - start_time, 2) + self.assertEqual(len(ready_ids), 3) + self.assertEqual(len(remaining_ids), 1) + ray.wait(objectids) + objectids = [f.remote(1.0), f.remote(0.5), f.remote(0.5), + 
f.remote(0.5)] + start_time = time.time() + ready_ids, remaining_ids = ray.wait(objectids, timeout=5000) + self.assertTrue(time.time() - start_time < 5) + self.assertEqual(len(ready_ids), 1) + self.assertEqual(len(remaining_ids), 3) + + # Verify that calling wait with duplicate object IDs throws an + # exception. + x = ray.put(1) + self.assertRaises(Exception, lambda: ray.wait([x, x])) + + def testMultipleWaitsAndGets(self): + # It is important to use three workers here, so that the three tasks + # launched in this experiment can run at the same time. + self.init_ray() + + @ray.remote + def f(delay): + time.sleep(delay) + return 1 + + @ray.remote + def g(l): + # The argument l should be a list containing one object ID. + ray.wait([l[0]]) + + @ray.remote + def h(l): + # The argument l should be a list containing one object ID. + ray.get(l[0]) + + # Make sure that multiple wait requests involving the same object ID + # all return. + x = f.remote(1) + ray.get([g.remote([x]), g.remote([x])]) + + # Make sure that multiple get requests involving the same object ID all + # return. + x = f.remote(1) + ray.get([h.remote([x]), h.remote([x])]) + + def testCachingFunctionsToRun(self): + # Test that we export functions to run on all workers before the driver + # is connected. 
+ def f(worker_info): + sys.path.append(1) + ray.worker.global_worker.run_function_on_all_workers(f) + + def f(worker_info): + sys.path.append(2) + ray.worker.global_worker.run_function_on_all_workers(f) + + def g(worker_info): + sys.path.append(3) + ray.worker.global_worker.run_function_on_all_workers(g) + + def f(worker_info): + sys.path.append(4) + ray.worker.global_worker.run_function_on_all_workers(f) + + self.init_ray() + + @ray.remote + def get_state(): + time.sleep(1) + return sys.path[-4], sys.path[-3], sys.path[-2], sys.path[-1] + + res1 = get_state.remote() + res2 = get_state.remote() + self.assertEqual(ray.get(res1), (1, 2, 3, 4)) + self.assertEqual(ray.get(res2), (1, 2, 3, 4)) + + # Clean up the path on the workers. + def f(worker_info): + sys.path.pop() + sys.path.pop() + sys.path.pop() + sys.path.pop() + ray.worker.global_worker.run_function_on_all_workers(f) + + def testRunningFunctionOnAllWorkers(self): + self.init_ray() + + def f(worker_info): + sys.path.append("fake_directory") + ray.worker.global_worker.run_function_on_all_workers(f) + + @ray.remote + def get_path1(): + return sys.path + self.assertEqual("fake_directory", ray.get(get_path1.remote())[-1]) + + def f(worker_info): + sys.path.pop(-1) + ray.worker.global_worker.run_function_on_all_workers(f) + + # Create a second remote function to guarantee that when we call + # get_path2.remote(), the second function to run will have been run on + # the worker. + @ray.remote + def get_path2(): + return sys.path + self.assertTrue("fake_directory" not in ray.get(get_path2.remote())) + + def testLoggingAPI(self): + self.init_ray({"driver_mode": ray.SILENT_MODE}) + + def events(): + # This is a hack for getting the event log. It is not part of the + # API. 
+ keys = ray.worker.global_worker.redis_client.keys("event_log:*") + res = [] + for key in keys: + res.extend(ray.worker.global_worker.redis_client.zrange(key, 0, + -1)) + return res + + def wait_for_num_events(num_events, timeout=10): + start_time = time.time() + while time.time() - start_time < timeout: + if len(events()) >= num_events: + return + time.sleep(0.1) + print("Timing out of wait.") + + @ray.remote + def test_log_event(): + ray.log_event("event_type1", contents={"key": "val"}) + + @ray.remote + def test_log_span(): + with ray.log_span("event_type2", contents={"key": "val"}): + pass + + # Make sure that we can call ray.log_event in a remote function. + ray.get(test_log_event.remote()) + # Wait for the event to appear in the event log. + wait_for_num_events(1) + self.assertEqual(len(events()), 1) + + # Make sure that we can call ray.log_span in a remote function. + ray.get(test_log_span.remote()) + + # Wait for the events to appear in the event log. + wait_for_num_events(2) + self.assertEqual(len(events()), 2) + + @ray.remote + def test_log_span_exception(): + with ray.log_span("event_type2", contents={"key": "val"}): + raise Exception("This failed.") + + # Make sure that logging a span works if an exception is thrown. + test_log_span_exception.remote() + # Wait for the events to appear in the event log. + wait_for_num_events(3) + self.assertEqual(len(events()), 3) + + def testIdenticalFunctionNames(self): + # Define a bunch of remote functions and make sure that we don't + # accidentally call an older version. 
+ self.init_ray() + + num_calls = 200 + + @ray.remote + def f(): + return 1 + results1 = [f.remote() for _ in range(num_calls)] + + @ray.remote + def f(): + return 2 + results2 = [f.remote() for _ in range(num_calls)] + + @ray.remote + def f(): + return 3 + results3 = [f.remote() for _ in range(num_calls)] + + @ray.remote + def f(): + return 4 + results4 = [f.remote() for _ in range(num_calls)] + + @ray.remote + def f(): + return 5 + results5 = [f.remote() for _ in range(num_calls)] + + self.assertEqual(ray.get(results1), num_calls * [1]) + self.assertEqual(ray.get(results2), num_calls * [2]) + self.assertEqual(ray.get(results3), num_calls * [3]) + self.assertEqual(ray.get(results4), num_calls * [4]) + self.assertEqual(ray.get(results5), num_calls * [5]) + + @ray.remote + def g(): + return 1 + + @ray.remote # noqa: F811 + def g(): + return 2 + + @ray.remote # noqa: F811 + def g(): + return 3 + + @ray.remote # noqa: F811 + def g(): + return 4 + + @ray.remote # noqa: F811 + def g(): + return 5 + + result_values = ray.get([g.remote() for _ in range(num_calls)]) + self.assertEqual(result_values, num_calls * [5]) + + def testIllegalAPICalls(self): + self.init_ray() + + # Verify that we cannot call put on an ObjectID. + x = ray.put(1) + with self.assertRaises(Exception): + ray.put(x) + # Verify that we cannot call get on a regular value. 
+ with self.assertRaises(Exception): + ray.get(3) class APITestSharded(APITest): - def init_ray(self, kwargs=None): - if kwargs is None: - kwargs = {} - kwargs["start_ray_local"] = True - kwargs["num_redis_shards"] = 20 - kwargs["redirect_output"] = True - ray.worker._init(**kwargs) + def init_ray(self, kwargs=None): + if kwargs is None: + kwargs = {} + kwargs["start_ray_local"] = True + kwargs["num_redis_shards"] = 20 + kwargs["redirect_output"] = True + ray.worker._init(**kwargs) class PythonModeTest(unittest.TestCase): - def testPythonMode(self): - reload(test_functions) - ray.init(driver_mode=ray.PYTHON_MODE) + def testPythonMode(self): + reload(test_functions) + ray.init(driver_mode=ray.PYTHON_MODE) - @ray.remote - def f(): - return np.ones([3, 4, 5]) - xref = f.remote() - # Remote functions should return by value. - assert_equal(xref, np.ones([3, 4, 5])) - # Check that ray.get is the identity. - assert_equal(xref, ray.get(xref)) - y = np.random.normal(size=[11, 12]) - # Check that ray.put is the identity. - assert_equal(y, ray.put(y)) + @ray.remote + def f(): + return np.ones([3, 4, 5]) + xref = f.remote() + # Remote functions should return by value. + assert_equal(xref, np.ones([3, 4, 5])) + # Check that ray.get is the identity. + assert_equal(xref, ray.get(xref)) + y = np.random.normal(size=[11, 12]) + # Check that ray.put is the identity. + assert_equal(y, ray.put(y)) - # Make sure objects are immutable, this example is why we need to copy - # arguments before passing them into remote functions in python mode - aref = test_functions.python_mode_f.remote() - assert_equal(aref, np.array([0, 0])) - bref = test_functions.python_mode_g.remote(aref) - # Make sure python_mode_g does not mutate aref. 
- assert_equal(aref, np.array([0, 0])) - assert_equal(bref, np.array([1, 0])) + # Make sure objects are immutable, this example is why we need to copy + # arguments before passing them into remote functions in python mode + aref = test_functions.python_mode_f.remote() + assert_equal(aref, np.array([0, 0])) + bref = test_functions.python_mode_g.remote(aref) + # Make sure python_mode_g does not mutate aref. + assert_equal(aref, np.array([0, 0])) + assert_equal(bref, np.array([1, 0])) - ray.worker.cleanup() + ray.worker.cleanup() class UtilsTest(unittest.TestCase): - def testCopyingDirectory(self): - # The functionality being tested here is really multi-node functionality, - # but this test just uses a single node. + def testCopyingDirectory(self): + # The functionality being tested here is really multi-node + # functionality, but this test just uses a single node. - ray.init(num_workers=1) + ray.init(num_workers=1) - source_text = "hello world" + source_text = "hello world" - temp_dir1 = os.path.join(os.path.dirname(__file__), "temp_dir1") - source_dir = os.path.join(temp_dir1, "dir") - source_file = os.path.join(source_dir, "file.txt") - temp_dir2 = os.path.join(os.path.dirname(__file__), "temp_dir2") - target_dir = os.path.join(temp_dir2, "dir") - target_file = os.path.join(target_dir, "file.txt") + temp_dir1 = os.path.join(os.path.dirname(__file__), "temp_dir1") + source_dir = os.path.join(temp_dir1, "dir") + source_file = os.path.join(source_dir, "file.txt") + temp_dir2 = os.path.join(os.path.dirname(__file__), "temp_dir2") + target_dir = os.path.join(temp_dir2, "dir") + target_file = os.path.join(target_dir, "file.txt") - def remove_temporary_files(): - if os.path.exists(temp_dir1): - shutil.rmtree(temp_dir1) - if os.path.exists(temp_dir2): - shutil.rmtree(temp_dir2) + def remove_temporary_files(): + if os.path.exists(temp_dir1): + shutil.rmtree(temp_dir1) + if os.path.exists(temp_dir2): + shutil.rmtree(temp_dir2) - # Remove the relevant files if they are left 
over from a previous run of - # this test. - remove_temporary_files() + # Remove the relevant files if they are left over from a previous run + # of this test. + remove_temporary_files() - # Create the source files. - os.mkdir(temp_dir1) - os.mkdir(source_dir) - with open(source_file, "w") as f: - f.write(source_text) + # Create the source files. + os.mkdir(temp_dir1) + os.mkdir(source_dir) + with open(source_file, "w") as f: + f.write(source_text) - # Copy the source directory to the target directory. - ray.experimental.copy_directory(source_dir, target_dir) - time.sleep(0.5) + # Copy the source directory to the target directory. + ray.experimental.copy_directory(source_dir, target_dir) + time.sleep(0.5) - # Check that the target files exist and are the same as the source files. - self.assertTrue(os.path.exists(target_dir)) - self.assertTrue(os.path.exists(target_file)) - with open(target_file, "r") as f: - self.assertEqual(f.read(), source_text) + # Check that the target files exist and are the same as the source + # files. + self.assertTrue(os.path.exists(target_dir)) + self.assertTrue(os.path.exists(target_file)) + with open(target_file, "r") as f: + self.assertEqual(f.read(), source_text) - # Remove the relevant files to clean up. - remove_temporary_files() + # Remove the relevant files to clean up. + remove_temporary_files() - ray.worker.cleanup() + ray.worker.cleanup() class ResourcesTest(unittest.TestCase): - def testResourceConstraints(self): - num_workers = 20 - ray.init(num_workers=num_workers, num_cpus=10, num_gpus=2) + def testResourceConstraints(self): + num_workers = 20 + ray.init(num_workers=num_workers, num_cpus=10, num_gpus=2) - # Attempt to wait for all of the workers to start up. - ray.worker.global_worker.run_function_on_all_workers( - lambda worker_info: sys.path.append(worker_info["counter"])) + # Attempt to wait for all of the workers to start up. 
+ ray.worker.global_worker.run_function_on_all_workers( + lambda worker_info: sys.path.append(worker_info["counter"])) - @ray.remote(num_cpus=0) - def get_worker_id(): - time.sleep(1) - return sys.path[-1] - while True: - if len(set(ray.get([get_worker_id.remote() - for _ in range(num_workers)]))) == num_workers: - break + @ray.remote(num_cpus=0) + def get_worker_id(): + time.sleep(1) + return sys.path[-1] + while True: + if len(set(ray.get([get_worker_id.remote() + for _ in range(num_workers)]))) == num_workers: + break - time_buffer = 0.3 + time_buffer = 0.3 - # At most 10 copies of this can run at once. - @ray.remote(num_cpus=1) - def f(n): - time.sleep(n) + # At most 10 copies of this can run at once. + @ray.remote(num_cpus=1) + def f(n): + time.sleep(n) - start_time = time.time() - ray.get([f.remote(0.5) for _ in range(10)]) - duration = time.time() - start_time - self.assertLess(duration, 0.5 + time_buffer) - self.assertGreater(duration, 0.5) + start_time = time.time() + ray.get([f.remote(0.5) for _ in range(10)]) + duration = time.time() - start_time + self.assertLess(duration, 0.5 + time_buffer) + self.assertGreater(duration, 0.5) - start_time = time.time() - ray.get([f.remote(0.5) for _ in range(11)]) - duration = time.time() - start_time - self.assertLess(duration, 1 + time_buffer) - self.assertGreater(duration, 1) + start_time = time.time() + ray.get([f.remote(0.5) for _ in range(11)]) + duration = time.time() - start_time + self.assertLess(duration, 1 + time_buffer) + self.assertGreater(duration, 1) - @ray.remote(num_cpus=3) - def f(n): - time.sleep(n) + @ray.remote(num_cpus=3) + def f(n): + time.sleep(n) - start_time = time.time() - ray.get([f.remote(0.5) for _ in range(3)]) - duration = time.time() - start_time - self.assertLess(duration, 0.5 + time_buffer) - self.assertGreater(duration, 0.5) + start_time = time.time() + ray.get([f.remote(0.5) for _ in range(3)]) + duration = time.time() - start_time + self.assertLess(duration, 0.5 + time_buffer) + 
self.assertGreater(duration, 0.5) - start_time = time.time() - ray.get([f.remote(0.5) for _ in range(4)]) - duration = time.time() - start_time - self.assertLess(duration, 1 + time_buffer) - self.assertGreater(duration, 1) + start_time = time.time() + ray.get([f.remote(0.5) for _ in range(4)]) + duration = time.time() - start_time + self.assertLess(duration, 1 + time_buffer) + self.assertGreater(duration, 1) - @ray.remote(num_gpus=1) - def f(n): - time.sleep(n) + @ray.remote(num_gpus=1) + def f(n): + time.sleep(n) - start_time = time.time() - ray.get([f.remote(0.5) for _ in range(2)]) - duration = time.time() - start_time - self.assertLess(duration, 0.5 + time_buffer) - self.assertGreater(duration, 0.5) + start_time = time.time() + ray.get([f.remote(0.5) for _ in range(2)]) + duration = time.time() - start_time + self.assertLess(duration, 0.5 + time_buffer) + self.assertGreater(duration, 0.5) - start_time = time.time() - ray.get([f.remote(0.5) for _ in range(3)]) - duration = time.time() - start_time - self.assertLess(duration, 1 + time_buffer) - self.assertGreater(duration, 1) + start_time = time.time() + ray.get([f.remote(0.5) for _ in range(3)]) + duration = time.time() - start_time + self.assertLess(duration, 1 + time_buffer) + self.assertGreater(duration, 1) - start_time = time.time() - ray.get([f.remote(0.5) for _ in range(4)]) - duration = time.time() - start_time - self.assertLess(duration, 1 + time_buffer) - self.assertGreater(duration, 1) + start_time = time.time() + ray.get([f.remote(0.5) for _ in range(4)]) + duration = time.time() - start_time + self.assertLess(duration, 1 + time_buffer) + self.assertGreater(duration, 1) - ray.worker.cleanup() + ray.worker.cleanup() - def testMultiResourceConstraints(self): - num_workers = 20 - ray.init(num_workers=num_workers, num_cpus=10, num_gpus=10) + def testMultiResourceConstraints(self): + num_workers = 20 + ray.init(num_workers=num_workers, num_cpus=10, num_gpus=10) - # Attempt to wait for all of the workers to 
start up. - ray.worker.global_worker.run_function_on_all_workers( - lambda worker_info: sys.path.append(worker_info["counter"])) + # Attempt to wait for all of the workers to start up. + ray.worker.global_worker.run_function_on_all_workers( + lambda worker_info: sys.path.append(worker_info["counter"])) - @ray.remote(num_cpus=0) - def get_worker_id(): - time.sleep(1) - return sys.path[-1] - while True: - if len(set(ray.get([get_worker_id.remote() - for _ in range(num_workers)]))) == num_workers: - break + @ray.remote(num_cpus=0) + def get_worker_id(): + time.sleep(1) + return sys.path[-1] + while True: + if len(set(ray.get([get_worker_id.remote() + for _ in range(num_workers)]))) == num_workers: + break - @ray.remote(num_cpus=1, num_gpus=9) - def f(n): - time.sleep(n) + @ray.remote(num_cpus=1, num_gpus=9) + def f(n): + time.sleep(n) - @ray.remote(num_cpus=9, num_gpus=1) - def g(n): - time.sleep(n) + @ray.remote(num_cpus=9, num_gpus=1) + def g(n): + time.sleep(n) - time_buffer = 0.3 + time_buffer = 0.3 - start_time = time.time() - ray.get([f.remote(0.5), g.remote(0.5)]) - duration = time.time() - start_time - self.assertLess(duration, 0.5 + time_buffer) - self.assertGreater(duration, 0.5) + start_time = time.time() + ray.get([f.remote(0.5), g.remote(0.5)]) + duration = time.time() - start_time + self.assertLess(duration, 0.5 + time_buffer) + self.assertGreater(duration, 0.5) - start_time = time.time() - ray.get([f.remote(0.5), f.remote(0.5)]) - duration = time.time() - start_time - self.assertLess(duration, 1 + time_buffer) - self.assertGreater(duration, 1) + start_time = time.time() + ray.get([f.remote(0.5), f.remote(0.5)]) + duration = time.time() - start_time + self.assertLess(duration, 1 + time_buffer) + self.assertGreater(duration, 1) - start_time = time.time() - ray.get([g.remote(0.5), g.remote(0.5)]) - duration = time.time() - start_time - self.assertLess(duration, 1 + time_buffer) - self.assertGreater(duration, 1) + start_time = time.time() + 
ray.get([g.remote(0.5), g.remote(0.5)]) + duration = time.time() - start_time + self.assertLess(duration, 1 + time_buffer) + self.assertGreater(duration, 1) - start_time = time.time() - ray.get([f.remote(0.5), f.remote(0.5), g.remote(0.5), g.remote(0.5)]) - duration = time.time() - start_time - self.assertLess(duration, 1 + time_buffer) - self.assertGreater(duration, 1) + start_time = time.time() + ray.get([f.remote(0.5), f.remote(0.5), g.remote(0.5), g.remote(0.5)]) + duration = time.time() - start_time + self.assertLess(duration, 1 + time_buffer) + self.assertGreater(duration, 1) - ray.worker.cleanup() + ray.worker.cleanup() - def testGPUIDs(self): - num_gpus = 10 - ray.init(num_cpus=10, num_gpus=num_gpus) + def testGPUIDs(self): + num_gpus = 10 + ray.init(num_cpus=10, num_gpus=num_gpus) - @ray.remote(num_gpus=0) - def f0(): - time.sleep(0.1) - gpu_ids = ray.get_gpu_ids() - assert len(gpu_ids) == 0 - for gpu_id in gpu_ids: - assert gpu_id in range(num_gpus) - return gpu_ids + @ray.remote(num_gpus=0) + def f0(): + time.sleep(0.1) + gpu_ids = ray.get_gpu_ids() + assert len(gpu_ids) == 0 + for gpu_id in gpu_ids: + assert gpu_id in range(num_gpus) + return gpu_ids - @ray.remote(num_gpus=1) - def f1(): - time.sleep(0.1) - gpu_ids = ray.get_gpu_ids() - assert len(gpu_ids) == 1 - for gpu_id in gpu_ids: - assert gpu_id in range(num_gpus) - return gpu_ids + @ray.remote(num_gpus=1) + def f1(): + time.sleep(0.1) + gpu_ids = ray.get_gpu_ids() + assert len(gpu_ids) == 1 + for gpu_id in gpu_ids: + assert gpu_id in range(num_gpus) + return gpu_ids - @ray.remote(num_gpus=2) - def f2(): - time.sleep(0.1) - gpu_ids = ray.get_gpu_ids() - assert len(gpu_ids) == 2 - for gpu_id in gpu_ids: - assert gpu_id in range(num_gpus) - return gpu_ids + @ray.remote(num_gpus=2) + def f2(): + time.sleep(0.1) + gpu_ids = ray.get_gpu_ids() + assert len(gpu_ids) == 2 + for gpu_id in gpu_ids: + assert gpu_id in range(num_gpus) + return gpu_ids - @ray.remote(num_gpus=3) - def f3(): - time.sleep(0.1) - 
gpu_ids = ray.get_gpu_ids() - assert len(gpu_ids) == 3 - for gpu_id in gpu_ids: - assert gpu_id in range(num_gpus) - return gpu_ids + @ray.remote(num_gpus=3) + def f3(): + time.sleep(0.1) + gpu_ids = ray.get_gpu_ids() + assert len(gpu_ids) == 3 + for gpu_id in gpu_ids: + assert gpu_id in range(num_gpus) + return gpu_ids - @ray.remote(num_gpus=4) - def f4(): - time.sleep(0.1) - gpu_ids = ray.get_gpu_ids() - assert len(gpu_ids) == 4 - for gpu_id in gpu_ids: - assert gpu_id in range(num_gpus) - return gpu_ids + @ray.remote(num_gpus=4) + def f4(): + time.sleep(0.1) + gpu_ids = ray.get_gpu_ids() + assert len(gpu_ids) == 4 + for gpu_id in gpu_ids: + assert gpu_id in range(num_gpus) + return gpu_ids - @ray.remote(num_gpus=5) - def f5(): - time.sleep(0.1) - gpu_ids = ray.get_gpu_ids() - assert len(gpu_ids) == 5 - for gpu_id in gpu_ids: - assert gpu_id in range(num_gpus) - return gpu_ids + @ray.remote(num_gpus=5) + def f5(): + time.sleep(0.1) + gpu_ids = ray.get_gpu_ids() + assert len(gpu_ids) == 5 + for gpu_id in gpu_ids: + assert gpu_id in range(num_gpus) + return gpu_ids - list_of_ids = ray.get([f0.remote() for _ in range(10)]) - self.assertEqual(list_of_ids, 10 * [[]]) + list_of_ids = ray.get([f0.remote() for _ in range(10)]) + self.assertEqual(list_of_ids, 10 * [[]]) - list_of_ids = ray.get([f1.remote() for _ in range(10)]) - set_of_ids = set([tuple(gpu_ids) for gpu_ids in list_of_ids]) - self.assertEqual(set_of_ids, set([(i,) for i in range(10)])) + list_of_ids = ray.get([f1.remote() for _ in range(10)]) + set_of_ids = set([tuple(gpu_ids) for gpu_ids in list_of_ids]) + self.assertEqual(set_of_ids, set([(i,) for i in range(10)])) - list_of_ids = ray.get([f2.remote(), f4.remote(), f4.remote()]) - all_ids = [gpu_id for gpu_ids in list_of_ids for gpu_id in gpu_ids] - self.assertEqual(set(all_ids), set(range(10))) + list_of_ids = ray.get([f2.remote(), f4.remote(), f4.remote()]) + all_ids = [gpu_id for gpu_ids in list_of_ids for gpu_id in gpu_ids] + 
self.assertEqual(set(all_ids), set(range(10))) - remaining = [f5.remote() for _ in range(20)] - for _ in range(10): - t1 = time.time() - ready, remaining = ray.wait(remaining, num_returns=2) - t2 = time.time() - # There are only 10 GPUs, and each task uses 2 GPUs, so there should only - # be 2 tasks scheduled at a given time, so if we wait for 2 tasks to - # finish, then it should take at least 0.1 seconds for each pair of tasks - # to finish. - self.assertGreater(t2 - t1, 0.09) - list_of_ids = ray.get(ready) - all_ids = [gpu_id for gpu_ids in list_of_ids for gpu_id in gpu_ids] - self.assertEqual(set(all_ids), set(range(10))) + remaining = [f5.remote() for _ in range(20)] + for _ in range(10): + t1 = time.time() + ready, remaining = ray.wait(remaining, num_returns=2) + t2 = time.time() + # There are only 10 GPUs, and each task uses 2 GPUs, so there + # should only be 2 tasks scheduled at a given time, so if we wait + # for 2 tasks to finish, then it should take at least 0.1 seconds + # for each pair of tasks to finish. + self.assertGreater(t2 - t1, 0.09) + list_of_ids = ray.get(ready) + all_ids = [gpu_id for gpu_ids in list_of_ids for gpu_id in gpu_ids] + self.assertEqual(set(all_ids), set(range(10))) - ray.worker.cleanup() + ray.worker.cleanup() - def testMultipleLocalSchedulers(self): - # This test will define a bunch of tasks that can only be assigned to - # specific local schedulers, and we will check that they are assigned to - # the correct local schedulers. - address_info = ray.worker._init(start_ray_local=True, - num_local_schedulers=3, - num_cpus=[100, 5, 10], - num_gpus=[0, 5, 1]) + def testMultipleLocalSchedulers(self): + # This test will define a bunch of tasks that can only be assigned to + # specific local schedulers, and we will check that they are assigned + # to the correct local schedulers. 
+ address_info = ray.worker._init(start_ray_local=True, + num_local_schedulers=3, + num_cpus=[100, 5, 10], + num_gpus=[0, 5, 1]) - # Define a bunch of remote functions that all return the socket name of the - # plasma store. Since there is a one-to-one correspondence between plasma - # stores and local schedulers (at least right now), this can be used to - # identify which local scheduler the task was assigned to. + # Define a bunch of remote functions that all return the socket name of + # the plasma store. Since there is a one-to-one correspondence between + # plasma stores and local schedulers (at least right now), this can be + # used to identify which local scheduler the task was assigned to. - # This must be run on the zeroth local scheduler. - @ray.remote(num_cpus=11) - def run_on_0(): - return ray.worker.global_worker.plasma_client.store_socket_name + # This must be run on the zeroth local scheduler. + @ray.remote(num_cpus=11) + def run_on_0(): + return ray.worker.global_worker.plasma_client.store_socket_name - # This must be run on the first local scheduler. - @ray.remote(num_gpus=2) - def run_on_1(): - return ray.worker.global_worker.plasma_client.store_socket_name + # This must be run on the first local scheduler. + @ray.remote(num_gpus=2) + def run_on_1(): + return ray.worker.global_worker.plasma_client.store_socket_name - # This must be run on the second local scheduler. - @ray.remote(num_cpus=6, num_gpus=1) - def run_on_2(): - return ray.worker.global_worker.plasma_client.store_socket_name + # This must be run on the second local scheduler. + @ray.remote(num_cpus=6, num_gpus=1) + def run_on_2(): + return ray.worker.global_worker.plasma_client.store_socket_name - # This can be run anywhere. - @ray.remote(num_cpus=0, num_gpus=0) - def run_on_0_1_2(): - return ray.worker.global_worker.plasma_client.store_socket_name + # This can be run anywhere. 
+ @ray.remote(num_cpus=0, num_gpus=0) + def run_on_0_1_2(): + return ray.worker.global_worker.plasma_client.store_socket_name - # This must be run on the first or second local scheduler. - @ray.remote(num_gpus=1) - def run_on_1_2(): - return ray.worker.global_worker.plasma_client.store_socket_name + # This must be run on the first or second local scheduler. + @ray.remote(num_gpus=1) + def run_on_1_2(): + return ray.worker.global_worker.plasma_client.store_socket_name - # This must be run on the zeroth or second local scheduler. - @ray.remote(num_cpus=8) - def run_on_0_2(): - return ray.worker.global_worker.plasma_client.store_socket_name + # This must be run on the zeroth or second local scheduler. + @ray.remote(num_cpus=8) + def run_on_0_2(): + return ray.worker.global_worker.plasma_client.store_socket_name - def run_lots_of_tasks(): - names = [] - results = [] - for i in range(100): - index = np.random.randint(6) - if index == 0: - names.append("run_on_0") - results.append(run_on_0.remote()) - elif index == 1: - names.append("run_on_1") - results.append(run_on_1.remote()) - elif index == 2: - names.append("run_on_2") - results.append(run_on_2.remote()) - elif index == 3: - names.append("run_on_0_1_2") - results.append(run_on_0_1_2.remote()) - elif index == 4: - names.append("run_on_1_2") - results.append(run_on_1_2.remote()) - elif index == 5: - names.append("run_on_0_2") - results.append(run_on_0_2.remote()) - return names, results + def run_lots_of_tasks(): + names = [] + results = [] + for i in range(100): + index = np.random.randint(6) + if index == 0: + names.append("run_on_0") + results.append(run_on_0.remote()) + elif index == 1: + names.append("run_on_1") + results.append(run_on_1.remote()) + elif index == 2: + names.append("run_on_2") + results.append(run_on_2.remote()) + elif index == 3: + names.append("run_on_0_1_2") + results.append(run_on_0_1_2.remote()) + elif index == 4: + names.append("run_on_1_2") + results.append(run_on_1_2.remote()) + elif 
index == 5: + names.append("run_on_0_2") + results.append(run_on_0_2.remote()) + return names, results - store_names = [object_store_address.name for object_store_address - in address_info["object_store_addresses"]] + store_names = [object_store_address.name for object_store_address + in address_info["object_store_addresses"]] - def validate_names_and_results(names, results): - for name, result in zip(names, ray.get(results)): - if name == "run_on_0": - self.assertIn(result, [store_names[0]]) - elif name == "run_on_1": - self.assertIn(result, [store_names[1]]) - elif name == "run_on_2": - self.assertIn(result, [store_names[2]]) - elif name == "run_on_0_1_2": - self.assertIn(result, [store_names[0], store_names[1], - store_names[2]]) - elif name == "run_on_1_2": - self.assertIn(result, [store_names[1], store_names[2]]) - elif name == "run_on_0_2": - self.assertIn(result, [store_names[0], store_names[2]]) - else: - raise Exception("This should be unreachable.") - self.assertEqual(set(ray.get(results)), set(store_names)) + def validate_names_and_results(names, results): + for name, result in zip(names, ray.get(results)): + if name == "run_on_0": + self.assertIn(result, [store_names[0]]) + elif name == "run_on_1": + self.assertIn(result, [store_names[1]]) + elif name == "run_on_2": + self.assertIn(result, [store_names[2]]) + elif name == "run_on_0_1_2": + self.assertIn(result, [store_names[0], store_names[1], + store_names[2]]) + elif name == "run_on_1_2": + self.assertIn(result, [store_names[1], store_names[2]]) + elif name == "run_on_0_2": + self.assertIn(result, [store_names[0], store_names[2]]) + else: + raise Exception("This should be unreachable.") + self.assertEqual(set(ray.get(results)), set(store_names)) - names, results = run_lots_of_tasks() - validate_names_and_results(names, results) + names, results = run_lots_of_tasks() + validate_names_and_results(names, results) - # Make sure the same thing works when this is nested inside of a task. 
+ # Make sure the same thing works when this is nested inside of a task. - @ray.remote - def run_nested1(): - names, results = run_lots_of_tasks() - return names, results + @ray.remote + def run_nested1(): + names, results = run_lots_of_tasks() + return names, results - @ray.remote - def run_nested2(): - names, results = ray.get(run_nested1.remote()) - return names, results + @ray.remote + def run_nested2(): + names, results = ray.get(run_nested1.remote()) + return names, results - names, results = ray.get(run_nested2.remote()) - validate_names_and_results(names, results) + names, results = ray.get(run_nested2.remote()) + validate_names_and_results(names, results) - ray.worker.cleanup() + ray.worker.cleanup() class WorkerPoolTests(unittest.TestCase): - def tearDown(self): - ray.worker.cleanup() + def tearDown(self): + ray.worker.cleanup() - def testNoWorkers(self): - ray.init(num_workers=0) + def testNoWorkers(self): + ray.init(num_workers=0) - @ray.remote - def f(): - return 1 + @ray.remote + def f(): + return 1 - # Make sure we can call a remote function. This will require starting a new - # worker. - ray.get(f.remote()) + # Make sure we can call a remote function. This will require starting a + # new worker. + ray.get(f.remote()) - ray.get([f.remote() for _ in range(100)]) + ray.get([f.remote() for _ in range(100)]) - def testBlockingTasks(self): - ray.init(num_workers=1) + def testBlockingTasks(self): + ray.init(num_workers=1) - @ray.remote - def f(i, j): - return (i, j) + @ray.remote + def f(i, j): + return (i, j) - @ray.remote - def g(i): - # Each instance of g submits and blocks on the result of another remote - # task. - object_ids = [f.remote(i, j) for j in range(10)] - return ray.get(object_ids) + @ray.remote + def g(i): + # Each instance of g submits and blocks on the result of another + # remote task. 
+ object_ids = [f.remote(i, j) for j in range(10)] + return ray.get(object_ids) - ray.get([g.remote(i) for i in range(100)]) + ray.get([g.remote(i) for i in range(100)]) - @ray.remote - def _sleep(i): - time.sleep(1) - return (i) + @ray.remote + def _sleep(i): + time.sleep(1) + return (i) - @ray.remote - def sleep(): - # Each instance of sleep submits and blocks on the result of another - # remote task, which takes one second to execute. - ray.get([_sleep.remote(i) for i in range(10)]) + @ray.remote + def sleep(): + # Each instance of sleep submits and blocks on the result of + # another remote task, which takes one second to execute. + ray.get([_sleep.remote(i) for i in range(10)]) - ray.get(sleep.remote()) + ray.get(sleep.remote()) - ray.worker.cleanup() + ray.worker.cleanup() - def testMaxCallTasks(self): - ray.init(num_cpus=1) + def testMaxCallTasks(self): + ray.init(num_cpus=1) - @ray.remote(max_calls=1) - def f(): - return os.getpid() + @ray.remote(max_calls=1) + def f(): + return os.getpid() - pid = ray.get(f.remote()) - ray.test.test_utils.wait_for_pid_to_exit(pid) + pid = ray.get(f.remote()) + ray.test.test_utils.wait_for_pid_to_exit(pid) - @ray.remote(max_calls=2) - def f(): - return os.getpid() + @ray.remote(max_calls=2) + def f(): + return os.getpid() - pid1 = ray.get(f.remote()) - pid2 = ray.get(f.remote()) - self.assertEqual(pid1, pid2) - ray.test.test_utils.wait_for_pid_to_exit(pid1) + pid1 = ray.get(f.remote()) + pid2 = ray.get(f.remote()) + self.assertEqual(pid1, pid2) + ray.test.test_utils.wait_for_pid_to_exit(pid1) - ray.worker.cleanup() + ray.worker.cleanup() class SchedulingAlgorithm(unittest.TestCase): - def attempt_to_load_balance(self, remote_function, args, total_tasks, - num_local_schedulers, minimum_count, - num_attempts=20): - attempts = 0 - while attempts < num_attempts: - locations = ray.get([remote_function.remote(*args) - for _ in range(total_tasks)]) - names = set(locations) - counts = [locations.count(name) for name in names] - 
print("Counts are {}.".format(counts)) - if len(names) == num_local_schedulers and all([count >= minimum_count - for count in counts]): - break - attempts += 1 - self.assertLess(attempts, num_attempts) + def attempt_to_load_balance(self, remote_function, args, total_tasks, + num_local_schedulers, minimum_count, + num_attempts=20): + attempts = 0 + while attempts < num_attempts: + locations = ray.get([remote_function.remote(*args) + for _ in range(total_tasks)]) + names = set(locations) + counts = [locations.count(name) for name in names] + print("Counts are {}.".format(counts)) + if (len(names) == num_local_schedulers and + all([count >= minimum_count for count in counts])): + break + attempts += 1 + self.assertLess(attempts, num_attempts) - def testLoadBalancing(self): - # This test ensures that tasks are being assigned to all local schedulers - # in a roughly equal manner. - num_local_schedulers = 3 - num_cpus = 7 - ray.worker._init(start_ray_local=True, - num_local_schedulers=num_local_schedulers, - num_cpus=num_cpus) + def testLoadBalancing(self): + # This test ensures that tasks are being assigned to all local + # schedulers in a roughly equal manner. 
+ num_local_schedulers = 3 + num_cpus = 7 + ray.worker._init(start_ray_local=True, + num_local_schedulers=num_local_schedulers, + num_cpus=num_cpus) - @ray.remote - def f(): - time.sleep(0.001) - return ray.worker.global_worker.plasma_client.store_socket_name + @ray.remote + def f(): + time.sleep(0.001) + return ray.worker.global_worker.plasma_client.store_socket_name - self.attempt_to_load_balance(f, [], 100, num_local_schedulers, 25) - self.attempt_to_load_balance(f, [], 1000, num_local_schedulers, 250) + self.attempt_to_load_balance(f, [], 100, num_local_schedulers, 25) + self.attempt_to_load_balance(f, [], 1000, num_local_schedulers, 250) - ray.worker.cleanup() + ray.worker.cleanup() - def testLoadBalancingWithDependencies(self): - # This test ensures that tasks are being assigned to all local schedulers - # in a roughly equal manner even when the tasks have dependencies. - num_workers = 3 - num_local_schedulers = 3 - ray.worker._init(start_ray_local=True, num_workers=num_workers, - num_local_schedulers=num_local_schedulers) + def testLoadBalancingWithDependencies(self): + # This test ensures that tasks are being assigned to all local + # schedulers in a roughly equal manner even when the tasks have + # dependencies. + num_workers = 3 + num_local_schedulers = 3 + ray.worker._init(start_ray_local=True, num_workers=num_workers, + num_local_schedulers=num_local_schedulers) - @ray.remote - def f(x): - return ray.worker.global_worker.plasma_client.store_socket_name + @ray.remote + def f(x): + return ray.worker.global_worker.plasma_client.store_socket_name - # This object will be local to one of the local schedulers. Make sure this - # doesn't prevent tasks from being scheduled on other local schedulers. - x = ray.put(np.zeros(1000000)) + # This object will be local to one of the local schedulers. Make sure + # this doesn't prevent tasks from being scheduled on other local + # schedulers. 
+ x = ray.put(np.zeros(1000000)) - self.attempt_to_load_balance(f, [x], 100, num_local_schedulers, 25) + self.attempt_to_load_balance(f, [x], 100, num_local_schedulers, 25) - ray.worker.cleanup() + ray.worker.cleanup() def wait_for_num_tasks(num_tasks, timeout=10): - start_time = time.time() - while time.time() - start_time < timeout: - if len(ray.global_state.task_table()) >= num_tasks: - return - time.sleep(0.1) - raise Exception("Timed out while waiting for global state.") + start_time = time.time() + while time.time() - start_time < timeout: + if len(ray.global_state.task_table()) >= num_tasks: + return + time.sleep(0.1) + raise Exception("Timed out while waiting for global state.") def wait_for_num_objects(num_objects, timeout=10): - start_time = time.time() - while time.time() - start_time < timeout: - if len(ray.global_state.object_table()) >= num_objects: - return - time.sleep(0.1) - raise Exception("Timed out while waiting for global state.") + start_time = time.time() + while time.time() - start_time < timeout: + if len(ray.global_state.object_table()) >= num_objects: + return + time.sleep(0.1) + raise Exception("Timed out while waiting for global state.") class GlobalStateAPI(unittest.TestCase): - def testGlobalStateAPI(self): - with self.assertRaises(Exception): - ray.global_state.object_table() + def testGlobalStateAPI(self): + with self.assertRaises(Exception): + ray.global_state.object_table() - with self.assertRaises(Exception): - ray.global_state.task_table() + with self.assertRaises(Exception): + ray.global_state.task_table() - with self.assertRaises(Exception): - ray.global_state.client_table() + with self.assertRaises(Exception): + ray.global_state.client_table() - with self.assertRaises(Exception): - ray.global_state.function_table() + with self.assertRaises(Exception): + ray.global_state.function_table() - with self.assertRaises(Exception): - ray.global_state.log_files() + with self.assertRaises(Exception): + ray.global_state.log_files() - 
ray.init() + ray.init() - self.assertEqual(ray.global_state.object_table(), dict()) + self.assertEqual(ray.global_state.object_table(), dict()) - ID_SIZE = 20 + ID_SIZE = 20 - driver_id = ray.experimental.state.binary_to_hex( - ray.worker.global_worker.worker_id) - driver_task_id = ray.experimental.state.binary_to_hex( - ray.worker.global_worker.current_task_id.id()) + driver_id = ray.experimental.state.binary_to_hex( + ray.worker.global_worker.worker_id) + driver_task_id = ray.experimental.state.binary_to_hex( + ray.worker.global_worker.current_task_id.id()) - # One task is put in the task table which corresponds to this driver. - wait_for_num_tasks(1) - task_table = ray.global_state.task_table() - self.assertEqual(len(task_table), 1) - self.assertEqual(driver_task_id, list(task_table.keys())[0]) - self.assertEqual(task_table[driver_task_id]["State"], - ray.experimental.state.TASK_STATUS_RUNNING) - self.assertEqual(task_table[driver_task_id]["TaskSpec"]["TaskID"], - driver_task_id) - self.assertEqual(task_table[driver_task_id]["TaskSpec"]["ActorID"], - ID_SIZE * "ff") - self.assertEqual(task_table[driver_task_id]["TaskSpec"]["Args"], []) - self.assertEqual(task_table[driver_task_id]["TaskSpec"]["DriverID"], - driver_id) - self.assertEqual(task_table[driver_task_id]["TaskSpec"]["FunctionID"], - ID_SIZE * "ff") - self.assertEqual(task_table[driver_task_id]["TaskSpec"]["ReturnObjectIDs"], - []) + # One task is put in the task table which corresponds to this driver. 
+ wait_for_num_tasks(1) + task_table = ray.global_state.task_table() + self.assertEqual(len(task_table), 1) + self.assertEqual(driver_task_id, list(task_table.keys())[0]) + self.assertEqual(task_table[driver_task_id]["State"], + ray.experimental.state.TASK_STATUS_RUNNING) + self.assertEqual(task_table[driver_task_id]["TaskSpec"]["TaskID"], + driver_task_id) + self.assertEqual(task_table[driver_task_id]["TaskSpec"]["ActorID"], + ID_SIZE * "ff") + self.assertEqual(task_table[driver_task_id]["TaskSpec"]["Args"], []) + self.assertEqual(task_table[driver_task_id]["TaskSpec"]["DriverID"], + driver_id) + self.assertEqual(task_table[driver_task_id]["TaskSpec"]["FunctionID"], + ID_SIZE * "ff") + self.assertEqual((task_table[driver_task_id]["TaskSpec"] + ["ReturnObjectIDs"]), + []) - client_table = ray.global_state.client_table() - node_ip_address = ray.worker.global_worker.node_ip_address - self.assertEqual(len(client_table[node_ip_address]), 3) - manager_client = [c for c in client_table[node_ip_address] - if c["ClientType"] == "plasma_manager"][0] + client_table = ray.global_state.client_table() + node_ip_address = ray.worker.global_worker.node_ip_address + self.assertEqual(len(client_table[node_ip_address]), 3) + manager_client = [c for c in client_table[node_ip_address] + if c["ClientType"] == "plasma_manager"][0] - @ray.remote - def f(*xs): - return 1 + @ray.remote + def f(*xs): + return 1 - x_id = ray.put(1) - result_id = f.remote(1, "hi", x_id) + x_id = ray.put(1) + result_id = f.remote(1, "hi", x_id) - # Wait for one additional task to complete. 
- start_time = time.time() - while time.time() - start_time < 10: - wait_for_num_tasks(1 + 1) - task_table = ray.global_state.task_table() - self.assertEqual(len(task_table), 1 + 1) - task_id_set = set(task_table.keys()) - task_id_set.remove(driver_task_id) - task_id = list(task_id_set)[0] - if task_table[task_id]["State"] == "DONE": - break - time.sleep(0.1) - function_table = ray.global_state.function_table() - task_spec = task_table[task_id]["TaskSpec"] - self.assertEqual(task_spec["ActorID"], ID_SIZE * "ff") - self.assertEqual(task_spec["Args"], [1, "hi", x_id]) - self.assertEqual(task_spec["DriverID"], driver_id) - self.assertEqual(task_spec["ReturnObjectIDs"], [result_id]) - function_table_entry = function_table[task_spec["FunctionID"]] - self.assertEqual(function_table_entry["Name"], "__main__.f") - self.assertEqual(function_table_entry["DriverID"], driver_id) - self.assertEqual(function_table_entry["Module"], "__main__") + # Wait for one additional task to complete. + start_time = time.time() + while time.time() - start_time < 10: + wait_for_num_tasks(1 + 1) + task_table = ray.global_state.task_table() + self.assertEqual(len(task_table), 1 + 1) + task_id_set = set(task_table.keys()) + task_id_set.remove(driver_task_id) + task_id = list(task_id_set)[0] + if task_table[task_id]["State"] == "DONE": + break + time.sleep(0.1) + function_table = ray.global_state.function_table() + task_spec = task_table[task_id]["TaskSpec"] + self.assertEqual(task_spec["ActorID"], ID_SIZE * "ff") + self.assertEqual(task_spec["Args"], [1, "hi", x_id]) + self.assertEqual(task_spec["DriverID"], driver_id) + self.assertEqual(task_spec["ReturnObjectIDs"], [result_id]) + function_table_entry = function_table[task_spec["FunctionID"]] + self.assertEqual(function_table_entry["Name"], "__main__.f") + self.assertEqual(function_table_entry["DriverID"], driver_id) + self.assertEqual(function_table_entry["Module"], "__main__") - self.assertEqual(task_table[task_id], 
ray.global_state.task_table(task_id)) + self.assertEqual(task_table[task_id], + ray.global_state.task_table(task_id)) - # Wait for two objects, one for the x_id and one for result_id. - wait_for_num_objects(2) + # Wait for two objects, one for the x_id and one for result_id. + wait_for_num_objects(2) - def wait_for_object_table(): - timeout = 10 - start_time = time.time() - while time.time() - start_time < timeout: + def wait_for_object_table(): + timeout = 10 + start_time = time.time() + while time.time() - start_time < timeout: + object_table = ray.global_state.object_table() + tables_ready = ( + object_table[x_id]["ManagerIDs"] is not None and + object_table[result_id]["ManagerIDs"] is not None) + if tables_ready: + return + time.sleep(0.1) + raise Exception("Timed out while waiting for object table to " + "update.") + + # Wait for the object table to be updated. + wait_for_object_table() object_table = ray.global_state.object_table() - tables_ready = (object_table[x_id]["ManagerIDs"] is not None and - object_table[result_id]["ManagerIDs"] is not None) - if tables_ready: - return - time.sleep(0.1) - raise Exception("Timed out while waiting for object table to update.") + self.assertEqual(len(object_table), 2) - # Wait for the object table to be updated. 
- wait_for_object_table() - object_table = ray.global_state.object_table() - self.assertEqual(len(object_table), 2) + self.assertEqual(object_table[x_id]["IsPut"], True) + self.assertEqual(object_table[x_id]["TaskID"], driver_task_id) + self.assertEqual(object_table[x_id]["ManagerIDs"], + [manager_client["DBClientID"]]) - self.assertEqual(object_table[x_id]["IsPut"], True) - self.assertEqual(object_table[x_id]["TaskID"], driver_task_id) - self.assertEqual(object_table[x_id]["ManagerIDs"], - [manager_client["DBClientID"]]) + self.assertEqual(object_table[result_id]["IsPut"], False) + self.assertEqual(object_table[result_id]["TaskID"], task_id) + self.assertEqual(object_table[result_id]["ManagerIDs"], + [manager_client["DBClientID"]]) - self.assertEqual(object_table[result_id]["IsPut"], False) - self.assertEqual(object_table[result_id]["TaskID"], task_id) - self.assertEqual(object_table[result_id]["ManagerIDs"], - [manager_client["DBClientID"]]) + self.assertEqual(object_table[x_id], + ray.global_state.object_table(x_id)) + self.assertEqual(object_table[result_id], + ray.global_state.object_table(result_id)) - self.assertEqual(object_table[x_id], ray.global_state.object_table(x_id)) - self.assertEqual(object_table[result_id], - ray.global_state.object_table(result_id)) + ray.worker.cleanup() - ray.worker.cleanup() + def testLogFileAPI(self): + ray.init(redirect_output=True) - def testLogFileAPI(self): - ray.init(redirect_output=True) + message = "unique message" - message = "unique message" + @ray.remote + def f(): + print(message) + # The call to sys.stdout.flush() seems to be necessary when using + # the system Python 2.7 on Ubuntu. + sys.stdout.flush() - @ray.remote - def f(): - print(message) - # The call to sys.stdout.flush() seems to be necessary when using the - # system Python 2.7 on Ubuntu. - sys.stdout.flush() + ray.get(f.remote()) - ray.get(f.remote()) + # Make sure that the message appears in the log files. 
+ start_time = time.time() + found_message = False + while time.time() - start_time < 10: + log_files = ray.global_state.log_files() + for ip, innerdict in log_files.items(): + for filename, contents in innerdict.items(): + contents_str = "".join(contents) + if message in contents_str: + found_message = True + if found_message: + break + time.sleep(0.1) - # Make sure that the message appears in the log files. - start_time = time.time() - found_message = False - while time.time() - start_time < 10: - log_files = ray.global_state.log_files() - for ip, innerdict in log_files.items(): - for filename, contents in innerdict.items(): - contents_str = "".join(contents) - if message in contents_str: - found_message = True - if found_message: - break - time.sleep(0.1) + self.assertEqual(found_message, True) - self.assertEqual(found_message, True) + ray.worker.cleanup() - ray.worker.cleanup() + def testTaskProfileAPI(self): + ray.init(redirect_output=True) - def testTaskProfileAPI(self): - ray.init(redirect_output=True) + @ray.remote + def f(): + return 1 - @ray.remote - def f(): - return 1 + num_calls = 5 + [f.remote() for _ in range(num_calls)] - num_calls = 5 - [f.remote() for _ in range(num_calls)] + # Make sure the event log has the correct number of events. + start_time = time.time() + while time.time() - start_time < 10: + profiles = ray.global_state.task_profiles(start=0, end=time.time()) + limited_profiles = ray.global_state.task_profiles(start=0, + end=time.time(), + num=1) + if len(profiles) == num_calls and len(limited_profiles) == 1: + break + time.sleep(0.1) + self.assertEqual(len(profiles), num_calls) + self.assertEqual(len(limited_profiles), 1) - # Make sure the event log has the correct number of events. 
- start_time = time.time() - while time.time() - start_time < 10: - profiles = ray.global_state.task_profiles(start=0, end=time.time()) - limited_profiles = ray.global_state.task_profiles(start=0, - end=time.time(), - num=1) - if len(profiles) == num_calls and len(limited_profiles) == 1: - break - time.sleep(0.1) - self.assertEqual(len(profiles), num_calls) - self.assertEqual(len(limited_profiles), 1) + # Make sure that each entry is properly formatted. + for task_id, data in profiles.items(): + self.assertIn("execute_start", data) + self.assertIn("execute_end", data) + self.assertIn("get_arguments_start", data) + self.assertIn("get_arguments_end", data) + self.assertIn("store_outputs_start", data) + self.assertIn("store_outputs_end", data) - # Make sure that each entry is properly formatted. - for task_id, data in profiles.items(): - self.assertIn("execute_start", data) - self.assertIn("execute_end", data) - self.assertIn("get_arguments_start", data) - self.assertIn("get_arguments_end", data) - self.assertIn("store_outputs_start", data) - self.assertIn("store_outputs_end", data) + ray.worker.cleanup() - ray.worker.cleanup() + def testWorkers(self): + num_workers = 3 + ray.init(redirect_output=True, num_cpus=num_workers, + num_workers=num_workers) - def testWorkers(self): - num_workers = 3 - ray.init(redirect_output=True, num_cpus=num_workers, - num_workers=num_workers) + @ray.remote + def f(): + return id(ray.worker.global_worker) - @ray.remote - def f(): - return id(ray.worker.global_worker) + # Wait until all of the workers have started. + worker_ids = set() + while len(worker_ids) != num_workers: + worker_ids = set(ray.get([f.remote() for _ in range(10)])) - # Wait until all of the workers have started. 
- worker_ids = set() - while len(worker_ids) != num_workers: - worker_ids = set(ray.get([f.remote() for _ in range(10)])) + worker_info = ray.global_state.workers() + self.assertEqual(len(worker_info), num_workers) + for worker_id, info in worker_info.items(): + self.assertEqual(info["node_ip_address"], "127.0.0.1") + self.assertIn("local_scheduler_socket", info) + self.assertIn("plasma_manager_socket", info) + self.assertIn("plasma_store_socket", info) + self.assertIn("stderr_file", info) + self.assertIn("stdout_file", info) - worker_info = ray.global_state.workers() - self.assertEqual(len(worker_info), num_workers) - for worker_id, info in worker_info.items(): - self.assertEqual(info["node_ip_address"], "127.0.0.1") - self.assertIn("local_scheduler_socket", info) - self.assertIn("plasma_manager_socket", info) - self.assertIn("plasma_store_socket", info) - self.assertIn("stderr_file", info) - self.assertIn("stdout_file", info) + ray.worker.cleanup() - ray.worker.cleanup() + def testDumpTraceFile(self): + ray.init(redirect_output=True) - def testDumpTraceFile(self): - ray.init(redirect_output=True) + @ray.remote + def f(): + return 1 - @ray.remote - def f(): - return 1 + @ray.remote + class Foo(object): + def __init__(self): + pass - @ray.remote - class Foo(object): - def __init__(self): - pass + def method(self): + pass - def method(self): - pass + ray.get([f.remote() for _ in range(10)]) + actors = [Foo.remote() for _ in range(5)] + ray.get([actor.method.remote() for actor in actors]) + ray.get([actor.method.remote() for actor in actors]) - ray.get([f.remote() for _ in range(10)]) - actors = [Foo.remote() for _ in range(5)] - ray.get([actor.method.remote() for actor in actors]) - ray.get([actor.method.remote() for actor in actors]) + path = os.path.join("/tmp/ray_test_trace") + ray.global_state.dump_catapult_trace(path) - path = os.path.join("/tmp/ray_test_trace") - ray.global_state.dump_catapult_trace(path) + # TODO(rkn): This test is not perfect because it does 
not verify that + # the visualization actually renders (e.g., the context of the dumped + # trace could be malformed). - # TODO(rkn): This test is not perfect because it does not verify that the - # visualization actually renders (e.g., the context of the dumped trace - # could be malformed). - - ray.worker.cleanup() + ray.worker.cleanup() if __name__ == "__main__": - unittest.main(verbosity=2) + unittest.main(verbosity=2) diff --git a/test/stress_tests.py b/test/stress_tests.py index 9fd15cf7e..37f6b969e 100644 --- a/test/stress_tests.py +++ b/test/stress_tests.py @@ -10,552 +10,562 @@ import time class TaskTests(unittest.TestCase): - def testSubmittingTasks(self): - for num_local_schedulers in [1, 4]: - for num_workers_per_scheduler in [4]: - num_workers = num_local_schedulers * num_workers_per_scheduler - ray.worker._init(start_ray_local=True, num_workers=num_workers, - num_local_schedulers=num_local_schedulers, - num_cpus=100) + def testSubmittingTasks(self): + for num_local_schedulers in [1, 4]: + for num_workers_per_scheduler in [4]: + num_workers = num_local_schedulers * num_workers_per_scheduler + ray.worker._init(start_ray_local=True, num_workers=num_workers, + num_local_schedulers=num_local_schedulers, + num_cpus=100) + + @ray.remote + def f(x): + return x + + for _ in range(1): + ray.get([f.remote(1) for _ in range(1000)]) + + for _ in range(10): + ray.get([f.remote(1) for _ in range(100)]) + + for _ in range(100): + ray.get([f.remote(1) for _ in range(10)]) + + for _ in range(1000): + ray.get([f.remote(1) for _ in range(1)]) + + self.assertTrue(ray.services.all_processes_alive()) + ray.worker.cleanup() + + def testDependencies(self): + for num_local_schedulers in [1, 4]: + for num_workers_per_scheduler in [4]: + num_workers = num_local_schedulers * num_workers_per_scheduler + ray.worker._init(start_ray_local=True, num_workers=num_workers, + num_local_schedulers=num_local_schedulers, + num_cpus=100) + + @ray.remote + def f(x): + return x + + x = 1 + for 
_ in range(1000): + x = f.remote(x) + ray.get(x) + + @ray.remote + def g(*xs): + return 1 + + xs = [g.remote(1)] + for _ in range(100): + xs.append(g.remote(*xs)) + xs.append(g.remote(1)) + ray.get(xs) + + self.assertTrue(ray.services.all_processes_alive()) + ray.worker.cleanup() + + def testSubmittingManyTasks(self): + ray.init() @ray.remote def f(x): - return x + return 1 - for _ in range(1): - ray.get([f.remote(1) for _ in range(1000)]) + def g(n): + x = 1 + for i in range(n): + x = f.remote(x) + return x - for _ in range(10): - ray.get([f.remote(1) for _ in range(100)]) + ray.get([g(1000) for _ in range(100)]) + self.assertTrue(ray.services.all_processes_alive()) + ray.worker.cleanup() - for _ in range(100): - ray.get([f.remote(1) for _ in range(10)]) + def testGettingAndPutting(self): + ray.init(num_workers=1) - for _ in range(1000): - ray.get([f.remote(1) for _ in range(1)]) + for n in range(8): + x = np.zeros(10 ** n) + + for _ in range(100): + ray.put(x) + + x_id = ray.put(x) + for _ in range(1000): + ray.get(x_id) self.assertTrue(ray.services.all_processes_alive()) ray.worker.cleanup() - def testDependencies(self): - for num_local_schedulers in [1, 4]: - for num_workers_per_scheduler in [4]: - num_workers = num_local_schedulers * num_workers_per_scheduler - ray.worker._init(start_ray_local=True, num_workers=num_workers, - num_local_schedulers=num_local_schedulers, - num_cpus=100) + def testGettingManyObjects(self): + ray.init() @ray.remote - def f(x): - return x + def f(): + return 1 - x = 1 - for _ in range(1000): - x = f.remote(x) - ray.get(x) - - @ray.remote - def g(*xs): - return 1 - - xs = [g.remote(1)] - for _ in range(100): - xs.append(g.remote(*xs)) - xs.append(g.remote(1)) - ray.get(xs) + n = 10 ** 4 # TODO(pcm): replace by 10 ** 5 once this is faster. 
+ l = ray.get([f.remote() for _ in range(n)]) + self.assertEqual(l, n * [1]) self.assertTrue(ray.services.all_processes_alive()) ray.worker.cleanup() - def testSubmittingManyTasks(self): - ray.init() + def testWait(self): + for num_local_schedulers in [1, 4]: + for num_workers_per_scheduler in [4]: + num_workers = num_local_schedulers * num_workers_per_scheduler + ray.worker._init(start_ray_local=True, num_workers=num_workers, + num_local_schedulers=num_local_schedulers, + num_cpus=100) - @ray.remote - def f(x): - return 1 + @ray.remote + def f(x): + return x - def g(n): - x = 1 - for i in range(n): - x = f.remote(x) - return x + x_ids = [f.remote(i) for i in range(100)] + for i in range(len(x_ids)): + ray.wait([x_ids[i]]) + for i in range(len(x_ids) - 1): + ray.wait(x_ids[i:]) - ray.get([g(1000) for _ in range(100)]) - self.assertTrue(ray.services.all_processes_alive()) - ray.worker.cleanup() + @ray.remote + def g(x): + time.sleep(x) - def testGettingAndPutting(self): - ray.init(num_workers=1) + for i in range(1, 5): + x_ids = [g.remote(np.random.uniform(0, i)) + for _ in range(2 * num_workers)] + ray.wait(x_ids, num_returns=len(x_ids)) - for n in range(8): - x = np.zeros(10 ** n) - - for _ in range(100): - ray.put(x) - - x_id = ray.put(x) - for _ in range(1000): - ray.get(x_id) - - self.assertTrue(ray.services.all_processes_alive()) - ray.worker.cleanup() - - def testGettingManyObjects(self): - ray.init() - - @ray.remote - def f(): - return 1 - - n = 10 ** 4 # TODO(pcm): replace by 10 ** 5 once this is faster. 
- l = ray.get([f.remote() for _ in range(n)]) - self.assertEqual(l, n * [1]) - - self.assertTrue(ray.services.all_processes_alive()) - ray.worker.cleanup() - - def testWait(self): - for num_local_schedulers in [1, 4]: - for num_workers_per_scheduler in [4]: - num_workers = num_local_schedulers * num_workers_per_scheduler - ray.worker._init(start_ray_local=True, num_workers=num_workers, - num_local_schedulers=num_local_schedulers, - num_cpus=100) - - @ray.remote - def f(x): - return x - - x_ids = [f.remote(i) for i in range(100)] - for i in range(len(x_ids)): - ray.wait([x_ids[i]]) - for i in range(len(x_ids) - 1): - ray.wait(x_ids[i:]) - - @ray.remote - def g(x): - time.sleep(x) - - for i in range(1, 5): - x_ids = [g.remote(np.random.uniform(0, i)) - for _ in range(2 * num_workers)] - ray.wait(x_ids, num_returns=len(x_ids)) - - self.assertTrue(ray.services.all_processes_alive()) - ray.worker.cleanup() + self.assertTrue(ray.services.all_processes_alive()) + ray.worker.cleanup() class ReconstructionTests(unittest.TestCase): - num_local_schedulers = 1 + num_local_schedulers = 1 - def setUp(self): - # Start the Redis global state store. - node_ip_address = "127.0.0.1" - redis_address, redis_shards = ray.services.start_redis(node_ip_address) - self.redis_ip_address = ray.services.get_ip_address(redis_address) - self.redis_port = ray.services.get_port(redis_address) - time.sleep(0.1) + def setUp(self): + # Start the Redis global state store. + node_ip_address = "127.0.0.1" + redis_address, redis_shards = ray.services.start_redis(node_ip_address) + self.redis_ip_address = ray.services.get_ip_address(redis_address) + self.redis_port = ray.services.get_port(redis_address) + time.sleep(0.1) - # Start the Plasma store instances with a total of 1GB memory. 
- self.plasma_store_memory = 10 ** 9 - plasma_addresses = [] - objstore_memory = (self.plasma_store_memory // self.num_local_schedulers) - for i in range(self.num_local_schedulers): - store_stdout_file, store_stderr_file = ray.services.new_log_files( - "plasma_store_{}".format(i), True) - manager_stdout_file, manager_stderr_file = ray.services.new_log_files( - "plasma_manager_{}".format(i), True) - plasma_addresses.append(ray.services.start_objstore( - node_ip_address, redis_address, objstore_memory=objstore_memory, - store_stdout_file=store_stdout_file, - store_stderr_file=store_stderr_file, - manager_stdout_file=manager_stdout_file, - manager_stderr_file=manager_stderr_file)) + # Start the Plasma store instances with a total of 1GB memory. + self.plasma_store_memory = 10 ** 9 + plasma_addresses = [] + objstore_memory = (self.plasma_store_memory // + self.num_local_schedulers) + for i in range(self.num_local_schedulers): + store_stdout_file, store_stderr_file = ray.services.new_log_files( + "plasma_store_{}".format(i), True) + manager_stdout_file, manager_stderr_file = ( + ray.services.new_log_files("plasma_manager_{}" + .format(i), True)) + plasma_addresses.append(ray.services.start_objstore( + node_ip_address, redis_address, + objstore_memory=objstore_memory, + store_stdout_file=store_stdout_file, + store_stderr_file=store_stderr_file, + manager_stdout_file=manager_stdout_file, + manager_stderr_file=manager_stderr_file)) - # Start the rest of the services in the Ray cluster. - address_info = {"redis_address": redis_address, - "redis_shards": redis_shards, - "object_store_addresses": plasma_addresses} - ray.worker._init(address_info=address_info, start_ray_local=True, - num_workers=1, - num_local_schedulers=self.num_local_schedulers, - num_cpus=[1] * self.num_local_schedulers, - redirect_output=True, - driver_mode=ray.SILENT_MODE) + # Start the rest of the services in the Ray cluster. 
+ address_info = {"redis_address": redis_address, + "redis_shards": redis_shards, + "object_store_addresses": plasma_addresses} + ray.worker._init(address_info=address_info, start_ray_local=True, + num_workers=1, + num_local_schedulers=self.num_local_schedulers, + num_cpus=[1] * self.num_local_schedulers, + redirect_output=True, + driver_mode=ray.SILENT_MODE) - def tearDown(self): - self.assertTrue(ray.services.all_processes_alive()) + def tearDown(self): + self.assertTrue(ray.services.all_processes_alive()) - # Determine the IDs of all local schedulers that had a task scheduled or - # submitted. - state = ray.experimental.state.GlobalState() - state._initialize_global_state(self.redis_ip_address, self.redis_port) - tasks = state.task_table() - local_scheduler_ids = set(task["LocalSchedulerID"] for task in - tasks.values()) + # Determine the IDs of all local schedulers that had a task scheduled + # or submitted. + state = ray.experimental.state.GlobalState() + state._initialize_global_state(self.redis_ip_address, self.redis_port) + tasks = state.task_table() + local_scheduler_ids = set(task["LocalSchedulerID"] for task in + tasks.values()) - # Make sure that all nodes in the cluster were used by checking that the - # set of local scheduler IDs that had a task scheduled or submitted is - # equal to the total number of local schedulers started. We add one to the - # total number of local schedulers to account for NIL_LOCAL_SCHEDULER_ID. - # This is the local scheduler ID associated with the driver task, since it - # is not scheduled by a particular local scheduler. - self.assertEqual(len(local_scheduler_ids), - self.num_local_schedulers + 1) + # Make sure that all nodes in the cluster were used by checking that + # the set of local scheduler IDs that had a task scheduled or submitted + # is equal to the total number of local schedulers started. We add one + # to the total number of local schedulers to account for + # NIL_LOCAL_SCHEDULER_ID. 
This is the local scheduler ID associated + # with the driver task, since it is not scheduled by a particular local + # scheduler. + self.assertEqual(len(local_scheduler_ids), + self.num_local_schedulers + 1) - # Clean up the Ray cluster. - ray.worker.cleanup() + # Clean up the Ray cluster. + ray.worker.cleanup() - def testSimple(self): - # Define the size of one task's return argument so that the combined sum of - # all objects' sizes is at least twice the plasma stores' combined allotted - # memory. - num_objects = 1000 - size = int(self.plasma_store_memory * 1.5 / (num_objects * 8)) + def testSimple(self): + # Define the size of one task's return argument so that the combined + # sum of all objects' sizes is at least twice the plasma stores' + # combined allotted memory. + num_objects = 1000 + size = int(self.plasma_store_memory * 1.5 / (num_objects * 8)) - # Define a remote task with no dependencies, which returns a numpy array of - # the given size. - @ray.remote - def foo(i, size): - array = np.zeros(size) - array[0] = i - return array + # Define a remote task with no dependencies, which returns a numpy + # array of the given size. + @ray.remote + def foo(i, size): + array = np.zeros(size) + array[0] = i + return array - # Launch num_objects instances of the remote task. - args = [] - for i in range(num_objects): - args.append(foo.remote(i, size)) + # Launch num_objects instances of the remote task. + args = [] + for i in range(num_objects): + args.append(foo.remote(i, size)) - # Get each value to force each task to finish. After some number of gets, - # old values should be evicted. - for i in range(num_objects): - value = ray.get(args[i]) - self.assertEqual(value[0], i) - # Get each value again to force reconstruction. - for i in range(num_objects): - value = ray.get(args[i]) - self.assertEqual(value[0], i) - # Get values sequentially, in chunks. 
- num_chunks = 4 * self.num_local_schedulers - chunk = num_objects // num_chunks - for i in range(num_chunks): - values = ray.get(args[i * chunk:(i + 1) * chunk]) - del values + # Get each value to force each task to finish. After some number of + # gets, old values should be evicted. + for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) + # Get each value again to force reconstruction. + for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) + # Get values sequentially, in chunks. + num_chunks = 4 * self.num_local_schedulers + chunk = num_objects // num_chunks + for i in range(num_chunks): + values = ray.get(args[i * chunk:(i + 1) * chunk]) + del values - def testRecursive(self): - # Define the size of one task's return argument so that the combined sum of - # all objects' sizes is at least twice the plasma stores' combined allotted - # memory. - num_objects = 1000 - size = int(self.plasma_store_memory * 1.5 / (num_objects * 8)) + def testRecursive(self): + # Define the size of one task's return argument so that the combined + # sum of all objects' sizes is at least twice the plasma stores' + # combined allotted memory. + num_objects = 1000 + size = int(self.plasma_store_memory * 1.5 / (num_objects * 8)) - # Define a root task with no dependencies, which returns a numpy array of - # the given size. - @ray.remote - def no_dependency_task(size): - array = np.zeros(size) - return array + # Define a root task with no dependencies, which returns a numpy array + # of the given size. + @ray.remote + def no_dependency_task(size): + array = np.zeros(size) + return array - # Define a task with a single dependency, which returns its one argument. - @ray.remote - def single_dependency(i, arg): - arg = np.copy(arg) - arg[0] = i - return arg + # Define a task with a single dependency, which returns its one + # argument. 
+ @ray.remote + def single_dependency(i, arg): + arg = np.copy(arg) + arg[0] = i + return arg - # Launch num_objects instances of the remote task, each dependent on the - # one before it. - arg = no_dependency_task.remote(size) - args = [] - for i in range(num_objects): - arg = single_dependency.remote(i, arg) - args.append(arg) + # Launch num_objects instances of the remote task, each dependent on + # the one before it. + arg = no_dependency_task.remote(size) + args = [] + for i in range(num_objects): + arg = single_dependency.remote(i, arg) + args.append(arg) - # Get each value to force each task to finish. After some number of gets, - # old values should be evicted. - for i in range(num_objects): - value = ray.get(args[i]) - self.assertEqual(value[0], i) - # Get each value again to force reconstruction. - for i in range(num_objects): - value = ray.get(args[i]) - self.assertEqual(value[0], i) - # Get 10 values randomly. - for _ in range(10): - i = np.random.randint(num_objects) - value = ray.get(args[i]) - self.assertEqual(value[0], i) - # Get values sequentially, in chunks. - num_chunks = 4 * self.num_local_schedulers - chunk = num_objects // num_chunks - for i in range(num_chunks): - values = ray.get(args[i * chunk:(i + 1) * chunk]) - del values + # Get each value to force each task to finish. After some number of + # gets, old values should be evicted. + for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) + # Get each value again to force reconstruction. + for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) + # Get 10 values randomly. + for _ in range(10): + i = np.random.randint(num_objects) + value = ray.get(args[i]) + self.assertEqual(value[0], i) + # Get values sequentially, in chunks. 
+ num_chunks = 4 * self.num_local_schedulers + chunk = num_objects // num_chunks + for i in range(num_chunks): + values = ray.get(args[i * chunk:(i + 1) * chunk]) + del values - def testMultipleRecursive(self): - # Define the size of one task's return argument so that the combined sum of - # all objects' sizes is at least twice the plasma stores' combined allotted - # memory. - num_objects = 1000 - size = self.plasma_store_memory * 2 // (num_objects * 8) + def testMultipleRecursive(self): + # Define the size of one task's return argument so that the combined + # sum of all objects' sizes is at least twice the plasma stores' + # combined allotted memory. + num_objects = 1000 + size = self.plasma_store_memory * 2 // (num_objects * 8) - # Define a root task with no dependencies, which returns a numpy array of - # the given size. - @ray.remote - def no_dependency_task(size): - array = np.zeros(size) - return array + # Define a root task with no dependencies, which returns a numpy array + # of the given size. + @ray.remote + def no_dependency_task(size): + array = np.zeros(size) + return array - # Define a task with multiple dependencies, which returns its first - # argument. - @ray.remote - def multiple_dependency(i, arg1, arg2, arg3): - arg1 = np.copy(arg1) - arg1[0] = i - return arg1 + # Define a task with multiple dependencies, which returns its first + # argument. + @ray.remote + def multiple_dependency(i, arg1, arg2, arg3): + arg1 = np.copy(arg1) + arg1[0] = i + return arg1 - # Launch num_args instances of the root task. Then launch num_objects - # instances of the multi-dependency remote task, each dependent on the - # num_args tasks before it. - num_args = 3 - args = [] - for i in range(num_args): - arg = no_dependency_task.remote(size) - args.append(arg) - for i in range(num_objects): - args.append(multiple_dependency.remote(i, *args[i:i + num_args])) + # Launch num_args instances of the root task. 
Then launch num_objects + # instances of the multi-dependency remote task, each dependent on the + # num_args tasks before it. + num_args = 3 + args = [] + for i in range(num_args): + arg = no_dependency_task.remote(size) + args.append(arg) + for i in range(num_objects): + args.append(multiple_dependency.remote(i, *args[i:i + num_args])) - # Get each value to force each task to finish. After some number of gets, - # old values should be evicted. - args = args[num_args:] - for i in range(num_objects): - value = ray.get(args[i]) - self.assertEqual(value[0], i) - # Get each value again to force reconstruction. - for i in range(num_objects): - value = ray.get(args[i]) - self.assertEqual(value[0], i) - # Get 10 values randomly. - for _ in range(10): - i = np.random.randint(num_objects) - value = ray.get(args[i]) - self.assertEqual(value[0], i) + # Get each value to force each task to finish. After some number of + # gets, old values should be evicted. + args = args[num_args:] + for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) + # Get each value again to force reconstruction. + for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) + # Get 10 values randomly. + for _ in range(10): + i = np.random.randint(num_objects) + value = ray.get(args[i]) + self.assertEqual(value[0], i) - def wait_for_errors(self, error_check): - # Wait for errors from all the nondeterministic tasks. - errors = [] - time_left = 100 - while time_left > 0: - errors = ray.error_info() - if error_check(errors): - break - time_left -= 1 - time.sleep(1) + def wait_for_errors(self, error_check): + # Wait for errors from all the nondeterministic tasks. + errors = [] + time_left = 100 + while time_left > 0: + errors = ray.error_info() + if error_check(errors): + break + time_left -= 1 + time.sleep(1) - # Make sure that enough errors came through. 
- self.assertTrue(error_check(errors)) - return errors + # Make sure that enough errors came through. + self.assertTrue(error_check(errors)) + return errors - def testNondeterministicTask(self): - # Define the size of one task's return argument so that the combined sum of - # all objects' sizes is at least twice the plasma stores' combined allotted - # memory. - num_objects = 1000 - size = self.plasma_store_memory * 2 // (num_objects * 8) + def testNondeterministicTask(self): + # Define the size of one task's return argument so that the combined + # sum of all objects' sizes is at least twice the plasma stores' + # combined allotted memory. + num_objects = 1000 + size = self.plasma_store_memory * 2 // (num_objects * 8) - # Define a nondeterministic remote task with no dependencies, which returns - # a random numpy array of the given size. This task should produce an error - # on the driver if it is ever reexecuted. - @ray.remote - def foo(i, size): - array = np.random.rand(size) - array[0] = i - return array + # Define a nondeterministic remote task with no dependencies, which + # returns a random numpy array of the given size. This task should + # produce an error on the driver if it is ever reexecuted. + @ray.remote + def foo(i, size): + array = np.random.rand(size) + array[0] = i + return array - # Define a deterministic remote task with no dependencies, which returns a - # numpy array of zeros of the given size. - @ray.remote - def bar(i, size): - array = np.zeros(size) - array[0] = i - return array + # Define a deterministic remote task with no dependencies, which + # returns a numpy array of zeros of the given size. + @ray.remote + def bar(i, size): + array = np.zeros(size) + array[0] = i + return array - # Launch num_objects instances, half deterministic and half - # nondeterministic. 
- args = [] - for i in range(num_objects): - if i % 2 == 0: - args.append(foo.remote(i, size)) - else: - args.append(bar.remote(i, size)) + # Launch num_objects instances, half deterministic and half + # nondeterministic. + args = [] + for i in range(num_objects): + if i % 2 == 0: + args.append(foo.remote(i, size)) + else: + args.append(bar.remote(i, size)) - # Get each value to force each task to finish. After some number of gets, - # old values should be evicted. - for i in range(num_objects): - value = ray.get(args[i]) - self.assertEqual(value[0], i) - # Get each value again to force reconstruction. - for i in range(num_objects): - value = ray.get(args[i]) - self.assertEqual(value[0], i) + # Get each value to force each task to finish. After some number of + # gets, old values should be evicted. + for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) + # Get each value again to force reconstruction. + for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) - def error_check(errors): - if self.num_local_schedulers == 1: - # In a single-node setting, each object is evicted and reconstructed - # exactly once, so exactly half the objects will produce an error - # during reconstruction. - min_errors = num_objects // 2 - else: - # In a multinode setting, each object is evicted zero or one times, so - # some of the nondeterministic tasks may not be reexecuted. - min_errors = 1 - return len(errors) >= min_errors - errors = self.wait_for_errors(error_check) - # Make sure all the errors have the correct type. - self.assertTrue(all(error[b"type"] == b"object_hash_mismatch" - for error in errors)) - # Make sure all the errors have the correct function name. 
- self.assertTrue(all(error[b"data"] == b"__main__.foo" for error in errors)) + def error_check(errors): + if self.num_local_schedulers == 1: + # In a single-node setting, each object is evicted and + # reconstructed exactly once, so exactly half the objects will + # produce an error during reconstruction. + min_errors = num_objects // 2 + else: + # In a multinode setting, each object is evicted zero or one + # times, so some of the nondeterministic tasks may not be + # reexecuted. + min_errors = 1 + return len(errors) >= min_errors + errors = self.wait_for_errors(error_check) + # Make sure all the errors have the correct type. + self.assertTrue(all(error[b"type"] == b"object_hash_mismatch" + for error in errors)) + # Make sure all the errors have the correct function name. + self.assertTrue(all(error[b"data"] == b"__main__.foo" + for error in errors)) - def testPutErrors(self): - # Define the size of one task's return argument so that the combined sum of - # all objects' sizes is at least twice the plasma stores' combined allotted - # memory. - num_objects = 1000 - size = self.plasma_store_memory * 2 // (num_objects * 8) + def testPutErrors(self): + # Define the size of one task's return argument so that the combined + # sum of all objects' sizes is at least twice the plasma stores' + # combined allotted memory. + num_objects = 1000 + size = self.plasma_store_memory * 2 // (num_objects * 8) - # Define a task with a single dependency, a numpy array, that returns - # another array. - @ray.remote - def single_dependency(i, arg): - arg = np.copy(arg) - arg[0] = i - return arg + # Define a task with a single dependency, a numpy array, that returns + # another array. + @ray.remote + def single_dependency(i, arg): + arg = np.copy(arg) + arg[0] = i + return arg - # Define a root task that calls `ray.put` to put an argument in the object - # store. 
- @ray.remote - def put_arg_task(size): - # Launch num_objects instances of the remote task, each dependent on the - # one before it. The first instance of the task takes a numpy array as an - # argument, which is put into the object store. - args = [] - arg = single_dependency.remote(0, np.zeros(size)) - for i in range(num_objects): - arg = single_dependency.remote(i, arg) - args.append(arg) + # Define a root task that calls `ray.put` to put an argument in the + # object store. + @ray.remote + def put_arg_task(size): + # Launch num_objects instances of the remote task, each dependent + # on the one before it. The first instance of the task takes a + # numpy array as an argument, which is put into the object store. + args = [] + arg = single_dependency.remote(0, np.zeros(size)) + for i in range(num_objects): + arg = single_dependency.remote(i, arg) + args.append(arg) - # Get each value to force each task to finish. After some number of gets, - # old values should be evicted. - for i in range(num_objects): - value = ray.get(args[i]) - self.assertEqual(value[0], i) - # Get each value again to force reconstruction. Currently, since we're - # not able to reconstruct `ray.put` objects that were evicted and whose - # originating tasks are still running, this for-loop should hang on its - # first iteration and push an error to the driver. - for i in range(num_objects): - value = ray.get(args[i]) - self.assertEqual(value[0], i) + # Get each value to force each task to finish. After some number of + # gets, old values should be evicted. + for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) + # Get each value again to force reconstruction. Currently, since + # we're not able to reconstruct `ray.put` objects that were evicted + # and whose originating tasks are still running, this for-loop + # should hang on its first iteration and push an error to the + # driver. 
+ for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) - # Define a root task that calls `ray.put` directly. - @ray.remote - def put_task(size): - # Launch num_objects instances of the remote task, each dependent on the - # one before it. The first instance of the task takes an object ID - # returned by ray.put. - args = [] - arg = ray.put(np.zeros(size)) - for i in range(num_objects): - arg = single_dependency.remote(i, arg) - args.append(arg) + # Define a root task that calls `ray.put` directly. + @ray.remote + def put_task(size): + # Launch num_objects instances of the remote task, each dependent + # on the one before it. The first instance of the task takes an + # object ID returned by ray.put. + args = [] + arg = ray.put(np.zeros(size)) + for i in range(num_objects): + arg = single_dependency.remote(i, arg) + args.append(arg) - # Get each value to force each task to finish. After some number of gets, - # old values should be evicted. - for i in range(num_objects): - value = ray.get(args[i]) - self.assertEqual(value[0], i) - # Get each value again to force reconstruction. Currently, since we're - # not able to reconstruct `ray.put` objects that were evicted and whose - # originating tasks are still running, this for-loop should hang on its - # first iteration and push an error to the driver. - for i in range(num_objects): - value = ray.get(args[i]) - self.assertEqual(value[0], i) + # Get each value to force each task to finish. After some number of + # gets, old values should be evicted. + for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) + # Get each value again to force reconstruction. Currently, since + # we're not able to reconstruct `ray.put` objects that were evicted + # and whose originating tasks are still running, this for-loop + # should hang on its first iteration and push an error to the + # driver. 
+ for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) - put_arg_task.remote(size) + put_arg_task.remote(size) - def error_check(errors): - return len(errors) > 1 - errors = self.wait_for_errors(error_check) - # Make sure all the errors have the correct type. - self.assertTrue(all(error[b"type"] == b"put_reconstruction" - for error in errors)) - self.assertTrue(all(error[b"data"] == b"__main__.put_arg_task" - for error in errors)) + def error_check(errors): + return len(errors) > 1 + errors = self.wait_for_errors(error_check) + # Make sure all the errors have the correct type. + self.assertTrue(all(error[b"type"] == b"put_reconstruction" + for error in errors)) + self.assertTrue(all(error[b"data"] == b"__main__.put_arg_task" + for error in errors)) - put_task.remote(size) + put_task.remote(size) - def error_check(errors): - return any(error[b"data"] == b"__main__.put_task" for error in errors) - errors = self.wait_for_errors(error_check) - # Make sure all the errors have the correct type. - self.assertTrue(all(error[b"type"] == b"put_reconstruction" - for error in errors)) - self.assertTrue(any(error[b"data"] == b"__main__.put_task" - for error in errors)) + def error_check(errors): + return any(error[b"data"] == b"__main__.put_task" + for error in errors) + errors = self.wait_for_errors(error_check) + # Make sure all the errors have the correct type. + self.assertTrue(all(error[b"type"] == b"put_reconstruction" + for error in errors)) + self.assertTrue(any(error[b"data"] == b"__main__.put_task" + for error in errors)) - def testDriverPutErrors(self): - # Define the size of one task's return argument so that the combined sum of - # all objects' sizes is at least twice the plasma stores' combined allotted - # memory. 
- num_objects = 1000 - size = self.plasma_store_memory * 2 // (num_objects * 8) + def testDriverPutErrors(self): + # Define the size of one task's return argument so that the combined + # sum of all objects' sizes is at least twice the plasma stores' + # combined allotted memory. + num_objects = 1000 + size = self.plasma_store_memory * 2 // (num_objects * 8) - # Define a task with a single dependency, a numpy array, that returns - # another array. - @ray.remote - def single_dependency(i, arg): - arg = np.copy(arg) - arg[0] = i - return arg + # Define a task with a single dependency, a numpy array, that returns + # another array. + @ray.remote + def single_dependency(i, arg): + arg = np.copy(arg) + arg[0] = i + return arg - # Launch num_objects instances of the remote task, each dependent on the - # one before it. The first instance of the task takes a numpy array as an - # argument, which is put into the object store. - args = [] - arg = single_dependency.remote(0, np.zeros(size)) - for i in range(num_objects): - arg = single_dependency.remote(i, arg) - args.append(arg) - # Get each value to force each task to finish. After some number of gets, - # old values should be evicted. - for i in range(num_objects): - value = ray.get(args[i]) - self.assertEqual(value[0], i) + # Launch num_objects instances of the remote task, each dependent on + # the one before it. The first instance of the task takes a numpy array + # as an argument, which is put into the object store. + args = [] + arg = single_dependency.remote(0, np.zeros(size)) + for i in range(num_objects): + arg = single_dependency.remote(i, arg) + args.append(arg) + # Get each value to force each task to finish. After some number of + # gets, old values should be evicted. + for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) - # Get each value starting from the beginning to force reconstruction. 
- # Currently, since we're not able to reconstruct `ray.put` objects that - # were evicted and whose originating tasks are still running, this - # for-loop should hang on its first iteration and push an error to the - # driver. - ray.worker.global_worker.local_scheduler_client.reconstruct_object( - args[0].id()) + # Get each value starting from the beginning to force reconstruction. + # Currently, since we're not able to reconstruct `ray.put` objects that + # were evicted and whose originating tasks are still running, this + # for-loop should hang on its first iteration and push an error to the + # driver. + ray.worker.global_worker.local_scheduler_client.reconstruct_object( + args[0].id()) - def error_check(errors): - return len(errors) > 1 - errors = self.wait_for_errors(error_check) - self.assertTrue(all(error[b"type"] == b"put_reconstruction" - for error in errors)) - self.assertTrue(all(error[b"data"] == b"Driver" for error in errors)) + def error_check(errors): + return len(errors) > 1 + errors = self.wait_for_errors(error_check) + self.assertTrue(all(error[b"type"] == b"put_reconstruction" + for error in errors)) + self.assertTrue(all(error[b"data"] == b"Driver" for error in errors)) class ReconstructionTestsMultinode(ReconstructionTests): - # Run the same tests as the single-node suite, but with 4 local schedulers, - # one worker each. - num_local_schedulers = 4 + # Run the same tests as the single-node suite, but with 4 local schedulers, + # one worker each. + num_local_schedulers = 4 # NOTE(swang): This test tries to launch 1000 workers and breaks. 
# class WorkerPoolTests(unittest.TestCase): @@ -581,4 +591,4 @@ class ReconstructionTestsMultinode(ReconstructionTests): if __name__ == "__main__": - unittest.main(verbosity=2) + unittest.main(verbosity=2) diff --git a/test/tensorflow_test.py b/test/tensorflow_test.py index 55d7dfe27..f83e7283d 100644 --- a/test/tensorflow_test.py +++ b/test/tensorflow_test.py @@ -10,218 +10,225 @@ import ray def make_linear_network(w_name=None, b_name=None): - # Define the inputs. - x_data = tf.placeholder(tf.float32, shape=[100]) - y_data = tf.placeholder(tf.float32, shape=[100]) - # Define the weights and computation. - w = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name=w_name) - b = tf.Variable(tf.zeros([1]), name=b_name) - y = w * x_data + b - # Return the loss and weight initializer. - return (tf.reduce_mean(tf.square(y - y_data)), - tf.global_variables_initializer(), x_data, y_data) + # Define the inputs. + x_data = tf.placeholder(tf.float32, shape=[100]) + y_data = tf.placeholder(tf.float32, shape=[100]) + # Define the weights and computation. + w = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name=w_name) + b = tf.Variable(tf.zeros([1]), name=b_name) + y = w * x_data + b + # Return the loss and weight initializer. + return (tf.reduce_mean(tf.square(y - y_data)), + tf.global_variables_initializer(), x_data, y_data) class NetActor(object): - def __init__(self): - # Uses a separate graph for each network. - with tf.Graph().as_default(): - # Create the network. - loss, init, _, _ = make_linear_network() - sess = tf.Session() - # Additional code for setting and getting the weights. - variables = ray.experimental.TensorFlowVariables(loss, sess) - # Return all of the data needed to use the network. - self.values = [variables, init, sess] - sess.run(init) + def __init__(self): + # Uses a separate graph for each network. + with tf.Graph().as_default(): + # Create the network. 
+ loss, init, _, _ = make_linear_network() + sess = tf.Session() + # Additional code for setting and getting the weights. + variables = ray.experimental.TensorFlowVariables(loss, sess) + # Return all of the data needed to use the network. + self.values = [variables, init, sess] + sess.run(init) - def set_and_get_weights(self, weights): - self.values[0].set_weights(weights) - return self.values[0].get_weights() + def set_and_get_weights(self, weights): + self.values[0].set_weights(weights) + return self.values[0].get_weights() - def get_weights(self): - return self.values[0].get_weights() + def get_weights(self): + return self.values[0].get_weights() class TrainActor(object): - def __init__(self): - # Almost the same as above, but now returns the placeholders and gradient. - with tf.Graph().as_default(): - loss, init, x_data, y_data = make_linear_network() - sess = tf.Session() - variables = ray.experimental.TensorFlowVariables(loss, sess) - optimizer = tf.train.GradientDescentOptimizer(0.9) - grads = optimizer.compute_gradients(loss) - train = optimizer.apply_gradients(grads) - self.values = [loss, variables, init, sess, grads, train, [x_data, y_data]] - sess.run(init) + def __init__(self): + # Almost the same as above, but now returns the placeholders and + # gradient. 
+ with tf.Graph().as_default(): + loss, init, x_data, y_data = make_linear_network() + sess = tf.Session() + variables = ray.experimental.TensorFlowVariables(loss, sess) + optimizer = tf.train.GradientDescentOptimizer(0.9) + grads = optimizer.compute_gradients(loss) + train = optimizer.apply_gradients(grads) + self.values = [loss, variables, init, sess, grads, train, + [x_data, y_data]] + sess.run(init) - def training_step(self, weights): - _, variables, _, sess, grads, _, placeholders = self.values - variables.set_weights(weights) - return sess.run([grad[0] for grad in grads], - feed_dict=dict(zip(placeholders, [[1] * 100, [2] * 100]))) + def training_step(self, weights): + _, variables, _, sess, grads, _, placeholders = self.values + variables.set_weights(weights) + return sess.run([grad[0] for grad in grads], + feed_dict=dict(zip(placeholders, + [[1] * 100, [2] * 100]))) - def get_weights(self): - return self.values[1].get_weights() + def get_weights(self): + return self.values[1].get_weights() class TensorFlowTest(unittest.TestCase): - def testTensorFlowVariables(self): - ray.init(num_workers=2) + def testTensorFlowVariables(self): + ray.init(num_workers=2) - sess = tf.Session() - loss, init, _, _ = make_linear_network() - sess.run(init) + sess = tf.Session() + loss, init, _, _ = make_linear_network() + sess.run(init) - variables = ray.experimental.TensorFlowVariables(loss, sess) - weights = variables.get_weights() + variables = ray.experimental.TensorFlowVariables(loss, sess) + weights = variables.get_weights() - for (name, val) in weights.items(): - weights[name] += 1.0 + for (name, val) in weights.items(): + weights[name] += 1.0 - variables.set_weights(weights) - self.assertEqual(weights, variables.get_weights()) + variables.set_weights(weights) + self.assertEqual(weights, variables.get_weights()) - loss2, init2, _, _ = make_linear_network("w", "b") - sess.run(init2) + loss2, init2, _, _ = make_linear_network("w", "b") + sess.run(init2) - variables2 = 
ray.experimental.TensorFlowVariables(loss2, sess) - weights2 = variables2.get_weights() + variables2 = ray.experimental.TensorFlowVariables(loss2, sess) + weights2 = variables2.get_weights() - for (name, val) in weights2.items(): - weights2[name] += 2.0 + for (name, val) in weights2.items(): + weights2[name] += 2.0 - variables2.set_weights(weights2) - self.assertEqual(weights2, variables2.get_weights()) + variables2.set_weights(weights2) + self.assertEqual(weights2, variables2.get_weights()) - flat_weights = variables2.get_flat() + 2.0 - variables2.set_flat(flat_weights) - assert_almost_equal(flat_weights, variables2.get_flat()) + flat_weights = variables2.get_flat() + 2.0 + variables2.set_flat(flat_weights) + assert_almost_equal(flat_weights, variables2.get_flat()) - variables3 = ray.experimental.TensorFlowVariables(loss2) - self.assertEqual(variables3.sess, None) - sess = tf.Session() - variables3.set_session(sess) - self.assertEqual(variables3.sess, sess) + variables3 = ray.experimental.TensorFlowVariables(loss2) + self.assertEqual(variables3.sess, None) + sess = tf.Session() + variables3.set_session(sess) + self.assertEqual(variables3.sess, sess) - ray.worker.cleanup() + ray.worker.cleanup() - # Test that the variable names for the two different nets are not - # modified by TensorFlow to be unique (i.e. they should already - # be unique because of the variable prefix). - def testVariableNameCollision(self): - ray.init(num_workers=2) + # Test that the variable names for the two different nets are not + # modified by TensorFlow to be unique (i.e. they should already + # be unique because of the variable prefix). + def testVariableNameCollision(self): + ray.init(num_workers=2) - net1 = NetActor() - net2 = NetActor() + net1 = NetActor() + net2 = NetActor() - # This is checking that the variable names of the two nets are the same, - # i.e. 
that the names in the weight dictionaries are the same - net1.values[0].set_weights(net2.values[0].get_weights()) + # This is checking that the variable names of the two nets are the + # same, i.e. that the names in the weight dictionaries are the same + net1.values[0].set_weights(net2.values[0].get_weights()) - ray.worker.cleanup() + ray.worker.cleanup() - # Test that different networks on the same worker are independent and - # we can get/set their weights without any interaction. - def testNetworksIndependent(self): - # Note we use only one worker to ensure that all of the remote functions - # run on the same worker. - ray.init(num_workers=1) - net1 = NetActor() - net2 = NetActor() + # Test that different networks on the same worker are independent and + # we can get/set their weights without any interaction. + def testNetworksIndependent(self): + # Note we use only one worker to ensure that all of the remote + # functions run on the same worker. + ray.init(num_workers=1) + net1 = NetActor() + net2 = NetActor() - # Make sure the two networks have different weights. TODO(rkn): Note that - # equality comparisons of numpy arrays normally does not work. This only - # works because at the moment they have size 1. - weights1 = net1.get_weights() - weights2 = net2.get_weights() - self.assertNotEqual(weights1, weights2) + # Make sure the two networks have different weights. TODO(rkn): Note + # that equality comparisons of numpy arrays normally does not work. + # This only works because at the moment they have size 1. + weights1 = net1.get_weights() + weights2 = net2.get_weights() + self.assertNotEqual(weights1, weights2) - # Set the weights and get the weights, and make sure they are unchanged. - new_weights1 = net1.set_and_get_weights(weights1) - new_weights2 = net2.set_and_get_weights(weights2) - self.assertEqual(weights1, new_weights1) - self.assertEqual(weights2, new_weights2) + # Set the weights and get the weights, and make sure they are + # unchanged. 
+ new_weights1 = net1.set_and_get_weights(weights1) + new_weights2 = net2.set_and_get_weights(weights2) + self.assertEqual(weights1, new_weights1) + self.assertEqual(weights2, new_weights2) - # Swap the weights. - new_weights1 = net2.set_and_get_weights(weights1) - new_weights2 = net1.set_and_get_weights(weights2) - self.assertEqual(weights1, new_weights1) - self.assertEqual(weights2, new_weights2) + # Swap the weights. + new_weights1 = net2.set_and_get_weights(weights1) + new_weights2 = net1.set_and_get_weights(weights2) + self.assertEqual(weights1, new_weights1) + self.assertEqual(weights2, new_weights2) - ray.worker.cleanup() + ray.worker.cleanup() - # This test creates an additional network on the driver so that the - # tensorflow variables on the driver and the worker differ. - def testNetworkDriverWorkerIndependent(self): - ray.init(num_workers=1) + # This test creates an additional network on the driver so that the + # tensorflow variables on the driver and the worker differ. + def testNetworkDriverWorkerIndependent(self): + ray.init(num_workers=1) - # Create a network on the driver locally. - sess1 = tf.Session() - loss1, init1, _, _ = make_linear_network() - ray.experimental.TensorFlowVariables(loss1, sess1) - sess1.run(init1) + # Create a network on the driver locally. 
+ sess1 = tf.Session() + loss1, init1, _, _ = make_linear_network() + ray.experimental.TensorFlowVariables(loss1, sess1) + sess1.run(init1) - net2 = ray.remote(NetActor).remote() - weights2 = ray.get(net2.get_weights.remote()) + net2 = ray.remote(NetActor).remote() + weights2 = ray.get(net2.get_weights.remote()) - new_weights2 = ray.get(net2.set_and_get_weights.remote( - net2.get_weights.remote())) - self.assertEqual(weights2, new_weights2) + new_weights2 = ray.get(net2.set_and_get_weights.remote( + net2.get_weights.remote())) + self.assertEqual(weights2, new_weights2) - ray.worker.cleanup() + ray.worker.cleanup() - def testVariablesControlDependencies(self): - ray.init(num_workers=1) + def testVariablesControlDependencies(self): + ray.init(num_workers=1) - # Creates a network and appends a momentum optimizer. - sess = tf.Session() - loss, init, _, _ = make_linear_network() - minimizer = tf.train.MomentumOptimizer(0.9, 0.9).minimize(loss) - net_vars = ray.experimental.TensorFlowVariables(minimizer, sess) - sess.run(init) + # Creates a network and appends a momentum optimizer. + sess = tf.Session() + loss, init, _, _ = make_linear_network() + minimizer = tf.train.MomentumOptimizer(0.9, 0.9).minimize(loss) + net_vars = ray.experimental.TensorFlowVariables(minimizer, sess) + sess.run(init) - # Tests if all variables are properly retrieved, 2 variables and 2 momentum - # variables. - self.assertEqual(len(net_vars.variables.items()), 4) + # Tests if all variables are properly retrieved, 2 variables and 2 + # momentum variables. 
+ self.assertEqual(len(net_vars.variables.items()), 4) - ray.worker.cleanup() + ray.worker.cleanup() - def testRemoteTrainingStep(self): - ray.init(num_workers=1) + def testRemoteTrainingStep(self): + ray.init(num_workers=1) - net = ray.remote(TrainActor).remote() - ray.get(net.training_step.remote(net.get_weights.remote())) + net = ray.remote(TrainActor).remote() + ray.get(net.training_step.remote(net.get_weights.remote())) - ray.worker.cleanup() + ray.worker.cleanup() - def testRemoteTrainingLoss(self): - ray.init(num_workers=2) + def testRemoteTrainingLoss(self): + ray.init(num_workers=2) - net = ray.remote(TrainActor).remote() - loss, variables, _, sess, grads, train, placeholders = TrainActor().values + net = ray.remote(TrainActor).remote() + (loss, variables, _, sess, grads, + train, placeholders) = TrainActor().values - before_acc = sess.run(loss, feed_dict=dict(zip(placeholders, - [[2] * 100, [4] * 100]))) + before_acc = sess.run(loss, + feed_dict=dict(zip(placeholders, + [[2] * 100, [4] * 100]))) - for _ in range(3): - gradients_list = ray.get( - [net.training_step.remote(variables.get_weights()) - for _ in range(2)]) - mean_grads = [sum([gradients[i] for gradients in gradients_list]) / - len(gradients_list) for i in range(len(gradients_list[0]))] - feed_dict = {grad[0]: mean_grad for (grad, mean_grad) - in zip(grads, mean_grads)} - sess.run(train, feed_dict=feed_dict) - after_acc = sess.run(loss, feed_dict=dict(zip(placeholders, - [[2] * 100, [4] * 100]))) - self.assertTrue(before_acc < after_acc) - ray.worker.cleanup() + for _ in range(3): + gradients_list = ray.get( + [net.training_step.remote(variables.get_weights()) + for _ in range(2)]) + mean_grads = [sum([gradients[i] for gradients in gradients_list]) / + len(gradients_list) for i + in range(len(gradients_list[0]))] + feed_dict = {grad[0]: mean_grad for (grad, mean_grad) + in zip(grads, mean_grads)} + sess.run(train, feed_dict=feed_dict) + after_acc = sess.run(loss, feed_dict=dict(zip(placeholders, 
+ [[2] * 100, [4] * 100]))) + self.assertTrue(before_acc < after_acc) + ray.worker.cleanup() if __name__ == "__main__": - unittest.main(verbosity=2) + unittest.main(verbosity=2)