diff --git a/scripts/cluster.py b/scripts/cluster.py index ccd379b5b..59a27e339 100644 --- a/scripts/cluster.py +++ b/scripts/cluster.py @@ -165,7 +165,7 @@ class RayCluster(object): start_workers_command = """ cd "{}"; source ../setup-env.sh; - python -c "import ray; ray.services.start_node(\\\"{}:10001\\\", \\\"{}\\\", {}, user_source_directory=\\\"{}\\\")" > start_workers.out 2> start_workers.err < /dev/null & + python -c "import ray; ray.services.start_node(\\\"{}:10001\\\", \\\"{}\\\", {}, user_source_directory={})" > start_workers.out 2> start_workers.err < /dev/null & """.format(scripts_directory, self.node_private_ip_addresses[0], self.node_private_ip_addresses[i], num_workers_per_node, remote_user_source_directory_str) start_workers_commands.append(start_workers_command) self._run_command_over_ssh_on_all_nodes_in_parallel(start_workers_commands) @@ -213,8 +213,8 @@ class RayCluster(object): change_branch_command = "git checkout -f {}".format(branch) if branch is not None else "" update_cluster_command = """ cd "{}" && - {} git fetch && + {} git reset --hard "@{{upstream}}" -- && (make -C "./build" clean || rm -rf "./build") && ./build.sh @@ -244,7 +244,7 @@ class RayCluster(object): raise Exception("Directory {} does not exist.".format(user_source_directory)) # If user_source_directory is "/a/b/c", then local_directory_name is "c". local_directory_name = os.path.split(os.path.realpath(user_source_directory))[1] - remote_directory = os.path.join("user_source_files", local_directory_name) + remote_directory = os.path.join(self.installation_directory, "user_source_files", local_directory_name) # Remove and recreate the directory on the node. recreate_directory_command = """ rm -r "{}"; @@ -298,10 +298,9 @@ if __name__ == "__main__": args = parser.parse_args() username = args.username key_file = args.key_file - # Install Ray in the user's home directory on the cluster. - installation_directory = "$HOME" node_ip_addresses = [] node_private_ip_addresses = [] + # Check if the IP addresses in the nodes file are valid. for line in open(args.nodes).readlines(): parts = line.split(",") ip_address = str(parts[0].strip()) @@ -313,5 +312,14 @@ if __name__ == "__main__": raise Exception("Each line in the nodes file must have either one or two ip addresses.") node_ip_addresses.append(ip_address) node_private_ip_addresses.append(private_ip_address) + # This command finds the home directory on the cluster. That directory will be + # used for installing Ray. Note that single quotes around 'echo $HOME' are + # important. If you use double quotes, then the $HOME environment variable + # will be expanded locally instead of remotely. + echo_home_command = "ssh -i {} {}@{} 'echo $HOME'".format(key_file, username, node_ip_addresses[0]) + installation_directory = subprocess.check_output(echo_home_command, shell=True).strip() + print "Using '{}' as the home directory on the cluster.".format(installation_directory) + # Create the Raycluster object. cluster = RayCluster(node_ip_addresses, node_private_ip_addresses, username, key_file, installation_directory) + # Drop into an IPython shell. IPython.embed()