Remove legacy Ray code. (#3121)

* Remove legacy Ray code.

* Fix cmake and simplify monitor.

* Fix linting

* Updates

* Fix

* Implement some methods.

* Remove more plasma manager references.

* Fix

* Linting

* Fix

* Fix

* Make sure class IDs are strings.

* Some path fixes

* Fix

* Path fixes and update arrow

* Fixes.

* linting

* Fixes

* Java fixes

* Some java fixes

* TaskLanguage -> Language

* Minor

* Fix python test and remove unused method signature.

* Fix java tests

* Fix jenkins tests

* Remove commented out code.
This commit is contained in:
Robert Nishihara
2018-10-26 13:36:58 -07:00
committed by Philipp Moritz
parent 055daf17a0
commit 658c14282c
289 changed files with 2460 additions and 40708 deletions
+2 -3
View File
@@ -1,6 +1,5 @@
BasedOnStyle: Chromium
ColumnLimit: 80
BasedOnStyle: Google
ColumnLimit: 90
DerivePointerAlignment: false
IndentCaseLabels: false
PointerAlignment: Right
SpaceAfterCStyleCast: true
-15
View File
@@ -4,22 +4,10 @@
/python/build
/python/dist
/python/flatbuffers-1.7.1/
/src/common/thirdparty/redis
/src/thirdparty/arrow
/flatbuffers-1.7.1/
/src/thirdparty/boost/
/src/thirdparty/boost_1_65_1/
/src/thirdparty/boost_1_60_0/
/src/thirdparty/catapult/
/src/thirdparty/flatbuffers/
/src/thirdparty/parquet-cpp
/thirdparty/pkg/
# Files generated by flatc should be ignored
/src/common/format/*.py
/src/common/format/*_generated.h
/src/plasma/format/
/src/local_scheduler/format/*_generated.h
/src/ray/gcs/format/*_generated.h
/src/ray/object_manager/format/*_generated.h
/src/ray/raylet/format/*_generated.h
@@ -54,9 +42,6 @@ python/.eggs
*.dylib
*.dll
# Cython-generated files
*.c
# Incremental linking files
*.ilk
+2 -84
View File
@@ -53,7 +53,7 @@ matrix:
- sphinx-build -W -b html -d _build/doctrees source _build/html
- cd ..
# Run Python linting, ignore dict vs {} (C408), others are defaults
- flake8 --exclude=python/ray/core/src/common/flatbuffers_ep-prefix/,python/ray/core/generated/,src/common/format/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605
- flake8 --exclude=python/ray/core/generated/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605
- .travis/format.sh --all
- os: linux
@@ -69,16 +69,9 @@ matrix:
script:
- cd build
# - bash ../src/common/test/run_valgrind.sh
# - bash ../src/plasma/test/run_valgrind.sh
# - bash ../src/local_scheduler/test/run_valgrind.sh
- bash ../src/ray/test/run_object_manager_valgrind.sh
- cd ..
# - python ./python/ray/plasma/test/test.py valgrind
# - python ./python/ray/local_scheduler/test/test.py valgrind
# - python ./python/ray/global_scheduler/test/test.py valgrind
# Build Linux wheels.
- os: linux
dist: trusty
@@ -108,75 +101,6 @@ matrix:
- PYTHON=3.5
- RAY_USE_NEW_GCS=on
# Test legacy Ray.
- os: linux
dist: trusty
env: PYTHON=3.5 RAY_USE_XRAY=0
install:
- ./.travis/install-dependencies.sh
- export PATH="$HOME/miniconda/bin:$PATH"
- ./.travis/install-ray.sh
- ./.travis/install-cython-examples.sh
- cd build
- bash ../src/common/test/run_tests.sh
- bash ../src/plasma/test/run_tests.sh
- bash ../src/local_scheduler/test/run_tests.sh
- cd ..
script:
- export PATH="$HOME/miniconda/bin:$PATH"
# The following is needed so cloudpickle can find some of the
# class definitions: The main module of tests that are run
# with pytest have the same name as the test file -- and this
# module is only found if the test directory is in the PYTHONPATH.
- export PYTHONPATH="$PYTHONPATH:./test/"
- python -m pytest -v python/ray/common/test/test.py
- python -m pytest -v python/ray/common/redis_module/runtest.py
- python -m pytest -v python/ray/plasma/test/test.py
- python -m pytest -v python/ray/local_scheduler/test/test.py
- python -m pytest -v python/ray/global_scheduler/test/test.py
- python -m pytest -v python/ray/test/test_global_state.py
- python -m pytest -v python/ray/test/test_queue.py
- python -m pytest -v python/ray/test/test_ray_init.py
- python -m pytest -v test/xray_test.py
- python -m pytest -v test/runtest.py
- python -m pytest -v test/array_test.py
- python -m pytest -v test/actor_test.py
- python -m pytest -v test/autoscaler_test.py
- python -m pytest -v test/tensorflow_test.py
- python -m pytest -v test/failure_test.py
- python -m pytest -v test/microbenchmarks.py
- python -m pytest -v test/stress_tests.py
- pytest test/component_failures_test.py
- python test/multi_node_test.py
- python -m pytest -v test/multi_node_test_2.py
- python -m pytest -v test/recursion_test.py
- pytest test/monitor_test.py
- python -m pytest -v test/cython_test.py
- python -m pytest -v test/credis_test.py
# ray tune tests
- python python/ray/tune/test/dependency_test.py
- python -m pytest -v python/ray/tune/test/trial_runner_test.py
- python -m pytest -v python/ray/tune/test/trial_scheduler_test.py
- python -m pytest -v python/ray/tune/test/experiment_test.py
- python -m pytest -v python/ray/tune/test/tune_server_test.py
- python -m pytest -v python/ray/tune/test/ray_trial_executor_test.py
- python -m pytest -v python/ray/tune/test/automl_searcher_test.py
# ray rllib tests
- python -m pytest -v python/ray/rllib/test/test_catalog.py
- python -m pytest -v python/ray/rllib/test/test_filters.py
- python -m pytest -v python/ray/rllib/test/test_optimizers.py
- python -m pytest -v python/ray/rllib/test/test_evaluators.py
# ray temp file tests
- python -m pytest -v test/tempfile_test.py
install:
- ./.travis/install-dependencies.sh
@@ -207,12 +131,6 @@ script:
# module is only found if the test directory is in the PYTHONPATH.
- export PYTHONPATH="$PYTHONPATH:./test/"
# - python -m pytest -v python/ray/common/test/test.py
# - python -m pytest -v python/ray/common/redis_module/runtest.py
# - python -m pytest -v python/ray/plasma/test/test.py
# - python -m pytest -v python/ray/local_scheduler/test/test.py
# - python -m pytest -v python/ray/global_scheduler/test/test.py
- python -m pytest -v python/ray/test/test_global_state.py
- python -m pytest -v python/ray/test/test_queue.py
- python -m pytest -v python/ray/test/test_ray_init.py
@@ -227,7 +145,7 @@ script:
- python -m pytest -v test/microbenchmarks.py
- python -m pytest -v test/stress_tests.py
- python -m pytest -v test/component_failures_test.py
- python test/multi_node_test.py
- python -m pytest -v test/multi_node_test.py
- python -m pytest -v test/multi_node_test_2.py
- python -m pytest -v test/recursion_test.py
- python -m pytest -v test/monitor_test.py
-1
View File
@@ -30,7 +30,6 @@ YAPF_EXCLUDES=(
'--exclude' 'python/build/*'
'--exclude' 'python/ray/pyarrow_files/*'
'--exclude' 'python/ray/core/src/ray/gcs/*'
'--exclude' 'python/ray/common/thirdparty/*'
)
# Format specified files
+7 -12
View File
@@ -82,18 +82,15 @@ include_directories(SYSTEM ${PLASMA_INCLUDE_DIR})
include_directories("${CMAKE_CURRENT_LIST_DIR}/src/")
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/src/ray/)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/src/common/)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/src/plasma/)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/src/local_scheduler/)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/src/global_scheduler/)
# final target copy_ray
add_custom_target(copy_ray ALL)
# copy plasma_store_server
add_custom_command(TARGET copy_ray POST_BUILD
COMMAND mkdir -p ${CMAKE_CURRENT_BINARY_DIR}/src/plasma
COMMAND ${CMAKE_COMMAND} -E
copy ${ARROW_HOME}/bin/plasma_store_server ${CMAKE_CURRENT_BINARY_DIR}/src/plasma)
copy ${ARROW_HOME}/bin/plasma_store_server ${CMAKE_CURRENT_BINARY_DIR}/src/plasma/)
if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES")
# add pyarrow as the dependency
@@ -102,12 +99,9 @@ if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES")
# NOTE: The lists below must be kept in sync with ray/python/setup.py.
set(ray_file_list
"src/common/thirdparty/redis/src/redis-server"
"src/common/redis_module/libray_redis_module.so"
"src/plasma/plasma_manager"
"src/local_scheduler/local_scheduler"
"src/local_scheduler/liblocal_scheduler_library_python.so"
"src/global_scheduler/global_scheduler"
"src/ray/thirdparty/redis/src/redis-server"
"src/ray/gcs/redis_module/libray_redis_module.so"
"src/ray/raylet/liblocal_scheduler_library_python.so"
"src/ray/raylet/raylet_monitor"
"src/ray/raylet/raylet")
@@ -154,5 +148,6 @@ if ("${CMAKE_RAY_LANG_JAVA}" STREQUAL "YES")
# copy libplasma_java files
add_custom_command(TARGET copy_ray POST_BUILD
COMMAND bash -c "cp ${ARROW_LIBRARY_DIR}/libplasma_java.* ${CMAKE_CURRENT_BINARY_DIR}/src/plasma")
COMMAND bash -c "mkdir -p ${CMAKE_CURRENT_BINARY_DIR}/src/plasma"
COMMAND bash -c "cp ${ARROW_LIBRARY_DIR}/libplasma_java.* ${CMAKE_CURRENT_BINARY_DIR}/src/plasma/")
endif()
+2 -2
View File
@@ -15,10 +15,10 @@
# - PLASMA_SHARED_LIB
set(arrow_URL https://github.com/apache/arrow.git)
# The PR for this commit is https://github.com/apache/arrow/pull/2792. We
# The PR for this commit is https://github.com/apache/arrow/pull/2826. We
# include the link here to make it easier to find the right commit because
# Arrow often rewrites git history and invalidates certain commits.
set(arrow_TAG 2d0d3d0dc51999fbaafb15d8b8362a1ef3de2ef7)
set(arrow_TAG b4f7ed6d6ed5cdb6dd136bac3181a438f35c8ea0)
set(ARROW_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/external/arrow-install)
set(ARROW_HOME ${ARROW_INSTALL_PREFIX})
-3
View File
@@ -41,6 +41,3 @@ if ("${CMAKE_RAY_LANG_JAVA}" STREQUAL "YES")
message (WARNING "NOT FIND JNI")
endif()
endif()
include_directories(${CMAKE_SOURCE_DIR}/src/common)
include_directories(${CMAKE_SOURCE_DIR}/src/common/thirdparty)
+1 -2
View File
@@ -65,8 +65,7 @@ When ``a1.increment.remote()`` is called, the following events happens.
1. A task is created.
2. The task is assigned directly to the local scheduler responsible for the
actor by the driver's local scheduler. Thus, this scheduling procedure
bypasses the global scheduler.
actor by the driver's local scheduler.
3. An object ID is returned.
We can then call ``ray.get`` on the object ID to retrieve the actual value.
+44 -57
View File
@@ -18,44 +18,38 @@ import shlex
# These lines added to enable Sphinx to work without installing Ray.
import mock
MOCK_MODULES = ["gym",
"gym.spaces",
"scipy",
"scipy.signal",
"tensorflow",
"tensorflow.contrib",
"tensorflow.contrib.layers",
"tensorflow.contrib.slim",
"tensorflow.contrib.rnn",
"tensorflow.core",
"tensorflow.core.util",
"tensorflow.python",
"tensorflow.python.client",
"tensorflow.python.util",
"ray.local_scheduler",
"ray.plasma",
"ray.core",
"ray.core.generated",
"ray.core.generated.DriverTableMessage",
"ray.core.generated.LocalSchedulerInfoMessage",
"ray.core.generated.ResultTableReply",
"ray.core.generated.SubscribeToDBClientTableReply",
"ray.core.generated.SubscribeToNotificationsReply",
"ray.core.generated.TaskInfo",
"ray.core.generated.TaskReply",
"ray.core.generated.TaskExecutionDependencies",
"ray.core.generated.ClientTableData",
"ray.core.generated.GcsTableEntry",
"ray.core.generated.HeartbeatTableData",
"ray.core.generated.DriverTableData",
"ray.core.generated.ErrorTableData",
"ray.core.generated.ProfileTableData",
"ray.core.generated.ObjectTableData",
"ray.core.generated.ray.protocol.Task",
"ray.core.generated.TablePrefix",
"ray.core.generated.TablePubsub",]
MOCK_MODULES = [
"gym",
"gym.spaces",
"scipy",
"scipy.signal",
"tensorflow",
"tensorflow.contrib",
"tensorflow.contrib.layers",
"tensorflow.contrib.slim",
"tensorflow.contrib.rnn",
"tensorflow.core",
"tensorflow.core.util",
"tensorflow.python",
"tensorflow.python.client",
"tensorflow.python.util",
"ray.raylet",
"ray.plasma",
"ray.core",
"ray.core.generated",
"ray.core.generated.ClientTableData",
"ray.core.generated.GcsTableEntry",
"ray.core.generated.HeartbeatTableData",
"ray.core.generated.DriverTableData",
"ray.core.generated.ErrorTableData",
"ray.core.generated.ProfileTableData",
"ray.core.generated.ObjectTableData",
"ray.core.generated.ray.protocol.Task",
"ray.core.generated.TablePrefix",
"ray.core.generated.TablePubsub",
]
for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock()
sys.modules[mod_name] = mock.Mock()
# ray.rllib.models.action_dist.py and
# ray.rllib.models.lstm.py will use tf.VERSION
sys.modules["tensorflow"].VERSION = "9.9.9"
@@ -89,7 +83,7 @@ from recommonmark.parser import CommonMarkParser
source_suffix = ['.rst', '.md']
source_parsers = {
'.md': CommonMarkParser,
'.md': CommonMarkParser,
}
# The encoding of source files.
@@ -259,25 +253,24 @@ htmlhelp_basename = 'Raydoc'
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#'papersize': 'letterpaper',
# The paper size ('letterpaper' or 'a4paper').
#'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#'pointsize': '10pt',
# The font size ('10pt', '11pt' or '12pt').
#'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#'preamble': '',
# Additional stuff for the LaTeX preamble.
#'preamble': '',
# Latex figure (float) alignment
#'figure_align': 'htbp',
# Latex figure (float) alignment
#'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'Ray.tex', u'Ray Documentation',
u'The Ray Team', 'manual'),
(master_doc, 'Ray.tex', u'Ray Documentation', u'The Ray Team', 'manual'),
]
# The name of an image file (relative to this directory) to place at the top of
@@ -300,29 +293,23 @@ latex_documents = [
# If false, no module index is generated.
#latex_domain_indices = True
# -- Options for manual page output ---------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'ray', u'Ray Documentation',
[author], 1)
]
man_pages = [(master_doc, 'ray', u'Ray Documentation', [author], 1)]
# If true, show URL addresses after external links.
#man_show_urls = False
# -- Options for Texinfo output -------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'Ray', u'Ray Documentation',
author, 'Ray', 'One line description of project.',
'Miscellaneous'),
(master_doc, 'Ray', u'Ray Documentation', author, 'Ray',
'One line description of project.', 'Miscellaneous'),
]
# Documents to append as an appendix to all manuals.
+1 -1
View File
@@ -47,7 +47,7 @@ Process Failures
~~~~~~~~~~~~~~~~
1. Ray does not recover from the failure of any of the following processes:
a Redis server, the global scheduler, the monitor process.
a Redis server and the monitor process.
2. If a driver fails, that driver will not be restarted and the job will not
complete.
+3 -3
View File
@@ -15,8 +15,8 @@ Running Ray standalone
Ray can be used standalone by calling ``ray.init()`` within a script. When the
call to ``ray.init()`` happens, all of the relevant processes are started.
These include a local scheduler, a global scheduler, an object store and
manager, a Redis server, and a number of worker processes.
These include a local scheduler, an object store and manager, a Redis server,
and a number of worker processes.
When the script exits, these processes will be killed.
@@ -112,7 +112,7 @@ When a driver or worker invokes a remote function, a number of things happen.
- The task object is then sent to the local scheduler on the same node as the
driver or worker.
- The local scheduler makes a decision to either schedule the task locally or to
pass the task on to a global scheduler.
pass the task on to another local scheduler.
- If all of the task's object dependencies are present in the local object
store and there are enough CPU and GPU resources available to execute the
-2
View File
@@ -45,8 +45,6 @@ A typical layout of temporary files could look like this:
│   ├── log_monitor.out
│   ├── monitor.err
│   ├── monitor.out
│   ├── plasma_manager_0.err # array of plasma managers' outputs
│   ├── plasma_manager_0.out
│   ├── plasma_store_0.err # array of plasma stores' outputs
│   ├── plasma_store_0.out
│   ├── raylet_0.err # array of raylets' outputs. Control it with `--no-redirect-worker-output` (in Ray's command line) or `redirect_worker_output` (in ray.init())
+1 -3
View File
@@ -9,7 +9,7 @@ To use Ray, you need to understand the following:
Overview
--------
Ray is a Python-based distributed execution engine. The same code can be run on
Ray is a distributed execution engine. The same code can be run on
a single machine to achieve efficient multiprocessing, and it can be used on a
cluster for large computations.
@@ -21,8 +21,6 @@ When using Ray, several processes are involved.
allows workers to efficiently share objects on the same node with minimal
copying and deserialization.
- One **local scheduler** per node assigns tasks to workers on the same node.
- A **global scheduler** receives tasks from local schedulers and assigns them
to other local schedulers.
- A **driver** is the Python process that the user controls. For example, if the
user is running a script or using a Python shell, then the driver is the Python
process that runs the script or the shell. A driver is similar to a worker in
-1
View File
@@ -51,7 +51,6 @@ Now we've started all of the Ray processes on each node Ray. This includes
- An object store on each machine.
- A local scheduler on each machine.
- Multiple Redis servers (on the head node).
- One global scheduler (on the head node).
To run some commands, start up Python on one of the nodes in the cluster, and do
the following.
@@ -154,7 +154,6 @@ Now you have started all of the Ray processes on each node. These include:
- An object store on each machine.
- A local scheduler on each machine.
- Multiple Redis servers (on the head node).
- One global scheduler (on the head node).
To confirm that the Ray cluster setup is working, start up Python on one of the
nodes in the cluster and enter the following commands to connect to the Ray
+1 -1
View File
@@ -10,5 +10,5 @@
<suppress checks=".*" files="RayCall.java"/>
<!-- suppress check for flatbuffer-generated files. -->
<!-- TODO(raulchen): move these files to a directory, so this rule can be simpler. -->
<suppress checks=".*" files="(Arg|ResourcePair|TaskLanguage|TaskInfo|ClientTableData).java" />
<suppress checks=".*" files="(Arg|ResourcePair|Language|TaskInfo|ClientTableData).java" />
</suppressions>
+3 -3
View File
@@ -42,15 +42,15 @@ fi
# echo "ray_dir = $ray_dir"
declare -a nativeBinaries=(
"./src/common/thirdparty/redis/src/redis-server"
"./src/ray/thirdparty/redis/src/redis-server"
"./src/plasma/plasma_store_server"
"./src/ray/raylet/raylet"
"./src/ray/raylet/raylet_monitor"
)
declare -a nativeLibraries=(
"./src/common/redis_module/libray_redis_module.so"
"./src/local_scheduler/liblocal_scheduler_library_java.*"
"./src/ray/gcs/redis_module/libray_redis_module.so"
"./src/ray/raylet/liblocal_scheduler_library_java.*"
"./src/plasma/libplasma_java.*"
"./src/ray/raylet/*lib.a"
)
@@ -165,12 +165,12 @@ public class RayConfig {
// library path
this.libraryPath = new ImmutableList.Builder<String>().add(
rayHome + "/build/src/plasma",
rayHome + "/build/src/local_scheduler"
rayHome + "/build/src/ray/raylet"
).addAll(customLibraryPath).build();
redisServerExecutablePath = rayHome +
"/build/src/common/thirdparty/redis/src/redis-server";
redisModulePath = rayHome + "/build/src/common/redis_module/libray_redis_module.so";
"/build/src/ray/thirdparty/redis/src/redis-server";
redisModulePath = rayHome + "/build/src/ray/gcs/redis_module/libray_redis_module.so";
plasmaStoreExecutablePath = rayHome + "/build/src/plasma/plasma_store_server";
rayletExecutablePath = rayHome + "/build/src/ray/raylet/raylet";
@@ -2,13 +2,13 @@
package org.ray.runtime.generated;
public final class TaskLanguage {
private TaskLanguage() { }
public final class Language {
private Language() { }
public static final int PYTHON = 0;
public static final int JAVA = 1;
public static final int CPP = 1;
public static final int JAVA = 2;
public static final String[] names = { "PYTHON", "JAVA", };
public static final String[] names = { "PYTHON", "CPP", "JAVA", };
public static String name(int e) { return names[e]; }
}
@@ -13,9 +13,9 @@ import org.ray.api.WaitResult;
import org.ray.api.id.UniqueId;
import org.ray.runtime.functionmanager.FunctionDescriptor;
import org.ray.runtime.generated.Arg;
import org.ray.runtime.generated.Language;
import org.ray.runtime.generated.ResourcePair;
import org.ray.runtime.generated.TaskInfo;
import org.ray.runtime.generated.TaskLanguage;
import org.ray.runtime.task.FunctionArg;
import org.ray.runtime.task.TaskSpec;
import org.ray.runtime.util.UniqueIdUtil;
@@ -229,7 +229,7 @@ public class RayletClientImpl implements RayletClient {
actorIdOffset, actorHandleIdOffset, actorCounter,
false, functionIdOffset,
argsOffset, returnsOffset, requiredResourcesOffset,
requiredPlacementResourcesOffset, TaskLanguage.JAVA,
requiredPlacementResourcesOffset, Language.JAVA,
functionDescriptorOffset);
fbb.finish(root);
ByteBuffer buffer = fbb.dataBuffer();
@@ -256,8 +256,8 @@ public class RayletClientImpl implements RayletClient {
/// 1) pushd $Dir/java/runtime/target/classes
/// 2) javah -classpath .:$Dir/java/api/target/classes org.ray.runtime.raylet.RayletClientImpl
/// 3) clang-format -i org_ray_runtime_raylet_RayletClientImpl.h
/// 4) cp org_ray_runtime_raylet_RayletClientImpl.h $Dir/src/local_scheduler/lib/java/
/// 5) vim $Dir/src/local_scheduler/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc
/// 4) cp org_ray_runtime_raylet_RayletClientImpl.h $Dir/src/ray/raylet/lib/java/
/// 5) vim $Dir/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc
/// 6) popd
private static native long nativeInit(String localSchedulerSocket, byte[] workerId,
@@ -23,7 +23,7 @@ public class RayConfigTest {
Assert.assertEquals(System.getProperty("user.dir"), rayConfig.rayHome);
Assert.assertEquals(System.getProperty("user.dir") +
"/build/src/common/thirdparty/redis/src/redis-server", rayConfig.redisServerExecutablePath);
"/build/src/ray/thirdparty/redis/src/redis-server", rayConfig.redisServerExecutablePath);
Assert.assertEquals("path/to/ray/driver/resource/path", rayConfig.driverResourcePath);
+1 -1
View File
@@ -40,7 +40,7 @@
<RAY_CONFIG>${basedir}/../ray.config.ini</RAY_CONFIG>
</environmentVariables>
<argLine>-ea
-Djava.library.path=${basedir}/../../build/src/plasma:${basedir}/../../build/src/local_scheduler
-Djava.library.path=${basedir}/../../build/src/plasma:${basedir}/../../build/src/ray/raylet
-noverify
-DlogOutput=console
</argLine>
+1 -1
View File
@@ -46,7 +46,7 @@ except ImportError as e:
e.args += (helpful_message, )
raise
from ray.local_scheduler import ObjectID, _config # noqa: E402
from ray.raylet import ObjectID, _config # noqa: E402
from ray.profiling import profile # noqa: E402
from ray.worker import (error_info, init, connect, disconnect, get, put, wait,
remote, get_gpu_ids, get_resource_ids, get_webui_url,
+1 -1
View File
@@ -9,7 +9,7 @@ import traceback
import ray.cloudpickle as pickle
from ray.function_manager import FunctionActorManager
import ray.local_scheduler
import ray.raylet
import ray.ray_constants as ray_constants
import ray.signature as signature
import ray.worker
-451
View File
@@ -1,451 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import redis
import sys
import time
import unittest
import ray.gcs_utils
import ray.services
def integerToAsciiHex(num, numbytes):
    """Serialize the low-order bytes of an integer, least significant first.

    Args:
        num: The integer to serialize. Only its lowest ``numbytes`` bytes
            are emitted; higher-order bits are dropped.
        numbytes: Number of bytes to emit. Must be 4 or 8.

    Returns:
        A byte string of length ``numbytes`` whose first element is the
        least significant byte of ``num``.

    Raises:
        AssertionError: If ``numbytes`` is neither 4 nor 8.
    """
    # Only 32-bit and 64-bit widths are supported.
    assert numbytes in (4, 8)
    on_python3 = sys.version_info >= (3, 0)
    pieces = []
    for _ in range(numbytes):
        low_byte = num & 0xff
        # bytes([n]) yields a single byte on Python 3; Python 2 needs chr().
        pieces.append(bytes([low_byte]) if on_python3 else chr(low_byte))
        num >>= 8
    return b"".join(pieces)
def get_next_message(pubsub_client, timeout_seconds=10):
    """Poll a pubsub client until it yields a message, then return it.

    Args:
        pubsub_client: Object exposing ``get_message()``, which returns
            ``None`` while no message is pending.
        timeout_seconds: How long to keep polling before giving up. The
            deadline is only checked after each 0.1 second sleep.

    Returns:
        The first non-``None`` value returned by ``get_message()``.

    Raises:
        Exception: If no message arrived within ``timeout_seconds``.
    """
    began_at = time.time()
    received = pubsub_client.get_message()
    while received is None:
        # Back off briefly between polls instead of spinning.
        time.sleep(0.1)
        waited = time.time() - began_at
        if waited > timeout_seconds:
            raise Exception("Timed out while waiting for next message.")
        received = pubsub_client.get_message()
    return received
class TestGlobalStateStore(unittest.TestCase):
def setUp(self):
unused_primary_redis_addr, redis_shards = ray.services.start_redis(
"localhost", use_credis="RAY_USE_NEW_GCS" in os.environ)
self.redis = redis.StrictRedis(
host="localhost", port=redis_shards[0].split(":")[-1], db=0)
def tearDown(self):
ray.services.cleanup()
def testInvalidObjectTableAdd(self):
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD is called
# with the wrong arguments.
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "hello")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2",
"one", "hash2", "manager_id1")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", 1,
"hash2", "manager_id1",
"extra argument")
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD adds an
# object ID that is already present with a different hash.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1"})
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash2", "manager_id2")
# Check that the second manager was added, even though the hash was
# mismatched.
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Check that it is fine if we add the same object ID multiple times
# with the most recent hash.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash2", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash2", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash2", "manager_id2")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 2,
"hash2", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
def testObjectTableAddAndLookup(self):
# Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not
# been added yet.
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(response, None)
# Add some managers and try again.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Add a manager that already exists again and try again.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Check that we properly handle NULL characters. In the past, NULL
# characters were handled improperly causing a "hash mismatch" error if
# two object IDs that agreed up to the NULL character were inserted
# with different hashes.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id3", 1,
"hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id4", 1,
"hash2", "manager_id1")
# Check that NULL characters in the hash are handled properly.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1,
"\x00hash1", "manager_id1")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1,
"\x00hash2", "manager_id1")
def testObjectTableAddAndRemove(self):
# Try removing a manager from an object ID that has not been added yet.
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id1")
# Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not
# been added yet.
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(response, None)
# Add some managers and try again.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Remove a manager that doesn't exist, and make sure we still have the
# same set.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id3")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Remove a manager that does exist. Make sure it gets removed the first
# time and does nothing the second time.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id2"})
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id2"})
# Remove the last manager, and make sure we have an empty set.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), set())
# Remove a manager from an empty set, and make sure we now have an
# empty set.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id3")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), set())
def testObjectTableSubscribeToNotifications(self):
# Define a helper method for checking the contents of object
# notifications.
def check_object_notification(notification_message, object_id,
object_size, manager_ids):
notification_object = (ray.gcs_utils.SubscribeToNotificationsReply.
GetRootAsSubscribeToNotificationsReply(
notification_message, 0))
self.assertEqual(notification_object.ObjectId(), object_id)
self.assertEqual(notification_object.ObjectSize(), object_size)
self.assertEqual(notification_object.ManagerIdsLength(),
len(manager_ids))
for i in range(len(manager_ids)):
self.assertEqual(
notification_object.ManagerIds(i), manager_ids[i])
data_size = 0xf1f0
p = self.redis.pubsub()
# Subscribe to an object ID.
p.psubscribe("{}manager_id1".format(
ray.gcs_utils.OBJECT_CHANNEL_PREFIX))
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1",
data_size, "hash1", "manager_id2")
# Receive the acknowledgement message.
self.assertEqual(get_next_message(p)["data"], 1)
# Request a notification and receive the data.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
"manager_id1", "object_id1")
# Verify that the notification is correct.
check_object_notification(
get_next_message(p)["data"], b"object_id1", data_size,
[b"manager_id2"])
# Request a notification for an object that isn't there. Then add the
# object and receive the data. Only the first call to
# RAY.OBJECT_TABLE_ADD should trigger notifications.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
"manager_id1", "object_id2", "object_id3")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3",
data_size, "hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3",
data_size, "hash1", "manager_id2")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3",
data_size, "hash1", "manager_id3")
# Verify that the notification is correct.
check_object_notification(
get_next_message(p)["data"], b"object_id3", data_size,
[b"manager_id1"])
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2",
data_size, "hash1", "manager_id3")
# Verify that the notification is correct.
check_object_notification(
get_next_message(p)["data"], b"object_id2", data_size,
[b"manager_id3"])
# Request notifications for object_id3 again.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
"manager_id1", "object_id3")
# Verify that the notification is correct.
check_object_notification(
get_next_message(p)["data"], b"object_id3", data_size,
[b"manager_id1", b"manager_id2", b"manager_id3"])
def testResultTableAddAndLookup(self):
    """Exercise RAY.RESULT_TABLE_ADD and RAY.RESULT_TABLE_LOOKUP.

    A lookup returns None until a result table entry has been added for
    the object; afterwards it returns a ResultTableReply holding the task
    ID and the is_put flag that were stored.
    """

    def lookup(object_id):
        # Query the result table for a single object ID.
        return self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
                                          object_id)

    def check_result_table_entry(message, task_id, is_put):
        # Parse the flatbuffer reply and compare it with what was stored.
        reply = ray.gcs_utils.ResultTableReply.GetRootAsResultTableReply(
            message, 0)
        self.assertEqual(reply.TaskId(), task_id)
        self.assertEqual(reply.IsPut(), is_put)

    # Nothing has been added yet, so the lookup comes back empty.
    self.assertIsNone(lookup("object_id1"))

    # Adding the object to the object table should have no effect on the
    # result table.
    self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
                               "hash1", "manager_id1")
    self.assertIsNone(lookup("object_id1"))

    # Add the result to the result table. The lookup now returns the task
    # ID.
    first_task_id = b"task_id1"
    self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1",
                               first_task_id, 0)
    # The second pass checks that repeating the lookup still works.
    for _ in range(2):
        check_result_table_entry(
            lookup("object_id1"), first_task_id, False)

    # Add and look up a second result, this time flagged as a put.
    second_task_id = b"task_id2"
    self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2",
                               second_task_id, 1)
    check_result_table_entry(lookup("object_id2"), second_task_id, True)
def testInvalidTaskTableAdd(self):
    """Check that Redis rejects malformed task table commands.

    Every command below is expected to raise redis.ResponseError.
    """
    invalid_commands = [
        # RAY.TASK_TABLE_ADD called with the wrong number of arguments.
        ("RAY.TASK_TABLE_ADD", ),
        ("RAY.TASK_TABLE_ADD", "hello"),
        ("RAY.TASK_TABLE_ADD", "task_id", 3, "node_id"),
        # Non-integer scheduling states should not be added.
        ("RAY.TASK_TABLE_ADD", "task_id", "invalid_state", "node_id",
         "task_spec"),
        # Should not be able to update a non-existent task.
        ("RAY.TASK_TABLE_UPDATE", "task_id", 10, "node_id", b""),
    ]
    for command in invalid_commands:
        with self.assertRaises(redis.ResponseError):
            self.redis.execute_command(*command)
def testTaskTableAddAndLookup(self):
    """Exercise the task table add, update, get and test-and-update commands.

    One task entry is added and then moved through several scheduling
    states. After each command the test checks the command's reply and/or
    a fresh RAY.TASK_TABLE_GET against the expected entry.
    """
    waiting, scheduled, queued = 1, 2, 4
    task_key = "task_id"

    # Make sure somebody will get a notification (checked in the redis
    # module).
    subscriber = self.redis.pubsub()
    subscriber.psubscribe(
        "{prefix}*:*".format(prefix=ray.gcs_utils.TASK_PREFIX))

    def check_task_reply(message, task_args, updated=False):
        # task_args is (state, local scheduler ID, execution dependencies,
        # spillback count, task spec). The execution dependencies are
        # unpacked but not compared against the reply.
        (state, scheduler_id, _dependencies, spillbacks, spec) = task_args
        reply = ray.gcs_utils.TaskReply.GetRootAsTaskReply(message, 0)
        self.assertEqual(reply.State(), state)
        self.assertEqual(reply.LocalSchedulerId(), scheduler_id)
        self.assertEqual(reply.SpillbackCount(), spillbacks)
        self.assertEqual(reply.TaskSpec(), spec)
        self.assertEqual(reply.Updated(), updated)

    def get_task():
        return self.redis.execute_command("RAY.TASK_TABLE_GET", task_key)

    def test_and_update(test_state, new_state):
        # Arguments after the task ID: test state (bitmask), state to set,
        # and the local scheduler ID of the entry.
        return self.redis.execute_command(
            "RAY.TASK_TABLE_TEST_AND_UPDATE", task_key, test_state,
            new_state, entry[1])

    # Check that task table adds, updates, and lookups work correctly.
    # `entry` always holds the task fields the table is expected to contain.
    entry = [waiting, b"node_id", b"", 0, b"task_spec"]
    self.redis.execute_command("RAY.TASK_TABLE_ADD", task_key, *entry)
    check_task_reply(get_task(), entry)

    entry[0] = scheduled
    self.redis.execute_command("RAY.TASK_TABLE_UPDATE", task_key,
                               *entry[:4])
    check_task_reply(get_task(), entry)

    # If the current value, test value, and set value are all the same, the
    # update happens, and the response is still the same task.
    check_task_reply(
        test_and_update(scheduled, scheduled), entry, updated=True)
    # Check that the task entry is still the same.
    check_task_reply(get_task(), entry)

    # If the current value is the same as the test value, and the set value
    # is different, the update happens, and the response is the entire
    # task.
    entry[0] = queued
    check_task_reply(
        test_and_update(scheduled, queued), entry, updated=True)
    # Check that the update happened.
    queued_snapshot = get_task()
    check_task_reply(queued_snapshot, entry)

    # If the current value is no longer the same as the test value, the
    # response is the same task as before the test-and-set.
    check_task_reply(
        test_and_update(scheduled, waiting), entry, updated=False)
    # Check that the update did not happen.
    self.assertEqual(get_task(), queued_snapshot)

    # If the test value is a bitmask that matches the current value, the
    # update happens.
    entry[0] = waiting
    updated_reply = test_and_update(scheduled | queued, waiting)
    check_task_reply(updated_reply, entry, updated=True)

    # If the test value is a bitmask that does not match the current value,
    # the update does not happen, and the response is the same task as
    # before the test-and-set.
    check_task_reply(
        test_and_update(scheduled, waiting), entry, updated=False)
    # The stored entry still holds the same task fields, but it is not
    # byte-equal to the earlier reply, which was flagged as updated.
    latest = get_task()
    self.assertNotEqual(latest, updated_reply)
    check_task_reply(latest, entry)
def check_task_subscription(self, p, scheduling_state, local_scheduler_id):
    """Add a task and verify the notification delivered to subscriber p.

    Args:
        p: Pubsub object that is already subscribed to a matching channel.
        scheduling_state: Scheduling state to store for the task.
        local_scheduler_id: Local scheduler ID as a str; it is
            ASCII-encoded before being sent.
    """
    task_id = b"task_id"
    scheduler_id_bytes = local_scheduler_id.encode("ascii")
    execution_dependencies = b""
    task_spec = b"task_spec"
    self.redis.execute_command("RAY.TASK_TABLE_ADD", task_id,
                               scheduling_state, scheduler_id_bytes,
                               execution_dependencies, 0, task_spec)

    # Receive the data and parse it as a TaskReply flatbuffer.
    payload = get_next_message(p)["data"]
    notification = ray.gcs_utils.TaskReply.GetRootAsTaskReply(payload, 0)

    # Check that the notification object echoes what was added.
    self.assertEqual(notification.TaskId(), task_id)
    self.assertEqual(notification.State(), scheduling_state)
    self.assertEqual(notification.LocalSchedulerId(), scheduler_id_bytes)
    self.assertEqual(notification.ExecutionDependencies(),
                     execution_dependencies)
    self.assertEqual(notification.TaskSpec(), task_spec)
def testTaskTableSubscribe(self):
    """Check task table notifications for three subscription patterns.

    The patterns match, in turn, every task, every task in a given
    scheduling state, and every task of a given local scheduler.
    """
    scheduling_state = 1
    local_scheduler_id = "local_scheduler_id"
    task_prefix = ray.gcs_utils.TASK_PREFIX
    patterns = [
        "{prefix}*:*".format(prefix=task_prefix),
        "{prefix}*:{state}".format(
            prefix=task_prefix, state=scheduling_state),
        "{prefix}{local_scheduler_id}:*".format(
            prefix=task_prefix, local_scheduler_id=local_scheduler_id),
    ]

    subscriber = self.redis.pubsub()
    for pattern in patterns:
        # Subscribe to the task table and receive the acknowledgment.
        subscriber.psubscribe(pattern)
        self.assertEqual(get_next_message(subscriber)["data"], 1)
        self.check_task_subscription(subscriber, scheduling_state,
                                     local_scheduler_id)
        # Unsubscribe to make sure there is only one subscriber at a given
        # time, and receive the acknowledgment.
        subscriber.punsubscribe(pattern)
        self.assertEqual(get_next_message(subscriber)["data"], 0)
# Run every test case in this module verbosely when executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
-181
View File
@@ -1,181 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import pickle
import sys
import unittest
import ray.local_scheduler as local_scheduler
import ray.ray_constants as ray_constants
def random_object_id():
    """Return a local_scheduler.ObjectID built from ID_SIZE random bytes."""
    random_bytes = np.random.bytes(ray_constants.ID_SIZE)
    return local_scheduler.ObjectID(random_bytes)
def random_function_id():
    """Return a random function ID, represented as an ObjectID."""
    random_bytes = np.random.bytes(ray_constants.ID_SIZE)
    return local_scheduler.ObjectID(random_bytes)
def random_driver_id():
    """Return a random driver ID, represented as an ObjectID."""
    random_bytes = np.random.bytes(ray_constants.ID_SIZE)
    return local_scheduler.ObjectID(random_bytes)
def random_task_id():
    """Return a random task ID, represented as an ObjectID."""
    random_bytes = np.random.bytes(ray_constants.ID_SIZE)
    return local_scheduler.ObjectID(random_bytes)
# Values that TestSerialization expects check_simple_value to accept,
# grouped by kind. The order of the elements is unchanged.
BASE_SIMPLE_OBJECTS = (
    # Integers and floats.
    [0, 1, 100000, 0.0, 0.5, 0.9, 100000.1]
    # Empty containers.
    + [(), [], {}]
    # Byte/unicode strings, both empty and fairly long.
    + ["", 990 * "h", u"", 990 * u"h"]
    # Small numpy arrays.
    + [np.ones(3), np.array([True, False])]
    # Singletons.
    + [None, True, False])
if sys.version_info[0] < 3:
    # Python 2 only: also cover the separate `long` integer type.
    BASE_SIMPLE_OBJECTS.extend(
        long(value)  # noqa: F821
        for value in (0, 1, 100000, 1 << 100))

# Wrap every base object in a one-element list, tuple and dict.
LIST_SIMPLE_OBJECTS = [[item] for item in BASE_SIMPLE_OBJECTS]
TUPLE_SIMPLE_OBJECTS = [(item, ) for item in BASE_SIMPLE_OBJECTS]
DICT_SIMPLE_OBJECTS = [{(): item} for item in BASE_SIMPLE_OBJECTS]

# All simple objects: the base objects followed by their wrapped variants.
SIMPLE_OBJECTS = list(BASE_SIMPLE_OBJECTS)
SIMPLE_OBJECTS.extend(LIST_SIMPLE_OBJECTS)
SIMPLE_OBJECTS.extend(TUPLE_SIMPLE_OBJECTS)
SIMPLE_OBJECTS.extend(DICT_SIMPLE_OBJECTS)
# Create some complex objects that cannot be serialized by value in tasks.
# The first one is a list whose only element is the list itself.
lst = []
lst.insert(0, lst)
class Foo(object):
    """Empty user-defined class used as a complex (non-simple) value."""

    def __init__(self):
        """Create an instance without any attributes."""
# Values that TestSerialization expects check_simple_value to reject.
BASE_COMPLEX_OBJECTS = [
    # Very long byte/unicode strings.
    15000 * "h",
    15000 * u"h",
    # The self-referential list.
    lst,
    # An instance of a user-defined class.
    Foo(),
    # A deeply nested list.
    100 * [100 * [10 * [1]]],
    # A numpy array holding a user-defined object.
    np.array([Foo()]),
]

# Wrap every base object in a one-element list, tuple and dict.
LIST_COMPLEX_OBJECTS = [[item] for item in BASE_COMPLEX_OBJECTS]
TUPLE_COMPLEX_OBJECTS = [(item, ) for item in BASE_COMPLEX_OBJECTS]
DICT_COMPLEX_OBJECTS = [{(): item} for item in BASE_COMPLEX_OBJECTS]

# All complex objects: the base objects followed by their wrapped variants.
COMPLEX_OBJECTS = list(BASE_COMPLEX_OBJECTS)
COMPLEX_OBJECTS.extend(LIST_COMPLEX_OBJECTS)
COMPLEX_OBJECTS.extend(TUPLE_COMPLEX_OBJECTS)
COMPLEX_OBJECTS.extend(DICT_COMPLEX_OBJECTS)
class TestSerialization(unittest.TestCase):
    """Tests for local_scheduler.check_simple_value."""

    def test_serialize_by_value(self):
        """Simple objects must be accepted and complex objects rejected."""
        expectations = [
            (SIMPLE_OBJECTS, self.assertTrue),
            (COMPLEX_OBJECTS, self.assertFalse),
        ]
        for objects, assert_outcome in expectations:
            for value in objects:
                assert_outcome(local_scheduler.check_simple_value(value))
class TestObjectID(unittest.TestCase):
    """Tests for creating, pickling, comparing and hashing ObjectIDs."""

    def test_create_object_id(self):
        """Creating a random object ID must not raise."""
        random_object_id()

    def test_cannot_pickle_object_ids(self):
        """pickle.dumps must raise for object IDs and closures over them."""
        object_ids = [random_object_id() for _ in range(256)]

        def f():
            return object_ids

        def g(val=object_ids):
            return 1

        def h():
            object_ids[0]
            return 1

        # Make sure that object IDs cannot be pickled (including functions
        # that close over object IDs).
        for target in (object_ids[0], object_ids, f, g, h):
            with self.assertRaises(Exception):
                pickle.dumps(target)

    def test_equality_comparisons(self):
        """Object IDs built from equal bytes compare and hash as equal."""
        id_size = ray_constants.ID_SIZE
        x1 = local_scheduler.ObjectID(id_size * b"a")
        x2 = local_scheduler.ObjectID(id_size * b"a")
        y1 = local_scheduler.ObjectID(id_size * b"b")
        y2 = local_scheduler.ObjectID(id_size * b"b")
        self.assertEqual(x1, x2)
        self.assertEqual(y1, y2)
        self.assertNotEqual(x1, y1)

        # Build two independent lists of IDs from the same random byte
        # strings.
        random_strings = [np.random.bytes(id_size) for _ in range(256)]
        object_ids1 = [
            local_scheduler.ObjectID(raw_id) for raw_id in random_strings
        ]
        object_ids2 = [
            local_scheduler.ObjectID(raw_id) for raw_id in random_strings
        ]
        self.assertEqual(len(set(object_ids1)), 256)
        self.assertEqual(len(set(object_ids1 + object_ids2)), 256)
        self.assertEqual(set(object_ids1), set(object_ids2))

    def test_hashability(self):
        """Object IDs must be usable as dict keys and set members."""
        x = random_object_id()
        y = random_object_id()
        dict([(x, y)])
        set([x, y])
class TestTask(unittest.TestCase):
    """Tests for building local_scheduler.Task objects and serializing them."""

    def check_task(self, task, function_id, num_return_vals, args):
        """Assert that task reports the given function ID, returns and args."""
        self.assertEqual(function_id.id(), task.function_id().id())
        retrieved_args = task.arguments()
        self.assertEqual(num_return_vals, len(task.returns()))
        self.assertEqual(len(args), len(retrieved_args))
        # The lengths were just checked to be equal, so zip drops nothing.
        for retrieved, expected in zip(retrieved_args, args):
            if isinstance(retrieved, local_scheduler.ObjectID):
                # Object IDs are compared through their binary IDs.
                self.assertEqual(retrieved.id(), expected.id())
            else:
                self.assertEqual(retrieved, expected)

    def test_create_and_serialize_task(self):
        """Tasks keep their fields after task_to_string/task_from_string."""
        # TODO(rkn): The function ID should be a FunctionID object, not an
        # ObjectID.
        driver_id = random_driver_id()
        parent_id = random_task_id()
        function_id = random_function_id()
        object_ids = [random_object_id() for _ in range(256)]

        # Argument lists of increasing size: plain values only, object IDs
        # only, and finally mixtures of both.
        args_list = [[]]
        args_list += [count * [1] for count in (1, 10, 100, 1000)]
        args_list += [count * ["a"] for count in (1, 10, 100, 1000)]
        args_list.append([1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]])
        args_list += [
            object_ids[:count] for count in (1, 2, 3, 4, 5, 10, 100, 256)
        ]
        args_list += [
            [1, object_ids[0]],
            [object_ids[0], "a"],
            [1, object_ids[0], "a"],
            [object_ids[0], 1, object_ids[1], "a"],
            object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
            object_ids + 100 * ["a"] + object_ids,
        ]

        return_counts = [0, 1, 2, 3, 5, 10, 100]
        for args in args_list:
            for num_return_vals in return_counts:
                task = local_scheduler.Task(driver_id, function_id, args,
                                            num_return_vals, parent_id, 0)
                self.check_task(task, function_id, num_return_vals, args)
                # Round-trip through the string form and check again.
                serialized = local_scheduler.task_to_string(task)
                restored = local_scheduler.task_from_string(serialized)
                self.check_task(restored, function_id, num_return_vals,
                                args)
# Run every test case in this module verbosely when executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
View File
+2 -4
View File
@@ -108,8 +108,6 @@ class SGDWorker(object):
if plasma_op:
store_socket = (
ray.worker.global_worker.plasma_client.store_socket_name)
manager_socket = (
ray.worker.global_worker.plasma_client.manager_socket_name)
if not plasma.tf_plasma_op:
plasma.build_plasma_tensorflow_op()
@@ -130,7 +128,7 @@ class SGDWorker(object):
[grad],
self.plasma_in_grads_oids[j],
plasma_store_socket_name=store_socket,
plasma_manager_socket_name=manager_socket)
plasma_manager_socket_name="")
self.plasma_in_grads.append(plasma_grad)
# For applying grads <- plasma
@@ -149,7 +147,7 @@ class SGDWorker(object):
self.plasma_out_grads_oids[j],
dtype=tf.float32,
plasma_store_socket_name=store_socket,
plasma_manager_socket_name=manager_socket)
plasma_manager_socket_name="")
grad_ph = tf.reshape(grad_ph,
self.packed_grads_and_vars[0][j][0].shape)
logger.debug("Packed tensor {}".format(grad_ph))
+4 -9
View File
@@ -14,15 +14,10 @@ logger = logging.getLogger(__name__)
def fetch(oids):
if ray.global_state.use_raylet:
local_sched_client = ray.worker.global_worker.local_scheduler_client
for o in oids:
ray_obj_id = ray.ObjectID(o)
local_sched_client.reconstruct_objects([ray_obj_id], True)
else:
for o in oids:
plasma_id = ray.pyarrow.plasma.ObjectID(o)
ray.worker.global_worker.plasma_client.fetch([plasma_id])
local_sched_client = ray.worker.global_worker.local_scheduler_client
for o in oids:
ray_obj_id = ray.ObjectID(o)
local_sched_client.reconstruct_objects([ray_obj_id], True)
def run_timeline(sess, ops, feed_dict=None, write_timeline=False, name=""):
+128 -398
View File
@@ -6,8 +6,6 @@ import copy
from collections import defaultdict
import heapq
import json
import numbers
import os
import redis
import sys
import time
@@ -18,25 +16,6 @@ import ray.ray_constants as ray_constants
from ray.utils import (decode, binary_to_object_id, binary_to_hex,
hex_to_binary)
# Task scheduling states. Each state is a distinct power of two. These values
# and the mapping below must be kept up-to-date with the scheduling_state
# enum in task.h.
TASK_STATUS_WAITING = 1
TASK_STATUS_SCHEDULED = 2
TASK_STATUS_QUEUED = 4
TASK_STATUS_RUNNING = 8
TASK_STATUS_DONE = 16
TASK_STATUS_LOST = 32
TASK_STATUS_RECONSTRUCTING = 64

# Human-readable name for every scheduling state value.
TASK_STATUS_MAPPING = dict([
    (TASK_STATUS_WAITING, "WAITING"),
    (TASK_STATUS_SCHEDULED, "SCHEDULED"),
    (TASK_STATUS_QUEUED, "QUEUED"),
    (TASK_STATUS_RUNNING, "RUNNING"),
    (TASK_STATUS_DONE, "DONE"),
    (TASK_STATUS_LOST, "LOST"),
    (TASK_STATUS_RECONSTRUCTING, "RECONSTRUCTING"),
])
class GlobalState(object):
"""A class used to interface with the Ray control state.
@@ -47,7 +26,6 @@ class GlobalState(object):
Attributes:
redis_client: The Redis client used to query the primary redis server.
redis_clients: Redis clients for each of the Redis shards.
use_raylet: True if we are using the raylet code path.
"""
def __init__(self):
@@ -57,8 +35,6 @@ class GlobalState(object):
self.redis_client = None
# Clients for the redis shards, storing the object table & task table.
self.redis_clients = None
# True if we are using the raylet code path and false otherwise.
self.use_raylet = None
def _check_connected(self):
"""Check that the object has been initialized before it is used.
@@ -130,18 +106,6 @@ class GlobalState(object):
"ip_address_ports = {}".format(
num_redis_shards, ip_address_ports))
use_raylet = self.redis_client.get("UseRaylet")
if use_raylet is not None:
self.use_raylet = bool(int(use_raylet))
elif os.environ.get("RAY_USE_XRAY") == "0":
# This environment variable is used in our testing setup.
print("Detected environment variable 'RAY_USE_XRAY' with value "
"{}. This turns OFF xray.".format(
os.environ.get("RAY_USE_XRAY")))
self.use_raylet = False
else:
self.use_raylet = True
# Get the rest of the information.
self.redis_clients = []
for ip_address_port in ip_address_ports:
@@ -195,51 +159,23 @@ class GlobalState(object):
object_id = ray.ObjectID(hex_to_binary(object_id))
# Return information about a single object ID.
if not self.use_raylet:
# Use the non-raylet code path.
object_locations = self._execute_command(
object_id, "RAY.OBJECT_TABLE_LOOKUP", object_id.id())
if object_locations is not None:
manager_ids = [
binary_to_hex(manager_id)
for manager_id in object_locations
]
else:
manager_ids = None
message = self._execute_command(object_id, "RAY.TABLE_LOOKUP",
ray.gcs_utils.TablePrefix.OBJECT, "",
object_id.id())
result = []
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
result_table_response = self._execute_command(
object_id, "RAY.RESULT_TABLE_LOOKUP", object_id.id())
result_table_message = (
ray.gcs_utils.ResultTableReply.GetRootAsResultTableReply(
result_table_response, 0))
result = {
"ManagerIDs": manager_ids,
"TaskID": binary_to_hex(result_table_message.TaskId()),
"IsPut": bool(result_table_message.IsPut()),
"DataSize": result_table_message.DataSize(),
"Hash": binary_to_hex(result_table_message.Hash())
for i in range(gcs_entry.EntriesLength()):
entry = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData(
gcs_entry.Entries(i), 0)
object_info = {
"DataSize": entry.ObjectSize(),
"Manager": entry.Manager(),
"IsEviction": entry.IsEviction(),
"NumEvictions": entry.NumEvictions()
}
else:
# Use the raylet code path.
message = self._execute_command(object_id, "RAY.TABLE_LOOKUP",
ray.gcs_utils.TablePrefix.OBJECT,
"", object_id.id())
result = []
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
for i in range(gcs_entry.EntriesLength()):
entry = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData(
gcs_entry.Entries(i), 0)
object_info = {
"DataSize": entry.ObjectSize(),
"Manager": entry.Manager(),
"IsEviction": entry.IsEviction(),
"NumEvictions": entry.NumEvictions()
}
result.append(object_info)
result.append(object_info)
return result
@@ -259,25 +195,12 @@ class GlobalState(object):
return self._object_table(object_id)
else:
# Return the entire object table.
if not self.use_raylet:
object_info_keys = self._keys(
ray.gcs_utils.OBJECT_INFO_PREFIX + "*")
object_location_keys = self._keys(
ray.gcs_utils.OBJECT_LOCATION_PREFIX + "*")
object_ids_binary = set([
key[len(ray.gcs_utils.OBJECT_INFO_PREFIX):]
for key in object_info_keys
] + [
key[len(ray.gcs_utils.OBJECT_LOCATION_PREFIX):]
for key in object_location_keys
])
else:
object_keys = self._keys(
ray.gcs_utils.TablePrefix_OBJECT_string + "*")
object_ids_binary = {
key[len(ray.gcs_utils.TablePrefix_OBJECT_string):]
for key in object_keys
}
object_keys = self._keys(ray.gcs_utils.TablePrefix_OBJECT_string +
"*")
object_ids_binary = {
key[len(ray.gcs_utils.TablePrefix_OBJECT_string):]
for key in object_keys
}
results = {}
for object_id_binary in object_ids_binary:
@@ -294,21 +217,21 @@ class GlobalState(object):
Returns:
A dictionary with information about the task ID in question.
TASK_STATUS_MAPPING should be used to parse the "State" field
into a human-readable string.
"""
if not self.use_raylet:
# Use the non-raylet code path.
task_table_response = self._execute_command(
task_id, "RAY.TASK_TABLE_GET", task_id.id())
if task_table_response is None:
raise Exception("There is no entry for task ID {} in the task "
"table.".format(binary_to_hex(task_id.id())))
task_table_message = ray.gcs_utils.TaskReply.GetRootAsTaskReply(
task_table_response, 0)
task_spec = task_table_message.TaskSpec()
task_spec = ray.local_scheduler.task_from_string(task_spec)
message = self._execute_command(task_id, "RAY.TABLE_LOOKUP",
ray.gcs_utils.TablePrefix.RAYLET_TASK,
"", task_id.id())
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
info = []
for i in range(gcs_entries.EntriesLength()):
task_table_message = ray.gcs_utils.Task.GetRootAsTask(
gcs_entries.Entries(i), 0)
execution_spec = task_table_message.TaskExecutionSpec()
task_spec = task_table_message.TaskSpecification()
task_spec = ray.raylet.task_from_string(task_spec)
task_spec_info = {
"DriverID": binary_to_hex(task_spec.driver_id().id()),
"TaskID": binary_to_hex(task_spec.task_id().id()),
@@ -326,80 +249,19 @@ class GlobalState(object):
"RequiredResources": task_spec.required_resources()
}
execution_dependencies_message = (
ray.gcs_utils.TaskExecutionDependencies.
GetRootAsTaskExecutionDependencies(
task_table_message.ExecutionDependencies(), 0))
execution_dependencies = [
ray.ObjectID(
execution_dependencies_message.ExecutionDependencies(i))
for i in range(execution_dependencies_message.
ExecutionDependenciesLength())
]
# TODO(rkn): The return fields ExecutionDependenciesString and
# ExecutionDependencies are redundant, so we should remove
# ExecutionDependencies. However, it is currently used in
# monitor.py.
return {
"State": task_table_message.State(),
"LocalSchedulerID": binary_to_hex(
task_table_message.LocalSchedulerId()),
"ExecutionDependenciesString": task_table_message.
ExecutionDependencies(),
"ExecutionDependencies": execution_dependencies,
"SpillbackCount": task_table_message.SpillbackCount(),
info.append({
"ExecutionSpec": {
"Dependencies": [
execution_spec.Dependencies(i)
for i in range(execution_spec.DependenciesLength())
],
"LastTimestamp": execution_spec.LastTimestamp(),
"NumForwards": execution_spec.NumForwards()
},
"TaskSpec": task_spec_info
}
})
else:
# Use the raylet code path.
message = self._execute_command(
task_id, "RAY.TABLE_LOOKUP",
ray.gcs_utils.TablePrefix.RAYLET_TASK, "", task_id.id())
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
info = []
for i in range(gcs_entries.EntriesLength()):
task_table_message = ray.gcs_utils.Task.GetRootAsTask(
gcs_entries.Entries(i), 0)
execution_spec = task_table_message.TaskExecutionSpec()
task_spec = task_table_message.TaskSpecification()
task_spec = ray.local_scheduler.task_from_string(task_spec)
task_spec_info = {
"DriverID": binary_to_hex(task_spec.driver_id().id()),
"TaskID": binary_to_hex(task_spec.task_id().id()),
"ParentTaskID": binary_to_hex(
task_spec.parent_task_id().id()),
"ParentCounter": task_spec.parent_counter(),
"ActorID": binary_to_hex(task_spec.actor_id().id()),
"ActorCreationID": binary_to_hex(
task_spec.actor_creation_id().id()),
"ActorCreationDummyObjectID": binary_to_hex(
task_spec.actor_creation_dummy_object_id().id()),
"ActorCounter": task_spec.actor_counter(),
"FunctionID": binary_to_hex(task_spec.function_id().id()),
"Args": task_spec.arguments(),
"ReturnObjectIDs": task_spec.returns(),
"RequiredResources": task_spec.required_resources()
}
info.append({
"ExecutionSpec": {
"Dependencies": [
execution_spec.Dependencies(i)
for i in range(execution_spec.DependenciesLength())
],
"LastTimestamp": execution_spec.LastTimestamp(),
"NumForwards": execution_spec.NumForwards()
},
"TaskSpec": task_spec_info
})
return info
return info
def task_table(self, task_id=None):
"""Fetch and parse the task table information for one or more task IDs.
@@ -416,19 +278,12 @@ class GlobalState(object):
task_id = ray.ObjectID(hex_to_binary(task_id))
return self._task_table(task_id)
else:
if not self.use_raylet:
task_table_keys = self._keys(ray.gcs_utils.TASK_PREFIX + "*")
task_ids_binary = [
key[len(ray.gcs_utils.TASK_PREFIX):]
for key in task_table_keys
]
else:
task_table_keys = self._keys(
ray.gcs_utils.TablePrefix_RAYLET_TASK_string + "*")
task_ids_binary = [
key[len(ray.gcs_utils.TablePrefix_RAYLET_TASK_string):]
for key in task_table_keys
]
task_table_keys = self._keys(
ray.gcs_utils.TablePrefix_RAYLET_TASK_string + "*")
task_ids_binary = [
key[len(ray.gcs_utils.TablePrefix_RAYLET_TASK_string):]
for key in task_table_keys
]
results = {}
for task_id_binary in task_ids_binary:
@@ -464,95 +319,54 @@ class GlobalState(object):
Information about the Ray clients in the cluster.
"""
self._check_connected()
if not self.use_raylet:
db_client_keys = self.redis_client.keys(
ray.gcs_utils.DB_CLIENT_PREFIX + "*")
node_info = {}
for key in db_client_keys:
client_info = self.redis_client.hgetall(key)
node_ip_address = decode(client_info[b"node_ip_address"])
if node_ip_address not in node_info:
node_info[node_ip_address] = []
client_info_parsed = {}
assert b"client_type" in client_info
assert b"deleted" in client_info
assert b"ray_client_id" in client_info
for field, value in client_info.items():
if field == b"node_ip_address":
pass
elif field == b"client_type":
client_info_parsed["ClientType"] = decode(value)
elif field == b"deleted":
client_info_parsed["Deleted"] = bool(
int(decode(value)))
elif field == b"ray_client_id":
client_info_parsed["DBClientID"] = binary_to_hex(value)
elif field == b"manager_address":
client_info_parsed["AuxAddress"] = decode(value)
elif field == b"local_scheduler_socket_name":
client_info_parsed["LocalSchedulerSocketName"] = (
decode(value))
elif client_info[b"client_type"] == b"local_scheduler":
# The remaining fields are resource types.
client_info_parsed[decode(field)] = float(
decode(value))
else:
client_info_parsed[decode(field)] = decode(value)
node_info[node_ip_address].append(client_info_parsed)
NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff"
message = self.redis_client.execute_command(
"RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.CLIENT, "",
NIL_CLIENT_ID)
return node_info
# Handle the case where no clients are returned. This should only
# occur potentially immediately after the cluster is started.
if message is None:
return []
else:
# This is the raylet code path.
NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff"
message = self.redis_client.execute_command(
"RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.CLIENT, "",
NIL_CLIENT_ID)
node_info = {}
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
# Handle the case where no clients are returned. This should only
# occur potentially immediately after the cluster is started.
if message is None:
return []
# Since GCS entries are append-only, we override so that
# only the latest entries are kept.
for i in range(gcs_entry.EntriesLength()):
client = (ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
gcs_entry.Entries(i), 0))
node_info = {}
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
resources = {
decode(client.ResourcesTotalLabel(i)):
client.ResourcesTotalCapacity(i)
for i in range(client.ResourcesTotalLabelLength())
}
client_id = ray.utils.binary_to_hex(client.ClientId())
# Since GCS entries are append-only, we override so that
# only the latest entries are kept.
for i in range(gcs_entry.EntriesLength()):
client = (
ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
gcs_entry.Entries(i), 0))
# If this client is being removed, then it must
# have previously been inserted, and
# it cannot have previously been removed.
if not client.IsInsertion():
assert client_id in node_info, "Client removed not found!"
assert node_info[client_id]["IsInsertion"], (
"Unexpected duplicate removal of client.")
resources = {
decode(client.ResourcesTotalLabel(i)):
client.ResourcesTotalCapacity(i)
for i in range(client.ResourcesTotalLabelLength())
}
client_id = ray.utils.binary_to_hex(client.ClientId())
# If this client is being removed, then it must
# have previously been inserted, and
# it cannot have previously been removed.
if not client.IsInsertion():
assert client_id in node_info, "Client removed not found!"
assert node_info[client_id]["IsInsertion"], (
"Unexpected duplicate removal of client.")
node_info[client_id] = {
"ClientID": client_id,
"IsInsertion": client.IsInsertion(),
"NodeManagerAddress": decode(client.NodeManagerAddress()),
"NodeManagerPort": client.NodeManagerPort(),
"ObjectManagerPort": client.ObjectManagerPort(),
"ObjectStoreSocketName": decode(
client.ObjectStoreSocketName()),
"RayletSocketName": decode(client.RayletSocketName()),
"Resources": resources
}
return list(node_info.values())
node_info[client_id] = {
"ClientID": client_id,
"IsInsertion": client.IsInsertion(),
"NodeManagerAddress": decode(client.NodeManagerAddress()),
"NodeManagerPort": client.NodeManagerPort(),
"ObjectManagerPort": client.ObjectManagerPort(),
"ObjectStoreSocketName": decode(
client.ObjectStoreSocketName()),
"RayletSocketName": decode(client.RayletSocketName()),
"Resources": resources
}
return list(node_info.values())
def log_files(self):
"""Fetch and return a dictionary of log file names to outputs.
@@ -755,10 +569,6 @@ class GlobalState(object):
return profile_events
def profile_table(self):
if not self.use_raylet:
raise Exception("This method is only supported in the raylet "
"code path.")
profile_table_keys = self._keys(
ray.gcs_utils.TablePrefix_PROFILE_string + "*")
component_identifiers_binary = [
@@ -1207,23 +1017,6 @@ class GlobalState(object):
info[key] = cur
latest_timestamp = cur
def local_schedulers(self):
    """Get a list of live local schedulers.

    Returns:
        A list with the client table entries whose "ClientType" is
        "local_scheduler" and that are not marked as deleted.

    Raises:
        Exception: If the raylet code path is in use, where this method
            is deprecated.
    """
    if self.use_raylet:
        raise Exception("The local_schedulers() method is deprecated.")
    # The client table groups client entries per node IP address; the
    # address itself is not needed here.
    return [
        client
        for client_list in self.client_table().values()
        for client in client_list
        if client["ClientType"] == "local_scheduler"
        and not client["Deleted"]
    ]
def workers(self):
"""Get a dictionary mapping worker ID to worker information."""
worker_keys = self.redis_client.keys("Worker*")
@@ -1237,8 +1030,6 @@ class GlobalState(object):
"local_scheduler_socket": (decode(
worker_info[b"local_scheduler_socket"])),
"node_ip_address": decode(worker_info[b"node_ip_address"]),
"plasma_manager_socket": decode(
worker_info[b"plasma_manager_socket"]),
"plasma_store_socket": decode(
worker_info[b"plasma_store_socket"])
}
@@ -1298,24 +1089,12 @@ class GlobalState(object):
resource in the cluster.
"""
resources = defaultdict(int)
if not self.use_raylet:
local_schedulers = self.local_schedulers()
for local_scheduler in local_schedulers:
for key, value in local_scheduler.items():
if key not in [
"ClientType", "Deleted", "DBClientID",
"AuxAddress", "LocalSchedulerSocketName"
]:
resources[key] += value
else:
clients = self.client_table()
for client in clients:
# Only count resources from live clients.
if client["IsInsertion"]:
for key, value in client["Resources"].items():
resources[key] += value
clients = self.client_table()
for client in clients:
# Only count resources from live clients.
if client["IsInsertion"]:
for key, value in client["Resources"].items():
resources[key] += value
return dict(resources)
@@ -1340,93 +1119,48 @@ class GlobalState(object):
"""
available_resources_by_id = {}
if not self.use_raylet:
subscribe_client = self.redis_client.pubsub()
subscribe_client.subscribe(
ray.gcs_utils.LOCAL_SCHEDULER_INFO_CHANNEL)
subscribe_clients = [
redis_client.pubsub(ignore_subscribe_messages=True)
for redis_client in self.redis_clients
]
for subscribe_client in subscribe_clients:
subscribe_client.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL)
local_scheduler_ids = {
local_scheduler["DBClientID"]
for local_scheduler in self.local_schedulers()
}
client_ids = self._live_client_ids()
while set(available_resources_by_id.keys()) != local_scheduler_ids:
while set(available_resources_by_id.keys()) != client_ids:
for subscribe_client in subscribe_clients:
# Parse client message
raw_message = subscribe_client.get_message()
if raw_message is None:
if (raw_message is None or raw_message["channel"] !=
ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL):
continue
data = raw_message["data"]
# Ignore subscription success message from Redis
# This is a long in python 2 and an int in python 3
if isinstance(data, numbers.Number):
continue
message = (ray.gcs_utils.LocalSchedulerInfoMessage.
GetRootAsLocalSchedulerInfoMessage(data, 0))
num_resources = message.DynamicResourcesLength()
gcs_entries = (
ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
data, 0))
heartbeat_data = gcs_entries.Entries(0)
message = (ray.gcs_utils.HeartbeatTableData.
GetRootAsHeartbeatTableData(heartbeat_data, 0))
# Calculate available resources for this client
num_resources = message.ResourcesAvailableLabelLength()
dynamic_resources = {}
for i in range(num_resources):
dyn = message.DynamicResources(i)
resource_id = decode(dyn.Key())
dynamic_resources[resource_id] = dyn.Value()
resource_id = decode(message.ResourcesAvailableLabel(i))
dynamic_resources[resource_id] = (
message.ResourcesAvailableCapacity(i))
# Update available resources for this local scheduler
client_id = binary_to_hex(message.DbClientId())
# Update available resources for this client
client_id = ray.utils.binary_to_hex(message.ClientId())
available_resources_by_id[client_id] = dynamic_resources
# Update local schedulers in cluster
local_scheduler_ids = {
local_scheduler["DBClientID"]
for local_scheduler in self.local_schedulers()
}
# Remove disconnected local schedulers
for local_scheduler_id in available_resources_by_id.keys():
if local_scheduler_id not in local_scheduler_ids:
del available_resources_by_id[local_scheduler_id]
else:
subscribe_clients = [
redis_client.pubsub(ignore_subscribe_messages=True)
for redis_client in self.redis_clients
]
for subscribe_client in subscribe_clients:
subscribe_client.subscribe(
ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL)
# Update clients in cluster
client_ids = self._live_client_ids()
while set(available_resources_by_id.keys()) != client_ids:
for subscribe_client in subscribe_clients:
# Parse client message
raw_message = subscribe_client.get_message()
if (raw_message is None or raw_message["channel"] !=
ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL):
continue
data = raw_message["data"]
gcs_entries = (
ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
data, 0))
heartbeat_data = gcs_entries.Entries(0)
message = (ray.gcs_utils.HeartbeatTableData.
GetRootAsHeartbeatTableData(heartbeat_data, 0))
# Calculate available resources for this client
num_resources = message.ResourcesAvailableLabelLength()
dynamic_resources = {}
for i in range(num_resources):
resource_id = decode(
message.ResourcesAvailableLabel(i))
dynamic_resources[resource_id] = (
message.ResourcesAvailableCapacity(i))
# Update available resources for this client
client_id = ray.utils.binary_to_hex(message.ClientId())
available_resources_by_id[client_id] = dynamic_resources
# Update clients in cluster
client_ids = self._live_client_ids()
# Remove disconnected clients
for client_id in available_resources_by_id.keys():
if client_id not in client_ids:
del available_resources_by_id[client_id]
# Remove disconnected clients
for client_id in available_resources_by_id.keys():
if client_id not in client_ids:
del available_resources_by_id[client_id]
# Calculate total available resources
total_available_resources = defaultdict(int)
@@ -1479,10 +1213,6 @@ class GlobalState(object):
A dictionary mapping job ID to a list of the error messages for
that job.
"""
if not self.use_raylet:
raise Exception("The error_messages method is only supported in "
"the raylet code path.")
if job_id is not None:
return self._error_messages(job_id)
-29
View File
@@ -4,19 +4,6 @@ from __future__ import print_function
import flatbuffers
from ray.core.generated.ResultTableReply import ResultTableReply
from ray.core.generated.SubscribeToNotificationsReply \
import SubscribeToNotificationsReply
from ray.core.generated.TaskExecutionDependencies import \
TaskExecutionDependencies
from ray.core.generated.TaskReply import TaskReply
from ray.core.generated.DriverTableMessage import DriverTableMessage
from ray.core.generated.LocalSchedulerInfoMessage import \
LocalSchedulerInfoMessage
from ray.core.generated.SubscribeToDBClientTableReply import \
SubscribeToDBClientTableReply
from ray.core.generated.TaskInfo import TaskInfo
import ray.core.generated.ErrorTableData
from ray.core.generated.GcsTableEntry import GcsTableEntry
@@ -32,29 +19,13 @@ from ray.core.generated.TablePrefix import TablePrefix
from ray.core.generated.TablePubsub import TablePubsub
__all__ = [
"SubscribeToNotificationsReply", "ResultTableReply",
"TaskExecutionDependencies", "TaskReply", "DriverTableMessage",
"LocalSchedulerInfoMessage", "SubscribeToDBClientTableReply", "TaskInfo",
"GcsTableEntry", "ClientTableData", "ErrorTableData", "HeartbeatTableData",
"DriverTableData", "ProfileTableData", "ObjectTableData", "Task",
"TablePrefix", "TablePubsub", "construct_error_message"
]
# These prefixes must be kept up-to-date with the definitions in
# ray_redis_module.cc.
DB_CLIENT_PREFIX = "CL:"
TASK_PREFIX = "TT:"
OBJECT_CHANNEL_PREFIX = "OC:"
OBJECT_INFO_PREFIX = "OI:"
OBJECT_LOCATION_PREFIX = "OL:"
FUNCTION_PREFIX = "RemoteFunction:"
# These prefixes must be kept up-to-date with the definitions in
# common/state/redis.cc
LOCAL_SCHEDULER_INFO_CHANNEL = b"local_schedulers"
PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers"
DRIVER_DEATH_CHANNEL = b"driver_deaths"
# xray heartbeats
XRAY_HEARTBEAT_CHANNEL = str(TablePubsub.HEARTBEAT).encode("ascii")
-7
View File
@@ -1,7 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .global_scheduler_services import start_global_scheduler
__all__ = ["start_global_scheduler"]
@@ -1,61 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import subprocess
import time
def start_global_scheduler(redis_address,
                           node_ip_address,
                           use_valgrind=False,
                           use_profiler=False,
                           stdout_file=None,
                           stderr_file=None):
    """Start a global scheduler process.

    Args:
        redis_address (str): The address of the Redis instance.
        node_ip_address: The IP address of the node that this scheduler will
            run on.
        use_valgrind (bool): True if the global scheduler should be started
            inside of valgrind. If this is True, use_profiler must be False.
        use_profiler (bool): True if the global scheduler should be started
            inside a profiler. If this is True, use_valgrind must be False.
        stdout_file: A file handle opened for writing to redirect stdout to. If
            no redirection should happen, then this should be None.
        stderr_file: A file handle opened for writing to redirect stderr to. If
            no redirection should happen, then this should be None.

    Returns:
        The subprocess.Popen object for the global scheduler process (not a
        bare process ID).

    Raises:
        ValueError: If both use_valgrind and use_profiler are True.
    """
    if use_valgrind and use_profiler:
        raise ValueError("Cannot use valgrind and profiler at the same time.")
    global_scheduler_executable = os.path.join(
        os.path.abspath(os.path.dirname(__file__)),
        "../core/src/global_scheduler/global_scheduler")
    command = [
        global_scheduler_executable, "-r", redis_address, "-h", node_ip_address
    ]

    # Choose the wrapper command (if any) and how long to wait after launch.
    # NOTE(review): the sleep presumably gives the process time to start up;
    # the valgrind-based modes wait longer.
    if use_valgrind:
        wrapper = [
            "valgrind", "--track-origins=yes", "--leak-check=full",
            "--show-leak-kinds=all", "--leak-check-heuristics=stdstring",
            "--error-exitcode=1"
        ]
        startup_delay = 1.0
    elif use_profiler:
        wrapper = ["valgrind", "--tool=callgrind"]
        startup_delay = 1.0
    else:
        wrapper = []
        startup_delay = 0.1

    process = subprocess.Popen(
        wrapper + command, stdout=stdout_file, stderr=stderr_file)
    time.sleep(startup_delay)
    return process
-332
View File
@@ -1,332 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import random
import signal
import sys
import time
import unittest
# The ray import must come before the pyarrow import because ray modifies the
# python path so that the right version of pyarrow is found.
import ray.global_scheduler as global_scheduler
import ray.local_scheduler as local_scheduler
import ray.plasma as plasma
from ray.plasma.utils import create_object
from ray import services
from ray.experimental import state
import ray.ray_constants as ray_constants
import pyarrow as pa
# Whether the global scheduler is started under valgrind; switched to True
# when the script is run with a trailing "valgrind" argument.
USE_VALGRIND = False
PLASMA_STORE_MEMORY = 1000000000
# Number of (plasma store, plasma manager, local scheduler) groups that setUp
# starts.
NUM_CLUSTER_NODES = 2
# Nil IDs: ID_SIZE bytes, each set to 0xff.
NIL_WORKER_ID = ray_constants.ID_SIZE * b"\xff"
NIL_OBJECT_ID = ray_constants.ID_SIZE * b"\xff"
NIL_ACTOR_ID = ray_constants.ID_SIZE * b"\xff"
def random_driver_id():
    """Return an ObjectID built from ID_SIZE random bytes, used as a driver ID."""
    raw_id = np.random.bytes(ray_constants.ID_SIZE)
    return local_scheduler.ObjectID(raw_id)
def random_task_id():
    """Return an ObjectID built from ID_SIZE random bytes, used as a task ID."""
    raw_id = np.random.bytes(ray_constants.ID_SIZE)
    return local_scheduler.ObjectID(raw_id)
def random_function_id():
    """Return an ObjectID built from ID_SIZE random bytes, used as a function ID."""
    raw_id = np.random.bytes(ray_constants.ID_SIZE)
    return local_scheduler.ObjectID(raw_id)
def random_object_id():
    """Return an ObjectID built from ID_SIZE random bytes."""
    raw_id = np.random.bytes(ray_constants.ID_SIZE)
    return local_scheduler.ObjectID(raw_id)
def new_port():
    """Return a random port number in the inclusive range [10000, 65535]."""
    lowest_port, highest_port = 10000, 65535
    return random.randint(lowest_port, highest_port)
class TestGlobalScheduler(unittest.TestCase):
    """Integration tests for the global scheduler.

    Every test runs against one Redis server, one global scheduler and
    NUM_CLUSTER_NODES groups of (plasma store, plasma manager, local
    scheduler). setUp starts these processes and tearDown kills them.
    """

    def setUp(self):
        """Start Redis, the global scheduler and the per-node processes."""
        # Start one Redis server and N pairs of (plasma, local_scheduler)
        self.node_ip_address = "127.0.0.1"
        redis_address, redis_shards = services.start_redis(
            self.node_ip_address, use_raylet=False)
        redis_port = services.get_port(redis_address)
        time.sleep(0.1)
        # Create a client for the global state store.
        self.state = state.GlobalState()
        self.state._initialize_global_state(self.node_ip_address, redis_port)

        # Start one global scheduler.
        self.p1 = global_scheduler.start_global_scheduler(
            redis_address, self.node_ip_address, use_valgrind=USE_VALGRIND)
        self.plasma_store_pids = []
        self.plasma_manager_pids = []
        self.local_scheduler_pids = []
        self.plasma_clients = []
        self.local_scheduler_clients = []

        for i in range(NUM_CLUSTER_NODES):
            # Start the Plasma store. Plasma store name is randomly generated.
            plasma_store_name, p2 = plasma.start_plasma_store()
            self.plasma_store_pids.append(p2)
            # Start the Plasma manager.
            # Assumption: Plasma manager name and port are randomly generated
            # by the plasma module.
            manager_info = plasma.start_plasma_manager(plasma_store_name,
                                                       redis_address)
            plasma_manager_name, p3, plasma_manager_port = manager_info
            self.plasma_manager_pids.append(p3)
            plasma_address = "{}:{}".format(self.node_ip_address,
                                            plasma_manager_port)
            plasma_client = pa.plasma.connect(plasma_store_name,
                                              plasma_manager_name, 64)
            self.plasma_clients.append(plasma_client)
            # Start the local scheduler.
            local_scheduler_name, p4 = local_scheduler.start_local_scheduler(
                plasma_store_name,
                plasma_manager_name=plasma_manager_name,
                plasma_address=plasma_address,
                redis_address=redis_address,
                static_resources={"CPU": 10})
            # Connect to the scheduler.
            local_scheduler_client = local_scheduler.LocalSchedulerClient(
                local_scheduler_name, NIL_WORKER_ID, False, random_task_id(),
                False)
            self.local_scheduler_clients.append(local_scheduler_client)
            self.local_scheduler_pids.append(p4)

    def tearDown(self):
        """Check that every process is still alive, then kill them all."""
        # Check that the processes are still alive.
        self.assertEqual(self.p1.poll(), None)
        for plasma_store in self.plasma_store_pids:
            self.assertEqual(plasma_store.poll(), None)
        for plasma_manager in self.plasma_manager_pids:
            self.assertEqual(plasma_manager.poll(), None)
        for scheduler in self.local_scheduler_pids:
            self.assertEqual(scheduler.poll(), None)

        redis_processes = services.all_processes[
            services.PROCESS_TYPE_REDIS_SERVER]
        for redis_process in redis_processes:
            self.assertEqual(redis_process.poll(), None)

        # Kill the global scheduler.
        if USE_VALGRIND:
            self.p1.send_signal(signal.SIGTERM)
            self.p1.wait()
            if self.p1.returncode != 0:
                os._exit(-1)
        else:
            self.p1.kill()
        # Kill local schedulers, plasma managers, and plasma stores.
        for scheduler in self.local_scheduler_pids:
            scheduler.kill()
        for plasma_manager in self.plasma_manager_pids:
            plasma_manager.kill()
        for plasma_store in self.plasma_store_pids:
            plasma_store.kill()
        # Kill Redis. In the event that we are using valgrind, this needs to
        # happen after we kill the global scheduler.
        while redis_processes:
            redis_process = redis_processes.pop()
            redis_process.kill()

    def get_plasma_manager_id(self):
        """Get the db_client_id with client_type equal to plasma_manager.

        Iterates over the client table entries for this node and returns the
        db_client_id of the first client whose client type is plasma_manager.
        TODO(atumanov): write a separate function to get all plasma manager
        client IDs.

        Returns:
            The db_client_id if one is found and otherwise None.
        """
        db_client_id = None

        client_list = self.state.client_table()[self.node_ip_address]
        for client in client_list:
            if client["ClientType"] == "plasma_manager":
                db_client_id = client["DBClientID"]
                break

        return db_client_id

    def test_task_default_resources(self):
        """A task requires {"CPU": 1} unless resources are given explicitly."""
        task1 = local_scheduler.Task(
            random_driver_id(), random_function_id(), [random_object_id()], 0,
            random_task_id(), 0)
        self.assertEqual(task1.required_resources(), {"CPU": 1})
        task2 = local_scheduler.Task(
            random_driver_id(), random_function_id(), [random_object_id()], 0,
            random_task_id(), 0, local_scheduler.ObjectID(NIL_ACTOR_ID),
            local_scheduler.ObjectID(NIL_OBJECT_ID),
            local_scheduler.ObjectID(NIL_ACTOR_ID),
            local_scheduler.ObjectID(NIL_ACTOR_ID), 0, 0, [], {
                "CPU": 1,
                "GPU": 2
            })
        self.assertEqual(task2.required_resources(), {"CPU": 1, "GPU": 2})

    def test_redis_only_single_task(self):
        """Check the db client count and that a plasma manager is registered."""
        # Tests global scheduler functionality by interacting with Redis and
        # checking task state transitions in Redis only. TODO(atumanov):
        # implement.

        # Check precondition for this test:
        # There should be 2n+1 db clients: the global scheduler + one local
        # scheduler and one plasma per node.
        self.assertEqual(
            len(self.state.client_table()[self.node_ip_address]),
            2 * NUM_CLUSTER_NODES + 1)
        db_client_id = self.get_plasma_manager_id()
        assert (db_client_id is not None)

    @unittest.skipIf(
        os.environ.get("RAY_USE_NEW_GCS", False),
        "New GCS API doesn't have a Python API yet.")
    def test_integration_single_task(self):
        """Submit one task and wait for it to reach the queued state."""
        # There should be 2n+1 db clients: the global scheduler + one local
        # scheduler and one plasma manager per node.
        self.assertEqual(
            len(self.state.client_table()[self.node_ip_address]),
            2 * NUM_CLUSTER_NODES + 1)

        num_return_vals = [0, 1, 2, 3, 5, 10]
        # Insert the object into Redis.
        data_size = 0xf1f0
        metadata_size = 0x40
        plasma_client = self.plasma_clients[0]
        object_dep, memory_buffer, metadata = create_object(
            plasma_client, data_size, metadata_size, seal=True)

        # Sleep before submitting task to local scheduler.
        time.sleep(0.1)
        # Submit a task to Redis.
        task = local_scheduler.Task(
            random_driver_id(), random_function_id(),
            [local_scheduler.ObjectID(object_dep.binary())],
            num_return_vals[0], random_task_id(), 0)
        self.local_scheduler_clients[0].submit(task)
        time.sleep(0.1)
        # There should now be a task in Redis, and it should get assigned to
        # the local scheduler
        num_retries = 10
        # Stays None if the task never shows up in the task table, so the
        # check after the loop cannot hit an unbound name.
        task_status = None
        while num_retries > 0:
            task_entries = self.state.task_table()
            self.assertLessEqual(len(task_entries), 1)
            if len(task_entries) == 1:
                task_id, task = task_entries.popitem()
                task_status = task["State"]
                self.assertTrue(task_status in [
                    state.TASK_STATUS_WAITING, state.TASK_STATUS_SCHEDULED,
                    state.TASK_STATUS_QUEUED
                ])
                if task_status == state.TASK_STATUS_QUEUED:
                    break
                else:
                    print(task_status)
            print("The task has not been scheduled yet, trying again.")
            num_retries -= 1
            time.sleep(1)

        if task_status != state.TASK_STATUS_QUEUED:
            # Failed to submit and schedule a single task. Report it as a
            # test failure; the test runner then runs tearDown itself.
            self.fail("The task was not queued on a local scheduler "
                      "(last observed state: {}).".format(task_status))

    def integration_many_tasks_helper(self, timesync=True):
        """Submit many tasks and check that they are all accounted for.

        Args:
            timesync (bool): If True, sleep for 10ms after creating each
                object and before submitting the task that depends on it.
        """
        # There should be 2n+1 db clients: the global scheduler + one local
        # scheduler and one plasma manager per node.
        self.assertEqual(
            len(self.state.client_table()[self.node_ip_address]),
            2 * NUM_CLUSTER_NODES + 1)
        num_return_vals = [0, 1, 2, 3, 5, 10]

        # Submit a bunch of tasks to Redis.
        num_tasks = 1000
        for _ in range(num_tasks):
            # Create a new object for each task.
            data_size = np.random.randint(1 << 12)
            metadata_size = np.random.randint(1 << 9)
            plasma_client = self.plasma_clients[0]
            object_dep, memory_buffer, metadata = create_object(
                plasma_client, data_size, metadata_size, seal=True)
            if timesync:
                # Give 10ms for object info handler to fire (long enough to
                # yield CPU).
                time.sleep(0.010)
            task = local_scheduler.Task(
                random_driver_id(), random_function_id(),
                [local_scheduler.ObjectID(object_dep.binary())],
                num_return_vals[0], random_task_id(), 0)
            self.local_scheduler_clients[0].submit(task)
        # Check that there are the correct number of tasks in Redis and that
        # they all get assigned to the local scheduler.
        num_retries = 20
        # Initialise all counters so the final assertion reports a normal
        # failure (instead of a NameError) if the tasks never all reach Redis.
        num_tasks_done = 0
        num_tasks_scheduled = 0
        num_tasks_waiting = 0
        while num_retries > 0:
            task_entries = self.state.task_table()
            self.assertLessEqual(len(task_entries), num_tasks)
            # First, check if all tasks made it to Redis.
            if len(task_entries) == num_tasks:
                task_statuses = [
                    task_entry["State"]
                    for task_entry in task_entries.values()
                ]
                self.assertTrue(
                    all(status in [
                        state.TASK_STATUS_WAITING, state.TASK_STATUS_SCHEDULED,
                        state.TASK_STATUS_QUEUED
                    ] for status in task_statuses))
                num_tasks_done = task_statuses.count(state.TASK_STATUS_QUEUED)
                num_tasks_scheduled = task_statuses.count(
                    state.TASK_STATUS_SCHEDULED)
                num_tasks_waiting = task_statuses.count(
                    state.TASK_STATUS_WAITING)
                print("tasks in Redis = {}, tasks waiting = {}, "
                      "tasks scheduled = {}, "
                      "tasks queued = {}, retries left = {}".format(
                          len(task_entries), num_tasks_waiting,
                          num_tasks_scheduled, num_tasks_done, num_retries))
                if all(status == state.TASK_STATUS_QUEUED
                       for status in task_statuses):
                    # We're done, so pass.
                    break
            num_retries -= 1
            time.sleep(0.1)

        # Tasks can either be queued or in the global scheduler due to
        # spillback.
        self.assertEqual(num_tasks_done + num_tasks_waiting, num_tasks)

    @unittest.skipIf(
        os.environ.get("RAY_USE_NEW_GCS", False),
        "New GCS API doesn't have a Python API yet.")
    def test_integration_many_tasks_handler_sync(self):
        """Run the many-tasks scenario with a pause after each object."""
        self.integration_many_tasks_helper(timesync=True)

    @unittest.skipIf(
        os.environ.get("RAY_USE_NEW_GCS", False),
        "New GCS API doesn't have a Python API yet.")
    def test_integration_many_tasks(self):
        """Run the many-tasks scenario without pausing between submissions."""
        # More realistic case: should handle out of order object and task
        # notifications.
        self.integration_many_tasks_helper(timesync=False)
if __name__ == "__main__":
    # A trailing "valgrind" argument turns on valgrind mode. It is popped
    # from sys.argv so that unittest's own argument parser never sees it.
    if len(sys.argv) > 1 and sys.argv[-1] == "valgrind":
        arg = sys.argv.pop()
        USE_VALGRIND = True
        print("Using valgrind for tests")
    unittest.main(verbosity=2)
+2 -5
View File
@@ -2,7 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ray.local_scheduler
import ray.raylet
import ray.worker
from ray import profiling
@@ -42,7 +42,4 @@ def free(object_ids, local_only=False, worker=None):
if len(object_ids) == 0:
return
if worker.use_raylet:
worker.local_scheduler_client.free(object_ids, local_only)
else:
raise Exception("Free is not supported in legacy backend.")
worker.local_scheduler_client.free(object_ids, local_only)
@@ -1,132 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
import os
import subprocess
import sys
import time
from ray.tempfile_services import (get_local_scheduler_socket_name,
get_temp_root)
def start_local_scheduler(plasma_store_name,
                          plasma_manager_name=None,
                          worker_path=None,
                          plasma_address=None,
                          node_ip_address="127.0.0.1",
                          redis_address=None,
                          use_valgrind=False,
                          use_profiler=False,
                          stdout_file=None,
                          stderr_file=None,
                          static_resources=None,
                          num_workers=0):
    """Start a local scheduler process.

    Args:
        plasma_store_name (str): The name of the plasma store socket to connect
            to.
        plasma_manager_name (str): The name of the plasma manager to connect
            to. This does not need to be provided, but if it is, then the Redis
            address must be provided as well.
        worker_path (str): The path of the worker script to use when the local
            scheduler starts up new workers.
        plasma_address (str): The address of the plasma manager to connect to.
            This is only used by the global scheduler to figure out which
            plasma managers are connected to which local schedulers.
        node_ip_address (str): The address of the node that this local
            scheduler is running on.
        redis_address (str): The address of the Redis instance to connect to.
            If this is not provided, then the local scheduler will not connect
            to Redis.
        use_valgrind (bool): True if the local scheduler should be started
            inside of valgrind. If this is True, use_profiler must be False.
        use_profiler (bool): True if the local scheduler should be started
            inside a profiler. If this is True, use_valgrind must be False.
        stdout_file: A file handle opened for writing to redirect stdout to. If
            no redirection should happen, then this should be None.
        stderr_file: A file handle opened for writing to redirect stderr to. If
            no redirection should happen, then this should be None.
        static_resources: A dictionary specifying the local scheduler's
            resource capacities. This maps resource names (strings) to
            integers or floats. If None, the scheduler is given one "CPU"
            resource per CPU reported by multiprocessing.cpu_count().
        num_workers (int): The number of workers that the local scheduler
            should start.

    Returns:
        A tuple of the name of the local scheduler socket and the
        subprocess.Popen object for the local scheduler process (not a bare
        process ID).

    Raises:
        ValueError: If exactly one of plasma_manager_name and redis_address
            is provided, or if both use_valgrind and use_profiler are True.
    """
    if (plasma_manager_name is None) != (redis_address is None):
        raise ValueError("If one of the plasma_manager_name and the "
                         "redis_address is provided, then both must be "
                         "provided.")
    if use_valgrind and use_profiler:
        raise ValueError("Cannot use valgrind and profiler at the same time.")
    local_scheduler_executable = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "../core/src/local_scheduler/local_scheduler")
    local_scheduler_name = get_local_scheduler_socket_name()
    command = [
        local_scheduler_executable, "-s", local_scheduler_name, "-p",
        plasma_store_name, "-h", node_ip_address, "-n",
        str(num_workers)
    ]
    if plasma_manager_name is not None:
        command += ["-m", plasma_manager_name]
    if worker_path is not None:
        # The worker start command embeds all three of these values.
        assert plasma_store_name is not None
        assert plasma_manager_name is not None
        assert redis_address is not None
        start_worker_command = ("{} {} "
                                "--node-ip-address={} "
                                "--object-store-name={} "
                                "--object-store-manager-name={} "
                                "--local-scheduler-name={} "
                                "--redis-address={} "
                                "--temp-dir={}".format(
                                    sys.executable, worker_path,
                                    node_ip_address, plasma_store_name,
                                    plasma_manager_name, local_scheduler_name,
                                    redis_address, get_temp_root()))
        command += ["-w", start_worker_command]
    if redis_address is not None:
        command += ["-r", redis_address]
    if plasma_address is not None:
        command += ["-a", plasma_address]
    if static_resources is not None:
        for quantity in static_resources.values():
            assert isinstance(quantity, (int, float))
        # The resources are passed as one flat comma-separated list of
        # alternating names and quantities, e.g. "CPU,4,GPU,1".
        resource_argument = ",".join(
            name + "," + str(quantity)
            for name, quantity in static_resources.items())
    else:
        resource_argument = "CPU,{}".format(multiprocessing.cpu_count())
    command += ["-c", resource_argument]

    # Choose the wrapper command (if any) and how long to wait after launch.
    # NOTE(review): the sleep presumably gives the process time to start up;
    # the valgrind-based modes wait longer.
    if use_valgrind:
        wrapper = [
            "valgrind", "--track-origins=yes", "--leak-check=full",
            "--show-leak-kinds=all", "--leak-check-heuristics=stdstring",
            "--error-exitcode=1"
        ]
        startup_delay = 1.0
    elif use_profiler:
        wrapper = ["valgrind", "--tool=callgrind"]
        startup_delay = 1.0
    else:
        wrapper = []
        startup_delay = 0.1

    process = subprocess.Popen(
        wrapper + command, stdout=stdout_file, stderr=stderr_file)
    time.sleep(startup_delay)
    return local_scheduler_name, process
-206
View File
@@ -1,206 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import signal
import sys
import threading
import time
import unittest
import ray.local_scheduler as local_scheduler
import ray.plasma as plasma
import ray.ray_constants as ray_constants
import pyarrow as pa
# Whether the local scheduler is started under valgrind; switched to True
# when the script is run with a trailing "valgrind" argument.
USE_VALGRIND = False
# Nil worker ID: ID_SIZE bytes, each set to 0xff.
NIL_WORKER_ID = ray_constants.ID_SIZE * b"\xff"
def random_object_id():
    """Return an ObjectID built from ID_SIZE random bytes."""
    raw_id = np.random.bytes(ray_constants.ID_SIZE)
    return local_scheduler.ObjectID(raw_id)
def random_driver_id():
    """Return an ObjectID built from ID_SIZE random bytes, used as a driver ID."""
    raw_id = np.random.bytes(ray_constants.ID_SIZE)
    return local_scheduler.ObjectID(raw_id)
def random_task_id():
    """Return an ObjectID built from ID_SIZE random bytes, used as a task ID."""
    raw_id = np.random.bytes(ray_constants.ID_SIZE)
    return local_scheduler.ObjectID(raw_id)
def random_function_id():
    """Return an ObjectID built from ID_SIZE random bytes, used as a function ID."""
    raw_id = np.random.bytes(ray_constants.ID_SIZE)
    return local_scheduler.ObjectID(raw_id)
class TestLocalSchedulerClient(unittest.TestCase):
    """Tests for a local scheduler client talking to one local scheduler.

    setUp starts a plasma store and a local scheduler (without a plasma
    manager or Redis) and connects a plasma client and a local scheduler
    client to them. tearDown kills both processes.
    """

    def setUp(self):
        """Start the plasma store and local scheduler and connect clients."""
        # Start Plasma store.
        plasma_store_name, self.p1 = plasma.start_plasma_store()
        self.plasma_client = pa.plasma.connect(plasma_store_name, "", 0)
        # Start a local scheduler.
        scheduler_name, self.p2 = local_scheduler.start_local_scheduler(
            plasma_store_name, use_valgrind=USE_VALGRIND)
        # Connect to the scheduler.
        self.local_scheduler_client = local_scheduler.LocalSchedulerClient(
            scheduler_name, NIL_WORKER_ID, False, random_task_id(), False)

    def tearDown(self):
        """Check that both processes are still alive, then kill them."""
        # Check that the processes are still alive.
        self.assertEqual(self.p1.poll(), None)
        self.assertEqual(self.p2.poll(), None)

        # Kill Plasma.
        self.p1.kill()
        # Kill the local scheduler.
        if USE_VALGRIND:
            # Under valgrind, terminate with SIGTERM and wait; a non-zero
            # return code aborts the whole test process.
            self.p2.send_signal(signal.SIGTERM)
            self.p2.wait()
            if self.p2.returncode != 0:
                os._exit(-1)
        else:
            self.p2.kill()

    def test_submit_and_get_task(self):
        """Submit tasks with many argument shapes and fetch them back."""
        function_id = random_function_id()
        object_ids = [random_object_id() for i in range(256)]
        # Create and seal the objects in the object store so that we can
        # schedule all of the subsequent tasks.
        for object_id in object_ids:
            self.plasma_client.create(pa.plasma.ObjectID(object_id.id()), 0)
            self.plasma_client.seal(pa.plasma.ObjectID(object_id.id()))
        # Define some arguments to use for the tasks.
        args_list = [[], [{}], [()], 1 * [1], 10 * [1], 100 * [1], 1000 * [1],
                     1 * ["a"], 10 * ["a"], 100 * ["a"], 1000 * ["a"], [
                         1, 1.3, 1 << 100, "hi", u"hi", [1, 2]
                     ], object_ids[:1], object_ids[:2], object_ids[:3],
                     object_ids[:4], object_ids[:5], object_ids[:10],
                     object_ids[:100], object_ids[:256], [1, object_ids[0]], [
                         object_ids[0], "a"
                     ], [1, object_ids[0], "a"], [
                         object_ids[0], 1, object_ids[1], "a"
                     ], object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
                     object_ids + 100 * ["a"] + object_ids]

        # First pass: submit each task and fetch it back straight away,
        # comparing the function ID, the arguments and the number of returns.
        for args in args_list:
            for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
                task = local_scheduler.Task(random_driver_id(), function_id,
                                            args, num_return_vals,
                                            random_task_id(), 0)
                # Submit a task.
                self.local_scheduler_client.submit(task)
                # Get the task.
                new_task = self.local_scheduler_client.get_task()
                self.assertEqual(task.function_id().id(),
                                 new_task.function_id().id())
                retrieved_args = new_task.arguments()
                returns = new_task.returns()
                self.assertEqual(len(args), len(retrieved_args))
                self.assertEqual(num_return_vals, len(returns))
                for i in range(len(retrieved_args)):
                    # Object ID arguments are compared by their underlying
                    # ID; all other arguments are compared by value.
                    if isinstance(args[i], local_scheduler.ObjectID):
                        self.assertEqual(args[i].id(), retrieved_args[i].id())
                    else:
                        self.assertEqual(args[i], retrieved_args[i])

        # Submit all of the tasks.
        for args in args_list:
            for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
                task = local_scheduler.Task(random_driver_id(), function_id,
                                            args, num_return_vals,
                                            random_task_id(), 0)
                self.local_scheduler_client.submit(task)
        # Get all of the tasks.
        for args in args_list:
            for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
                # The fetched task is not inspected; this only checks that
                # one get_task() call returns per submitted task.
                new_task = self.local_scheduler_client.get_task()

    def test_scheduling_when_objects_ready(self):
        """A waiting get_task() returns once the task's dependency exists."""
        # Create a task and submit it.
        object_id = random_object_id()
        task = local_scheduler.Task(random_driver_id(), random_function_id(),
                                    [object_id], 0, random_task_id(), 0)
        self.local_scheduler_client.submit(task)

        # Launch a thread to get the task.
        def get_task():
            self.local_scheduler_client.get_task()

        t = threading.Thread(target=get_task)
        t.start()
        # Sleep to give the thread time to call get_task.
        time.sleep(0.1)
        # Create and seal the object ID in the object store. This should
        # trigger a scheduling event.
        self.plasma_client.create(pa.plasma.ObjectID(object_id.id()), 0)
        self.plasma_client.seal(pa.plasma.ObjectID(object_id.id()))
        # Wait until the thread finishes so that we know the task was
        # scheduled.
        t.join()

    def test_scheduling_when_objects_evicted(self):
        """The task is scheduled only once both dependencies are present.

        One dependency is created and then evicted before the second one is
        created, so the task must keep waiting until the first is recreated.
        """
        # Create a task with two dependencies and submit it.
        object_id1 = random_object_id()
        object_id2 = random_object_id()
        task = local_scheduler.Task(random_driver_id(), random_function_id(),
                                    [object_id1, object_id2], 0,
                                    random_task_id(), 0)
        self.local_scheduler_client.submit(task)

        # Launch a thread to get the task.
        def get_task():
            self.local_scheduler_client.get_task()

        t = threading.Thread(target=get_task)
        t.start()

        # Make one of the dependencies available.
        buf = self.plasma_client.create(pa.plasma.ObjectID(object_id1.id()), 1)
        self.plasma_client.seal(pa.plasma.ObjectID(object_id1.id()))
        # Release the object.
        del buf
        # Check that the thread is still waiting for a task.
        time.sleep(0.1)
        self.assertTrue(t.is_alive())
        # Force eviction of the first dependency.
        self.plasma_client.evict(plasma.DEFAULT_PLASMA_STORE_MEMORY)
        # Check that the thread is still waiting for a task.
        time.sleep(0.1)
        self.assertTrue(t.is_alive())
        # Check that the first object dependency was evicted.
        object1 = self.plasma_client.get_buffers(
            [pa.plasma.ObjectID(object_id1.id())], timeout_ms=0)
        self.assertEqual(object1, [None])
        # Check that the thread is still waiting for a task.
        time.sleep(0.1)
        self.assertTrue(t.is_alive())

        # Create the second dependency.
        self.plasma_client.create(pa.plasma.ObjectID(object_id2.id()), 1)
        self.plasma_client.seal(pa.plasma.ObjectID(object_id2.id()))
        # Check that the thread is still waiting for a task.
        time.sleep(0.1)
        self.assertTrue(t.is_alive())

        # Create the first dependency again. Both dependencies are now
        # available.
        self.plasma_client.create(pa.plasma.ObjectID(object_id1.id()), 1)
        self.plasma_client.seal(pa.plasma.ObjectID(object_id1.id()))

        # Wait until the thread finishes so that we know the task was
        # scheduled.
        t.join()
if __name__ == "__main__":
    # A trailing "valgrind" argument turns on valgrind mode. It is popped
    # from sys.argv so that unittest's own argument parser never sees it.
    if len(sys.argv) > 1 and sys.argv[-1] == "valgrind":
        arg = sys.argv.pop()
        USE_VALGRIND = True
        print("Using valgrind for tests")
    unittest.main(verbosity=2)
+11 -446
View File
@@ -3,11 +3,9 @@ from __future__ import division
from __future__ import print_function
import argparse
import binascii
import logging
import os
import time
from collections import Counter, defaultdict
import traceback
import redis
@@ -20,27 +18,6 @@ import ray.utils
import ray.ray_constants as ray_constants
from ray.services import get_ip_address, get_port
from ray.utils import binary_to_hex, binary_to_object_id, hex_to_binary
from ray.worker import NIL_ACTOR_ID
# These variables must be kept in sync with the C codebase.
# common/common.h
NIL_ID = b"\xff" * ray_constants.ID_SIZE
# common/task.h
TASK_STATUS_LOST = 32
# common/redis_module/ray_redis_module.cc
OBJECT_INFO_PREFIX = b"OI:"
OBJECT_LOCATION_PREFIX = b"OL:"
TASK_TABLE_PREFIX = b"TT:"
DB_CLIENT_PREFIX = b"CL:"
DB_CLIENT_TABLE_NAME = b"db_clients"
# local_scheduler/local_scheduler.h
LOCAL_SCHEDULER_CLIENT_TYPE = b"local_scheduler"
# plasma/plasma_manager.cc
PLASMA_MANAGER_CLIENT_TYPE = b"plasma_manager"
# Set up logging.
logger = logging.getLogger(__name__)
@@ -55,19 +32,8 @@ class Monitor(object):
Attributes:
redis: A connection to the Redis server.
use_raylet: A bool indicating whether to use the raylet code path or
not.
subscribe_client: A pubsub client for the Redis server. This is used to
receive notifications about failed components.
dead_local_schedulers: A set of the local scheduler IDs of all of the
local schedulers that were up at one point and have died since
then.
live_plasma_managers: A counter mapping live plasma manager IDs to the
number of heartbeats that have passed since we last heard from that
plasma manager. A plasma manager is live if we received a heartbeat
from it at any point, and if it has not timed out.
dead_plasma_managers: A set of the plasma manager IDs of all the plasma
managers that were up at one point and have died since then.
"""
def __init__(self,
@@ -79,26 +45,16 @@ class Monitor(object):
self.state = ray.experimental.state.GlobalState()
self.state._initialize_global_state(
redis_address, redis_port, redis_password=redis_password)
self.use_raylet = self.state.use_raylet
self.redis = redis.StrictRedis(
host=redis_address, port=redis_port, db=0, password=redis_password)
# Setup subscriptions to the primary Redis server and the Redis shards.
self.primary_subscribe_client = self.redis.pubsub(
ignore_subscribe_messages=True)
if self.use_raylet:
self.shard_subscribe_clients = []
for redis_client in self.state.redis_clients:
subscribe_client = redis_client.pubsub(
ignore_subscribe_messages=True)
self.shard_subscribe_clients.append(subscribe_client)
else:
# We don't need to subscribe to the shards in legacy Ray.
self.shard_subscribe_clients = []
# Initialize data structures to keep track of the active database
# clients.
self.dead_local_schedulers = set()
self.live_plasma_managers = Counter()
self.dead_plasma_managers = set()
self.shard_subscribe_clients = []
for redis_client in self.state.redis_clients:
subscribe_client = redis_client.pubsub(
ignore_subscribe_messages=True)
self.shard_subscribe_clients.append(subscribe_client)
# Keep a mapping from local scheduler client ID to IP address to use
# for updating the load metrics.
self.local_scheduler_id_to_ip_map = {}
@@ -152,170 +108,6 @@ class Monitor(object):
for subscribe_client in self.shard_subscribe_clients:
subscribe_client.subscribe(channel)
def cleanup_task_table(self):
    """Clean up global state for failed local schedulers.

    Every task in the task table whose "LocalSchedulerID" is in
    self.dead_local_schedulers is updated to TASK_STATUS_LOST. For actor
    tasks, the location entries of the task's dummy return object (the last
    entry of "ReturnObjectIDs") are removed from the object table first.

    Failures of individual Redis commands are logged and otherwise ignored;
    the method returns None.
    """
    tasks = self.state.task_table()
    num_tasks_updated = 0
    for task_id, task in tasks.items():
        # Only tasks that were scheduled on a dead local scheduler need any
        # work.
        if task["LocalSchedulerID"] not in self.dead_local_schedulers:
            continue

        # Remove dummy objects returned by actor tasks from any plasma
        # manager. Although the objects may still exist in that object
        # store, this deletion makes them effectively unreachable by any
        # local scheduler connected to a different store.
        # TODO(swang): Actually remove the objects from the object store,
        # so that the reconstructed actor can reuse the same object store.
        if hex_to_binary(task["TaskSpec"]["ActorID"]) != NIL_ACTOR_ID:
            dummy_object_id = task["TaskSpec"]["ReturnObjectIDs"][-1]
            obj = self.state.object_table(dummy_object_id)
            manager_ids = obj["ManagerIDs"]
            if manager_ids is not None:
                # The dummy object should exist on at most one plasma
                # manager, the manager associated with the local scheduler
                # that died.
                assert len(manager_ids) <= 1
                # Remove the dummy object from the plasma manager
                # associated with the dead local scheduler, if any.
                for manager in manager_ids:
                    ok = self.state._execute_command(
                        dummy_object_id, "RAY.OBJECT_TABLE_REMOVE",
                        dummy_object_id.id(), hex_to_binary(manager))
                    if ok != b"OK":
                        logger.warning(
                            "Failed to remove object location for "
                            "dead plasma manager.")

        # The task is scheduled on a dead local scheduler, so mark it as
        # lost.
        key = binary_to_object_id(hex_to_binary(task_id))
        ok = self.state._execute_command(
            key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id),
            ray.experimental.state.TASK_STATUS_LOST, NIL_ID,
            task["ExecutionDependenciesString"], task["SpillbackCount"])
        if ok != b"OK":
            logger.warning("Failed to update lost task for dead scheduler.")
        # The counter tracks attempted updates, including failed ones.
        num_tasks_updated += 1

    if num_tasks_updated > 0:
        logger.warning("Marked {} tasks as lost.".format(num_tasks_updated))
def cleanup_object_table(self):
    """Clean up global state for failed plasma managers.

    For every object in the object table, each location entry that refers
    to a manager in self.dead_plasma_managers is removed with
    RAY.OBJECT_TABLE_REMOVE. Objects whose "ManagerIDs" is None are skipped.

    Failures of individual Redis commands are logged and otherwise ignored;
    the method returns None.
    """
    # TODO(swang): Also kill the associated plasma store, since it's no
    # longer reachable without a plasma manager.
    objects = self.state.object_table()
    num_objects_removed = 0
    for object_id, obj in objects.items():
        manager_ids = obj["ManagerIDs"]
        if manager_ids is None:
            continue
        for manager in manager_ids:
            if manager not in self.dead_plasma_managers:
                continue
            # The object was on a dead plasma manager, so remove that
            # location entry.
            ok = self.state._execute_command(
                object_id, "RAY.OBJECT_TABLE_REMOVE", object_id.id(),
                hex_to_binary(manager))
            if ok != b"OK":
                logger.warning("Failed to remove object location for "
                               "dead plasma manager.")
            # The counter tracks attempted removals, including failed ones.
            num_objects_removed += 1

    if num_objects_removed > 0:
        logger.warning(
            "Marked {} objects as lost.".format(num_objects_removed))
def scan_db_client_table(self):
    """Record clients that are already marked as deleted in the client table.

    This has to run after subscribing to the client table and before any
    message is read from the subscription channel, so that deletions which
    happened before the subscription are not missed.

    Deleted local schedulers are added to self.dead_local_schedulers and
    deleted plasma managers to self.dead_plasma_managers.
    """
    # Exit if we are using the raylet code path because client_table is
    # implemented differently. TODO(rkn): Fix this.
    if self.use_raylet:
        return

    for node_clients in self.state.client_table().values():
        for client in node_clients:
            client_id = client["DBClientID"]
            kind = client["ClientType"]
            if not client["Deleted"]:
                continue
            if kind == LOCAL_SCHEDULER_CLIENT_TYPE:
                self.dead_local_schedulers.add(client_id)
            elif kind == PLASMA_MANAGER_CLIENT_TYPE:
                self.dead_plasma_managers.add(client_id)
def db_client_notification_handler(self, unused_channel, data):
    """Handle a notification from the db_client table from Redis.

    The payload is parsed with the SubscribeToDBClientTableReply flatbuffer.
    Insertions are ignored. For a deletion, the client's hex ID is added to
    self.dead_local_schedulers or self.dead_plasma_managers depending on
    its client type. Cleaning up the associated state in the state tables
    is left to the caller.

    Args:
        unused_channel: The channel the message arrived on (not used).
        data: The serialized notification.
    """
    notification_object = (ray.gcs_utils.SubscribeToDBClientTableReply.
                           GetRootAsSubscribeToDBClientTableReply(data, 0))

    # If the update was an insertion, we ignore it.
    if notification_object.IsInsertion():
        return

    raw_client_id = notification_object.DbClientId()
    db_client_id = binary_to_hex(raw_client_id)
    client_type = notification_object.ClientType()

    # The update was a deletion, so add the client to our accounting for
    # dead local schedulers and plasma managers.
    logger.warning("Removed {}, client ID {}".format(client_type,
                                                     db_client_id))
    if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
        self.dead_local_schedulers.add(db_client_id)
    elif client_type == PLASMA_MANAGER_CLIENT_TYPE:
        self.dead_plasma_managers.add(db_client_id)
        # Stop tracking this plasma manager's heartbeats, since it's
        # already dead. The heartbeat handler keys live_plasma_managers by
        # the raw binary client ID, so the raw ID (not the hex form kept in
        # the dead sets) must be used here; deleting by the hex ID never
        # matched an entry.
        self.live_plasma_managers.pop(raw_client_id, None)
def local_scheduler_info_handler(self, unused_channel, data):
    """Handle a local scheduler heartbeat from Redis.

    Decodes the static and dynamic resource maps from the
    LocalSchedulerInfoMessage and forwards them to self.load_metrics under
    the IP address registered for the sending client. If no IP address is
    known for the client, a warning is logged instead.

    Args:
        unused_channel: The channel the message arrived on (not used).
        data: The serialized heartbeat message.
    """
    message = (ray.gcs_utils.LocalSchedulerInfoMessage.
               GetRootAsLocalSchedulerInfoMessage(data, 0))

    static_resources = {}
    dynamic_resources = {}
    # Both resource vectors are indexed with the dynamic vector's length.
    for index in range(message.DynamicResourcesLength()):
        dynamic_entry = message.DynamicResources(index)
        static_entry = message.StaticResources(index)
        dynamic_key = dynamic_entry.Key().decode("utf-8")
        dynamic_resources[dynamic_key] = dynamic_entry.Value()
        static_key = static_entry.Key().decode("utf-8")
        static_resources[static_key] = static_entry.Value()

    # Update the load metrics for this local scheduler.
    client_id = binascii.hexlify(message.DbClientId()).decode("utf-8")
    ip = self.local_scheduler_id_to_ip_map.get(client_id)
    if not ip:
        logger.warning(
            "Warning: could not find ip for client {} in {}.".format(
                client_id, self.local_scheduler_id_to_ip_map))
        return
    self.load_metrics.update(ip, static_resources, dynamic_resources)
def xray_heartbeat_handler(self, unused_channel, data):
"""Handle an xray heartbeat message from Redis."""
@@ -342,160 +134,6 @@ class Monitor(object):
print("Warning: could not find ip for client {} in {}.".format(
client_id, self.local_scheduler_id_to_ip_map))
def plasma_manager_heartbeat_handler(self, unused_channel, data):
    """Handle a plasma manager heartbeat from Redis.

    The sender is identified by the first ray_constants.ID_SIZE bytes of the
    message. Its missed-heartbeat count in self.live_plasma_managers is
    reset to zero.

    Args:
        unused_channel: The channel the message arrived on (not used).
        data: The raw heartbeat message.
    """
    manager_id = data[:ray_constants.ID_SIZE]
    self.live_plasma_managers[manager_id] = 0
def _entries_for_driver_in_shard(self, driver_id, redis_shard_index):
    """Collect IDs of control-state entries for a driver from a shard.

    Args:
        driver_id: The ID of the driver.
        redis_shard_index: The index of the Redis shard to query.

    Returns:
        Lists of IDs: (returned_object_ids, task_ids, put_objects). The
            first two are relevant to the driver and are safe to delete.
            The last contains all "put" objects in this redis shard; each
            element is an (object_id, corresponding task_id) pair.
    """
    # TODO(zongheng): consider adding save & restore functionalities.
    # Named "shard" rather than "redis" so the redis module is not shadowed.
    shard = self.state.redis_clients[redis_shard_index]

    # Scan the task table and keep the tasks that belong to this driver. A
    # cursor-based scan is used in order not to block the redis shard.
    task_table_infos = {}  # task id -> TaskInfo message
    for key in shard.scan_iter(match=TASK_TABLE_PREFIX + b"*"):
        entry = shard.hgetall(key)
        task_info = ray.gcs_utils.TaskInfo.GetRootAsTaskInfo(
            entry[b"TaskSpec"], 0)
        if driver_id != task_info.DriverId():
            # Ignore tasks that aren't from this driver.
            continue
        task_table_infos[task_info.TaskId()] = task_info

    # Get the list of objects returned by these tasks. Note these might
    # not belong to this redis shard.
    returned_object_ids = []
    for task_info in task_table_infos.values():
        returned_object_ids.extend(
            task_info.Returns(i) for i in range(task_info.ReturnsLength()))

    # Also record all the ray.put()'d objects.
    put_objects = []
    prefix_length = len(OBJECT_INFO_PREFIX)
    for key in shard.scan_iter(match=OBJECT_INFO_PREFIX + b"*"):
        entry = shard.hgetall(key)
        # The hash is read with bytes keys, so its values are bytes too.
        # Comparing against the str "0" is always False on Python 3.
        if entry[b"is_put"] == b"0":
            continue
        # Strip the prefix by slicing. Splitting on the prefix would
        # truncate a binary object ID that happens to contain its bytes.
        object_id = key[prefix_length:]
        task_id = entry[b"task"]
        put_objects.append((object_id, task_id))

    return returned_object_ids, list(task_table_infos.keys()), put_objects
def _clean_up_entries_from_shard(self, object_ids, task_ids, shard_index):
    """Delete a driver's task and object entries from one redis shard.

    Args:
        object_ids: Binary object IDs whose object-location (OL) and
            object-info (OI) entries should be removed. Only entries that
            are non-empty in the shard are deleted.
        task_ids: Binary task IDs whose task-table entries should be
            removed.
        shard_index: The index of the redis shard to clean up.
    """
    shard = self.state.redis_clients[shard_index]

    # Clean up (in the future, save) entries for non-empty objects.
    object_ids_locs = set()
    object_ids_infos = set()
    for object_id in object_ids:
        # OL.
        if shard.zrange(OBJECT_LOCATION_PREFIX + object_id, 0, -1):
            object_ids_locs.add(object_id)
        # OI.
        if shard.hgetall(OBJECT_INFO_PREFIX + object_id):
            object_ids_infos.add(object_id)

    # Form the redis keys to delete.
    keys = [TASK_TABLE_PREFIX + k for k in task_ids]
    keys.extend(OBJECT_LOCATION_PREFIX + k for k in object_ids_locs)
    keys.extend(OBJECT_INFO_PREFIX + k for k in object_ids_infos)
    if not keys:
        return

    # Remove with best effort.
    num_deleted = shard.delete(*keys)
    logger.info(
        "Removed {} dead redis entries of the driver from redis shard {}.".
        format(num_deleted, shard_index))
    if num_deleted != len(keys):
        # Both placeholders need a value. Passing only the count raised
        # IndexError here instead of logging the warning.
        logger.warning(
            "Failed to remove {} relevant redis entries"
            " from redis shard {}.".format(
                len(keys) - num_deleted, shard_index))
def _clean_up_entries_for_driver(self, driver_id):
    """Remove this driver's object/task entries from all redis shards.

    Specifically, removes control-state entries of:
        * all objects (OI and OL entries) created by `ray.put()` from the
          driver
        * all tasks belonging to the driver.
    """
    # TODO(zongheng): handle function_table, client_table, log_files --
    # these are in the metadata redis server, not in the shards.
    num_shards = len(self.state.redis_clients)

    # Gather the relevant IDs from every shard.
    # TODO(zongheng): consider parallelizing this loop.
    driver_object_ids, driver_task_ids, all_put_objects = [], [], []
    for shard_index in range(num_shards):
        returned_ids, shard_task_ids, shard_puts = \
            self._entries_for_driver_in_shard(driver_id, shard_index)
        driver_object_ids.extend(returned_ids)
        driver_task_ids.extend(shard_task_ids)
        all_put_objects.extend(shard_puts)

    # Of the put objects, keep only those created by this driver's tasks.
    owned_task_ids = set(driver_task_ids)
    driver_object_ids.extend(
        object_id for object_id, task_id in all_put_objects
        if task_id in owned_task_ids)

    def shard_for(binary_id):
        # Map a binary ID to the index of the shard that stores it.
        return binary_to_object_id(binary_id).redis_shard_hash() % num_shards

    # Partition the IDs by the shard that holds them.
    objects_by_shard = defaultdict(list)
    for object_id in driver_object_ids:
        objects_by_shard[shard_for(object_id)].append(object_id)
    tasks_by_shard = defaultdict(list)
    for task_id in driver_task_ids:
        tasks_by_shard[shard_for(task_id)].append(task_id)

    # TODO(zongheng): consider parallelizing this loop.
    for shard_index in range(num_shards):
        self._clean_up_entries_from_shard(objects_by_shard[shard_index],
                                          tasks_by_shard[shard_index],
                                          shard_index)
def driver_removed_handler(self, unused_channel, data):
    """Handle a notification that a driver has been removed.

    Parses the DriverTableMessage, logs the removed driver's ID in hex and
    removes the driver's task and object entries from the redis shards.

    Args:
        unused_channel: The channel the message arrived on (not used).
        data: The serialized DriverTableMessage.
    """
    driver_id = ray.gcs_utils.DriverTableMessage.GetRootAsDriverTableMessage(
        data, 0).DriverId()
    removed_driver_hex = binary_to_hex(driver_id)
    logger.info("Driver {} has been removed.".format(removed_driver_hex))
    self._clean_up_entries_for_driver(driver_id)
def _xray_clean_up_entries_for_driver(self, driver_id):
"""Remove this driver's object/task entries from redis.
@@ -529,7 +167,7 @@ class Monitor(object):
driver_object_id_bins = set()
for object_id, object_table_object in object_table_objects.items():
assert len(object_table_object) > 0
task_id_bin = ray.local_scheduler.compute_task_id(object_id).id()
task_id_bin = ray.raylet.compute_task_id(object_id).id()
if task_id_bin in driver_task_id_bins:
driver_object_id_bins.add(object_id.id())
@@ -602,20 +240,7 @@ class Monitor(object):
# Determine the appropriate message handler.
message_handler = None
if channel == ray.gcs_utils.PLASMA_MANAGER_HEARTBEAT_CHANNEL:
# The message was a heartbeat from a plasma manager.
message_handler = self.plasma_manager_heartbeat_handler
elif channel == ray.gcs_utils.LOCAL_SCHEDULER_INFO_CHANNEL:
# The message was a heartbeat from a local scheduler
message_handler = self.local_scheduler_info_handler
elif channel == DB_CLIENT_TABLE_NAME:
# The message was a notification from the db_client table.
message_handler = self.db_client_notification_handler
elif channel == ray.gcs_utils.DRIVER_DEATH_CHANNEL:
# The message was a notification that a driver was removed.
logger.info("message-handler: driver_removed_handler")
message_handler = self.driver_removed_handler
elif channel == ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL:
if channel == ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL:
# Similar functionality as local scheduler info channel
message_handler = self.xray_heartbeat_handler
elif channel == ray.gcs_utils.XRAY_DRIVER_CHANNEL:
@@ -629,10 +254,7 @@ class Monitor(object):
message_handler(channel, data)
def update_local_scheduler_map(self):
if self.use_raylet:
local_schedulers = self.state.client_table()
else:
local_schedulers = self.state.local_schedulers()
local_schedulers = self.state.client_table()
self.local_scheduler_id_to_ip_map = {}
for local_scheduler_info in local_schedulers:
client_id = local_scheduler_info.get("DBClientID") or \
@@ -680,33 +302,11 @@ class Monitor(object):
clients and cleaning up state accordingly.
"""
# Initialize the subscription channel.
self.subscribe(DB_CLIENT_TABLE_NAME)
self.subscribe(ray.gcs_utils.LOCAL_SCHEDULER_INFO_CHANNEL)
self.subscribe(ray.gcs_utils.PLASMA_MANAGER_HEARTBEAT_CHANNEL)
self.subscribe(ray.gcs_utils.DRIVER_DEATH_CHANNEL)
self.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL, primary=False)
self.subscribe(ray.gcs_utils.XRAY_DRIVER_CHANNEL)
# Scan the database table for dead database clients. NOTE: This must be
# called before reading any messages from the subscription channel.
# This ensures that we start in a consistent state, since we may have
# missed notifications that were sent before we connected to the
# subscription channel.
self.scan_db_client_table()
# If there were any dead clients at startup, clean up the associated
# state in the state tables.
if len(self.dead_local_schedulers) > 0:
self.cleanup_task_table()
if len(self.dead_plasma_managers) > 0:
self.cleanup_object_table()
num_plasma_managers = len(self.live_plasma_managers) + len(
self.dead_plasma_managers)
logger.debug("{} dead local schedulers, {} plasma managers total, {} "
"dead plasma managers".format(
len(self.dead_local_schedulers), num_plasma_managers,
len(self.dead_plasma_managers)))
# TODO(rkn): If there were any dead clients at startup, we should clean
# up the associated state in the state tables.
# Handle messages from the subscription channels.
while True:
@@ -720,43 +320,9 @@ class Monitor(object):
self._maybe_flush_gcs()
# Record how many dead local schedulers and plasma managers we had
# at the beginning of this round.
num_dead_local_schedulers = len(self.dead_local_schedulers)
num_dead_plasma_managers = len(self.dead_plasma_managers)
# Process a round of messages.
self.process_messages()
# If any new local schedulers or plasma managers were marked as
# dead in this round, clean up the associated state.
if len(self.dead_local_schedulers) > num_dead_local_schedulers:
self.cleanup_task_table()
if len(self.dead_plasma_managers) > num_dead_plasma_managers:
self.cleanup_object_table()
# Handle plasma managers that timed out during this round.
plasma_manager_ids = list(self.live_plasma_managers.keys())
for plasma_manager_id in plasma_manager_ids:
if ((self.live_plasma_managers[plasma_manager_id]) >=
ray._config.num_heartbeats_timeout()):
logger.warn("Timed out {}"
.format(PLASMA_MANAGER_CLIENT_TYPE))
# Remove the plasma manager from the managers whose
# heartbeats we're tracking.
del self.live_plasma_managers[plasma_manager_id]
# Remove the plasma manager from the db_client table. The
# corresponding state in the object table will be cleaned
# up once we receive the notification for this db_client
# deletion.
self.redis.execute_command("RAY.DISCONNECT",
plasma_manager_id)
# Increment the number of heartbeats that we've missed from each
# plasma manager.
for plasma_manager_id in self.live_plasma_managers:
self.live_plasma_managers[plasma_manager_id] += 1
# Wait for a heartbeat interval before processing the next round of
# messages.
time.sleep(ray._config.heartbeat_timeout_milliseconds() * 1e-3)
@@ -827,6 +393,5 @@ if __name__ == "__main__":
message = "The monitor failed with the following error:\n{}".format(
traceback_str)
ray.utils.push_error_to_driver_through_redis(
redis_client, monitor.use_raylet, ray_constants.MONITOR_DIED_ERROR,
message)
redis_client, ray_constants.MONITOR_DIED_ERROR, message)
raise e
+2 -5
View File
@@ -2,9 +2,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.plasma.plasma import (start_plasma_store, start_plasma_manager,
DEFAULT_PLASMA_STORE_MEMORY)
from ray.plasma.plasma import start_plasma_store, DEFAULT_PLASMA_STORE_MEMORY
__all__ = [
"start_plasma_store", "start_plasma_manager", "DEFAULT_PLASMA_STORE_MEMORY"
]
__all__ = ["start_plasma_store", "DEFAULT_PLASMA_STORE_MEMORY"]
+2 -101
View File
@@ -3,17 +3,13 @@ from __future__ import division
from __future__ import print_function
import os
import random
import subprocess
import sys
import time
from ray.tempfile_services import (get_object_store_socket_name,
get_plasma_manager_socket_name)
from ray.tempfile_services import get_object_store_socket_name
__all__ = [
"start_plasma_store", "start_plasma_manager", "DEFAULT_PLASMA_STORE_MEMORY"
]
__all__ = ["start_plasma_store", "DEFAULT_PLASMA_STORE_MEMORY"]
PLASMA_WAIT_TIMEOUT = 2**30
@@ -97,98 +93,3 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
time.sleep(0.1)
return plasma_store_name, pid
def new_port():
    """Return a random port number between 10000 and 65535, inclusive."""
    lowest_port, highest_port = 10000, 65535
    return random.randint(lowest_port, highest_port)
def start_plasma_manager(store_name,
                         redis_address,
                         node_ip_address="127.0.0.1",
                         plasma_manager_port=None,
                         num_retries=20,
                         use_valgrind=False,
                         run_profiler=False,
                         stdout_file=None,
                         stderr_file=None):
    """Start a plasma manager and return the port it listens on.

    Args:
        store_name (str): The name of the plasma store socket.
        redis_address (str): The address of the Redis server.
        node_ip_address (str): The IP address of the node.
        plasma_manager_port (int): The port to use for the plasma manager. If
            this is not provided, a port will be generated at random.
        num_retries (int): The maximum number of start attempts. Each retry
            uses a freshly generated port, so this must be 1 when
            plasma_manager_port is given.
        use_valgrind (bool): True if the Plasma manager should be started
            inside of valgrind and False otherwise. Takes precedence over
            run_profiler.
        run_profiler (bool): True if the Plasma manager should be started
            under valgrind's callgrind tool.
        stdout_file: A file handle opened for writing to redirect stdout to. If
            no redirection should happen, then this should be None.
        stderr_file: A file handle opened for writing to redirect stderr to. If
            no redirection should happen, then this should be None.

    Returns:
        A tuple of the Plasma manager socket name, the subprocess.Popen
            object of the Plasma manager process, and the port that the
            manager is listening on.

    Raises:
        Exception: An exception is raised if the manager could not be started,
            or if a port is given together with num_retries != 1.
    """
    manager_binary = os.path.join(
        os.path.abspath(os.path.dirname(__file__)),
        "../core/src/plasma/plasma_manager")
    plasma_manager_name = get_plasma_manager_socket_name()

    if plasma_manager_port is None:
        plasma_manager_port = new_port()
    elif num_retries != 1:
        raise Exception("num_retries must be 1 if port is specified.")

    # Neither the wrapper command nor the startup grace period changes
    # between attempts, so both are chosen once up front.
    if use_valgrind:
        wrapper = [
            "valgrind", "--track-origins=yes", "--leak-check=full",
            "--show-leak-kinds=all", "--error-exitcode=1"
        ]
    elif run_profiler:
        wrapper = ["valgrind", "--tool=callgrind"]
    else:
        wrapper = []
    startup_grace_seconds = 1 if use_valgrind else 0.1

    attempt = 0
    while attempt < num_retries:
        if attempt > 0:
            print("Plasma manager failed to start, retrying now.")
        manager_command = [
            manager_binary,
            "-s",
            store_name,
            "-m",
            plasma_manager_name,
            "-h",
            node_ip_address,
            "-p",
            str(plasma_manager_port),
            "-r",
            redis_address,
        ]
        process = subprocess.Popen(
            wrapper + manager_command,
            stdout=stdout_file,
            stderr=stderr_file)
        # This sleep is critical. If the plasma_manager fails to start because
        # the port is already in use, then we need it to fail within this
        # grace period so that the poll below can see it.
        time.sleep(startup_grace_seconds)
        # A process that is still running is treated as started.
        if process.poll() is None:
            return plasma_manager_name, process, plasma_manager_port
        # Generate a new port and try again.
        plasma_manager_port = new_port()
        attempt += 1

    raise Exception("Couldn't start plasma manager.")
-560
View File
@@ -1,560 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from numpy.testing import assert_equal
import os
import random
import signal
import subprocess
import sys
import threading
import time
import unittest
# The ray import must come before the pyarrow import because ray modifies the
# python path so that the right version of pyarrow is found.
import ray
from ray.plasma.utils import (random_object_id, create_object_with_id,
create_object)
import ray.ray_constants as ray_constants
from ray import services
import pyarrow as pa
import pyarrow.plasma as plasma
USE_VALGRIND = False
PLASMA_STORE_MEMORY = 1000000000
def random_name():
    """Return a random integer in [0, 99999999] as a decimal string."""
    suffix = random.randint(0, 99999999)
    return "{}".format(suffix)
def assert_get_object_equal(unit_test,
                            client1,
                            client2,
                            object_id,
                            memory_buffer=None,
                            metadata=None):
    """Assert that two clients return the same data for an object.

    Args:
        unit_test: The test case whose assertEqual is used for the length
            checks.
        client1: The first client; must provide get_buffers and
            get_metadata.
        client2: The second client, with the same interface.
        object_id: The ID of the object to compare.
        memory_buffer: Optional reference data buffer that client1's data
            must match.
        metadata: Optional reference metadata buffer that client1's
            metadata must match.
    """

    def as_uint8(buf):
        # View a buffer as raw bytes for element-wise comparison.
        return np.frombuffer(buf, dtype="uint8")

    data1 = client1.get_buffers([object_id])[0]
    data2 = client2.get_buffers([object_id])[0]
    meta1 = client1.get_metadata([object_id])[0]
    meta2 = client2.get_metadata([object_id])[0]

    unit_test.assertEqual(len(data1), len(data2))
    unit_test.assertEqual(len(meta1), len(meta2))

    # The data and metadata seen by the two clients must be identical.
    assert_equal(as_uint8(data1), as_uint8(data2))
    assert_equal(as_uint8(meta1), as_uint8(meta2))

    # Compare against the reference buffers when they are supplied.
    if memory_buffer is not None:
        assert_equal(as_uint8(memory_buffer), as_uint8(data1))
    if metadata is not None:
        assert_equal(as_uint8(metadata), as_uint8(meta1))
DEFAULT_PLASMA_STORE_MEMORY = 10**9


def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
                       use_valgrind=False,
                       use_profiler=False,
                       stdout_file=None,
                       stderr_file=None):
    """Start a plasma store process.

    Args:
        plasma_store_memory: The value passed to the store's "-m" option.
        use_valgrind (bool): True if the plasma store should be started inside
            of valgrind. If this is True, use_profiler must be False.
        use_profiler (bool): True if the plasma store should be started inside
            a profiler. If this is True, use_valgrind must be False.
        stdout_file: A file handle opened for writing to redirect stdout to. If
            no redirection should happen, then this should be None.
        stderr_file: A file handle opened for writing to redirect stderr to. If
            no redirection should happen, then this should be None.

    Return:
        A tuple of the name of the plasma store socket and the
            subprocess.Popen object of the plasma store process.

    Raises:
        Exception: If both use_valgrind and use_profiler are set.
    """
    if use_valgrind and use_profiler:
        raise Exception("Cannot use valgrind and profiler at the same time.")

    store_binary = os.path.join(pa.__path__[0], "plasma_store_server")
    plasma_store_name = "/tmp/plasma_store{}".format(random_name())
    store_command = [
        store_binary, "-s", plasma_store_name, "-m",
        str(plasma_store_memory)
    ]

    # Choose the wrapper to run the store under and how long to wait after
    # launching it.
    if use_valgrind:
        wrapper = [
            "valgrind", "--track-origins=yes", "--leak-check=full",
            "--show-leak-kinds=all", "--leak-check-heuristics=stdstring",
            "--error-exitcode=1"
        ]
        startup_delay = 1.0
    elif use_profiler:
        wrapper = ["valgrind", "--tool=callgrind"]
        startup_delay = 1.0
    else:
        wrapper = []
        startup_delay = 0.1

    store_process = subprocess.Popen(
        wrapper + store_command, stdout=stdout_file, stderr=stderr_file)
    time.sleep(startup_delay)
    return plasma_store_name, store_process
# Plasma client tests were moved into arrow
class TestPlasmaManager(unittest.TestCase):
def setUp(self):
# Test fixture: starts two plasma stores, one Redis server and two plasma
# managers (one per store), then connects one client to each store/manager
# pair. The process handles are kept on self so tearDown can stop them.
# Start two PlasmaStores.
store_name1, self.p2 = start_plasma_store(use_valgrind=USE_VALGRIND)
store_name2, self.p3 = start_plasma_store(use_valgrind=USE_VALGRIND)
# Start a Redis server.
redis_address, _ = services.start_redis("127.0.0.1", use_raylet=False)
# Start two PlasmaManagers.
manager_name1, self.p4, self.port1 = ray.plasma.start_plasma_manager(
store_name1, redis_address, use_valgrind=USE_VALGRIND)
manager_name2, self.p5, self.port2 = ray.plasma.start_plasma_manager(
store_name2, redis_address, use_valgrind=USE_VALGRIND)
# Connect two PlasmaClients.
self.client1 = plasma.connect(store_name1, manager_name1, 64)
self.client2 = plasma.connect(store_name2, manager_name2, 64)
# Store the processes that will be explicitly killed during tearDown so
# that a test case can remove ones that will be killed during the test.
# NOTE: If this specific order is changed, valgrind will fail.
# The order is both managers (p4, p5) first, then both stores (p2, p3).
self.processes_to_kill = [self.p4, self.p5, self.p2, self.p3]
def tearDown(self):
# Verifies that every process in self.processes_to_kill is still running,
# stops them in list order, and then cleans up the Redis server.
# Check that the processes are still alive.
for process in self.processes_to_kill:
self.assertEqual(process.poll(), None)
# Kill the Plasma store and Plasma manager processes.
if USE_VALGRIND:
# Give processes opportunity to finish work.
time.sleep(1)
# Under valgrind each process is sent SIGTERM and waited on so that its
# exit code can be checked; a non-zero exit code ends the whole test run
# immediately through os._exit.
for process in self.processes_to_kill:
process.send_signal(signal.SIGTERM)
process.wait()
if process.returncode != 0:
print("aborting due to valgrind error")
os._exit(-1)
else:
# Without valgrind the processes are killed outright and not waited on.
for process in self.processes_to_kill:
process.kill()
# Clean up the Redis server.
services.cleanup()
def test_fetch(self):
# Exercises fetch() between the two clients created in setUp: fetching an
# object that exists on the other side, fetching an ID before the object is
# created, and issuing the same fetch request many times.
for _ in range(10):
# Create an object.
object_id1, memory_buffer1, metadata1 = create_object(
self.client1, 2000, 2000)
self.client1.fetch([object_id1])
self.assertEqual(self.client1.contains(object_id1), True)
self.assertEqual(self.client2.contains(object_id1), False)
# Fetch the object from the other plasma manager.
# TODO(rkn): Right now we must wait for the object table to be
# updated.
# NOTE(review): this retry loop has no timeout of its own.
while not self.client2.contains(object_id1):
self.client2.fetch([object_id1])
# Compare the two buffers.
assert_get_object_equal(
self,
self.client1,
self.client2,
object_id1,
memory_buffer=memory_buffer1,
metadata=metadata1)
# Test that we can call fetch on object IDs that don't exist yet.
object_id2 = random_object_id()
self.client1.fetch([object_id2])
self.assertEqual(self.client1.contains(object_id2), False)
# NOTE(review): memory_buffer2 and metadata2 are only referenced by the
# commented-out checks below, so the outcome of this early fetch is
# currently not verified.
memory_buffer2, metadata2 = create_object_with_id(
self.client2, object_id2, 2000, 2000)
# # Check that the object has been fetched.
# self.assertEqual(self.client1.contains(object_id2), True)
# Compare the two buffers.
# assert_get_object_equal(self, self.client1, self.client2, object_id2,
# memory_buffer=memory_buffer2,
# metadata=metadata2)
# Test calling the same fetch request a bunch of times.
object_id3 = random_object_id()
self.assertEqual(self.client1.contains(object_id3), False)
self.assertEqual(self.client2.contains(object_id3), False)
for _ in range(10):
self.client1.fetch([object_id3])
self.client2.fetch([object_id3])
memory_buffer3, metadata3 = create_object_with_id(
self.client1, object_id3, 2000, 2000)
for _ in range(10):
self.client1.fetch([object_id3])
self.client2.fetch([object_id3])
# TODO(rkn): Right now we must wait for the object table to be updated.
while not self.client2.contains(object_id3):
self.client2.fetch([object_id3])
assert_get_object_equal(
self,
self.client1,
self.client2,
object_id3,
memory_buffer=memory_buffer3,
metadata=metadata3)
def test_fetch_multiple(self):
# Exercises fetch() with several object IDs in one call, including an ID
# for which no object is ever created and lists with duplicated IDs.
for _ in range(20):
# Create two objects and a third fake one that doesn't exist.
object_id1, memory_buffer1, metadata1 = create_object(
self.client1, 2000, 2000)
missing_object_id = random_object_id()
object_id2, memory_buffer2, metadata2 = create_object(
self.client1, 2000, 2000)
object_ids = [object_id1, missing_object_id, object_id2]
# Fetch the objects from the other plasma store. The second object
# ID should timeout since it does not exist.
# TODO(rkn): Right now we must wait for the object table to be
# updated.
# NOTE(review): this retry loop has no timeout of its own.
while ((not self.client2.contains(object_id1))
or (not self.client2.contains(object_id2))):
self.client2.fetch(object_ids)
# Compare the buffers of the objects that do exist.
assert_get_object_equal(
self,
self.client1,
self.client2,
object_id1,
memory_buffer=memory_buffer1,
metadata=metadata1)
assert_get_object_equal(
self,
self.client1,
self.client2,
object_id2,
memory_buffer=memory_buffer2,
metadata=metadata2)
# Fetch in the other direction. The fake object still does not
# exist.
self.client1.fetch(object_ids)
assert_get_object_equal(
self,
self.client2,
self.client1,
object_id1,
memory_buffer=memory_buffer1,
metadata=metadata1)
assert_get_object_equal(
self,
self.client2,
self.client1,
object_id2,
memory_buffer=memory_buffer2,
metadata=metadata2)
# Check that we can call fetch with duplicated object IDs.
# object_id3 is a random ID for which this test never creates an object.
object_id3 = random_object_id()
self.client1.fetch([object_id3, object_id3])
object_id4, memory_buffer4, metadata4 = create_object(
self.client1, 2000, 2000)
time.sleep(0.1)
# TODO(rkn): Right now we must wait for the object table to be updated.
while not self.client2.contains(object_id4):
self.client2.fetch(
[object_id3, object_id3, object_id4, object_id4])
assert_get_object_equal(
self,
self.client2,
self.client1,
object_id4,
memory_buffer=memory_buffer4,
metadata=metadata4)
def test_wait(self):
    # Exercise client.wait in several situations: an ID that is never
    # created, locally sealed objects, an unsealed object, an object that
    # is sealed later from a timer thread, and many objects spread over
    # both clients.

    # Test timeout.
    # obj_id0 is never created in this test; the return value of this call
    # is deliberately ignored.
    obj_id0 = random_object_id()
    self.client1.wait([obj_id0], timeout=100, num_returns=1)
    # If we get here, the test worked.
    # Test wait if local objects available.
    obj_id1 = random_object_id()
    self.client1.create(obj_id1, 1000)
    self.client1.seal(obj_id1)
    ready, waiting = self.client1.wait(
        [obj_id1], timeout=100, num_returns=1)
    self.assertEqual(set(ready), {obj_id1})
    self.assertEqual(waiting, [])
    # Test wait if only one object available and only one object waited
    # for.
    obj_id2 = random_object_id()
    self.client1.create(obj_id2, 1000)
    # Don't seal.
    # obj_id2 stays created-but-unsealed until near the end of this section,
    # so it is expected in `waiting` in every check below.
    ready, waiting = self.client1.wait(
        [obj_id2, obj_id1], timeout=100, num_returns=1)
    self.assertEqual(set(ready), {obj_id1})
    self.assertEqual(set(waiting), {obj_id2})
    # Test wait if object is sealed later.
    obj_id3 = random_object_id()

    def finish():
        # Runs on the timer thread: create and seal obj_id3 through the
        # second client.
        self.client2.create(obj_id3, 1000)
        self.client2.seal(obj_id3)

    # Schedule finish() to run 0.1 seconds from now, then start waiting.
    t = threading.Timer(0.1, finish)
    t.start()
    ready, waiting = self.client1.wait(
        [obj_id3, obj_id2, obj_id1], timeout=1000, num_returns=2)
    self.assertEqual(set(ready), {obj_id1, obj_id3})
    self.assertEqual(set(waiting), {obj_id2})
    # Test if the appropriate number of objects is shown if some objects
    # are not ready.
    # NOTE(review): the two positional arguments are presumably timeout=100
    # and num_returns=3, matching the keyword calls above -- confirm against
    # the client API.
    ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1], 100, 3)
    self.assertEqual(set(ready), {obj_id1, obj_id3})
    self.assertEqual(set(waiting), {obj_id2})
    # Don't forget to seal obj_id2.
    self.client1.seal(obj_id2)
    # Test calling wait a bunch of times.
    object_ids = []
    # TODO(rkn): Increasing n to 100 (or larger) will cause failures. The
    # problem appears to be that the number of timers added to the manager
    # event loop slow down the manager so much that some of the
    # asynchronous Redis commands timeout triggering fatal failure
    # callbacks.
    n = 40
    # Create n * (n + 1) // 2 objects (1 + 2 + ... + n), alternating between
    # the two clients, so the loops below can request 1, 2, ..., n objects
    # in turn and end up with every ID exactly once.
    for i in range(n * (n + 1) // 2):
        if i % 2 == 0:
            object_id, _, _ = create_object(self.client1, 200, 200)
        else:
            object_id, _, _ = create_object(self.client2, 200, 200)
        object_ids.append(object_id)
    # Try waiting for all of the object IDs on the first client.
    waiting = object_ids
    retrieved = []
    for i in range(1, n + 1):
        # Each round waits only on the IDs left over from the previous one.
        ready, waiting = self.client1.wait(
            waiting, timeout=1000, num_returns=i)
        self.assertEqual(len(ready), i)
        retrieved += ready
    self.assertEqual(set(retrieved), set(object_ids))
    ready, waiting = self.client1.wait(
        object_ids, timeout=1000, num_returns=len(object_ids))
    self.assertEqual(set(ready), set(object_ids))
    self.assertEqual(waiting, [])
    # Try waiting for all of the object IDs on the second client.
    waiting = object_ids
    retrieved = []
    for i in range(1, n + 1):
        ready, waiting = self.client2.wait(
            waiting, timeout=1000, num_returns=i)
        self.assertEqual(len(ready), i)
        retrieved += ready
    self.assertEqual(set(retrieved), set(object_ids))
    ready, waiting = self.client2.wait(
        object_ids, timeout=1000, num_returns=len(object_ids))
    self.assertEqual(set(ready), set(object_ids))
    self.assertEqual(waiting, [])
    # Make sure that wait returns when the requested number of object IDs
    # are available and does not wait for all object IDs to be available.
    # Nine random IDs plus one ID made entirely of zero bytes.
    object_ids = [random_object_id() for _ in range(9)] + \
        [plasma.ObjectID(ray_constants.ID_SIZE * b'\x00')]
    # Objects are created in a shuffled order while wait is always called
    # with the unshuffled list.
    object_ids_perm = object_ids[:]
    random.shuffle(object_ids_perm)
    for i in range(10):
        if i % 2 == 0:
            create_object_with_id(self.client1, object_ids_perm[i], 2000,
                                  2000)
        else:
            create_object_with_id(self.client2, object_ids_perm[i], 2000,
                                  2000)
        # After i + 1 creations, exactly those IDs must be reported ready.
        ready, waiting = self.client1.wait(object_ids, num_returns=(i + 1))
        self.assertEqual(set(ready), set(object_ids_perm[:(i + 1)]))
        self.assertEqual(set(waiting), set(object_ids_perm[(i + 1):]))
def test_transfer(self):
    # Repeatedly transfer a freshly created object from client1 to the
    # second manager and another one from client2 to the first manager,
    # checking after each transfer that both clients see identical data.
    num_attempts = 100
    for _ in range(100):
        # Create an object.
        object_id1, memory_buffer1, metadata1 = create_object(
            self.client1, 2000, 2000)
        # Transfer the buffer to the other Plasma store. There is a
        # race condition on the create and transfer of the object, so keep
        # trying until the object appears on the second Plasma store.
        for _attempt in range(num_attempts):
            self.client1.transfer("127.0.0.1", self.port2, object_id1)
            buff = self.client2.get_buffers(
                [object_id1], timeout_ms=100)[0]
            if buff is not None:
                break
        # Identity check, consistent with the `is not None` test in the
        # retry loop; avoids running an equality comparison against None.
        self.assertIsNotNone(buff)
        # Drop the local reference to the buffer before comparing.
        del buff
        # Compare the two buffers.
        assert_get_object_equal(
            self,
            self.client1,
            self.client2,
            object_id1,
            memory_buffer=memory_buffer1,
            metadata=metadata1)
        # # Transfer the buffer again.
        # self.client1.transfer("127.0.0.1", self.port2, object_id1)
        # # Compare the two buffers.
        # assert_get_object_equal(self, self.client1, self.client2,
        #                         object_id1,
        #                         memory_buffer=memory_buffer1,
        #                         metadata=metadata1)
        # Create an object.
        object_id2, memory_buffer2, metadata2 = create_object(
            self.client2, 20000, 20000)
        # Transfer the buffer to the other Plasma store. There is a
        # race condition on the create and transfer of the object, so keep
        # trying until the object appears on the second Plasma store.
        for _attempt in range(num_attempts):
            self.client2.transfer("127.0.0.1", self.port1, object_id2)
            buff = self.client1.get_buffers(
                [object_id2], timeout_ms=100)[0]
            if buff is not None:
                break
        self.assertIsNotNone(buff)
        del buff
        # Compare the two buffers.
        assert_get_object_equal(
            self,
            self.client1,
            self.client2,
            object_id2,
            memory_buffer=memory_buffer2,
            metadata=metadata2)
def test_illegal_functionality(self):
    # Placeholder test: every step below is commented out, so the body is
    # only `pass` and nothing is checked.
    # Create an object id string.
    # object_id = random_object_id()
    # Create a new buffer.
    # memory_buffer = self.client1.create(object_id, 20000)
    # This test is commented out because it currently fails.
    # # Transferring the buffer before sealing it should fail.
    # self.assertRaises(Exception,
    #                   lambda : self.manager1.transfer(1, object_id))
    pass
def test_stresstest(self):
    # Measure how long it takes to create and seal many one-byte objects on
    # client1 and then request a transfer of each one to port2.
    start_time = time.time()
    created_ids = []
    for _ in range(10000):  # TODO(pcm): increase this to 100000.
        new_id = random_object_id()
        created_ids.append(new_id)
        self.client1.create(new_id, 1)
        self.client1.seal(new_id)
    for new_id in created_ids:
        self.client1.transfer("127.0.0.1", self.port2, new_id)
    elapsed = time.time() - start_time
    print("it took", elapsed, "seconds to put and transfer the objects")
class TestPlasmaManagerRecovery(unittest.TestCase):
    # Checks that a plasma manager started after objects were already sealed
    # in its store still reports those objects as available.

    def setUp(self):
        # Start a Plasma store.
        self.store_name, self.p2 = start_plasma_store(
            use_valgrind=USE_VALGRIND)
        # Start a Redis server.
        self.redis_address, _ = services.start_redis(
            "127.0.0.1", use_raylet=False)
        # Start a PlasmaManager.
        manager_name, self.p3, self.port1 = ray.plasma.start_plasma_manager(
            self.store_name, self.redis_address, use_valgrind=USE_VALGRIND)
        # Connect a PlasmaClient.
        self.client = plasma.connect(self.store_name, manager_name, 64)
        # Store the processes that will be explicitly killed during tearDown so
        # that a test case can remove ones that will be killed during the test.
        # NOTE: The plasma managers must be killed before the plasma store
        # since plasma store death will bring down the managers.
        self.processes_to_kill = [self.p3, self.p2]

    def tearDown(self):
        # Check that the processes are still alive.
        for process in self.processes_to_kill:
            self.assertEqual(process.poll(), None)
        # Kill the Plasma store and Plasma manager processes.
        if USE_VALGRIND:
            # Give processes opportunity to finish work.
            time.sleep(1)
            for process in self.processes_to_kill:
                process.send_signal(signal.SIGTERM)
                process.wait()
                if process.returncode != 0:
                    # A nonzero exit status is treated as a valgrind
                    # failure and aborts the whole test run immediately.
                    print("aborting due to valgrind error")
                    os._exit(-1)
        else:
            for process in self.processes_to_kill:
                process.kill()
        # Clean up the Redis server.
        services.cleanup()

    def test_delayed_start(self):
        num_objects = 10
        # Create some objects using one client. Iterate over the IDs
        # themselves so the number created always matches num_objects.
        object_ids = [random_object_id() for _ in range(num_objects)]
        for object_id in object_ids:
            create_object_with_id(self.client, object_id, 2000, 2000)
        # Wait until the objects have been sealed in the store.
        ready, waiting = self.client.wait(object_ids, num_returns=num_objects)
        self.assertEqual(set(ready), set(object_ids))
        self.assertEqual(waiting, [])
        # Start a second plasma manager attached to the same store.
        manager_name, self.p5, self.port2 = ray.plasma.start_plasma_manager(
            self.store_name, self.redis_address, use_valgrind=USE_VALGRIND)
        # The new manager goes first so that tearDown kills it before the
        # store (see the NOTE in setUp).
        self.processes_to_kill = [self.p5] + self.processes_to_kill
        # Check that the second manager knows about existing objects.
        client2 = plasma.connect(self.store_name, manager_name, 64)
        # Poll with a zero timeout until every object is reported ready.
        while True:
            ready, waiting = client2.wait(
                object_ids, num_returns=num_objects, timeout=0)
            if len(ready) == len(object_ids):
                break
        self.assertEqual(set(ready), set(object_ids))
        self.assertEqual(waiting, [])
if __name__ == "__main__":
    # A trailing "valgrind" argument switches the tests into valgrind mode.
    # It is popped from sys.argv so we don't mess with unittest's own
    # argument parser.
    if len(sys.argv) > 1 and sys.argv[-1] == "valgrind":
        arg = sys.argv.pop()
        USE_VALGRIND = True
        print("Using valgrind for tests")
    unittest.main(verbosity=2)
-53
View File
@@ -1,53 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import random
import pyarrow.plasma as plasma
import ray.ray_constants as ray_constants
def random_object_id():
    """Return a plasma ObjectID built from ID_SIZE random bytes."""
    raw_id = np.random.bytes(ray_constants.ID_SIZE)
    return plasma.ObjectID(raw_id)
def generate_metadata(length):
    """Build a bytearray of `length` bytes with some randomized contents.

    The buffer starts out zeroed. When `length` is positive, the first and
    last bytes are set to random values and then 100 randomly chosen
    positions are overwritten with random values.

    Args:
        length: The number of bytes in the returned buffer.

    Returns:
        A bytearray of the requested length.
    """
    buf = bytearray(length)
    if length <= 0:
        return buf
    buf[0] = random.randint(0, 255)
    buf[-1] = random.randint(0, 255)
    for _ in range(100):
        # Draw the value before the position so the random stream is
        # consumed in the same order as a single subscript assignment.
        value = random.randint(0, 255)
        position = random.randint(0, length - 1)
        buf[position] = value
    return buf
def write_to_data_buffer(buff, length):
    """Write random bytes into `buff` in place.

    The buffer is viewed as an array of uint8. When `length` is positive,
    the first and last elements are set to random values and then 100
    randomly chosen indices below `length` are overwritten.

    Args:
        buff: A writable buffer to modify.
        length: Upper bound (exclusive) for the randomly chosen indices.
    """
    view = np.frombuffer(buff, dtype="uint8")
    if length <= 0:
        return
    view[0] = random.randint(0, 255)
    view[-1] = random.randint(0, 255)
    for _ in range(100):
        # Draw the value before the index so the random stream is consumed
        # in the same order as a single subscript assignment.
        value = random.randint(0, 255)
        index = random.randint(0, length - 1)
        view[index] = value
def create_object_with_id(client,
                          object_id,
                          data_size,
                          metadata_size,
                          seal=True):
    """Create an object with the given ID and fill it with random data.

    Args:
        client: The client used to create (and optionally seal) the object.
        object_id: The ID to create the object under.
        data_size: Size passed to client.create for the data buffer.
        metadata_size: Length of the generated metadata.
        seal: If true, seal the object after writing to it.

    Returns:
        A tuple of the object's data buffer and its metadata.
    """
    object_metadata = generate_metadata(metadata_size)
    data_buffer = client.create(object_id, data_size, object_metadata)
    write_to_data_buffer(data_buffer, data_size)
    if not seal:
        return data_buffer, object_metadata
    client.seal(object_id)
    return data_buffer, object_metadata
def create_object(client, data_size, metadata_size, seal=True):
    """Create an object under a freshly generated random ID.

    Args:
        client: The client used to create the object.
        data_size: Size of the object's data buffer.
        metadata_size: Length of the generated metadata.
        seal: If true, seal the object after writing to it.

    Returns:
        A tuple of the new object ID, the data buffer and the metadata.
    """
    new_id = random_object_id()
    created = create_object_with_id(
        client, new_id, data_size, metadata_size, seal=seal)
    data_buffer, object_metadata = created
    return new_id, data_buffer, object_metadata
+7 -84
View File
@@ -59,17 +59,7 @@ def profile(event_type, extra_data=None, worker=None):
"""
if worker is None:
worker = ray.worker.global_worker
if not worker.use_raylet:
# Log the event if this is a worker and not a driver, since the
# driver's event log never gets flushed.
if worker.mode == ray.WORKER_MODE:
return RayLogSpanNonRaylet(
worker.profiler, event_type, contents=extra_data)
else:
return NULL_LOG_SPAN
else:
return RayLogSpanRaylet(
worker.profiler, event_type, extra_data=extra_data)
return RayLogSpanRaylet(worker.profiler, event_type, extra_data=extra_data)
class Profiler(object):
@@ -124,87 +114,20 @@ class Profiler(object):
events = self.events
self.events = []
if not self.worker.use_raylet:
event_log_key = b"event_log:" + self.worker.worker_id
event_log_value = json.dumps(events)
self.worker.local_scheduler_client.log_event(
event_log_key, event_log_value, time.time())
if self.worker.mode == ray.WORKER_MODE:
component_type = "worker"
else:
if self.worker.mode == ray.WORKER_MODE:
component_type = "worker"
else:
component_type = "driver"
component_type = "driver"
self.worker.local_scheduler_client.push_profile_events(
component_type, ray.ObjectID(self.worker.worker_id),
self.worker.node_ip_address, events)
self.worker.local_scheduler_client.push_profile_events(
component_type, ray.ObjectID(self.worker.worker_id),
self.worker.node_ip_address, events)
def add_event(self, event):
with self.lock:
self.events.append(event)
class RayLogSpanNonRaylet(object):
    """Context manager that records the start and end of a span of events.

    Attributes:
        profiler: The object whose add_event method receives each event.
        event_type (str): The type of the event being logged.
        contents: Additional information to log with the start event.
    """

    def __init__(self, profiler, event_type, contents=None):
        """Store the profiler, event type and optional contents."""
        self.profiler = profiler
        self.event_type = event_type
        self.contents = contents

    def _log(self, event_type, kind, contents=None):
        """Append one event to the profiler's buffer.

        The event is a tuple of the current time, the event type, the kind
        and the contents with every key and value converted to a string.

        Args:
            event_type (str): The type of the event.
            kind (int): Either LOG_POINT, LOG_SPAN_START, or LOG_SPAN_END.
                This is LOG_POINT if the event happens at a single point in
                time, LOG_SPAN_START if a span of time is beginning, and
                LOG_SPAN_END if a span of time is ending.
            contents: More general data to store with the event. Must be a
                dictionary or None.
        """
        # TODO(rkn): This code currently takes around half a microsecond. Since
        # we call it tens of times per task, this adds up. We will need to redo
        # the logging code, perhaps in C.
        if contents is None:
            contents = {}
        assert isinstance(contents, dict)
        # Make sure all of the keys and values in the dictionary are strings.
        stringified = {str(k): str(v) for k, v in contents.items()}
        event = (time.time(), event_type, kind, stringified)
        self.profiler.add_event(event)

    def __enter__(self):
        """Log the beginning of a span event."""
        self._log(
            kind=LOG_SPAN_START,
            event_type=self.event_type,
            contents=self.contents)

    def __exit__(self, type, value, tb):
        """Log the end of a span event, including any exception raised."""
        if type is not None:
            failure = {
                "type": str(type),
                "value": value,
                "traceback": traceback.format_exc()
            }
            self._log(
                kind=LOG_SPAN_END,
                event_type=self.event_type,
                contents=failure)
        else:
            self._log(kind=LOG_SPAN_END, event_type=self.event_type)
class RayLogSpanRaylet(object):
"""An object used to enable logging a span of events with a with statement.
+1 -2
View File
@@ -5,7 +5,7 @@ from __future__ import print_function
import os
from ray.local_scheduler import ObjectID
from ray.raylet import ObjectID
def env_integer(key, default):
@@ -41,7 +41,6 @@ REGISTER_ACTOR_PUSH_ERROR = "register_actor"
WORKER_CRASH_PUSH_ERROR = "worker_crash"
WORKER_DIED_PUSH_ERROR = "worker_died"
PUT_RECONSTRUCTION_PUSH_ERROR = "put_reconstruction"
HASH_MISMATCH_PUSH_ERROR = "object_hash_mismatch"
INFEASIBLE_TASK_ERROR = "infeasible_task"
REMOVED_NODE_ERROR = "node_removed"
MONITOR_DIED_ERROR = "monitor_died"
@@ -2,10 +2,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.core.src.local_scheduler.liblocal_scheduler_library_python import (
from ray.core.src.ray.raylet.liblocal_scheduler_library_python import (
Task, LocalSchedulerClient, ObjectID, check_simple_value, compute_task_id,
task_from_string, task_to_string, _config, common_error)
from .local_scheduler_services import start_local_scheduler
__all__ = [
"Task", "LocalSchedulerClient", "ObjectID", "check_simple_value",
+2 -5
View File
@@ -39,11 +39,8 @@ class TaskPool(object):
for worker, obj_id in self.completed():
plasma_id = ray.pyarrow.plasma.ObjectID(obj_id.id())
if not ray.global_state.use_raylet:
ray.worker.global_worker.plasma_client.fetch([plasma_id])
else:
(ray.worker.global_worker.local_scheduler_client.
reconstruct_objects([obj_id], True))
(ray.worker.global_worker.local_scheduler_client.
reconstruct_objects([obj_id], True))
self._fetching.append((worker, obj_id))
remaining = []
+5 -33
View File
@@ -5,7 +5,6 @@ from __future__ import print_function
import click
import json
import logging
import os
import subprocess
import ray.services as services
@@ -20,7 +19,7 @@ logger = logging.getLogger(__name__)
def check_no_existing_redis_clients(node_ip_address, redis_client):
# The client table prefix must be kept in sync with the file
# "src/common/redis_module/ray_redis_module.cc" where it is defined.
# "src/ray/gcs/redis_module/ray_redis_module.cc" where it is defined.
REDIS_CLIENT_TABLE_PREFIX = "CL:"
client_keys = redis_client.keys("{}*".format(REDIS_CLIENT_TABLE_PREFIX))
# Filter to clients on the same node and do some basic checking.
@@ -167,11 +166,6 @@ def cli(logging_level, logging_format):
required=False,
type=str,
help="the file that contains the autoscaling config")
@click.option(
"--use-raylet",
default=None,
type=bool,
help="use the raylet code path, this defaults to false")
@click.option(
"--no-redirect-worker-output",
is_flag=True,
@@ -198,31 +192,15 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
redis_max_clients, redis_password, redis_shard_ports,
object_manager_port, object_store_memory, num_workers, num_cpus,
num_gpus, resources, head, no_ui, block, plasma_directory,
huge_pages, autoscaling_config, use_raylet,
no_redirect_worker_output, no_redirect_output,
plasma_store_socket_name, raylet_socket_name, temp_dir):
huge_pages, autoscaling_config, no_redirect_worker_output,
no_redirect_output, plasma_store_socket_name, raylet_socket_name,
temp_dir):
# Convert hostnames to numerical IP address.
if node_ip_address is not None:
node_ip_address = services.address_to_ip(node_ip_address)
if redis_address is not None:
redis_address = services.address_to_ip(redis_address)
if use_raylet is None:
if os.environ.get("RAY_USE_XRAY") == "0":
# This environment variable is used in our testing setup.
logger.info("Detected environment variable 'RAY_USE_XRAY' with "
"value {}. This turns OFF xray.".format(
os.environ.get("RAY_USE_XRAY")))
use_raylet = False
else:
use_raylet = True
if not use_raylet and redis_password is not None:
raise Exception("Setting the 'redis-password' argument is not "
"supported in legacy Ray. To run Ray with "
"password-protected Redis ports, pass "
"the '--use-raylet' flag.")
try:
resources = json.loads(resources)
except Exception:
@@ -290,7 +268,6 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
plasma_directory=plasma_directory,
huge_pages=huge_pages,
autoscaling_config=autoscaling_config,
use_raylet=use_raylet,
plasma_store_socket_name=plasma_store_socket_name,
raylet_socket_name=raylet_socket_name,
temp_dir=temp_dir)
@@ -369,7 +346,6 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
resources=resources,
plasma_directory=plasma_directory,
huge_pages=huge_pages,
use_raylet=use_raylet,
plasma_store_socket_name=plasma_store_socket_name,
raylet_socket_name=raylet_socket_name,
temp_dir=temp_dir)
@@ -387,11 +363,7 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
@cli.command()
def stop():
subprocess.call(
[
"killall global_scheduler plasma_store_server plasma_manager "
"local_scheduler raylet raylet_monitor"
],
shell=True)
["killall plasma_store_server raylet raylet_monitor"], shell=True)
# Find the PID of the monitor process and kill it.
subprocess.call(
+33 -316
View File
@@ -14,32 +14,25 @@ import subprocess
import sys
import threading
import time
from collections import OrderedDict, namedtuple
from collections import OrderedDict
import redis
import pyarrow
# Ray modules
import ray.ray_constants
import ray.global_scheduler as global_scheduler
import ray.local_scheduler
import ray.plasma
from ray.tempfile_services import (
get_ipython_notebook_path, get_logs_dir_path, get_raylet_socket_name,
get_temp_root, new_global_scheduler_log_file, new_local_scheduler_log_file,
new_log_monitor_log_file, new_monitor_log_file,
new_plasma_manager_log_file, new_plasma_store_log_file,
new_raylet_log_file, new_redis_log_file, new_webui_log_file,
new_worker_log_file, set_temp_root)
get_temp_root, new_log_monitor_log_file, new_monitor_log_file,
new_plasma_store_log_file, new_raylet_log_file, new_redis_log_file,
new_webui_log_file, set_temp_root)
PROCESS_TYPE_MONITOR = "monitor"
PROCESS_TYPE_LOG_MONITOR = "log_monitor"
PROCESS_TYPE_WORKER = "worker"
PROCESS_TYPE_RAYLET = "raylet"
PROCESS_TYPE_LOCAL_SCHEDULER = "local_scheduler"
PROCESS_TYPE_PLASMA_MANAGER = "plasma_manager"
PROCESS_TYPE_PLASMA_STORE = "plasma_store"
PROCESS_TYPE_GLOBAL_SCHEDULER = "global_scheduler"
PROCESS_TYPE_REDIS_SERVER = "redis_server"
PROCESS_TYPE_WEB_UI = "web_ui"
@@ -51,23 +44,20 @@ PROCESS_TYPE_WEB_UI = "web_ui"
all_processes = OrderedDict(
[(PROCESS_TYPE_MONITOR, []), (PROCESS_TYPE_LOG_MONITOR, []),
(PROCESS_TYPE_WORKER, []), (PROCESS_TYPE_RAYLET, []),
(PROCESS_TYPE_LOCAL_SCHEDULER, []), (PROCESS_TYPE_PLASMA_MANAGER, []),
(PROCESS_TYPE_PLASMA_STORE, []), (PROCESS_TYPE_GLOBAL_SCHEDULER, []),
(PROCESS_TYPE_REDIS_SERVER, []), (PROCESS_TYPE_WEB_UI, [])], )
(PROCESS_TYPE_PLASMA_STORE, []), (PROCESS_TYPE_REDIS_SERVER, []),
(PROCESS_TYPE_WEB_UI, [])], )
# True if processes are run in the valgrind profiler.
RUN_RAYLET_PROFILER = False
RUN_LOCAL_SCHEDULER_PROFILER = False
RUN_PLASMA_MANAGER_PROFILER = False
RUN_PLASMA_STORE_PROFILER = False
# Location of the redis server and module.
REDIS_EXECUTABLE = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"core/src/common/thirdparty/redis/src/redis-server")
"core/src/ray/thirdparty/redis/src/redis-server")
REDIS_MODULE = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"core/src/common/redis_module/libray_redis_module.so")
"core/src/ray/gcs/redis_module/libray_redis_module.so")
# Location of the credis server and modules.
# credis will be enabled if the environment variable RAY_USE_NEW_GCS is set.
@@ -88,14 +78,6 @@ RAYLET_MONITOR_EXECUTABLE = os.path.join(
RAYLET_EXECUTABLE = os.path.join(
os.path.abspath(os.path.dirname(__file__)), "core/src/ray/raylet/raylet")
# ObjectStoreAddress tuples contain all information necessary to connect to an
# object store. The fields are:
# - name: The socket name for the object store
# - manager_name: The socket name for the object store manager
# - manager_port: The Internet port that the object store manager listens on
ObjectStoreAddress = namedtuple("ObjectStoreAddress",
["name", "manager_name", "manager_port"])
# Logger for this module. It should be configured at the entry point
# into the program using Ray. Ray configures it by default automatically
# using logging.basicConfig in its entry/init points.
@@ -136,10 +118,7 @@ def kill_process(p):
if p.poll() is not None:
# The process has already terminated.
return True
if any([
RUN_RAYLET_PROFILER, RUN_LOCAL_SCHEDULER_PROFILER,
RUN_PLASMA_MANAGER_PROFILER, RUN_PLASMA_STORE_PROFILER
]):
if any([RUN_RAYLET_PROFILER, RUN_PLASMA_STORE_PROFILER]):
# Give process signal to write profiler data.
os.kill(p.pid, signal.SIGINT)
# Wait for profiling data to be written.
@@ -430,7 +409,6 @@ def start_redis(node_ip_address,
redis_shard_ports=None,
num_redis_shards=1,
redis_max_clients=None,
use_raylet=True,
redirect_output=False,
redirect_worker_output=False,
cleanup=True,
@@ -450,7 +428,6 @@ def start_redis(node_ip_address,
shard.
redis_max_clients: If this is provided, Ray will attempt to configure
Redis with this maxclients number.
use_raylet: True if the new raylet code path should be used.
redirect_output (bool): True if output should be redirected to a file
and false otherwise.
redirect_worker_output (bool): True if worker output should be
@@ -515,12 +492,6 @@ def start_redis(node_ip_address,
port = assigned_port
redis_address = address(node_ip_address, port)
redis_client = redis.StrictRedis(
host=node_ip_address, port=port, password=password)
# Store whether we're using the raylet code path or not.
redis_client.set("UseRaylet", 1 if use_raylet else 0)
# Register the number of Redis shards in the primary shard, so that clients
# know how many redis shards to expect under RedisShards.
primary_redis_client = redis.StrictRedis(
@@ -762,40 +733,6 @@ def start_log_monitor(redis_address,
password=redis_password)
def start_global_scheduler(redis_address,
                           node_ip_address,
                           stdout_file=None,
                           stderr_file=None,
                           cleanup=True,
                           redis_password=None):
    """Launch a global scheduler process and record its log files.

    Args:
        redis_address (str): The address of the Redis instance.
        node_ip_address: The IP address of the node that this scheduler will
            run on.
        stdout_file: A file handle opened for writing to redirect stdout to,
            or None for no redirection.
        stderr_file: A file handle opened for writing to redirect stderr to,
            or None for no redirection.
        cleanup (bool): If true, the process is added to all_processes so
            that services.cleanup() kills it when the Python process that
            imported services exits.
        redis_password (str): The password of the redis server.
    """
    scheduler_process = global_scheduler.start_global_scheduler(
        redis_address,
        node_ip_address,
        stdout_file=stdout_file,
        stderr_file=stderr_file)
    if cleanup:
        tracked = all_processes[PROCESS_TYPE_GLOBAL_SCHEDULER]
        tracked.append(scheduler_process)
    log_files = [stdout_file, stderr_file]
    record_log_files_in_redis(
        redis_address, node_ip_address, log_files, password=redis_password)
def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True):
"""Start a UI process.
@@ -856,13 +793,11 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True):
return webui_url
def check_and_update_resources(resources, use_raylet):
def check_and_update_resources(resources):
"""Sanity check a resource dictionary and add sensible defaults.
Args:
resources: A dictionary mapping resource names to resource quantities.
use_raylet: True if we are using the raylet code path and false
otherwise.
Returns:
A new resource dictionary.
@@ -901,79 +836,13 @@ def check_and_update_resources(resources, use_raylet):
and not resource_quantity.is_integer()):
raise ValueError("Resource quantities must all be whole numbers.")
if (use_raylet and
resource_quantity > ray.ray_constants.MAX_RESOURCE_QUANTITY):
if resource_quantity > ray.ray_constants.MAX_RESOURCE_QUANTITY:
raise ValueError("Resource quantities must be at most {}.".format(
ray.ray_constants.MAX_RESOURCE_QUANTITY))
return resources
def start_local_scheduler(redis_address,
                          node_ip_address,
                          plasma_store_name,
                          plasma_manager_name,
                          worker_path,
                          plasma_address=None,
                          stdout_file=None,
                          stderr_file=None,
                          cleanup=True,
                          resources=None,
                          num_workers=0,
                          redis_password=None):
    """Launch a local scheduler process and record its log files.

    Args:
        redis_address (str): The address of the Redis instance.
        node_ip_address (str): The IP address of the node that this local
            scheduler is running on.
        plasma_store_name (str): The name of the plasma store socket to
            connect to.
        plasma_manager_name (str): The name of the plasma manager socket to
            connect to.
        worker_path (str): The path of the script to use when the local
            scheduler starts up new workers.
        plasma_address: Passed through unchanged to
            ray.local_scheduler.start_local_scheduler.
        stdout_file: A file handle opened for writing to redirect stdout to,
            or None for no redirection.
        stderr_file: A file handle opened for writing to redirect stderr to,
            or None for no redirection.
        cleanup (bool): If true, the process is added to all_processes so
            that services.cleanup() kills it when the Python process that
            imported services exits.
        resources: A dictionary mapping the name of a resource to the
            available quantity of that resource.
        num_workers (int): The number of workers that the local scheduler
            should start.
        redis_password (str): The password of the redis server.

    Returns:
        The name of the local scheduler socket.
    """
    # Validate the resources and fill in defaults before launching.
    static_resources = check_and_update_resources(resources, False)
    startup_message = (
        "Starting local scheduler with the following resources: {}.")
    logger.info(startup_message.format(static_resources))

    launched = ray.local_scheduler.start_local_scheduler(
        plasma_store_name,
        plasma_manager_name,
        worker_path=worker_path,
        node_ip_address=node_ip_address,
        redis_address=redis_address,
        plasma_address=plasma_address,
        use_profiler=RUN_LOCAL_SCHEDULER_PROFILER,
        stdout_file=stdout_file,
        stderr_file=stderr_file,
        static_resources=static_resources,
        num_workers=num_workers)
    scheduler_socket_name, scheduler_process = launched

    if cleanup:
        tracked = all_processes[PROCESS_TYPE_LOCAL_SCHEDULER]
        tracked.append(scheduler_process)
    log_files = [stdout_file, stderr_file]
    record_log_files_in_redis(
        redis_address, node_ip_address, log_files, password=redis_password)
    return scheduler_socket_name
def start_raylet(redis_address,
node_ip_address,
raylet_name,
@@ -1017,7 +886,7 @@ def start_raylet(redis_address,
if use_valgrind and use_profiler:
raise Exception("Cannot use valgrind and profiler at the same time.")
static_resources = check_and_update_resources(resources, True)
static_resources = check_and_update_resources(resources)
# Limit the number of workers that can be started in parallel by the
# raylet. However, make sure it is at least 1.
@@ -1093,13 +962,10 @@ def start_plasma_store(node_ip_address,
object_manager_port=None,
store_stdout_file=None,
store_stderr_file=None,
manager_stdout_file=None,
manager_stderr_file=None,
objstore_memory=None,
cleanup=True,
plasma_directory=None,
huge_pages=False,
use_raylet=True,
plasma_store_socket_name=None,
redis_password=None):
"""This method starts an object store process.
@@ -1114,12 +980,6 @@ def start_plasma_store(node_ip_address,
to. If no redirection should happen, then this should be None.
store_stderr_file: A file handle opened for writing to redirect stderr
to. If no redirection should happen, then this should be None.
manager_stdout_file: A file handle opened for writing to redirect
stdout to. If no redirection should happen, then this should be
None.
manager_stderr_file: A file handle opened for writing to redirect
stderr to. If no redirection should happen, then this should be
None.
objstore_memory: The amount of memory (in bytes) to start the object
store with.
cleanup (bool): True if using Ray in local mode. If cleanup is true,
@@ -1129,12 +989,10 @@ def start_plasma_store(node_ip_address,
be created.
huge_pages: Boolean flag indicating whether to start the Object
Store with hugetlbfs support. Requires plasma_directory.
use_raylet: True if the new raylet code path should be used.
redis_password (str): The password of the redis server.
Return:
A tuple of the Plasma store socket name, the Plasma manager socket
name, and the plasma manager port.
The Plasma store socket name.
"""
if objstore_memory is None:
# Compute a fraction of the system memory for the Plasma store to use.
@@ -1177,32 +1035,6 @@ def start_plasma_store(node_ip_address,
plasma_directory=plasma_directory,
huge_pages=huge_pages,
socket_name=plasma_store_socket_name)
# Start the plasma manager.
if not use_raylet:
if object_manager_port is not None:
(plasma_manager_name, p2,
plasma_manager_port) = ray.plasma.start_plasma_manager(
plasma_store_name,
redis_address,
plasma_manager_port=object_manager_port,
node_ip_address=node_ip_address,
num_retries=1,
run_profiler=RUN_PLASMA_MANAGER_PROFILER,
stdout_file=manager_stdout_file,
stderr_file=manager_stderr_file)
assert plasma_manager_port == object_manager_port
else:
(plasma_manager_name, p2,
plasma_manager_port) = ray.plasma.start_plasma_manager(
plasma_store_name,
redis_address,
node_ip_address=node_ip_address,
run_profiler=RUN_PLASMA_MANAGER_PROFILER,
stdout_file=manager_stdout_file,
stderr_file=manager_stderr_file)
else:
plasma_manager_port = None
plasma_manager_name = None
if cleanup:
all_processes[PROCESS_TYPE_PLASMA_STORE].append(p1)
@@ -1210,19 +1042,12 @@ def start_plasma_store(node_ip_address,
redis_address,
node_ip_address, [store_stdout_file, store_stderr_file],
password=redis_password)
if not use_raylet:
if cleanup:
all_processes[PROCESS_TYPE_PLASMA_MANAGER].append(p2)
record_log_files_in_redis(redis_address, node_ip_address,
[manager_stdout_file, manager_stderr_file])
return ObjectStoreAddress(plasma_store_name, plasma_manager_name,
plasma_manager_port)
return plasma_store_name
def start_worker(node_ip_address,
object_store_name,
object_store_manager_name,
local_scheduler_name,
redis_address,
worker_path,
@@ -1235,7 +1060,6 @@ def start_worker(node_ip_address,
node_ip_address (str): The IP address of the node that this worker is
running on.
object_store_name (str): The name of the object store.
object_store_manager_name (str): The name of the object store manager.
local_scheduler_name (str): The name of the local scheduler.
redis_address (str): The address that the Redis server is listening on.
worker_path (str): The path of the source code which the worker process
@@ -1253,7 +1077,6 @@ def start_worker(node_ip_address,
sys.executable, "-u", worker_path,
"--node-ip-address=" + node_ip_address,
"--object-store-name=" + object_store_name,
"--object-store-manager-name=" + object_store_manager_name,
"--local-scheduler-name=" + local_scheduler_name,
"--redis-address=" + str(redis_address),
"--temp-dir=" + get_temp_root()
@@ -1349,7 +1172,6 @@ def start_ray_processes(address_info=None,
cleanup=True,
redirect_worker_output=False,
redirect_output=False,
include_global_scheduler=False,
include_log_monitor=False,
include_webui=False,
start_workers_from_local_scheduler=True,
@@ -1357,7 +1179,6 @@ def start_ray_processes(address_info=None,
plasma_directory=None,
huge_pages=False,
autoscaling_config=None,
use_raylet=True,
plasma_store_socket_name=None,
raylet_socket_name=None,
temp_dir=None):
@@ -1398,8 +1219,6 @@ def start_ray_processes(address_info=None,
processes should be redirected to files.
redirect_output (bool): True if stdout and stderr for non-worker
processes should be redirected to files and false otherwise.
include_global_scheduler (bool): If include_global_scheduler is True,
then start a global scheduler process.
include_log_monitor (bool): If True, then start a log monitor to
monitor the log files for all processes on this node and push their
contents to Redis.
@@ -1415,7 +1234,6 @@ def start_ray_processes(address_info=None,
huge_pages: Boolean flag indicating whether to start the Object
Store with hugetlbfs support. Requires plasma_directory.
autoscaling_config: path to autoscaling config file.
use_raylet: True if the new raylet code path should be used.
plasma_store_socket_name (str): If provided, it will specify the socket
name used by the plasma store.
raylet_socket_name (str): If provided, it will specify the socket path
@@ -1469,7 +1287,6 @@ def start_ray_processes(address_info=None,
redis_shard_ports=redis_shard_ports,
num_redis_shards=num_redis_shards,
redis_max_clients=redis_max_clients,
use_raylet=use_raylet,
redirect_output=True,
redirect_worker_output=redirect_worker_output,
cleanup=cleanup,
@@ -1488,13 +1305,12 @@ def start_ray_processes(address_info=None,
cleanup=cleanup,
autoscaling_config=autoscaling_config,
redis_password=redis_password)
if use_raylet:
start_raylet_monitor(
redis_address,
stdout_file=monitor_stdout_file,
stderr_file=monitor_stderr_file,
cleanup=cleanup,
redis_password=redis_password)
start_raylet_monitor(
redis_address,
stdout_file=monitor_stdout_file,
stderr_file=monitor_stderr_file,
cleanup=cleanup,
redis_password=redis_password)
if redis_shards == []:
# Get redis shards from primary redis instance.
redis_ip_address, redis_port = redis_address.split(":")
@@ -1516,25 +1332,10 @@ def start_ray_processes(address_info=None,
cleanup=cleanup,
redis_password=redis_password)
# Start the global scheduler, if necessary.
if include_global_scheduler and not use_raylet:
global_scheduler_stdout_file, global_scheduler_stderr_file = (
new_global_scheduler_log_file(redirect_output))
start_global_scheduler(
redis_address,
node_ip_address,
stdout_file=global_scheduler_stdout_file,
stderr_file=global_scheduler_stderr_file,
cleanup=cleanup,
redis_password=redis_password)
# Initialize with existing services.
if "object_store_addresses" not in address_info:
address_info["object_store_addresses"] = []
object_store_addresses = address_info["object_store_addresses"]
if "local_scheduler_socket_names" not in address_info:
address_info["local_scheduler_socket_names"] = []
local_scheduler_socket_names = address_info["local_scheduler_socket_names"]
if "raylet_socket_names" not in address_info:
address_info["raylet_socket_names"] = []
raylet_socket_names = address_info["raylet_socket_names"]
@@ -1552,114 +1353,37 @@ def start_ray_processes(address_info=None,
plasma_store_stdout_file, plasma_store_stderr_file = (
new_plasma_store_log_file(i, redirect_output))
# If we use raylet, plasma manager won't be started and we don't need
# to create temp files for them.
plasma_manager_stdout_file, plasma_manager_stderr_file = (
new_plasma_manager_log_file(i, redirect_output and not use_raylet))
object_store_address = start_plasma_store(
node_ip_address,
redis_address,
object_manager_port=object_manager_ports[i],
store_stdout_file=plasma_store_stdout_file,
store_stderr_file=plasma_store_stderr_file,
manager_stdout_file=plasma_manager_stdout_file,
manager_stderr_file=plasma_manager_stderr_file,
objstore_memory=object_store_memory,
cleanup=cleanup,
plasma_directory=plasma_directory,
huge_pages=huge_pages,
use_raylet=use_raylet,
plasma_store_socket_name=plasma_store_socket_name,
redis_password=redis_password)
object_store_addresses.append(object_store_address)
time.sleep(0.1)
if not use_raylet:
# Start any local schedulers that do not yet exist.
for i in range(
len(local_scheduler_socket_names), num_local_schedulers):
# Connect the local scheduler to the object store at the same
# index.
object_store_address = object_store_addresses[i]
plasma_address = "{}:{}".format(node_ip_address,
object_store_address.manager_port)
# Determine how many workers this local scheduler should start.
if start_workers_from_local_scheduler:
num_local_scheduler_workers = workers_per_local_scheduler[i]
workers_per_local_scheduler[i] = 0
else:
# If we're starting the workers from Python, the local
# scheduler should not start any workers.
num_local_scheduler_workers = 0
# Start the local scheduler. Note that if we do not wish to
# redirect the worker output, then we cannot redirect the local
# scheduler output.
local_scheduler_stdout_file, local_scheduler_stderr_file = (
new_local_scheduler_log_file(
i, redirect_output=redirect_worker_output))
local_scheduler_name = start_local_scheduler(
# Start any raylets that do not exist yet.
for i in range(len(raylet_socket_names), num_local_schedulers):
raylet_stdout_file, raylet_stderr_file = new_raylet_log_file(
i, redirect_output=redirect_worker_output)
address_info["raylet_socket_names"].append(
start_raylet(
redis_address,
node_ip_address,
object_store_address.name,
object_store_address.manager_name,
raylet_socket_name or get_raylet_socket_name(),
object_store_addresses[i],
worker_path,
plasma_address=plasma_address,
stdout_file=local_scheduler_stdout_file,
stderr_file=local_scheduler_stderr_file,
cleanup=cleanup,
resources=resources[i],
num_workers=num_local_scheduler_workers,
redis_password=redis_password)
local_scheduler_socket_names.append(local_scheduler_name)
# Make sure that we have exactly num_local_schedulers instances of
# object stores and local schedulers.
assert len(object_store_addresses) == num_local_schedulers
assert len(local_scheduler_socket_names) == num_local_schedulers
else:
# Start any raylets that do not exist yet.
for i in range(len(raylet_socket_names), num_local_schedulers):
raylet_stdout_file, raylet_stderr_file = new_raylet_log_file(
i, redirect_output=redirect_worker_output)
address_info["raylet_socket_names"].append(
start_raylet(
redis_address,
node_ip_address,
raylet_socket_name or get_raylet_socket_name(),
object_store_addresses[i].name,
worker_path,
resources=resources[i],
num_workers=workers_per_local_scheduler[i],
stdout_file=raylet_stdout_file,
stderr_file=raylet_stderr_file,
cleanup=cleanup,
redis_password=redis_password))
if not use_raylet:
# Start any workers that the local scheduler has not already started.
for i, num_local_scheduler_workers in enumerate(
workers_per_local_scheduler):
object_store_address = object_store_addresses[i]
local_scheduler_name = local_scheduler_socket_names[i]
for j in range(num_local_scheduler_workers):
worker_stdout_file, worker_stderr_file = new_worker_log_file(
i, j, redirect_output)
start_worker(
node_ip_address,
object_store_address.name,
object_store_address.manager_name,
local_scheduler_name,
redis_address,
worker_path,
stdout_file=worker_stdout_file,
stderr_file=worker_stderr_file,
cleanup=cleanup)
workers_per_local_scheduler[i] -= 1
# Make sure that we've started all the workers.
assert (sum(workers_per_local_scheduler) == 0)
num_workers=workers_per_local_scheduler[i],
stdout_file=raylet_stdout_file,
stderr_file=raylet_stderr_file,
cleanup=cleanup,
redis_password=redis_password))
# Try to start the web UI.
if include_webui:
@@ -1689,7 +1413,6 @@ def start_ray_node(node_ip_address,
resources=None,
plasma_directory=None,
huge_pages=False,
use_raylet=True,
plasma_store_socket_name=None,
raylet_socket_name=None,
temp_dir=None):
@@ -1727,7 +1450,6 @@ def start_ray_node(node_ip_address,
be created.
huge_pages: Boolean flag indicating whether to start the Object
Store with hugetlbfs support. Requires plasma_directory.
use_raylet: True if the new raylet code path should be used.
plasma_store_socket_name (str): If provided, it will specify the socket
name used by the plasma store.
raylet_socket_name (str): If provided, it will specify the socket path
@@ -1758,7 +1480,6 @@ def start_ray_node(node_ip_address,
resources=resources,
plasma_directory=plasma_directory,
huge_pages=huge_pages,
use_raylet=use_raylet,
plasma_store_socket_name=plasma_store_socket_name,
raylet_socket_name=raylet_socket_name,
temp_dir=temp_dir)
@@ -1784,7 +1505,6 @@ def start_ray_head(address_info=None,
plasma_directory=None,
huge_pages=False,
autoscaling_config=None,
use_raylet=True,
plasma_store_socket_name=None,
raylet_socket_name=None,
temp_dir=None):
@@ -1836,7 +1556,6 @@ def start_ray_head(address_info=None,
huge_pages: Boolean flag indicating whether to start the Object
Store with hugetlbfs support. Requires plasma_directory.
autoscaling_config: path to autoscaling config file.
use_raylet: True if the new raylet code path should be used.
plasma_store_socket_name (str): If provided, it will specify the socket
name used by the plasma store.
raylet_socket_name (str): If provided, it will specify the socket path
@@ -1861,7 +1580,6 @@ def start_ray_head(address_info=None,
cleanup=cleanup,
redirect_worker_output=redirect_worker_output,
redirect_output=redirect_output,
include_global_scheduler=True,
include_log_monitor=True,
include_webui=include_webui,
start_workers_from_local_scheduler=start_workers_from_local_scheduler,
@@ -1872,7 +1590,6 @@ def start_ray_head(address_info=None,
plasma_directory=plasma_directory,
huge_pages=huge_pages,
autoscaling_config=autoscaling_config,
use_raylet=use_raylet,
plasma_store_socket_name=plasma_store_socket_name,
raylet_socket_name=raylet_socket_name,
temp_dir=temp_dir)
-60
View File
@@ -117,27 +117,6 @@ def get_object_store_socket_name():
return make_inc_temp(prefix="plasma_store", directory_name=sockets_dir)
def get_plasma_manager_socket_name():
    """Return a socket name for the plasma manager.

    The name is produced by ``make_inc_temp`` using the ``plasma_manager``
    prefix inside the directory returned by ``get_sockets_dir_path()``.
    """
    return make_inc_temp(
        prefix="plasma_manager", directory_name=get_sockets_dir_path())
def get_local_scheduler_socket_name(suffix=""):
    """Return a socket name for the local scheduler.

    Warning: this can be unsafe. The returned name may refer to a file that
    did not exist when the name was chosen, yet another process may create
    that file before the caller does.

    Args:
        suffix: Forwarded unchanged to ``make_inc_temp`` as ``suffix``.

    Returns:
        The value returned by ``make_inc_temp`` for the ``scheduler`` prefix
        inside the directory given by ``get_sockets_dir_path()``.
    """
    return make_inc_temp(
        prefix="scheduler",
        directory_name=get_sockets_dir_path(),
        suffix=suffix)
def get_ipython_notebook_path(port):
"""Get a new ipython notebook path"""
@@ -211,17 +190,6 @@ def new_raylet_log_file(local_scheduler_index, redirect_output):
return raylet_stdout_file, raylet_stderr_file
def new_local_scheduler_log_file(local_scheduler_index, redirect_output):
    """Create the stdout/stderr log files for one local scheduler.

    Only used in non-raylet versions.

    Args:
        local_scheduler_index: Index formatted into the log name
            ``local_scheduler_<index>``.
        redirect_output: Forwarded to ``new_log_files`` as the
            ``redirect_output`` keyword argument.

    Returns:
        A ``(stdout_file, stderr_file)`` pair unpacked from the result of
        ``new_log_files``.
    """
    log_name = "local_scheduler_{}".format(local_scheduler_index)
    stdout_file, stderr_file = new_log_files(
        log_name, redirect_output=redirect_output)
    return stdout_file, stderr_file
def new_webui_log_file():
"""Create new logging files for web ui."""
ui_stdout_file, ui_stderr_file = new_log_files(
@@ -229,17 +197,6 @@ def new_webui_log_file():
return ui_stdout_file, ui_stderr_file
def new_worker_log_file(local_scheduler_index, worker_index, redirect_output):
    """Create the stdout/stderr log files for one worker.

    Only used in non-raylet versions.

    Args:
        local_scheduler_index: First index formatted into the log name
            ``worker_<local_scheduler_index>_<worker_index>``.
        worker_index: Second index formatted into the log name.
        redirect_output: Forwarded positionally to ``new_log_files``.

    Returns:
        A ``(stdout_file, stderr_file)`` pair unpacked from the result of
        ``new_log_files``.
    """
    log_name = "worker_{}_{}".format(local_scheduler_index, worker_index)
    stdout_file, stderr_file = new_log_files(log_name, redirect_output)
    return stdout_file, stderr_file
def new_worker_redirected_log_file(worker_id):
"""Create new logging files for workers to redirect its output."""
worker_stdout_file, worker_stderr_file = (new_log_files(
@@ -254,16 +211,6 @@ def new_log_monitor_log_file():
return log_monitor_stdout_file, log_monitor_stderr_file
def new_global_scheduler_log_file(redirect_output):
    """Create the stdout/stderr log files for the global scheduler.

    Only used in non-raylet versions.

    Args:
        redirect_output: Forwarded positionally to ``new_log_files``.

    Returns:
        A ``(stdout_file, stderr_file)`` pair unpacked from the result of
        ``new_log_files`` for the ``global_scheduler`` log name.
    """
    stdout_file, stderr_file = new_log_files("global_scheduler",
                                             redirect_output)
    return stdout_file, stderr_file
def new_plasma_store_log_file(local_scheduler_index, redirect_output):
"""Create new logging files for the plasma store."""
plasma_store_stdout_file, plasma_store_stderr_file = new_log_files(
@@ -271,13 +218,6 @@ def new_plasma_store_log_file(local_scheduler_index, redirect_output):
return plasma_store_stdout_file, plasma_store_stderr_file
def new_plasma_manager_log_file(local_scheduler_index, redirect_output):
    """Create the stdout/stderr log files for one plasma manager.

    Args:
        local_scheduler_index: Index formatted into the log name
            ``plasma_manager_<index>``.
        redirect_output: Forwarded positionally to ``new_log_files``.

    Returns:
        A ``(stdout_file, stderr_file)`` pair unpacked from the result of
        ``new_log_files``.
    """
    log_name = "plasma_manager_{}".format(local_scheduler_index)
    stdout_file, stderr_file = new_log_files(log_name, redirect_output)
    return stdout_file, stderr_file
def new_monitor_log_file(redirect_output):
"""Create new logging files for the monitor."""
monitor_stdout_file, monitor_stderr_file = new_log_files(
+10 -8
View File
@@ -44,7 +44,6 @@ class Cluster(object):
All nodes are by default started with the following settings:
cleanup=True,
use_raylet=True,
resources={"CPU": 1},
object_store_memory=100 * (2**20) # 100 MB
@@ -55,12 +54,13 @@ class Cluster(object):
Returns:
Node object of the added Ray node.
"""
node_kwargs = dict(
cleanup=True,
use_raylet=True,
resources={"CPU": 1},
object_store_memory=100 * (2**20) # 100 MB
)
node_kwargs = {
"cleanup": True,
"resources": {
"CPU": 1
},
"object_store_memory": 100 * (2**20) # 100 MB
}
node_kwargs.update(override_kwargs)
if self.head_node is None:
@@ -179,7 +179,9 @@ class Node(object):
for process_name, process_list in self.process_dict.items():
logger.info("Killing all {}(s)".format(process_name))
for process in process_list:
process.kill()
# Kill the process if it is still alive.
if process.poll() is None:
process.kill()
for process_name, process_list in self.process_dict.items():
logger.info("Waiting all {}(s)".format(process_name))
-3
View File
@@ -28,9 +28,6 @@ class TestRedisPassword(object):
@pytest.mark.skipif(
os.environ.get("RAY_USE_NEW_GCS") == "on",
reason="New GCS API doesn't support Redis authentication yet.")
@pytest.mark.skipif(
os.environ.get("RAY_USE_XRAY") == "0",
reason="Redis authentication is not supported in legacy Ray.")
def test_redis_password(self, password, shutdown_only):
# Workaround for https://github.com/ray-project/ray/issues/3045
@ray.remote
+3 -14
View File
@@ -35,22 +35,11 @@ def _wait_for_nodes_to_join(num_nodes, timeout=20):
client_table = ray.global_state.client_table()
num_ready_nodes = len(client_table)
if num_ready_nodes == num_nodes:
ready = True
# Check that for each node, a local scheduler and a plasma manager
# are present.
if ray.global_state.use_raylet:
# In raylet mode, this is a list of map.
# The GCS info will appear as a whole instead of part by part.
return
else:
for ip_address, clients in client_table.items():
client_types = [client["ClientType"] for client in clients]
if "local_scheduler" not in client_types:
ready = False
if "plasma_manager" not in client_types:
ready = False
if ready:
return
# In raylet mode, this is a list of map.
# The GCS info will appear as a whole instead of part by part.
return
if num_ready_nodes > num_nodes:
# Too many nodes have joined. Something must be wrong.
raise Exception("{} nodes have joined the cluster, but we were "
+3 -14
View File
@@ -213,20 +213,9 @@ class RayTrialExecutor(TrialExecutor):
assert self._committed_resources.gpu >= 0
def _update_avail_resources(self):
if ray.worker.global_worker.use_raylet:
# TODO(rliaw): Remove once raylet flag is swapped
resources = ray.global_state.cluster_resources()
num_cpus = resources["CPU"]
num_gpus = resources["GPU"]
else:
clients = ray.global_state.client_table()
local_schedulers = [
entry for client in clients.values() for entry in client
if (entry['ClientType'] == 'local_scheduler'
and not entry['Deleted'])
]
num_cpus = sum(ls['CPU'] for ls in local_schedulers)
num_gpus = sum(ls.get('GPU', 0) for ls in local_schedulers)
resources = ray.global_state.cluster_resources()
num_cpus = resources["CPU"]
num_gpus = resources["GPU"]
self._avail_resources = Resources(int(num_cpus), int(num_gpus))
self._resources_initialized = True
+6 -6
View File
@@ -107,7 +107,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
return Resources(cpu=config["cpu"], gpu=config["gpu"])
def _train(self):
return dict(timesteps_this_iter=1, done=True)
return {"timesteps_this_iter": 1, "done": True}
register_trainable("B", B)
@@ -440,7 +440,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
self.state = {"hi": 1}
def _train(self):
return dict(timesteps_this_iter=1, done=True)
return {"timesteps_this_iter": 1, "done": True}
def _save(self, path):
return self.state
@@ -471,7 +471,7 @@ class TrainableFunctionApiTest(unittest.TestCase):
def _train(self):
self.state["iter"] += 1
return dict(timesteps_this_iter=1, done=True)
return {"timesteps_this_iter": 1, "done": True}
def _save(self, path):
return self.state
@@ -604,7 +604,7 @@ class RunExperimentTest(unittest.TestCase):
class B(Trainable):
def _train(self):
return dict(timesteps_this_iter=1, done=True)
return {"timesteps_this_iter": 1, "done": True}
register_trainable("f1", train)
trials = run_experiments({
@@ -624,7 +624,7 @@ class RunExperimentTest(unittest.TestCase):
def testCheckpointAtEnd(self):
class train(Trainable):
def _train(self):
return dict(timesteps_this_iter=1, done=True)
return {"timesteps_this_iter": 1, "done": True}
def _save(self, path):
return path
@@ -887,7 +887,7 @@ class TrialRunnerTest(unittest.TestCase):
self.assertEqual(trials[1].status, Trial.PENDING)
def testFractionalGpus(self):
ray.init(num_cpus=4, num_gpus=1, use_raylet=True)
ray.init(num_cpus=4, num_gpus=1)
runner = TrialRunner(BasicVariantGenerator())
kwargs = {
"resources": Resources(cpu=1, gpu=0.5),
+1 -1
View File
@@ -28,7 +28,7 @@ def pin_in_object_store(obj):
def get_pinned_object(pinned_id):
"""Retrieve a pinned object from the object store."""
from ray.local_scheduler import ObjectID
from ray.raylet import ObjectID
return _from_pinnable(
ray.get(
+10 -33
View File
@@ -15,11 +15,9 @@ import time
import uuid
import ray.gcs_utils
import ray.local_scheduler
import ray.raylet
import ray.ray_constants as ray_constants
ERROR_KEY_PREFIX = b"Error:"
def _random_string():
id_hash = hashlib.sha1()
@@ -70,22 +68,12 @@ def push_error_to_driver(worker,
"""
if driver_id is None:
driver_id = ray_constants.NIL_JOB_ID.id()
error_key = ERROR_KEY_PREFIX + driver_id + b":" + _random_string()
data = {} if data is None else data
if not worker.use_raylet:
worker.redis_client.hmset(error_key, {
"type": error_type,
"message": message,
"data": data
})
worker.redis_client.rpush("ErrorKeys", error_key)
else:
worker.local_scheduler_client.push_error(
ray.ObjectID(driver_id), error_type, message, time.time())
worker.local_scheduler_client.push_error(
ray.ObjectID(driver_id), error_type, message, time.time())
def push_error_to_driver_through_redis(redis_client,
use_raylet,
error_type,
message,
driver_id=None,
@@ -99,8 +87,6 @@ def push_error_to_driver_through_redis(redis_client,
Args:
redis_client: The redis client to use.
use_raylet: True if we are using the Raylet code path and false
otherwise.
error_type (str): The type of the error.
message (str): The message that will be printed in the background
on the driver.
@@ -111,23 +97,14 @@ def push_error_to_driver_through_redis(redis_client,
"""
if driver_id is None:
driver_id = ray_constants.NIL_JOB_ID.id()
error_key = ERROR_KEY_PREFIX + driver_id + b":" + _random_string()
data = {} if data is None else data
if not use_raylet:
redis_client.hmset(error_key, {
"type": error_type,
"message": message,
"data": data
})
redis_client.rpush("ErrorKeys", error_key)
else:
# Do everything in Python and through the Python Redis client instead
# of through the raylet.
error_data = ray.gcs_utils.construct_error_message(
driver_id, error_type, message, time.time())
redis_client.execute_command(
"RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO,
ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id, error_data)
# Do everything in Python and through the Python Redis client instead
# of through the raylet.
error_data = ray.gcs_utils.construct_error_message(driver_id, error_type,
message, time.time())
redis_client.execute_command(
"RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO,
ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id, error_data)
def is_cython(obj):
+89 -306
View File
@@ -27,14 +27,13 @@ import ray.serialization as serialization
import ray.services as services
import ray.signature
import ray.tempfile_services as tempfile_services
import ray.local_scheduler
import ray.raylet
import ray.plasma
import ray.ray_constants as ray_constants
from ray import import_thread
from ray import profiling
from ray.function_manager import FunctionActorManager
from ray.utils import (
binary_to_hex,
check_oversized_pickle,
is_cython,
random_string,
@@ -56,14 +55,6 @@ NIL_ACTOR_ID = NIL_ID
NIL_ACTOR_HANDLE_ID = NIL_ID
NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff"
# This must be kept in sync with the `error_types` array in
# common/state/error_table.h.
OBJECT_HASH_MISMATCH_ERROR_TYPE = b"object_hash_mismatch"
PUT_RECONSTRUCTION_ERROR_TYPE = b"put_reconstruction"
# This must be kept in sync with the `scheduling_state` enum in common/task.h.
TASK_STATUS_RUNNING = 8
# Default resource requirements for actors when no resource requirements are
# specified.
DEFAULT_ACTOR_METHOD_CPUS_SIMPLE_CASE = 1
@@ -461,13 +452,9 @@ class Worker(object):
]
for i in range(0, len(object_ids),
ray._config.worker_fetch_request_size()):
if not self.use_raylet:
self.plasma_client.fetch(plain_object_ids[i:(
i + ray._config.worker_fetch_request_size())])
else:
self.local_scheduler_client.reconstruct_objects(
object_ids[i:(
i + ray._config.worker_fetch_request_size())], True)
self.local_scheduler_client.reconstruct_objects(
object_ids[i:(i + ray._config.worker_fetch_request_size())],
True)
# Get the objects. We initially try to get the objects immediately.
final_results = self.retrieve_and_deserialize(plain_object_ids, 0)
@@ -497,25 +484,9 @@ class Worker(object):
ray._config.worker_fetch_request_size())
for i in range(0, len(object_ids_to_fetch),
fetch_request_size):
if not self.use_raylet:
for unready_id in ray_object_ids_to_fetch[i:(
i + fetch_request_size)]:
(self.local_scheduler_client.
reconstruct_objects([unready_id], False))
# Do another fetch for objects that aren't
# available locally yet, in case they were evicted
# since the last fetch. We divide the fetch into
# smaller fetches so as to not block the manager
# for a prolonged period of time in a single call.
# This is only necessary for legacy ray since
# reconstruction and fetch are implemented by
# different processes.
self.plasma_client.fetch(object_ids_to_fetch[i:(
i + fetch_request_size)])
else:
self.local_scheduler_client.reconstruct_objects(
ray_object_ids_to_fetch[i:(
i + fetch_request_size)], False)
self.local_scheduler_client.reconstruct_objects(
ray_object_ids_to_fetch[i:(
i + fetch_request_size)], False)
results = self.retrieve_and_deserialize(
object_ids_to_fetch,
max([
@@ -608,7 +579,7 @@ class Worker(object):
for arg in args:
if isinstance(arg, ray.ObjectID):
args_for_local_scheduler.append(arg)
elif ray.local_scheduler.check_simple_value(arg):
elif ray.raylet.check_simple_value(arg):
args_for_local_scheduler.append(arg)
else:
args_for_local_scheduler.append(put(arg))
@@ -641,14 +612,13 @@ class Worker(object):
task_index = self.task_index
self.task_index += 1
# Submit the task to local scheduler.
task = ray.local_scheduler.Task(
task = ray.raylet.Task(
driver_id, ray.ObjectID(
function_id.id()), args_for_local_scheduler,
num_return_vals, self.current_task_id, task_index,
actor_creation_id, actor_creation_dummy_object_id, actor_id,
actor_handle_id, actor_counter, is_actor_checkpoint_method,
execution_dependencies, resources, placement_resources,
self.use_raylet)
actor_handle_id, actor_counter, execution_dependencies,
resources, placement_resources)
self.local_scheduler_client.submit(task)
return task.returns()
@@ -925,26 +895,13 @@ class Worker(object):
# good to know where the system is hanging.
with self.lock:
function_name = execution_info.function_name
if not self.use_raylet:
extra_data = {
"function_name": function_name,
"task_id": task.task_id().hex(),
"worker_id": binary_to_hex(self.worker_id)
}
else:
extra_data = {
"name": function_name,
"task_id": task.task_id().hex()
}
extra_data = {
"name": function_name,
"task_id": task.task_id().hex()
}
with profiling.profile("task", extra_data=extra_data, worker=self):
self._process_task(task, execution_info)
# In the non-raylet code path, push all of the log events to the global
# state store. In the raylet code path, this is done periodically in a
# background thread.
if not self.use_raylet:
self.profiler.flush_profile_data()
# Increase the task execution counter.
self.function_actor_manager.increase_task_counter(
driver_id, function_id.id())
@@ -998,13 +955,10 @@ def get_gpu_ids():
raise Exception("ray.get_gpu_ids() currently does not work in PYTHON "
"MODE.")
if not global_worker.use_raylet:
assigned_ids = global_worker.local_scheduler_client.gpu_ids()
else:
all_resource_ids = global_worker.local_scheduler_client.resource_ids()
assigned_ids = [
resource_id for resource_id, _ in all_resource_ids.get("GPU", [])
]
all_resource_ids = global_worker.local_scheduler_client.resource_ids()
assigned_ids = [
resource_id for resource_id, _ in all_resource_ids.get("GPU", [])
]
# If the user had already set CUDA_VISIBLE_DEVICES, then respect that (in
# the sense that only GPU IDs that appear in CUDA_VISIBLE_DEVICES should be
# returned).
@@ -1019,17 +973,11 @@ def get_gpu_ids():
def get_resource_ids():
"""Get the IDs of the resources that are available to the worker.
This function is only supported in the raylet code path.
Returns:
A dictionary mapping the name of a resource to a list of pairs, where
each pair consists of the ID of a resource and the fraction of that
resource reserved for this worker.
"""
if not global_worker.use_raylet:
raise Exception("ray.get_resource_ids() is only supported in the "
"raylet code path.")
if _mode() == LOCAL_MODE:
raise Exception(
"ray.get_resource_ids() currently does not work in PYTHON "
@@ -1112,22 +1060,8 @@ def error_applies_to_driver(error_key, worker=global_worker):
def error_info(worker=global_worker):
"""Return information about failed tasks."""
worker.check_connected()
if worker.use_raylet:
return (global_state.error_messages(job_id=worker.task_driver_id) +
global_state.error_messages(job_id=ray_constants.NIL_JOB_ID))
error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1)
errors = []
for error_key in error_keys:
if error_applies_to_driver(error_key, worker=worker):
error_contents = worker.redis_client.hgetall(error_key)
error_contents = {
"type": ray.utils.decode(error_contents[b"type"]),
"message": ray.utils.decode(error_contents[b"message"]),
"data": ray.utils.decode(error_contents[b"data"])
}
errors.append(error_contents)
return errors
return (global_state.error_messages(job_id=worker.task_driver_id) +
global_state.error_messages(job_id=ray_constants.NIL_JOB_ID))
def _initialize_serialization(driver_id, worker=global_worker):
@@ -1223,7 +1157,6 @@ def _initialize_serialization(driver_id, worker=global_worker):
def get_address_info_from_redis_helper(redis_address,
node_ip_address,
use_raylet=True,
redis_password=None):
redis_ip_address, redis_port = redis_address.split(":")
# For this command to work, some other client (on the same machine as
@@ -1231,118 +1164,50 @@ def get_address_info_from_redis_helper(redis_address,
redis_client = redis.StrictRedis(
host=redis_ip_address, port=int(redis_port), password=redis_password)
if not use_raylet:
# The client table prefix must be kept in sync with the file
# "src/common/redis_module/ray_redis_module.cc" where it is defined.
client_keys = redis_client.keys("{}*".format(
ray.gcs_utils.DB_CLIENT_PREFIX))
# Filter to live clients on the same node and do some basic checking.
plasma_managers = []
local_schedulers = []
for key in client_keys:
info = redis_client.hgetall(key)
# Ignore clients that were deleted.
deleted = info[b"deleted"]
deleted = bool(int(deleted))
if deleted:
continue
assert b"ray_client_id" in info
assert b"node_ip_address" in info
assert b"client_type" in info
client_node_ip_address = ray.utils.decode(info[b"node_ip_address"])
if (client_node_ip_address == node_ip_address or
(client_node_ip_address == "127.0.0.1"
and redis_ip_address == ray.services.get_node_ip_address())):
if ray.utils.decode(info[b"client_type"]) == "plasma_manager":
plasma_managers.append(info)
elif (ray.utils.decode(
info[b"client_type"]) == "local_scheduler"):
local_schedulers.append(info)
# Make sure that we got at least one plasma manager and local
# scheduler.
assert len(plasma_managers) >= 1
assert len(local_schedulers) >= 1
# Build the address information.
object_store_addresses = []
for manager in plasma_managers:
address = ray.utils.decode(manager[b"manager_address"])
port = services.get_port(address)
object_store_addresses.append(
services.ObjectStoreAddress(
name=ray.utils.decode(manager[b"store_socket_name"]),
manager_name=ray.utils.decode(
manager[b"manager_socket_name"]),
manager_port=port))
scheduler_names = [
ray.utils.decode(scheduler[b"local_scheduler_socket_name"])
for scheduler in local_schedulers
]
client_info = {
"node_ip_address": node_ip_address,
"redis_address": redis_address,
"object_store_addresses": object_store_addresses,
"local_scheduler_socket_names": scheduler_names,
# Web UI should be running.
"webui_url": _webui_url_helper(redis_client)
}
return client_info
# Handle the raylet case.
else:
# In the raylet code path, all client data is stored in a zset at the
# key for the nil client.
client_key = b"CLIENT" + NIL_CLIENT_ID
clients = redis_client.zrange(client_key, 0, -1)
raylets = []
for client_message in clients:
client = ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
client_message, 0)
client_node_ip_address = ray.utils.decode(
client.NodeManagerAddress())
if (client_node_ip_address == node_ip_address or
(client_node_ip_address == "127.0.0.1"
and redis_ip_address == ray.services.get_node_ip_address())):
raylets.append(client)
# Make sure that at least one raylet has started locally.
# This handles a race condition where Redis has started but
# the raylet has not connected.
if len(raylets) == 0:
raise Exception(
"Redis has started but no raylets have registered yet.")
object_store_addresses = [
services.ObjectStoreAddress(
name=ray.utils.decode(raylet.ObjectStoreSocketName()),
manager_name=None,
manager_port=None) for raylet in raylets
]
raylet_socket_names = [
ray.utils.decode(raylet.RayletSocketName()) for raylet in raylets
]
return {
"node_ip_address": node_ip_address,
"redis_address": redis_address,
"object_store_addresses": object_store_addresses,
"raylet_socket_names": raylet_socket_names,
# Web UI should be running.
"webui_url": _webui_url_helper(redis_client)
}
# In the raylet code path, all client data is stored in a zset at the
# key for the nil client.
client_key = b"CLIENT" + NIL_CLIENT_ID
clients = redis_client.zrange(client_key, 0, -1)
raylets = []
for client_message in clients:
client = ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
client_message, 0)
client_node_ip_address = ray.utils.decode(client.NodeManagerAddress())
if (client_node_ip_address == node_ip_address or
(client_node_ip_address == "127.0.0.1"
and redis_ip_address == ray.services.get_node_ip_address())):
raylets.append(client)
# Make sure that at least one raylet has started locally.
# This handles a race condition where Redis has started but
# the raylet has not connected.
if len(raylets) == 0:
raise Exception(
"Redis has started but no raylets have registered yet.")
object_store_addresses = [
ray.utils.decode(raylet.ObjectStoreSocketName()) for raylet in raylets
]
raylet_socket_names = [
ray.utils.decode(raylet.RayletSocketName()) for raylet in raylets
]
return {
"node_ip_address": node_ip_address,
"redis_address": redis_address,
"object_store_addresses": object_store_addresses,
"raylet_socket_names": raylet_socket_names,
# Web UI should be running.
"webui_url": _webui_url_helper(redis_client)
}
def get_address_info_from_redis(redis_address,
node_ip_address,
num_retries=5,
use_raylet=True,
redis_password=None):
counter = 0
while True:
try:
return get_address_info_from_redis_helper(
redis_address,
node_ip_address,
use_raylet=use_raylet,
redis_password=redis_password)
redis_address, node_ip_address, redis_password=redis_password)
except Exception:
if counter == num_retries:
raise
@@ -1414,7 +1279,6 @@ def _init(address_info=None,
plasma_directory=None,
huge_pages=False,
include_webui=True,
use_raylet=None,
plasma_store_socket_name=None,
raylet_socket_name=None,
temp_dir=None):
@@ -1474,7 +1338,6 @@ def _init(address_info=None,
Store with hugetlbfs support. Requires plasma_directory.
include_webui: Boolean flag indicating whether to start the web
UI, which is a Jupyter notebook.
use_raylet: True if the new raylet code path should be used.
plasma_store_socket_name (str): If provided, it will specify the socket
name used by the plasma store.
raylet_socket_name (str): If provided, it will specify the socket path
@@ -1497,16 +1360,6 @@ def _init(address_info=None,
else:
driver_mode = SCRIPT_MODE
if use_raylet is None:
if os.environ.get("RAY_USE_XRAY") == "0":
# This environment variable is used in our testing setup.
logger.info("Detected environment variable 'RAY_USE_XRAY' with "
"value {}. This turns OFF xray.".format(
os.environ.get("RAY_USE_XRAY")))
use_raylet = False
else:
use_raylet = True
# Get addresses of existing services.
if address_info is None:
address_info = {}
@@ -1561,7 +1414,6 @@ def _init(address_info=None,
plasma_directory=plasma_directory,
huge_pages=huge_pages,
include_webui=include_webui,
use_raylet=use_raylet,
plasma_store_socket_name=plasma_store_socket_name,
raylet_socket_name=raylet_socket_name,
temp_dir=temp_dir)
@@ -1610,10 +1462,7 @@ def _init(address_info=None,
node_ip_address = services.get_node_ip_address(redis_address)
# Get the address info of the processes to connect to from Redis.
address_info = get_address_info_from_redis(
redis_address,
node_ip_address,
use_raylet=use_raylet,
redis_password=redis_password)
redis_address, node_ip_address, redis_password=redis_password)
# Connect this driver to Redis, the object store, and the local scheduler.
# Choose the first object store and local scheduler if there are multiple.
@@ -1625,18 +1474,11 @@ def _init(address_info=None,
driver_address_info = {
"node_ip_address": node_ip_address,
"redis_address": address_info["redis_address"],
"store_socket_name": (
address_info["object_store_addresses"][0].name),
"store_socket_name": address_info["object_store_addresses"][0],
"webui_url": address_info["webui_url"]
}
if not use_raylet:
driver_address_info["manager_socket_name"] = (
address_info["object_store_addresses"][0].manager_name)
driver_address_info["local_scheduler_socket_name"] = (
address_info["local_scheduler_socket_names"][0])
else:
driver_address_info["raylet_socket_name"] = (
address_info["raylet_socket_names"][0])
driver_address_info["raylet_socket_name"] = (
address_info["raylet_socket_names"][0])
# We only pass `temp_dir` to a worker (WORKER_MODE).
# It can't be a worker here.
@@ -1645,7 +1487,6 @@ def _init(address_info=None,
object_id_seed=object_id_seed,
mode=driver_mode,
worker=global_worker,
use_raylet=use_raylet,
redis_password=redis_password)
return address_info
@@ -1669,7 +1510,6 @@ def init(redis_address=None,
plasma_directory=None,
huge_pages=False,
include_webui=True,
use_raylet=None,
configure_logging=True,
logging_level=logging.INFO,
logging_format=ray_constants.LOGGER_FORMAT,
@@ -1736,7 +1576,6 @@ def init(redis_address=None,
Store with hugetlbfs support. Requires plasma_directory.
include_webui: Boolean flag indicating whether to start the web
UI, which is a Jupyter notebook.
use_raylet: True if the new raylet code path should be used.
configure_logging: True if the logging configuration should be set up
here. Otherwise, users may want to configure it on their own.
logging_level: Logging level, default will be logging.INFO.
@@ -1767,22 +1606,6 @@ def init(redis_address=None,
else:
raise Exception("Perhaps you called ray.init twice by accident?")
if use_raylet is None:
if os.environ.get("RAY_USE_XRAY") == "0":
# This environment variable is used in our testing setup.
logger.info("Detected environment variable 'RAY_USE_XRAY' with "
"value {}. This turns OFF xray.".format(
os.environ.get("RAY_USE_XRAY")))
use_raylet = False
else:
use_raylet = True
if not use_raylet and redis_password is not None:
raise Exception("Setting the 'redis_password' argument is not "
"supported in legacy Ray. To run Ray with "
"password-protected Redis ports, set "
"'use_raylet=True'.")
# Convert hostnames to numerical IP address.
if node_ip_address is not None:
node_ip_address = services.address_to_ip(node_ip_address)
@@ -1809,7 +1632,6 @@ def init(redis_address=None,
huge_pages=huge_pages,
include_webui=include_webui,
object_store_memory=object_store_memory,
use_raylet=use_raylet,
plasma_store_socket_name=plasma_store_socket_name,
raylet_socket_name=raylet_socket_name,
temp_dir=temp_dir)
@@ -1887,9 +1709,6 @@ def print_error_messages_raylet(worker):
This runs in a separate thread on the driver and prints error messages in
the background.
"""
if not worker.use_raylet:
raise Exception("This function is specific to the raylet code path.")
worker.error_message_pubsub_client = worker.redis_client.pubsub(
ignore_subscribe_messages=True)
# Exports that are published after the call to
@@ -2004,7 +1823,6 @@ def connect(info,
object_id_seed=None,
mode=WORKER_MODE,
worker=global_worker,
use_raylet=True,
redis_password=None):
"""Connect this worker to the local scheduler, to Plasma, and to Redis.
@@ -2015,7 +1833,6 @@ def connect(info,
deterministic.
mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, and
LOCAL_MODE.
use_raylet: True if the new raylet code path should be used.
redis_password (str): Prevents external clients without the password
from connecting to Redis if provided.
"""
@@ -2038,7 +1855,6 @@ def connect(info,
worker.actor_id = NIL_ACTOR_ID
worker.connected = True
worker.set_mode(mode)
worker.use_raylet = use_raylet
# If running Ray in LOCAL_MODE, there is no need to call
# create_worker or to start the worker service.
@@ -2067,7 +1883,6 @@ def connect(info,
traceback_str = traceback.format_exc()
ray.utils.push_error_to_driver_through_redis(
worker.redis_client,
worker.use_raylet,
ray_constants.VERSION_MISMATCH_PUSH_ERROR,
traceback_str,
driver_id=None)
@@ -2108,7 +1923,6 @@ def connect(info,
"driver_id": worker.worker_id,
"start_time": time.time(),
"plasma_store_socket": info["store_socket_name"],
"plasma_manager_socket": info.get("manager_socket_name"),
"local_scheduler_socket": info.get("local_scheduler_socket_name"),
"raylet_socket": info.get("raylet_socket_name")
}
@@ -2123,7 +1937,6 @@ def connect(info,
worker_dict = {
"node_ip_address": worker.node_ip_address,
"plasma_store_socket": info["store_socket_name"],
"plasma_manager_socket": info["manager_socket_name"],
"local_scheduler_socket": info["local_scheduler_socket_name"]
}
if redirect_worker_output:
@@ -2135,18 +1948,10 @@ def connect(info,
raise Exception("This code should be unreachable.")
# Create an object store client.
if not worker.use_raylet:
worker.plasma_client = thread_safe_client(
plasma.connect(info["store_socket_name"],
info["manager_socket_name"], 64))
else:
worker.plasma_client = thread_safe_client(
plasma.connect(info["store_socket_name"], "", 64))
worker.plasma_client = thread_safe_client(
plasma.connect(info["store_socket_name"], "", 64))
if not worker.use_raylet:
local_scheduler_socket = info["local_scheduler_socket_name"]
else:
local_scheduler_socket = info["raylet_socket_name"]
local_scheduler_socket = info["raylet_socket_name"]
# If this is a driver, set the current task ID, the task driver ID, and set
# the task index to 0.
@@ -2177,28 +1982,22 @@ def connect(info,
# rerun the driver.
nil_actor_counter = 0
driver_task = ray.local_scheduler.Task(
worker.task_driver_id, ray.ObjectID(NIL_FUNCTION_ID), [], 0,
worker.current_task_id, worker.task_index,
ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID),
ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID),
nil_actor_counter, False, [], {"CPU": 0}, {}, worker.use_raylet)
driver_task = ray.raylet.Task(worker.task_driver_id,
ray.ObjectID(NIL_FUNCTION_ID), [], 0,
worker.current_task_id,
worker.task_index,
ray.ObjectID(NIL_ACTOR_ID),
ray.ObjectID(NIL_ACTOR_ID),
ray.ObjectID(NIL_ACTOR_ID),
ray.ObjectID(NIL_ACTOR_ID),
nil_actor_counter, [], {"CPU": 0}, {})
# Add the driver task to the task table.
if not worker.use_raylet:
global_state._execute_command(
driver_task.task_id(), "RAY.TASK_TABLE_ADD",
driver_task.task_id().id(), TASK_STATUS_RUNNING,
NIL_LOCAL_SCHEDULER_ID,
driver_task.execution_dependencies_string(), 0,
ray.local_scheduler.task_to_string(driver_task))
else:
global_state._execute_command(
driver_task.task_id(), "RAY.TABLE_ADD",
ray.gcs_utils.TablePrefix.RAYLET_TASK,
ray.gcs_utils.TablePubsub.RAYLET_TASK,
driver_task.task_id().id(),
driver_task._serialized_raylet_task())
global_state._execute_command(driver_task.task_id(), "RAY.TABLE_ADD",
ray.gcs_utils.TablePrefix.RAYLET_TASK,
ray.gcs_utils.TablePubsub.RAYLET_TASK,
driver_task.task_id().id(),
driver_task._serialized_raylet_task())
# Set the driver's current task ID to the task ID assigned to the
# driver task.
@@ -2207,9 +2006,9 @@ def connect(info,
# A non-driver worker begins without an assigned task.
worker.current_task_id = ray.ObjectID(NIL_ID)
worker.local_scheduler_client = ray.local_scheduler.LocalSchedulerClient(
worker.local_scheduler_client = ray.raylet.LocalSchedulerClient(
local_scheduler_socket, worker.worker_id, is_worker,
worker.current_task_id, worker.use_raylet)
worker.current_task_id)
# Start the import thread
import_thread.ImportThread(worker, mode).start()
@@ -2221,16 +2020,10 @@ def connect(info,
# temporarily using this implementation which constantly queries the
# scheduler for new error messages.
if mode == SCRIPT_MODE:
if not worker.use_raylet:
t = threading.Thread(
target=print_error_messages,
name="ray_print_error_messages",
args=(worker, ))
else:
t = threading.Thread(
target=print_error_messages_raylet,
name="ray_print_error_messages",
args=(worker, ))
t = threading.Thread(
target=print_error_messages_raylet,
name="ray_print_error_messages",
args=(worker, ))
# Making the thread a daemon causes it to exit when the main thread
# exits.
t.daemon = True
@@ -2238,7 +2031,7 @@ def connect(info,
# If we are using the raylet code path and we are not in local mode, start
# a background thread to periodically flush profiling data to the GCS.
if mode != LOCAL_MODE and worker.use_raylet:
if mode != LOCAL_MODE:
worker.profiler.start_flush_thread()
if mode == SCRIPT_MODE:
@@ -2395,6 +2188,9 @@ def register_custom_serializer(cls,
# worker and not across workers.
class_id = random_string()
# Make sure class_id is a string.
class_id = ray.utils.binary_to_hex(class_id)
if driver_id is None:
driver_id_bytes = worker.task_driver_id.id()
else:
@@ -2481,7 +2277,7 @@ def put(value, worker=global_worker):
# In LOCAL_MODE, ray.put is the identity operation.
return value
object_id = worker.local_scheduler_client.compute_put_id(
worker.current_task_id, worker.put_index, worker.use_raylet)
worker.current_task_id, worker.put_index)
worker.put_object(object_id, value)
worker.put_index += 1
return object_id
@@ -2554,21 +2350,8 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
raise Exception("num_returns cannot be greater than the number "
"of objects provided to ray.wait.")
timeout = timeout if timeout is not None else 2**30
if worker.use_raylet:
ready_ids, remaining_ids = worker.local_scheduler_client.wait(
object_ids, num_returns, timeout, False)
else:
object_id_strs = [
plasma.ObjectID(object_id.id()) for object_id in object_ids
]
ready_ids, remaining_ids = worker.plasma_client.wait(
object_id_strs, timeout, num_returns)
ready_ids = [
ray.ObjectID(object_id.binary()) for object_id in ready_ids
]
remaining_ids = [
ray.ObjectID(object_id.binary()) for object_id in remaining_ids
]
ready_ids, remaining_ids = worker.local_scheduler_client.wait(
object_ids, num_returns, timeout, False)
return ready_ids, remaining_ids
+1 -4
View File
@@ -88,10 +88,7 @@ if __name__ == "__main__":
tempfile_services.set_temp_root(args.temp_dir)
ray.worker.connect(
info,
mode=ray.WORKER_MODE,
use_raylet=(args.raylet_name is not None),
redis_password=args.redis_password)
info, mode=ray.WORKER_MODE, redis_password=args.redis_password)
error_explanation = """
This error is unexpected and should not have happened. Somehow a worker
+3 -6
View File
@@ -19,13 +19,10 @@ import setuptools.command.build_ext as _build_ext
# NOTE: The lists below must be kept in sync with ray/CMakeLists.txt.
ray_files = [
"ray/core/src/common/thirdparty/redis/src/redis-server",
"ray/core/src/common/redis_module/libray_redis_module.so",
"ray/core/src/ray/thirdparty/redis/src/redis-server",
"ray/core/src/ray/gcs/redis_module/libray_redis_module.so",
"ray/core/src/plasma/plasma_store_server",
"ray/core/src/plasma/plasma_manager",
"ray/core/src/local_scheduler/local_scheduler",
"ray/core/src/local_scheduler/liblocal_scheduler_library_python.so",
"ray/core/src/global_scheduler/global_scheduler",
"ray/core/src/ray/raylet/liblocal_scheduler_library_python.so",
"ray/core/src/ray/raylet/raylet_monitor", "ray/core/src/ray/raylet/raylet",
"ray/WebUI.ipynb"
]
-131
View File
@@ -1,131 +0,0 @@
cmake_minimum_required(VERSION 3.4)
project(common)
if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES")
include_directories("${CMAKE_CURRENT_LIST_DIR}/lib/python")
endif ()
add_subdirectory(redis_module)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -g")
include_directories(thirdparty/ae)
# Compile flatbuffers
set(COMMON_FBS_SRC "${CMAKE_CURRENT_LIST_DIR}/format/common.fbs")
set(OUTPUT_DIR ${CMAKE_CURRENT_LIST_DIR}/format/)
set(COMMON_FBS_OUTPUT_FILES
"${OUTPUT_DIR}/common_generated.h")
add_custom_target(gen_common_fbs DEPENDS ${COMMON_FBS_OUTPUT_FILES})
add_custom_command(
OUTPUT ${COMMON_FBS_OUTPUT_FILES}
# The --gen-object-api flag generates a C++ class MessageT for each
# flatbuffers message Message, which can be used to store deserialized
# messages in data structures. This is currently used for ObjectInfo for
# example.
COMMAND ${FLATBUFFERS_COMPILER} -c -o ${OUTPUT_DIR} ${COMMON_FBS_SRC} --gen-object-api --scoped-enums
DEPENDS ${FBS_DEPENDS}
COMMENT "Running flatc compiler on ${COMMON_FBS_SRC}"
VERBATIM)
if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES")
add_custom_target(gen_common_python_fbs DEPENDS ${COMMON_FBS_OUTPUT_FILES})
# Generate Python bindings for the flatbuffers objects.
set(PYTHON_OUTPUT_DIR ${CMAKE_CURRENT_LIST_DIR}/../../python/ray/core/generated/)
add_custom_command(
TARGET gen_common_python_fbs
COMMAND ${FLATBUFFERS_COMPILER} -p -o ${PYTHON_OUTPUT_DIR} ${COMMON_FBS_SRC}
DEPENDS ${FBS_DEPENDS}
COMMENT "Running flatc compiler on ${COMMON_FBS_SRC}"
VERBATIM)
# Encode the fact that the ray redis module requires the autogenerated
# flatbuffer files to compile.
add_dependencies(ray_redis_module gen_common_python_fbs)
add_dependencies(gen_common_python_fbs flatbuffers_ep)
endif()
if ("${CMAKE_RAY_LANG_JAVA}" STREQUAL "YES")
add_custom_target(gen_common_java_fbs DEPENDS ${COMMON_FBS_OUTPUT_FILES})
# Generate Java bindings for the flatbuffers objects.
set(JAVA_OUTPUT_DIR ${CMAKE_BINARY_DIR}/generated/java)
add_custom_command(
TARGET gen_common_java_fbs
COMMAND ${FLATBUFFERS_COMPILER} -j -o ${JAVA_OUTPUT_DIR} ${COMMON_FBS_SRC}
DEPENDS ${FBS_DEPENDS}
COMMENT "Running flatc compiler on ${COMMON_FBS_SRC}"
VERBATIM)
# Encode the fact that the ray redis module requires the autogenerated
# flatbuffer files to compile.
add_dependencies(ray_redis_module gen_common_java_fbs)
add_dependencies(gen_common_java_fbs flatbuffers_ep)
endif()
# Custom target that builds hiredis by running `make` inside
# thirdparty/hiredis. The tests defined by define_test() depend on it.
add_custom_target(
  hiredis
  COMMAND make
  WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/thirdparty/hiredis)
add_library(common STATIC
event_loop.cc
common.cc
common_protocol.cc
task.cc
io.cc
net.cc
logging.cc
state/redis.cc
state/table.cc
state/object_table.cc
state/task_table.cc
state/db_client_table.cc
state/driver_table.cc
state/actor_notification_table.cc
state/local_scheduler_table.cc
state/error_table.cc
thirdparty/ae/ae.c
thirdparty/sha256.c)
add_dependencies(common arrow)
if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES")
add_dependencies(common gen_common_python_fbs)
endif()
if ("${CMAKE_RAY_LANG_JAVA}" STREQUAL "YES")
add_dependencies(common gen_common_java_fbs)
endif()
target_link_libraries(common "${CMAKE_CURRENT_LIST_DIR}/thirdparty/hiredis/libhiredis.a")
# define_test(<test_name> <library> [<extra sources>...])
#
# Declares the test executable <test_name>, built from test/<test_name>.cc
# plus any additional source files passed after <library> (ARGN). The
# executable depends on the hiredis and flatbuffers_ep targets, is linked
# against common, flatbuffers, ray_static, plasma, arrow, <library> and
# pthread, and is compiled with the PLASMA_TEST, LOCAL_SCHEDULER_TEST and
# COMMON_TEST defines and RAY_COMMON_LOG_LEVEL=4.
function(define_test test_name library)
  add_executable(${test_name} test/${test_name}.cc ${ARGN})
  add_dependencies(${test_name} hiredis flatbuffers_ep)
  target_link_libraries(${test_name} common ${FLATBUFFERS_STATIC_LIB} ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} ${library} -lpthread)
  target_compile_options(${test_name} PUBLIC "-DPLASMA_TEST -DLOCAL_SCHEDULER_TEST -DCOMMON_TEST -DRAY_COMMON_LOG_LEVEL=4")
endfunction()
define_test(db_tests "")
define_test(io_tests "")
define_test(task_tests "")
define_test(redis_tests "")
define_test(task_table_tests "")
define_test(object_table_tests "")
# copy_redis is part of the default build (ALL). After it is built, the
# redis-cli and redis-server binaries are copied from thirdparty/pkg/redis/src
# into ${CMAKE_BINARY_DIR}/src/common/thirdparty/redis/src.
add_custom_target(copy_redis ALL)
foreach(file "redis-cli" "redis-server")
  add_custom_command(TARGET copy_redis POST_BUILD
                     COMMAND ${CMAKE_COMMAND} -E
                         copy ${CMAKE_CURRENT_LIST_DIR}/../../thirdparty/pkg/redis/src/${file}
                         ${CMAKE_BINARY_DIR}/src/common/thirdparty/redis/src/${file})
endforeach()
-20
View File
@@ -1,20 +0,0 @@
#include "common.h"
#include <chrono>
#include <stdio.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include "io.h"
#include <functional>
/* Digest made up entirely of zero bytes: the {0} initializer zero-fills all
 * DIGEST_SIZE elements. */
const unsigned char NIL_DIGEST[DIGEST_SIZE] = {0};
/* Return the number of whole milliseconds that have elapsed since the epoch
 * of std::chrono::steady_clock. */
int64_t current_time_ms() {
  using std::chrono::duration_cast;
  using std::chrono::milliseconds;
  using std::chrono::steady_clock;

  const steady_clock::duration elapsed = steady_clock::now().time_since_epoch();
  return duration_cast<milliseconds>(elapsed).count();
}
-75
View File
@@ -1,75 +0,0 @@
#ifndef COMMON_H
#define COMMON_H
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#ifndef __STDC_FORMAT_MACROS
#define __STDC_FORMAT_MACROS
#endif
#include <errno.h>
#include <inttypes.h>
#ifndef _WIN32
#include <execinfo.h>
#endif
#ifdef __cplusplus
#include <functional>
extern "C" {
#endif
#include "sha256.h"
#ifdef __cplusplus
}
#endif
#include "arrow/util/macros.h"
#include "plasma/common.h"
#include "ray/id.h"
#include "ray/util/logging.h"
#include "state/ray_config.h"
/** Definitions for Ray logging levels. */
#define RAY_COMMON_DEBUG 0
#define RAY_COMMON_INFO 1
#define RAY_COMMON_WARNING 2
#define RAY_COMMON_ERROR 3
#define RAY_COMMON_FATAL 4
/**
* RAY_COMMON_LOG_LEVEL should be defined to one of the above logging level
* integer values. Any logging statement in the code with a logging level
* greater than or equal to RAY_COMMON_LOG_LEVEL will be outputted to stderr.
* The default logging level is INFO. */
#ifndef RAY_COMMON_LOG_LEVEL
#define RAY_COMMON_LOG_LEVEL RAY_COMMON_INFO
#endif
/* These are exit codes for common errors that can occur in Ray components. */
#define EXIT_COULD_NOT_BIND_PORT -2
/** This macro indicates that this pointer owns the data it is pointing to
* and is responsible for freeing it. */
#define OWNER
/** The worker ID is the ID of a worker or driver. */
typedef ray::UniqueID WorkerID;
/** DBClientID is likewise an alias of ray::UniqueID. */
typedef ray::UniqueID DBClientID;
/* Function-like maximum/minimum macros. Each argument is expanded twice, so
 * do not pass expressions with side effects (for example MAX(i++, j)). */
#define MAX(x, y) ((x) >= (y) ? (x) : (y))
#define MIN(x, y) ((x) <= (y) ? (x) : (y))
/** Definitions for computing hash digests. */
#define DIGEST_SIZE SHA256_BLOCK_SIZE
extern const unsigned char NIL_DIGEST[DIGEST_SIZE];
/**
 * Return the current reading of the monotonic clock in milliseconds.
 *
 * The implementation in common.cc reads std::chrono::steady_clock, whose
 * epoch is unspecified (it is not necessarily the Unix epoch). The value is
 * therefore only meaningful for measuring the time elapsed between two calls
 * and must not be interpreted as a wall-clock timestamp.
 *
 * @return The number of milliseconds since the steady_clock epoch.
 */
int64_t current_time_ms();
#endif
-32
View File
@@ -1,32 +0,0 @@
# Task specifications, task instances and task logs
A *task specification* contains all information that is needed for computing
the results of a task:
- The ID of the task
- The function ID of the function that executes the task
- The arguments (either object IDs for pass by reference
or values for pass by value)
- The IDs of the result objects
From these, a task ID can be computed which is also stored in the task
specification.
A *task* represents the execution of a task specification.
It consists of:
- A scheduling state (WAITING, SCHEDULED, RUNNING, DONE)
- The target node where the task is scheduled or executed
- The task specification
The task data structures are defined in `common/task.h`.
The *task table* is a mapping from the task ID to the *task* information. It is
updated by various parts of the system:
1. The local scheduler writes it with status WAITING when it submits a task to the global scheduler
2. The global scheduler appends an update WAITING -> SCHEDULED together with the node ID when assigning the task to a local scheduler
3. The local scheduler appends an update SCHEDULED -> RUNNING when it assigns a task to a worker
4. The local scheduler appends an update RUNNING -> DONE when the task finishes execution
The task table is defined in `common/state/task_table.h`.
-63
View File
@@ -1,63 +0,0 @@
#include "event_loop.h"
#include "common.h"
#include <errno.h>
#define INITIAL_EVENT_LOOP_SIZE 1024
/* Create an event loop whose initial set size is INITIAL_EVENT_LOOP_SIZE and
 * hand it back to the caller. */
event_loop *event_loop_create(void) {
  event_loop *created = aeCreateEventLoop(INITIAL_EVENT_LOOP_SIZE);
  return created;
}
/* Release an event loop created with event_loop_create.
 *
 * All timer events still registered on the loop are freed first, then the
 * loop itself is deleted. The loop must not be used afterwards. */
void event_loop_destroy(event_loop *loop) {
  /* Clean up timer events. This is to make valgrind happy. */
  for (aeTimeEvent *te = loop->timeEventHead; te != NULL;) {
    aeTimeEvent *next = te->next;
    free(te);
    te = next;
  }
  /* Every node has been freed, so do not hand a dangling list head to
   * aeDeleteEventLoop. */
  loop->timeEventHead = NULL;
  aeDeleteEventLoop(loop);
}
/* Register a handler for events on a file descriptor.
 *
 * If the descriptor does not fit into the event loop's current set, the set
 * is grown (by a factor of 1.5, and always far enough to hold fd) and the
 * registration is retried once.
 *
 * Returns true if the handler was registered, false if fd is negative, the
 * resize failed, or the registration failed for any other reason. */
bool event_loop_add_file(event_loop *loop,
                         int fd,
                         int events,
                         event_loop_file_handler handler,
                         void *context) {
  /* A negative descriptor can never be registered. */
  if (fd < 0) {
    return false;
  }
  /* Try to add the file descriptor. */
  int err = aeCreateFileEvent(loop, fd, events, handler, context);
  /* If it cannot be added, increase the size of the event loop. */
  if (err == AE_ERR && errno == ERANGE) {
    int new_size = 3 * aeGetSetSize(loop) / 2;
    /* Growing by 1.5x is not enough when fd is far beyond the current set
     * size (or the set is tiny), so make sure fd itself fits. */
    if (new_size <= fd) {
      new_size = fd + 1;
    }
    if (aeResizeSetSize(loop, new_size) != AE_OK) {
      return false;
    }
    err = aeCreateFileEvent(loop, fd, events, handler, context);
  }
  /* In any case, test if there were errors. */
  return (err == AE_OK);
}
/* Stop watching fd: its file event is deleted for both the read and the
 * write direction. */
void event_loop_remove_file(event_loop *loop, int fd) {
  const int read_and_write = EVENT_LOOP_READ | EVENT_LOOP_WRITE;
  aeDeleteFileEvent(loop, fd, read_and_write);
}
/* Register a timer on the loop and return the ID that ae assigned to it.
 * No finalizer is installed for the timer (the last argument is NULL). */
int64_t event_loop_add_timer(event_loop *loop,
                             int64_t timeout,
                             event_loop_timer_handler handler,
                             void *context) {
  const int64_t new_timer_id =
      aeCreateTimeEvent(loop, timeout, handler, context, NULL);
  return new_timer_id;
}
/* Delete the timer with the given ID and pass ae's status code back to the
 * caller unchanged. */
int event_loop_remove_timer(event_loop *loop, int64_t id) {
  const int status = aeDeleteTimeEvent(loop, id);
  return status;
}
/* Run the event loop by handing control to aeMain. */
void event_loop_run(event_loop *loop) {
  aeMain(loop);
}
/* Ask the event loop to stop by calling aeStop. */
void event_loop_stop(event_loop *loop) {
  aeStop(loop);
}
-103
View File
@@ -1,103 +0,0 @@
#ifndef EVENT_LOOP_H
#define EVENT_LOOP_H
#include <stdint.h>
extern "C" {
#ifdef _WIN32
/* Quirks mean that Windows version needs to be included differently */
#include <hiredis/hiredis.h>
#include <ae.h>
#else
#include "ae/ae.h"
#endif
}
/* Unique timer ID that will be generated when the timer is added to the
* event loop. Will not be reused later on in another call
* to event_loop_add_timer. */
typedef long long timer_id;
typedef aeEventLoop event_loop;
/* File descriptor is readable. */
#define EVENT_LOOP_READ AE_READABLE
/* File descriptor is writable. */
#define EVENT_LOOP_WRITE AE_WRITABLE
/* Constant specifying that the timer is done and it will be removed. */
#define EVENT_LOOP_TIMER_DONE AE_NOMORE
/* Signature of the handler that will be called when there is a new event
* on the file descriptor that this handler has been registered for. The
* context is the one that was passed into add_file by the user. The
* events parameter indicates which event is available on the file,
* it can be EVENT_LOOP_READ or EVENT_LOOP_WRITE. */
typedef void (*event_loop_file_handler)(event_loop *loop,
int fd,
void *context,
int events);
/* This handler will be called when a timer times out. The id of the timer
* as well as the context that was specified when registering this handler
* are passed as arguments. The return is the number of milliseconds the
* timer shall be reset to or EVENT_LOOP_TIMER_DONE if the timer shall
* not be triggered again. */
typedef int (*event_loop_timer_handler)(event_loop *loop,
timer_id timer_id,
void *context);
/* Create and return a new event loop. */
event_loop *event_loop_create(void);
/* Deallocate space associated with the event loop that was created
* with the "create" function. */
void event_loop_destroy(event_loop *loop);
/* Register a handler that will be called any time a new event happens on
 * a file descriptor. Can specify a context that will be passed as an
 * argument to the handler. Currently there can only be one handler per file.
 * The events parameter specifies which events we listen to: EVENT_LOOP_READ
 * or EVENT_LOOP_WRITE.
 *
 * Returns true if the handler was registered and false otherwise. */
bool event_loop_add_file(event_loop *loop,
                         int fd,
                         int events,
                         event_loop_file_handler handler,
                         void *context);
/* Remove a registered file event handler from the event loop. */
void event_loop_remove_file(event_loop *loop, int fd);
/** Register a handler that will be called after a time slice of
* "timeout" milliseconds.
*
* @param loop The event loop.
* @param timeout The timeout in milliseconds.
* @param handler The handler for the timeout.
* @param context User context that can be passed in and will be passed in
* as an argument for the timer handler.
* @return The ID of the timer.
*/
int64_t event_loop_add_timer(event_loop *loop,
int64_t timeout,
event_loop_timer_handler handler,
void *context);
/**
* Remove a registered time event handler from the event loop. Can be called
* multiple times on the same timer.
*
* @param loop The event loop.
* @param timer_id The ID of the timer to be removed.
* @return Returns 0 if the removal was successful.
*/
int event_loop_remove_timer(event_loop *loop, int64_t timer_id);
/* Run the event loop. */
void event_loop_run(event_loop *loop);
/* Stop the event loop. */
void event_loop_stop(event_loop *loop);
#endif
-203
View File
@@ -1,203 +0,0 @@
// Indices into resource vectors.
// A resource vector maps a resource index to the number
// of units of that resource required.
table Arg {
// Object ID for pass-by-reference arguments. Normally there is only one
// object ID in this list which represents the object that is being passed.
// However to support reducers in a MapReduce workload, we also support
// passing multiple object IDs for each argument.
object_ids: [string];
// Data for pass-by-value arguments.
data: string;
}
// A resource name paired with a quantity. This is the element type of the
// resource lists in TaskInfo (required_resources,
// required_placement_resources) and LocalSchedulerInfoMessage
// (static_resources, dynamic_resources).
table ResourcePair {
  // The name of the resource.
  key: string;
  // The quantity of the resource.
  value: double;
}
// NOTE: This enum is duplicate with the `Language` enum in `gcs.fbs`,
// because we cannot include this file in `gcs.fbs` due to cyclic dependency.
// TODO(raulchen): remove it once we get rid of legacy ray.
enum TaskLanguage:int {
PYTHON = 0,
JAVA = 1
}
table TaskInfo {
// ID of the driver that created this task.
driver_id: string;
// Task ID of the task.
task_id: string;
// Task ID of the parent task.
parent_task_id: string;
// A count of the number of tasks submitted by the parent task before this one.
parent_counter: int;
// The ID of the actor to create if this is an actor creation task.
actor_creation_id: string;
// The dummy object ID of the actor creation task if this is an actor method.
actor_creation_dummy_object_id: string;
// Actor ID of the task. This is the actor that this task is executed on
// or NIL_ACTOR_ID if the task is just a normal task.
actor_id: string;
// The ID of the handle that was used to submit the task. This should be
// unique across handles with the same actor_id.
actor_handle_id: string;
// Number of tasks that have been submitted to this actor so far.
actor_counter: int;
// True if this task is an actor checkpoint task and false otherwise.
is_actor_checkpoint_method: bool;
// Function ID of the task.
function_id: string;
// Task arguments.
args: [Arg];
// Object IDs of return values.
returns: [string];
// The required_resources vector indicates the quantities of the different
// resources required by this task.
required_resources: [ResourcePair];
// The resources required for placing this task on a node. If this is empty,
// then the placement resources are equal to the required_resources.
required_placement_resources: [ResourcePair];
// The language that this task belongs to
language: TaskLanguage;
// Function descriptor, which is a list of strings that can
// uniquely describe a function.
// For a Python function, it should be: [module_name, class_name, function_name]
// For a Java function, it should be: [class_name, method_name, type_descriptor]
// TODO(hchen): after changing Python worker to use function_descriptor,
// function_id can be removed.
function_descriptor: [string];
}
// Object information data structure.
// NOTE(pcm): This structure is replicated in
// https://github.com/apache/arrow/blob/master/cpp/src/plasma/format/common.fbs,
// so if you modify it, you should also modify that one.
table ObjectInfo {
// Object ID of this object.
object_id: string;
// Number of bytes the content of this object occupies in memory.
data_size: long;
// Number of bytes the metadata of this object occupies in memory.
metadata_size: long;
// Number of clients using the objects.
ref_count: int;
// Unix epoch of when this object was created.
create_time: long;
// How long creation of this object took.
construct_duration: long;
// Hash of the object content. If the object is not sealed yet this is
// an empty string.
digest: string;
// Specifies if this object was deleted or added.
is_deletion: bool;
}
root_type TaskInfo;
table TaskExecutionDependencies {
// A list of object IDs representing this task's dependencies at execution
// time.
execution_dependencies: [string];
}
root_type TaskExecutionDependencies;
table SubscribeToNotificationsReply {
// The object ID of the object that the notification is about.
object_id: string;
// The size of the object.
object_size: long;
// The IDs of the managers that contain this object.
manager_ids: [string];
}
root_type SubscribeToNotificationsReply;
table TaskReply {
// The task ID of the task that the message is about.
task_id: string;
// The state of the task. This is encoded as a bit mask of scheduling_state
// enum values in task.h.
state: long;
// A local scheduler ID.
local_scheduler_id: string;
// A string of bytes representing the task's TaskExecutionDependencies.
execution_dependencies: string;
// A string of bytes representing the task specification.
task_spec: string;
// The number of times the task was spilled back by local schedulers.
spillback_count: long;
// A boolean representing whether the update was successful. This field
// should only be used for test-and-set operations.
updated: bool;
}
root_type TaskReply;
table SubscribeToDBClientTableReply {
// The db client ID of the client that the message is about.
db_client_id: string;
// The type of the client.
client_type: string;
// If the client is a local scheduler, this is the address of the plasma
// manager that the local scheduler is connected to. Otherwise, it is empty.
manager_address: string;
// True if the message is about the addition of a client and false if it is
// about the deletion of a client.
is_insertion: bool;
}
root_type SubscribeToDBClientTableReply;
// Load and resource report for one local scheduler, identified by
// `db_client_id`. When `is_dead` is true only `db_client_id` is meaningful.
table LocalSchedulerInfoMessage {
// The db client ID of the client that the message is about.
db_client_id: string;
// The total number of workers that are connected to this local scheduler.
total_num_workers: long;
// The number of tasks queued in this local scheduler.
task_queue_length: long;
// The number of workers that are available and waiting for tasks.
available_workers: long;
// The resources generally available to this local scheduler.
static_resources: [ResourcePair];
// The resources currently available to this local scheduler.
dynamic_resources: [ResourcePair];
// Whether the local scheduler is dead. If true, then all other fields
// besides `db_client_id` will not be set.
is_dead: bool;
}
root_type LocalSchedulerInfoMessage;
// Describes how one object came to exist: the task that created it, whether
// it was created by a ray.put, and the object's size and hash.
table ResultTableReply {
// The task ID of the task that created the object.
task_id: string;
// Whether the task created the object through a ray.put.
is_put: bool;
// The size of the object created.
data_size: long;
// The hash of the object created.
hash: string;
}
root_type ResultTableReply;
// Message announcing that a driver has died.
table DriverTableMessage {
// The driver ID of the driver that died.
driver_id: string;
}
// Notification that an actor was created, naming the actor, the driver that
// created it, and the local scheduler that created it.
table ActorCreationNotification {
// The ID of the actor that was created.
actor_id: string;
// The ID of the driver that created the actor.
driver_id: string;
// The ID of the local scheduler that created the actor.
local_scheduler_id: string;
}
-416
View File
@@ -1,416 +0,0 @@
#include "io.h"
#include <stdlib.h>
#include <unistd.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <string.h>
#include <stdio.h>
#include <inttypes.h>
#include <stdarg.h>
#include <sys/ioctl.h>
#include <netinet/in.h>
#include <netdb.h>
#include "common.h"
#include "event_loop.h"
#ifndef _WIN32
/* This function is actually not declared in standard POSIX, so declare it. */
extern int usleep(useconds_t usec);
#endif
/* Create a non-blocking stream socket bound to INADDR_ANY on the given port
 * and, if shall_listen is set, start listening on it (backlog 128).
 *
 * Returns the socket descriptor on success. On any failure an error is
 * logged, the socket (if it was created) is closed, and -1 is returned. */
int bind_inet_sock(const int port, bool shall_listen) {
  int socket_fd = socket(PF_INET, SOCK_STREAM, 0);
  if (socket_fd < 0) {
    RAY_LOG(ERROR) << "socket() failed for port " << port;
    return -1;
  }
  /* Zero the whole address before filling it in so that the padding and any
   * platform-specific members are not handed to bind() uninitialized. This
   * matches what bind_ipc_sock does for its sockaddr_un. */
  struct sockaddr_in name;
  memset(&name, 0, sizeof(name));
  name.sin_family = AF_INET;
  name.sin_port = htons(port);
  name.sin_addr.s_addr = htonl(INADDR_ANY);
  int on = 1;
  /* Switch the socket to non-blocking mode. */
  /* TODO(pcm): http://stackoverflow.com/q/1150635 */
  if (ioctl(socket_fd, FIONBIO, (char *) &on) < 0) {
    RAY_LOG(ERROR) << "ioctl failed";
    close(socket_fd);
    return -1;
  }
  /* Allow the port to be reused. */
  if (setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) < 0) {
    RAY_LOG(ERROR) << "setsockopt failed for port " << port;
    close(socket_fd);
    return -1;
  }
  if (bind(socket_fd, (struct sockaddr *) &name, sizeof(name)) < 0) {
    RAY_LOG(ERROR) << "Bind failed for port " << port;
    close(socket_fd);
    return -1;
  }
  if (shall_listen && listen(socket_fd, 128) == -1) {
    RAY_LOG(ERROR) << "Could not listen to socket " << port;
    close(socket_fd);
    return -1;
  }
  return socket_fd;
}
/* Create a Unix domain stream socket bound to socket_pathname, removing any
 * file that already exists at that path, and optionally start listening on
 * it (backlog 128). Returns the descriptor, or -1 after logging an error and
 * closing the socket if any step fails. */
int bind_ipc_sock(const char *socket_pathname, bool shall_listen) {
  const int socket_fd = socket(AF_UNIX, SOCK_STREAM, 0);
  if (socket_fd < 0) {
    RAY_LOG(ERROR) << "socket() failed for pathname " << socket_pathname;
    return -1;
  }

  /* Ask the system to allow the address to be reused. */
  int reuse = 1;
  if (setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, (char *) &reuse,
                 sizeof(reuse)) < 0) {
    RAY_LOG(ERROR) << "setsockopt failed for pathname " << socket_pathname;
    close(socket_fd);
    return -1;
  }

  /* Remove a stale socket file, if any, before binding. */
  unlink(socket_pathname);

  struct sockaddr_un socket_address;
  memset(&socket_address, 0, sizeof(socket_address));
  socket_address.sun_family = AF_UNIX;

  /* The path, including its terminating NUL, must fit into sun_path. */
  const size_t path_bytes = strlen(socket_pathname) + 1;
  if (path_bytes > sizeof(socket_address.sun_path)) {
    RAY_LOG(ERROR) << "Socket pathname is too long.";
    close(socket_fd);
    return -1;
  }
  strncpy(socket_address.sun_path, socket_pathname, path_bytes);

  const int bind_rc = bind(socket_fd, (struct sockaddr *) &socket_address,
                           sizeof(socket_address));
  if (bind_rc != 0) {
    RAY_LOG(ERROR) << "Bind failed for pathname " << socket_pathname;
    close(socket_fd);
    return -1;
  }

  if (shall_listen && listen(socket_fd, 128) == -1) {
    RAY_LOG(ERROR) << "Could not listen to socket " << socket_pathname;
    close(socket_fd);
    return -1;
  }
  return socket_fd;
}
/* Connect to the Unix domain socket at socket_pathname, trying up to
 * num_retries times and sleeping timeout milliseconds after each failed
 * attempt. A negative num_retries or timeout selects the corresponding
 * RayConfig default. If no attempt succeeds, a FATAL message is logged. */
int connect_ipc_sock_retry(const char *socket_pathname,
                           int num_retries,
                           int64_t timeout) {
  /* Fall back to the configured defaults for unspecified arguments. */
  if (num_retries < 0) {
    num_retries = RayConfig::instance().num_connect_attempts();
  }
  if (timeout < 0) {
    timeout = RayConfig::instance().connect_timeout_milliseconds();
  }
  RAY_CHECK(socket_pathname);

  int fd = -1;
  int attempt = 0;
  while (attempt < num_retries) {
    fd = connect_ipc_sock(socket_pathname);
    if (fd >= 0) {
      break;
    }
    /* Only report the first failure to avoid flooding the log. */
    if (attempt == 0) {
      RAY_LOG(ERROR) << "Connection to socket failed for pathname "
                     << socket_pathname;
    }
    /* usleep takes microseconds; timeout is in milliseconds. */
    usleep(timeout * 1000);
    ++attempt;
  }

  /* Every attempt failed (or none was made). */
  if (fd == -1) {
    RAY_LOG(FATAL) << "Could not connect to socket " << socket_pathname;
  }
  return fd;
}
/* Make a single attempt to connect to the Unix domain stream socket at
 * socket_pathname.
 *
 * Returns the connected descriptor, or -1 on failure. The socket is closed
 * on every failure path, so no descriptor is leaked. */
int connect_ipc_sock(const char *socket_pathname) {
  struct sockaddr_un socket_address;
  int socket_fd;
  socket_fd = socket(AF_UNIX, SOCK_STREAM, 0);
  if (socket_fd < 0) {
    RAY_LOG(ERROR) << "socket() failed for pathname " << socket_pathname;
    return -1;
  }
  memset(&socket_address, 0, sizeof(socket_address));
  socket_address.sun_family = AF_UNIX;
  /* The path, including its terminating NUL, must fit into sun_path. */
  if (strlen(socket_pathname) + 1 > sizeof(socket_address.sun_path)) {
    RAY_LOG(ERROR) << "Socket pathname is too long.";
    /* Release the descriptor created above; previously it was leaked here. */
    close(socket_fd);
    return -1;
  }
  strncpy(socket_address.sun_path, socket_pathname,
          strlen(socket_pathname) + 1);
  if (connect(socket_fd, (struct sockaddr *) &socket_address,
              sizeof(socket_address)) != 0) {
    close(socket_fd);
    return -1;
  }
  return socket_fd;
}
/* Connect to ip_addr:port, trying up to num_retries times and sleeping
 * timeout milliseconds after each failed attempt. A negative num_retries or
 * timeout selects the corresponding RayConfig default. If no attempt
 * succeeds, a FATAL message is logged. */
int connect_inet_sock_retry(const char *ip_addr,
                            int port,
                            int num_retries,
                            int64_t timeout) {
  /* Fall back to the configured defaults for unspecified arguments. */
  if (num_retries < 0) {
    num_retries = RayConfig::instance().num_connect_attempts();
  }
  if (timeout < 0) {
    timeout = RayConfig::instance().connect_timeout_milliseconds();
  }
  RAY_CHECK(ip_addr);

  int fd = -1;
  int attempt = 0;
  while (attempt < num_retries) {
    fd = connect_inet_sock(ip_addr, port);
    if (fd >= 0) {
      break;
    }
    /* Only report the first failure to avoid flooding the log. */
    if (attempt == 0) {
      RAY_LOG(ERROR) << "Connection to socket failed for address " << ip_addr
                     << ":" << port;
    }
    /* usleep takes microseconds; timeout is in milliseconds. */
    usleep(timeout * 1000);
    ++attempt;
  }

  /* Every attempt failed (or none was made). */
  if (fd == -1) {
    RAY_LOG(FATAL) << "Could not connect to address " << ip_addr << ":" << port;
  }
  return fd;
}
int connect_inet_sock(const char *ip_addr, int port) {
int fd = socket(PF_INET, SOCK_STREAM, 0);
if (fd < 0) {
RAY_LOG(ERROR) << "socket() failed for address " << ip_addr << ":" << port;
return -1;
}
struct hostent *manager = gethostbyname(ip_addr); /* TODO(pcm): cache this */
if (!manager) {
RAY_LOG(ERROR) << "Failed to get hostname from address " << ip_addr << ":"
<< port;
close(fd);
return -1;
}
struct sockaddr_in addr;
addr.sin_family = AF_INET;
memcpy(&addr.sin_addr.s_addr, manager->h_addr_list[0], manager->h_length);
addr.sin_port = htons(port);
if (connect(fd, (struct sockaddr *) &addr, sizeof(addr)) != 0) {
close(fd);
return -1;
}
return fd;
}
/* Accept one pending connection on the listening socket socket_fd. Returns
 * the descriptor of the new connection, or -1 (after logging) if accept
 * fails. The peer address is not retrieved. */
int accept_client(int socket_fd) {
  const int client_fd = accept(socket_fd, NULL, NULL);
  if (client_fd >= 0) {
    return client_fd;
  }
  RAY_LOG(ERROR) << "Error reading from socket.";
  return -1;
}
/* Write exactly length bytes starting at cursor to fd, retrying partial
 * writes. EAGAIN, EWOULDBLOCK and EINTR are retried immediately, so this
 * spins on a non-blocking descriptor until the data can be written.
 * Returns 0 on success and -1 on any other error or on a zero-byte write. */
int write_bytes(int fd, uint8_t *cursor, size_t length) {
  size_t written = 0;
  while (written < length) {
    /* Push out whatever is still pending, starting after what was sent. */
    const ssize_t rc = write(fd, cursor + written, length - written);
    if (rc < 0) {
      const bool retryable =
          errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR;
      if (retryable) {
        continue;
      }
      return -1; /* Errno will be set. */
    }
    if (rc == 0) {
      /* Encountered early EOF. */
      return -1;
    }
    RAY_CHECK(rc > 0);
    written += rc;
  }
  return 0;
}
/* Write one framed message to fd: the protocol version, the message type and
 * the payload length (each as a raw int64_t), followed by length payload
 * bytes. Stops at the first failing write and returns its status; returns 0
 * if everything was written. */
int do_write_message(int fd, int64_t type, int64_t length, uint8_t *bytes) {
  int64_t version = RayConfig::instance().ray_protocol_version();
  int status = write_bytes(fd, (uint8_t *) &version, sizeof(version));
  if (status == 0) {
    status = write_bytes(fd, (uint8_t *) &type, sizeof(type));
  }
  if (status == 0) {
    status = write_bytes(fd, (uint8_t *) &length, sizeof(length));
  }
  if (status == 0) {
    status = write_bytes(fd, bytes, length * sizeof(char));
  }
  return status;
}
/* Write one framed message to fd via do_write_message. When mutex is not
 * NULL it is held for the duration of the write so that concurrent writers
 * cannot interleave their frames. Returns do_write_message's status. */
int write_message(int fd,
                  int64_t type,
                  int64_t length,
                  uint8_t *bytes,
                  std::mutex *mutex) {
  if (mutex == NULL) {
    return do_write_message(fd, type, length, bytes);
  }
  std::unique_lock<std::mutex> guard(*mutex);
  return do_write_message(fd, type, length, bytes);
}
/* Read exactly length bytes from fd into the buffer at cursor, retrying
 * partial reads. EAGAIN, EWOULDBLOCK and EINTR are retried immediately, so
 * this spins on a non-blocking descriptor until the data arrives.
 * Returns 0 on success and -1 on any other error or on end of file. */
int read_bytes(int fd, uint8_t *cursor, size_t length) {
  size_t received = 0;
  /* Loop until the whole request is satisfied, or an error/EOF occurs. */
  while (received < length) {
    const ssize_t rc = read(fd, cursor + received, length - received);
    if (rc < 0) {
      const bool retryable =
          errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR;
      if (retryable) {
        continue;
      }
      return -1; /* Errno will be set. */
    }
    if (rc == 0) {
      /* Encountered early EOF. */
      return -1;
    }
    RAY_CHECK(rc > 0);
    received += rc;
  }
  return 0;
}
void read_message(int fd, int64_t *type, int64_t *length, uint8_t **bytes) {
int64_t version;
int closed = read_bytes(fd, (uint8_t *) &version, sizeof(version));
if (closed) {
goto disconnected;
}
RAY_CHECK(version == RayConfig::instance().ray_protocol_version());
closed = read_bytes(fd, (uint8_t *) type, sizeof(*type));
if (closed) {
goto disconnected;
}
closed = read_bytes(fd, (uint8_t *) length, sizeof(*length));
if (closed) {
goto disconnected;
}
*bytes = (uint8_t *) malloc(*length * sizeof(uint8_t));
closed = read_bytes(fd, *bytes, *length);
if (closed) {
free(*bytes);
goto disconnected;
}
return;
disconnected:
/* Handle the case in which the socket is closed. */
*type = static_cast<int64_t>(CommonMessageType::DISCONNECT_CLIENT);
*length = 0;
*bytes = NULL;
return;
}
/* Read one length-prefixed message from sock: an int64_t size followed by
 * that many payload bytes.
 *
 * Returns a malloc'ed buffer holding the payload, which the caller must
 * free. If either read fails, the announced size is negative, or the buffer
 * cannot be allocated, sock is removed from loop (when loop is not NULL) and
 * closed, any partially filled buffer is freed, and NULL is returned. */
uint8_t *read_message_async(event_loop *loop, int sock) {
  uint8_t *message = NULL;
  int64_t size;
  /* A negative size can only come from a broken peer; reject it up front
   * rather than converting it into a huge allocation request. */
  bool ok = read_bytes(sock, (uint8_t *) &size, sizeof(int64_t)) >= 0 &&
            size >= 0;
  if (ok) {
    message = (uint8_t *) malloc(size);
    /* malloc(0) may legitimately return NULL, so only a NULL result for a
     * non-empty payload is an allocation failure. */
    ok = (message != NULL || size == 0) &&
         read_bytes(sock, message, size) >= 0;
  }
  if (!ok) {
    /* The other side has closed the socket. */
    RAY_LOG(DEBUG) << "Socket has been closed, or some other error has "
                   << "occurred.";
    /* Release the payload buffer; previously it was leaked when the second
     * read failed. free(NULL) is a no-op. */
    free(message);
    if (loop != NULL) {
      event_loop_remove_file(loop, sock);
    }
    close(sock);
    return NULL;
  }
  return message;
}
/* Read one message framed by write_message from fd into buffer, growing
 * buffer only when it is smaller than the payload. Returns the payload
 * length. If any read fails, *type is set to DISCONNECT_CLIENT and 0 is
 * returned. A protocol version mismatch trips RAY_CHECK. */
int64_t read_vector(int fd, int64_t *type, std::vector<uint8_t> &buffer) {
  /* Shared failure path: report the peer as disconnected. */
  const auto disconnected = [type]() -> int64_t {
    *type = static_cast<int64_t>(CommonMessageType::DISCONNECT_CLIENT);
    return 0;
  };

  int64_t version;
  if (read_bytes(fd, (uint8_t *) &version, sizeof(version))) {
    return disconnected();
  }
  RAY_CHECK(version == RayConfig::instance().ray_protocol_version());

  if (read_bytes(fd, (uint8_t *) type, sizeof(*type))) {
    return disconnected();
  }

  int64_t length;
  if (read_bytes(fd, (uint8_t *) &length, sizeof(length))) {
    return disconnected();
  }

  /* Never shrink the buffer, so it can be reused across calls. */
  if (buffer.size() < static_cast<size_t>(length)) {
    buffer.resize(length);
  }
  if (read_bytes(fd, buffer.data(), length)) {
    return disconnected();
  }
  return length;
}
/* Send the NUL-terminated string message to fd as a LOG_MESSAGE frame. The
 * terminating NUL is included in the payload. The write status is ignored. */
void write_log_message(int fd, const char *message) {
  const int64_t type = static_cast<int64_t>(CommonMessageType::LOG_MESSAGE);
  /* Account for the \0 at the end of the string. */
  const int64_t length = strlen(message) + 1;
  do_write_message(fd, type, length, (uint8_t *) message);
}
/* Read one message from fd, check that it is a LOG_MESSAGE, and return its
 * payload as a C string. The buffer is allocated by read_message and must be
 * freed by the caller. Any other message type trips RAY_CHECK. */
char *read_log_message(int fd) {
  int64_t type;
  int64_t length;
  uint8_t *payload;
  read_message(fd, &type, &length, &payload);
  RAY_CHECK(static_cast<CommonMessageType>(type) ==
            CommonMessageType::LOG_MESSAGE);
  return (char *) payload;
}
-228
View File
@@ -1,228 +0,0 @@
#ifndef IO_H
#define IO_H
#include <stdint.h>
#include <stdlib.h>
#include <mutex>
#include <vector>
/* Forward declaration: this header only ever uses the event loop through a
 * pointer (see read_message_async), so the full definition is not needed. */
struct aeEventLoop;
typedef aeEventLoop event_loop;
/** Message types understood by the helpers in this header. The values are
 * converted to int64_t when they are stored in a message's type field. */
enum class CommonMessageType : int32_t {
/** Disconnect a client. */
DISCONNECT_CLIENT,
/** Log a message from a client. */
LOG_MESSAGE,
/** Submit a task to the local scheduler. */
SUBMIT_TASK,
};
/* Helper functions for socket communication. */
/**
 * Binds to an Internet stream socket at the given port on all local
 * interfaces. Returns a non-blocking file descriptor for the socket, or -1
 * if an error occurred.
 *
 * @note Since the returned file descriptor is non-blocking, it is not
 * recommended to use the Linux read and write calls directly, since these
 * might read or write a partial message. Instead, use the provided
 * write_message and read_message methods.
 *
 * @param port The port to bind to.
 * @param shall_listen Are we also starting to listen on the socket?
 * @return A non-blocking file descriptor for the socket, or -1 if an error
 * occurs.
 */
int bind_inet_sock(const int port, bool shall_listen);
/**
 * Binds to a Unix domain streaming socket at the given
 * pathname. Removes any existing file at the pathname.
 *
 * @param socket_pathname The pathname for the socket. Including its
 * terminating NUL it must fit into sockaddr_un::sun_path, otherwise the
 * call fails.
 * @param shall_listen Are we also starting to listen on the socket?
 * @return A blocking file descriptor for the socket, or -1 if an error
 * occurs.
 */
int bind_ipc_sock(const char *socket_pathname, bool shall_listen);
/**
 * Connect to a Unix domain streaming socket at the given
 * pathname. Only a single attempt is made; use connect_ipc_sock_retry to
 * retry.
 *
 * @param socket_pathname The pathname for the socket.
 * @return A file descriptor for the socket, or -1 if an error occurred.
 */
int connect_ipc_sock(const char *socket_pathname);
/**
 * Connect to a Unix domain streaming socket at the given
 * pathname, or fail after some number of retries.
 *
 * @param socket_pathname The pathname for the socket. Must not be NULL.
 * @param num_retries The number of times to retry the connection
 * before exiting. If a negative value is provided, then this defaults to
 * num_connect_attempts.
 * @param timeout The number of milliseconds to wait in between
 * retries. If a negative value is provided, then this defaults to
 * connect_timeout_milliseconds.
 * @return A file descriptor for the socket. If every attempt fails, the
 * failure is logged at FATAL level instead of being reported to the
 * caller (presumably this terminates the process -- TODO confirm).
 */
int connect_ipc_sock_retry(const char *socket_pathname,
int num_retries,
int64_t timeout);
/**
 * Connect to an Internet socket at the given address and port. Only a
 * single attempt is made; use connect_inet_sock_retry to retry.
 *
 * @param ip_addr The IP address to connect to.
 * @param port The port number to connect to.
 * @return A file descriptor for the socket, or -1 if an error occurred.
 */
int connect_inet_sock(const char *ip_addr, int port);
/**
 * Connect to an Internet socket at the given address and port, or fail after
 * some number of retries.
 *
 * @param ip_addr The IP address to connect to. Must not be NULL.
 * @param port The port number to connect to.
 * @param num_retries The number of times to retry the connection
 * before exiting. If a negative value is provided, then this defaults to
 * num_connect_attempts.
 * @param timeout The number of milliseconds to wait in between
 * retries. If a negative value is provided, then this defaults to
 * connect_timeout_milliseconds.
 * @return A file descriptor for the socket. If every attempt fails, the
 * failure is logged at FATAL level instead of being reported to the
 * caller (presumably this terminates the process -- TODO confirm).
 */
int connect_inet_sock_retry(const char *ip_addr,
int port,
int num_retries,
int64_t timeout);
/**
 * Accept a new client connection on the given socket
 * descriptor. Returns a descriptor for the new socket.
 *
 * @param socket_fd A listening socket descriptor.
 * @return A descriptor for the accepted connection, or -1 if accepting
 * failed.
 */
int accept_client(int socket_fd);
/* Reading and writing data. */
/**
 * Write a sequence of bytes on a file descriptor. The bytes should then be read
 * by read_message. The protocol version is taken from RayConfig and written
 * ahead of the type and length; it is not a parameter.
 *
 * @param fd The file descriptor to write to. It can be non-blocking.
 * @param type The type of the message to send.
 * @param length The size in bytes of the bytes parameter.
 * @param bytes The address of the message to send.
 * @param mutex If not NULL, the whole write operation will be locked
 * with this mutex, otherwise do nothing.
 * @return int Whether there was an error while writing. 0 corresponds to
 * success and -1 corresponds to an error (errno will be set).
 */
int write_message(int fd,
int64_t type,
int64_t length,
uint8_t *bytes,
std::mutex *mutex = NULL);
/**
 * Read a sequence of bytes written by write_message from a file descriptor.
 * This allocates space for the message with malloc.
 *
 * @note The caller must free the memory.
 *
 * @param fd The file descriptor to read from. It can be non-blocking.
 * @param type The type of the message that is read will be written at this
 * address. If there was an error while reading, this will be
 * DISCONNECT_CLIENT.
 * @param length The size in bytes of the message that is read will be written
 * at this address. This size does not include the bytes used to encode
 * the version, type and length. If there was an error while reading,
 * this will be 0.
 * @param bytes The address at which to write the pointer to the bytes that are
 * read and allocated by this function. If there was an error while
 * reading, this will be NULL.
 * @return Void.
 */
void read_message(int fd, int64_t *type, int64_t *length, uint8_t **bytes);
/**
 * Read a message from a file descriptor and remove the file descriptor from the
 * event loop if there is an error. This will actually do two reads. The first
 * read reads sizeof(int64_t) bytes to determine the number of bytes to read in
 * the next read. On error the file descriptor is also closed.
 *
 * @param loop The event loop. May be NULL, in which case nothing is removed
 * from an event loop on error.
 * @param sock The file descriptor to read from.
 * @return A byte buffer containing the message or NULL if there was an
 * error. The buffer needs to be freed by the user.
 */
uint8_t *read_message_async(event_loop *loop, int sock);
/**
 * Read a sequence of bytes written by write_message from a file descriptor.
 * This does not allocate space for the message if the provided buffer is
 * large enough and can therefore often avoid allocations.
 *
 * @param fd The file descriptor to read from. It can be non-blocking.
 * @param type The type of the message that is read will be written at this
 * address. If there was an error while reading, this will be
 * DISCONNECT_CLIENT.
 * @param buffer The array the message will be written to. If it is not
 * large enough to hold the message, it will be enlarged by read_vector.
 * It is never shrunk, so its size may exceed the returned length.
 * @return Number of bytes of the message that were read. This size does not
 * include the bytes used to encode the version, type and length. If
 * there was an error while reading, this will be 0.
 */
int64_t read_vector(int fd, int64_t *type, std::vector<uint8_t> &buffer);
/**
 * Write a null-terminated string to a file descriptor as a LOG_MESSAGE. The
 * terminating null byte is sent as part of the message. Write errors are
 * not reported to the caller.
 */
void write_log_message(int fd, const char *message);
/**
 * Reads a null-terminated string from the file descriptor that has been
 * written by write_log_message. Allocates and returns a pointer to the string.
 * The message read must be of type LOG_MESSAGE; this is enforced with
 * RAY_CHECK.
 * NOTE: Caller must free the memory!
 */
char *read_log_message(int fd);
/**
 * Read a sequence of bytes from a file descriptor into a buffer. This will
 * block until one of the following happens: (1) there is an error (2) end of
 * file, or (3) all length bytes have been read.
 *
 * @note The buffer pointed to by cursor must already have length number of
 * bytes allocated before calling this method.
 *
 * @param fd The file descriptor to read from. It can be non-blocking.
 * @param cursor The cursor pointing to the beginning of the buffer.
 * @param length The size of the byte sequence to read.
 * @return int Whether there was an error while reading. 0 corresponds to
 * success and -1 corresponds to an error (errno will be set) or to
 * reaching end of file before length bytes were read.
 */
int read_bytes(int fd, uint8_t *cursor, size_t length);
/**
 * Write a sequence of bytes into a file descriptor. This will block until one
 * of the following happens: (1) there is an error (2) a write reports that
 * zero bytes were written, or (3) all length bytes have been written.
 *
 * @param fd The file descriptor to write to. It can be non-blocking.
 * @param cursor The cursor pointing to the beginning of the bytes to send.
 * @param length The size of the bytes sequence to write.
 * @return int Whether there was an error while writing. 0 corresponds to
 * success and -1 corresponds to an error (errno will be set).
 */
int write_bytes(int fd, uint8_t *cursor, size_t length);
#endif /* IO_H */
-107
View File
@@ -1,107 +0,0 @@
#include "logging.h"
#include <inttypes.h>
#include <stdint.h>
#include <sys/time.h>
#include <hiredis/hiredis.h>
#include "state/redis.h"
#include "io.h"
#include <iostream>
#include <string>
/* Names of the log levels, indexed by the RAY_LOG_DEBUG..RAY_LOG_FATAL
 * values. */
static const char *log_levels[5] = {"DEBUG", "INFO", "WARN", "ERROR", "FATAL"};
/* printf-style template for the command that stores a log entry. RayLogger_log
 * fills the first %s with the timestamp and the second with the literal
 * text "%b", leaving a placeholder for the client ID to be supplied later;
 * the remaining fields are the level name, event type, message and
 * timestamp. */
static const char *log_fmt =
"HMSET log:%s:%s log_level %s event_type %s message %s timestamp %s";
/* State of one logger. Instances are created by RayLogger_init and released
 * by RayLogger_free; the logger does not own client_type or conn. */
struct RayLoggerImpl {
/* String that identifies this client type. The pointer is stored as given,
 * not copied. */
const char *client_type;
/* Suppress all log messages below this level. */
int log_level;
/* Whether or not we have a direct connection to Redis. */
int is_direct;
/* Either a db_handle or a socket to a process with a db_handle,
 * depending on the is_direct flag. When is_direct is false this is read as
 * a pointer to an int file descriptor. */
void *conn;
};
/* Allocate a logger with malloc and fill it in from the arguments. The
 * client_type and conn pointers are stored as given, so they must outlive
 * the logger. Release the result with RayLogger_free. */
RayLogger *RayLogger_init(const char *client_type,
                          int log_level,
                          int is_direct,
                          void *conn) {
  RayLogger *created = (RayLogger *) malloc(sizeof(*created));
  created->conn = conn;
  created->is_direct = is_direct;
  created->log_level = log_level;
  created->client_type = client_type;
  return created;
}
/* Release a logger created by RayLogger_init. Only the logger itself is
 * freed; the connection and client type string it points to are left
 * untouched. */
void RayLogger_free(RayLogger *logger) {
  free(logger);
}
/* Format a log entry and deliver it.
 *
 * Messages below the logger's threshold, or whose level is outside
 * RAY_LOG_DEBUG..RAY_LOG_FATAL, are dropped silently. With a direct
 * connection the entry is sent to Redis asynchronously; otherwise the
 * formatted command is written to the socket held by the logger and the
 * client ID is left to be filled in by the receiving process.
 *
 * @param logger The logger to use.
 * @param log_level One of RAY_LOG_DEBUG..RAY_LOG_FATAL.
 * @param event_type The event type field of the entry.
 * @param message The message field of the entry.
 */
void RayLogger_log(RayLogger *logger,
                   int log_level,
                   const char *event_type,
                   const char *message) {
  if (log_level < logger->log_level) {
    return;
  }
  /* log_level indexes log_levels below, so it must be in range. */
  if (log_level < RAY_LOG_DEBUG || log_level > RAY_LOG_FATAL) {
    return;
  }
  struct timeval tv;
  gettimeofday(&tv, NULL);
  /* NOTE(review): tv_usec is not zero-padded, so 5 microseconds is rendered
   * as ".5"; left unchanged because it affects the stored key format. */
  std::string timestamp =
      std::to_string(tv.tv_sec) + "." + std::to_string(tv.tv_usec);
  /* Ask snprintf how many bytes the formatted command needs (excluding the
   * terminating NUL). The "%b" argument leaves a placeholder for the client
   * ID, which is binary data and is filled in later. */
  const int needed =
      std::snprintf(nullptr, 0, log_fmt, timestamp.c_str(), "%b",
                    log_levels[log_level], event_type, message,
                    timestamp.c_str());
  if (needed < 0) {
    /* Formatting failed; there is nothing sensible to send. */
    return;
  }
  /* Use a heap buffer rather than a variable-length array: VLAs are not
   * standard C++ and a long message could exhaust the stack. */
  std::string formatted_message(static_cast<size_t>(needed) + 1, '\0');
  std::snprintf(&formatted_message[0], formatted_message.size(), log_fmt,
                timestamp.c_str(), "%b", log_levels[log_level], event_type,
                message, timestamp.c_str());
  if (logger->is_direct) {
    DBHandle *db = (DBHandle *) logger->conn;
    /* Fill in the client ID and send the message to Redis. */
    redisAsyncContext *context = get_redis_context(db, db->client);
    int status =
        redisAsyncCommand(context, NULL, NULL, formatted_message.c_str(),
                          (char *) db->client.data(), sizeof(db->client));
    if ((status == REDIS_ERR) || context->err) {
      LOG_REDIS_DEBUG(context, "error while logging message to log table");
    }
  } else {
    /* If we don't own a Redis connection, we leave our client
     * ID to be filled in by someone else. */
    int *socket_fd = (int *) logger->conn;
    write_log_message(*socket_fd, formatted_message.c_str());
  }
}
/* Record an event by issuing an asynchronous "ZADD key timestamp value"
 * command on the handle's Redis context. Key and value are passed as binary
 * data with explicit lengths. Failures are only reported through
 * LOG_REDIS_DEBUG. */
void RayLogger_log_event(DBHandle *db,
                         uint8_t *key,
                         int64_t key_length,
                         uint8_t *value,
                         int64_t value_length,
                         double timestamp) {
  /* The timestamp is sent in its decimal string form. */
  const std::string score = std::to_string(timestamp);
  const int status =
      redisAsyncCommand(db->context, NULL, NULL, "ZADD %b %s %b", key,
                        key_length, score.c_str(), value, value_length);
  const bool failed = (status == REDIS_ERR) || db->context->err;
  if (failed) {
    LOG_REDIS_DEBUG(db->context, "error while logging message to event log");
  }
}
-58
View File
@@ -1,58 +0,0 @@
#ifndef LOGGING_H
#define LOGGING_H
/* Log levels, in increasing order of severity. RayLogger_log only emits
 * levels in the range RAY_LOG_DEBUG..RAY_LOG_FATAL, so messages logged at
 * RAY_LOG_VERBOSE are dropped by it. */
#define RAY_LOG_VERBOSE -1
#define RAY_LOG_DEBUG 0
#define RAY_LOG_INFO 1
#define RAY_LOG_WARNING 2
#define RAY_LOG_ERROR 3
#define RAY_LOG_FATAL 4
/* Entity types. */
#define RAY_FUNCTION "FUNCTION"
#define RAY_OBJECT "OBJECT"
#define RAY_TASK "TASK"
#include "state/db.h"
typedef struct RayLoggerImpl RayLogger;
/* Initialize a Ray logger for the given client type and logging level. If the
 * is_direct flag is set, the logger will treat the given connection as a
 * direct connection to the log. Otherwise, it will treat it as a socket to
 * another process with a connection to the log.
 * The client_type and conn pointers are stored, not copied, so they must
 * remain valid for as long as the logger is used.
 * NOTE: User is responsible for freeing the returned logger. */
RayLogger *RayLogger_init(const char *client_type,
int log_level,
int is_direct,
void *conn);
/* Free the logger. This does not free the connection to the log, nor the
 * client type string passed to RayLogger_init. */
void RayLogger_free(RayLogger *logger);
/* Log an event at the given log level with the given event_type. The event
 * is dropped if log_level is below the logger's own level or outside the
 * range RAY_LOG_DEBUG..RAY_LOG_FATAL.
 * NOTE: message cannot contain spaces! JSON format is recommended.
 * TODO: Support spaces in messages. */
void RayLogger_log(RayLogger *logger,
int log_level,
const char *event_type,
const char *message);
/**
 * Log an event to the event log.
 *
 * @param db The database handle.
 * @param key The key in Redis to store the event in.
 * @param key_length The length of the key.
 * @param value The value to log.
 * @param value_length The length of the value.
 * @param time The timestamp of the event; it is used as the score under
 * which the value is added to the key.
 * @return Void.
 */
void RayLogger_log_event(DBHandle *db,
uint8_t *key,
int64_t key_length,
uint8_t *value,
int64_t value_length,
double time);
#endif /* LOGGING_H */
-24
View File
@@ -1,24 +0,0 @@
#include "net.h"
#include <arpa/inet.h>
#include <sstream>
#include "common.h"
/* Split a string of the form <IP address>:<port> into its two parts.
 *
 * At most 15 characters from the set [0-9.] are stored in ip_addr (which must
 * therefore hold at least 16 bytes) and at most 5 digits are read for the
 * port. Returns 0 on success. Returns -1 without touching *port if the input
 * does not contain both parts. */
int parse_ip_addr_port(const char *ip_addr_port, char *ip_addr, int *port) {
  char digits[6];
  const int fields =
      sscanf(ip_addr_port, "%15[0-9.]:%5[0-9]", ip_addr, digits);
  if (fields == 2) {
    *port = atoi(digits);
    return 0;
  }
  return -1;
}
/* Return true if ip_address is a well-formed dotted-decimal IPv4 address,
 * as judged by inet_pton with AF_INET. */
bool valid_ip_address(const std::string &ip_address) {
  struct in_addr parsed;
  return inet_pton(AF_INET, ip_address.c_str(), &parsed) == 1;
}
-9
View File
@@ -1,9 +0,0 @@
#ifndef NET_H
#define NET_H
/* Helper function to parse a string of the form <IP address>:<port> into the
 * given ip_addr and port pointers. The ip_addr buffer must already be
 * allocated and must have room for at least 16 bytes (up to 15 address
 * characters plus the terminating NUL). Return 0 upon success and -1 upon
 * failure; *port is only written on success. */
int parse_ip_addr_port(const char *ip_addr_port, char *ip_addr, int *port);
#endif /* NET_H */
File diff suppressed because it is too large Load Diff
-69
View File
@@ -1,69 +0,0 @@
/* http://stackoverflow.com/a/17195644/541686 */
#include <string.h>
#include <stdio.h>
int opterr = 1, /* if error message should be printed */
optind = 1, /* index into parent argv vector */
optopt, /* character checked for validity */
optreset; /* reset getopt */
char *optarg; /* argument associated with option */
#define BADCH (int) '?'
#define BADARG (int) ':'
#define EMSG ""
/*
 * getopt --
 * Parse argc/argv argument vector.
 *
 * nargc/nargv are the argument count and vector; ostr lists the accepted
 * option letters, a letter followed by ':' taking an argument.
 *
 * Returns the next option letter, -1 when option parsing is finished,
 * BADCH ('?') for an unknown option, and for a missing argument BADARG (':')
 * if ostr starts with ':' or BADCH otherwise. Results are also reported
 * through the globals optind, optopt and optarg. The scan position within
 * the current argument is kept in a function-local static, so parsing state
 * persists between calls; set optreset to restart.
 */
int getopt(int nargc, char *const nargv[], const char *ostr) {
static char *place = EMSG; /* option letter processing */
const char *oli; /* option letter list index */
if (optreset || !*place) { /* update scanning pointer */
optreset = 0;
/* Stop at the end of the arguments or at the first non-option argument. */
if (optind >= nargc || *(place = nargv[optind]) != '-') {
place = EMSG;
return (-1);
}
if (place[1] && *++place == '-') { /* found "--" */
++optind;
place = EMSG;
return (-1);
}
} /* option letter okay? */
if ((optopt = (int) *place++) == (int) ':' || !(oli = strchr(ostr, optopt))) {
/*
 * if the user didn't specify '-' as an option,
 * assume it means -1.
 */
if (optopt == (int) '-')
return (-1);
/* Move on to the next argument once this one is used up. */
if (!*place)
++optind;
/* A leading ':' in ostr suppresses the error message. */
if (opterr && *ostr != ':')
(void) printf("illegal option -- %c\n", optopt);
return (BADCH);
}
if (*++oli != ':') { /* don't need argument */
optarg = NULL;
if (!*place)
++optind;
} else { /* need an argument */
if (*place) /* no white space */
optarg = place;
else if (nargc <= ++optind) { /* no arg */
place = EMSG;
if (*ostr == ':')
return (BADARG);
if (opterr)
(void) printf("option requires an argument -- %c\n", optopt);
return (BADCH);
} else /* white space */
optarg = nargv[optind];
/* The argument consumed the rest of this element (or the next one). */
place = EMSG;
++optind;
}
return (optopt); /* dump back option letter */
}
-4
View File
@@ -1,4 +0,0 @@
#ifndef GETOPT_H
#define GETOPT_H
#endif /* GETOPT_H */
-208
View File
@@ -1,208 +0,0 @@
#include <sys/socket.h>
/* Emulation of socketpair() built on dumb_socketpair. Only stream sockets in
 * the AF_UNIX or AF_INET domain are accepted; anything else yields
 * INVALID_SOCKET. The protocol argument is ignored. On return sv holds the
 * two socket handles converted to int, and the result of dumb_socketpair is
 * passed through. */
int socketpair(int domain, int type, int protocol, int sv[2]) {
  const int domain_ok = (domain == AF_UNIX || domain == AF_INET);
  if (!domain_ok || type != SOCK_STREAM) {
    return INVALID_SOCKET;
  }
  SOCKET pair[2];
  const int rc = dumb_socketpair(pair);
  sv[0] = (int) pair[0];
  sv[1] = (int) pair[1];
  return rc;
}
#pragma comment(lib, "IPHlpAPI.lib")
/* One row of the TCP connection table filled in by GetTcpTable2: connection
 * state, local and remote address/port, and the ID of the owning process.
 * Declared locally, presumably to avoid including the system header that
 * normally provides it -- TODO confirm. */
struct _MIB_TCPROW2 {
DWORD dwState, dwLocalAddr, dwLocalPort, dwRemoteAddr, dwRemotePort,
dwOwningPid;
enum _TCP_CONNECTION_OFFLOAD_STATE dwOffloadState;
};
/* TCP connection table filled in by GetTcpTable2. table is declared with one
 * element but is indexed up to dwNumEntries by getsockpid, i.e. it is the
 * start of a variable-length array. */
struct _MIB_TCPTABLE2 {
DWORD dwNumEntries;
struct _MIB_TCPROW2 table[1];
};
/* Imported system function that fills TcpTable with the TCP connection
 * table. getsockpid relies on it returning ERROR_INSUFFICIENT_BUFFER and
 * updating *SizePointer when the supplied buffer is too small. */
DECLSPEC_IMPORT ULONG WINAPI GetTcpTable2(struct _MIB_TCPTABLE2 *TcpTable,
PULONG SizePointer,
BOOL Order);
/* Find the ID of the process that owns the other end of a connected TCP
 * socket by searching the system TCP table for the established connection
 * whose local endpoint is this socket's peer and whose remote endpoint is
 * this socket's own address. Returns 0 if the addresses cannot be obtained,
 * the table cannot be read, or no matching row exists. */
static DWORD getsockpid(SOCKET client) {
  /* http://stackoverflow.com/a/25431340 */
  struct sockaddr_in own_addr = {0};
  int own_len = sizeof(own_addr);
  struct sockaddr_in peer_addr = {0};
  int peer_len = sizeof(peer_addr);
  if (getsockname(client, (struct sockaddr *) &own_addr, &own_len) != 0 ||
      getpeername(client, (struct sockaddr *) &peer_addr, &peer_len) != 0) {
    return 0;
  }

  /* Grow the buffer until the table fits; the call reports the size it
   * needs through table_bytes. */
  struct _MIB_TCPTABLE2 *tcp_table = NULL;
  ULONG table_bytes = 0;
  ULONG status;
  for (;;) {
    status = GetTcpTable2(tcp_table, &table_bytes, TRUE);
    if (status != ERROR_INSUFFICIENT_BUFFER) {
      break;
    }
    free(tcp_table);
    tcp_table = (struct _MIB_TCPTABLE2 *) malloc(table_bytes);
    if (tcp_table == NULL) {
      break;
    }
  }

  DWORD owner = 0;
  if (status == NO_ERROR) {
    for (DWORD i = 0; i < tcp_table->dwNumEntries; ++i) {
      struct _MIB_TCPROW2 *entry = &(tcp_table->table[i]);
      if (entry->dwState != 5 /* MIB_TCP_STATE_ESTAB */) {
        continue;
      }
      if (entry->dwLocalAddr != peer_addr.sin_addr.s_addr ||
          (entry->dwLocalPort & 0xFFFF) != peer_addr.sin_port) {
        continue;
      }
      if (entry->dwRemoteAddr != own_addr.sin_addr.s_addr ||
          (entry->dwRemotePort & 0xFFFF) != own_addr.sin_port) {
        continue;
      }
      owner = entry->dwOwningPid;
      break;
    }
  }
  free(tcp_table);
  return owner;
}
/* Emulation of sendmsg() for passing a handle (SCM_RIGHTS) to the process at
 * the other end of sockfd.
 *
 * The handle is duplicated into the target process (or, for a socket, its
 * protocol info is exported for that process) and the resulting data is
 * sent ahead of the caller's buffers. Messages whose first control header is
 * not SOL_SOCKET/SCM_RIGHTS are not sent at all and yield -1. *msg is
 * restored to its original contents before returning. */
ssize_t sendmsg(int sockfd, struct msghdr *msg, int flags) {
ssize_t result = -1;
/* NOTE(review): header is dereferenced without a NULL check; presumably
 * callers always supply control data -- confirm. */
struct cmsghdr *header = CMSG_FIRSTHDR(msg);
if (header->cmsg_level == SOL_SOCKET && header->cmsg_type == SCM_RIGHTS) {
/* We're trying to send over a handle of some kind.
 * We have to look up which process we're communicating with,
 * open a handle to it, and then duplicate our handle into it.
 * However, the first two steps cannot be done atomically.
 * Therefore, this code HAS A RACE CONDITION and is therefore NOT SECURE.
 * In the absence of a malicious actor, though, it is exceedingly unlikely
 * that the child process closes AND that its process ID is reassigned
 * to another existing process.
 */
struct msghdr const old_msg = *msg;
int *const pfd = (int *) CMSG_DATA(header);
/* The control data is not transmitted as such; clear it for the send. */
msg->msg_control = NULL;
msg->msg_controllen = 0;
WSAPROTOCOL_INFO protocol_info = {0};
BOOL const is_socket = !!FDAPI_GetSocketStatePtr(*pfd);
DWORD const target_pid = getsockpid(sockfd);
HANDLE target_process = NULL;
if (target_pid) {
if (!is_socket) {
/* This is a regular handle... fit it into the same struct */
target_process = OpenProcess(PROCESS_DUP_HANDLE, FALSE, target_pid);
if (target_process) {
if (DuplicateHandle(GetCurrentProcess(), (HANDLE)(intptr_t) *pfd,
target_process, (HANDLE *) &protocol_info, 0,
TRUE, DUPLICATE_SAME_ACCESS)) {
result = 0;
}
}
} else {
/* This is a socket... */
result = FDAPI_WSADuplicateSocket(*pfd, target_pid, &protocol_info);
}
}
if (result == 0) {
/* Prepend protocol_info to the caller's buffers so that it travels as
 * the first part of the message. */
int const nbufs = msg->dwBufferCount + 1;
WSABUF *const bufs =
(struct _WSABUF *) _alloca(sizeof(*msg->lpBuffers) * nbufs);
bufs[0].buf = (char *) &protocol_info;
bufs[0].len = sizeof(protocol_info);
memcpy(&bufs[1], msg->lpBuffers,
msg->dwBufferCount * sizeof(*msg->lpBuffers));
DWORD nb;
msg->lpBuffers = bufs;
msg->dwBufferCount = nbufs;
GUID const wsaid_WSASendMsg = {
0xa441e712,
0x754f,
0x43ca,
{0x84, 0xa7, 0x0d, 0xee, 0x44, 0xcf, 0x60, 0x6d}};
typedef INT PASCAL WSASendMsg_t(
SOCKET s, LPWSAMSG lpMsg, DWORD dwFlags, LPDWORD lpNumberOfBytesSent,
LPWSAOVERLAPPED lpOverlapped,
LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine);
WSASendMsg_t *WSASendMsg = NULL;
/* WSASendMsg is an extension function and has to be looked up at run
 * time through its GUID. */
result = FDAPI_WSAIoctl(sockfd, SIO_GET_EXTENSION_FUNCTION_POINTER,
&wsaid_WSASendMsg, sizeof(wsaid_WSASendMsg),
&WSASendMsg, sizeof(WSASendMsg), &nb, NULL, 0);
if (result == 0) {
/* Report only the caller's bytes, not the prepended protocol_info.
 * NOTE(review): a failed send yields 0 here rather than -1 -- confirm
 * that this is intended. */
result = (*WSASendMsg)(sockfd, msg, flags, &nb, NULL, NULL) == 0
? (ssize_t)(nb - sizeof(protocol_info))
: 0;
}
}
if (result != 0 && target_process && !is_socket) {
/* we failed to send the handle, and it needs cleaning up! */
HANDLE duplicated_back = NULL;
if (DuplicateHandle(target_process, *(HANDLE *) &protocol_info,
GetCurrentProcess(), &duplicated_back, 0, FALSE,
DUPLICATE_CLOSE_SOURCE)) {
CloseHandle(duplicated_back);
}
}
if (target_process) {
CloseHandle(target_process);
}
*msg = old_msg;
}
return result;
}
/* Emulation of recvmsg() for receiving a handle sent by the sendmsg()
 * emulation in this file.
 *
 * Nothing is received and -1 is returned unless msg carries control space
 * and flags is 0. On the success path the received handle (or a socket
 * created from the received protocol info) is stored in the first control
 * header, which is marked SOL_SOCKET/SCM_RIGHTS. *msg is restored to its
 * original contents before returning. */
ssize_t recvmsg(int sockfd, struct msghdr *msg, int flags) {
int result = -1;
/* Taken before the control fields are cleared below, so that the header can
 * be filled in afterwards. */
struct cmsghdr *header = CMSG_FIRSTHDR(msg);
if (msg->msg_controllen &&
flags == 0 /* We can't send flags on Windows... */) {
struct msghdr const old_msg = *msg;
msg->msg_control = NULL;
msg->msg_controllen = 0;
WSAPROTOCOL_INFO protocol_info = {0};
/* NOTE(review): bufs is built with protocol_info as its first element but,
 * unlike in sendmsg, it is never stored into msg->lpBuffers /
 * msg->dwBufferCount, so it looks as if protocol_info is never filled by
 * the receive below -- confirm. */
int const nbufs = msg->dwBufferCount + 1;
WSABUF *const bufs =
(struct _WSABUF *) _alloca(sizeof(*msg->lpBuffers) * nbufs);
bufs[0].buf = (char *) &protocol_info;
bufs[0].len = sizeof(protocol_info);
memcpy(&bufs[1], msg->lpBuffers,
msg->dwBufferCount * sizeof(*msg->lpBuffers));
typedef INT PASCAL WSARecvMsg_t(
SOCKET s, LPWSAMSG lpMsg, LPDWORD lpNumberOfBytesRecvd,
LPWSAOVERLAPPED lpOverlapped,
LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine);
WSARecvMsg_t *WSARecvMsg = NULL;
DWORD nb;
GUID const wsaid_WSARecvMsg = {
0xf689d7c8,
0x6f1f,
0x436b,
{0x8a, 0x53, 0xe5, 0x4f, 0xe3, 0x51, 0xc3, 0x22}};
/* WSARecvMsg is an extension function and has to be looked up at run time
 * through its GUID. */
result = FDAPI_WSAIoctl(sockfd, SIO_GET_EXTENSION_FUNCTION_POINTER,
&wsaid_WSARecvMsg, sizeof(wsaid_WSARecvMsg),
&WSARecvMsg, sizeof(WSARecvMsg), &nb, NULL, 0);
if (result == 0) {
result = (*WSARecvMsg)(sockfd, msg, &nb, NULL, NULL) == 0
? (ssize_t)(nb - sizeof(protocol_info))
: 0;
}
/* NOTE(review): result is 0 here both when the receive failed and when
 * exactly sizeof(protocol_info) bytes arrived -- confirm that the handle
 * should be extracted in both cases. */
if (result == 0) {
int *const pfd = (int *) CMSG_DATA(header);
/* A protocol_info without socket type and protocol carries a plain
 * handle in its first bytes; otherwise it describes a socket. */
if (protocol_info.iSocketType == 0 && protocol_info.iProtocol == 0) {
*pfd = *(int *) &protocol_info;
} else {
*pfd = FDAPI_WSASocket(FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO,
FROM_PROTOCOL_INFO, &protocol_info, 0, 0);
}
header->cmsg_level = SOL_SOCKET;
header->cmsg_type = SCM_RIGHTS;
}
*msg = old_msg;
}
return result;
}
-4
View File
@@ -1,4 +0,0 @@
#ifndef NETDB_H
#define NETDB_H
#endif /* NETDB_H */
-4
View File
@@ -1,4 +0,0 @@
#ifndef IN_H
#define IN_H
#endif /* IN_H */
-4
View File
@@ -1,4 +0,0 @@
#ifndef POLL_H
#define POLL_H
#endif /* POLL_H */
-150
View File
@@ -1,150 +0,0 @@
/* socketpair.c
Copyright 2007, 2010 by Nathan C. Myers <ncm@cantrip.org>
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
The name of the author must not be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
/* Changes:
* 2014-02-12: merge David Woodhouse, Ger Hobbelt improvements
* git.infradead.org/users/dwmw2/openconnect.git/commitdiff/bdeefa54
* github.com/GerHobbelt/selectable-socketpair
* always init the socks[] to -1/INVALID_SOCKET on error, both on Win32/64
* and UNIX/other platforms
* 2013-07-18: Change to BSD 3-clause license
* 2010-03-31:
* set addr to 127.0.0.1 because win32 getsockname does not always set it.
* 2010-02-25:
* set SO_REUSEADDR option to avoid leaking some windows resource.
* Windows System Error 10049, "Event ID 4226 TCP/IP has reached
* the security limit imposed on the number of concurrent TCP connect
* attempts." Bleah.
* 2007-04-25:
* preserve value of WSAGetLastError() on all error returns.
* 2007-04-22: (Thanks to Matthew Gregan <kinetik@flim.org>)
* s/EINVAL/WSAEINVAL/ fix trivial compile failure
* s/socket/WSASocket/ enable creation of sockets suitable as stdin/stdout
* of a child process.
* add argument make_overlapped
*/
#include <string.h>
#ifdef WIN32
#include <ws2tcpip.h> /* socklen_t, et al (MSVC20xx) */
#include <windows.h>
#include <io.h>
#else
#ifdef _WIN32
#include <Win32_Interop/win32_types.h>
#include <Win32_Interop/Win32_FDAPI.h>
#endif
#include <sys/types.h>
#include <sys/socket.h>
#include <errno.h>
#endif
#ifdef WIN32
/* dumb_socketpair:
 *   Emulates socketpair() on Windows by connecting two TCP sockets through a
 *   temporary listener bound to an ephemeral port on the loopback address.
 *
 *   socks  Output array. On success socks[0] is the connecting end and
 *          socks[1] the accepted end.
 *
 *   Returns 0 on success. Returns SOCKET_ERROR if `socks` is null or any
 *   step fails; whenever `socks` is non-null, failure leaves both entries
 *   set to -1.
 *
 *   This variant takes no make_overlapped argument (the upstream version
 *   did); socks[0] is always created via FDAPI_WSASocket with flags 0.
 *
 *   NOTE(review): upstream documentation says the sockets must be closed
 *   with closesocket(); this port closes its own sockets with FDAPI_close().
 *   Confirm which call callers are expected to use.
 */
int dumb_socketpair(SOCKET socks[2]) {
  union {
    struct sockaddr_in inaddr;
    struct sockaddr addr;
  } a;
  SOCKET listener;
  socklen_t addrlen = sizeof(a.inaddr);
  int reuse = 1;

  if (socks == 0) {
    return SOCKET_ERROR;
  }
  socks[0] = socks[1] = -1;

  listener = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
  if (listener == -1)
    return SOCKET_ERROR;

  memset(&a, 0, sizeof(a));
  a.inaddr.sin_family = AF_INET;
  a.inaddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
  a.inaddr.sin_port = 0;

  /* Single-pass block: any failing step breaks out to the shared cleanup. */
  do {
    if (setsockopt(listener, SOL_SOCKET, SO_REUSEADDR, (char *) &reuse,
                   (socklen_t) sizeof(reuse)) == -1)
      break;
    if (bind(listener, &a.addr, sizeof(a.inaddr)) == SOCKET_ERROR)
      break;

    /* Ask the stack which ephemeral port was assigned. */
    memset(&a, 0, sizeof(a));
    if (getsockname(listener, &a.addr, &addrlen) == SOCKET_ERROR)
      break;
    // win32 getsockname may only set the port number, p=0.0005.
    // ( http://msdn.microsoft.com/library/ms738543.aspx ):
    a.inaddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
    a.inaddr.sin_family = AF_INET;

    if (listen(listener, 1) == SOCKET_ERROR)
      break;

    socks[0] = FDAPI_WSASocket(AF_INET, SOCK_STREAM, 0, NULL, 0, 0);
    if (socks[0] == -1)
      break;
    if (connect(socks[0], &a.addr, sizeof(a.inaddr)) == SOCKET_ERROR)
      break;

    socks[1] = accept(listener, NULL, NULL);
    if (socks[1] == -1)
      break;

    /* Both ends are connected; the listener is no longer needed. */
    FDAPI_close(listener);
    return 0;
  } while (0);

  /* Failure: release whatever was actually opened. */
  FDAPI_close(listener);
  if (socks[0] != -1)
    FDAPI_close(socks[0]);
  if (socks[1] != -1)
    FDAPI_close(socks[1]);
  socks[0] = socks[1] = -1;
  return SOCKET_ERROR;
}
#else
int dumb_socketpair(int socks[2], int dummy) {
if (socks == 0) {
errno = EINVAL;
return -1;
}
dummy = socketpair(AF_LOCAL, SOCK_STREAM, 0, socks);
if (dummy)
socks[0] = socks[1] = -1;
return dummy;
}
#endif
-4
View File
@@ -1,4 +0,0 @@
#ifndef STRINGS_H
#define STRINGS_H
#endif /* STRINGS_H */
-4
View File
@@ -1,4 +0,0 @@
#ifndef IOCTL_H
#define IOCTL_H
#endif /* IOCTL_H */
-36
View File
@@ -1,36 +0,0 @@
#ifndef MMAN_H
#define MMAN_H
#include <unistd.h>
#define MAP_SHARED 0x0010 /* share changes */
#define MAP_FAILED ((void *) -1)
#define PROT_READ 0x04 /* pages can be read */
#define PROT_WRITE 0x02 /* pages can be written */
#define PROT_EXEC 0x01 /* pages can be executed */
static void *mmap(void *addr,
size_t len,
int prot,
int flags,
int fd,
off_t off) {
void *result = (void *) (-1);
if (!addr && prot == MAP_SHARED) {
/* HACK: we're assuming handle sizes can't exceed 32 bits, which is wrong...
* but works for now. */
void *ptr = MapViewOfFile((HANDLE)(intptr_t) fd, FILE_MAP_ALL_ACCESS,
(DWORD)(off >> (CHAR_BIT * sizeof(DWORD))),
(DWORD) off, (SIZE_T) len);
if (ptr) {
result = ptr;
}
}
return result;
}
/* munmap() shim: releases a view via UnmapViewOfFile(). `length` is unused
 * because UnmapViewOfFile() is given only the base address. Returns 0 on
 * success and -1 on failure. */
static int munmap(void *addr, size_t length) {
  (void) length;
  if (!UnmapViewOfFile(addr)) {
    return -1;
  }
  return 0;
}
#endif /* MMAN_H */
-4
View File
@@ -1,4 +0,0 @@
#ifndef SELECT_H
#define SELECT_H
#endif /* SELECT_H */
-36
View File
@@ -1,36 +0,0 @@
#ifndef SOCKET_H
#define SOCKET_H
typedef unsigned short sa_family_t;
#include "../../src/Win32_Interop/Win32_FDAPI.h"
#include "../../src/Win32_Interop/Win32_APIs.h"
/* Map the POSIX control-message (cmsg) names onto their Winsock WSA_*
 * counterparts. CMSG_DATA is undefined first so the redefinition below does
 * not clash with an earlier definition. */
#define cmsghdr _WSACMSGHDR
#undef CMSG_DATA
#define CMSG_DATA WSA_CMSG_DATA
#define CMSG_SPACE WSA_CMSG_SPACE
#define CMSG_FIRSTHDR WSA_CMSG_FIRSTHDR
#define CMSG_LEN WSA_CMSG_LEN
#define CMSG_NXTHDR WSA_CMSG_NXTHDR
/* NOTE(review): presumably chosen to match the POSIX SCM_RIGHTS value; it is
 * only used as a tag by the sendmsg/recvmsg shims -- confirm. */
#define SCM_RIGHTS 1
/* struct iovec and its fields are spelled as the Winsock WSABUF. */
#define iovec _WSABUF
#define iov_base buf
#define iov_len len
/* struct msghdr and its fields are spelled as the Winsock WSAMSG. Note that
 * these are plain token replacements and apply to every later use of the
 * identifiers in any file that includes this header. */
#define msghdr _WSAMSG
#define msg_name name
#define msg_namelen namelen
#define msg_iov lpBuffers
#define msg_iovlen dwBufferCount
#define msg_control Control.buf
#define msg_controllen Control.len
#define msg_flags dwFlags
/* POSIX-style entry points provided by the Windows shim sources. */
int dumb_socketpair(SOCKET socks[2]);
ssize_t sendmsg(int sockfd, struct msghdr *msg, int flags);
ssize_t recvmsg(int sockfd, struct msghdr *msg, int flags);
int socketpair(int domain, int type, int protocol, int sv[2]);
#endif /* SOCKET_H */
-12
View File
@@ -1,12 +0,0 @@
#ifndef TIME_H
#define TIME_H
#include <WinSock2.h> /* timeval */
int gettimeofday_highres(struct timeval *tv, struct timezone *tz);
/* gettimeofday() stand-in for Windows: forwards both arguments to
 * gettimeofday_highres() and returns its result unchanged. */
static int gettimeofday(struct timeval *tv, struct timezone *tz) {
  const int status = gettimeofday_highres(tv, tz);
  return status;
}
#endif /* TIME_H */
-13
View File
@@ -1,13 +0,0 @@
#ifndef UN_H
#define UN_H
#include <sys/socket.h>
/** Stand-in for the POSIX UNIX-domain socket address structure, declared here
 * so that code written against <sys/un.h> compiles on Windows. */
struct sockaddr_un {
/** Address family; AF_UNIX. */
sa_family_t sun_family;
/** The socket pathname (fixed 108-byte buffer). */
char sun_path[108];
};
#endif /* UN_H */
-4
View File
@@ -1,4 +0,0 @@
#ifndef WAIT_H
#define WAIT_H
#endif /* WAIT_H */
-11
View File
@@ -1,11 +0,0 @@
#ifndef UNISTD_H
#define UNISTD_H
/* getopt() state variables and entry point; only declared here, the
 * definitions live elsewhere. */
extern char *optarg;
extern int optind, opterr, optopt;
int getopt(int nargc, char *const nargv[], const char *ostr);
#include "../../src/Win32_Interop/Win32_FDAPI.h"
/* Redirect every close(...) call to FDAPI_close(...). This is a textual
 * macro, so it affects any later use of the identifier `close` followed by
 * a parenthesised argument list. */
#define close(...) FDAPI_close(__VA_ARGS__)
#endif /* UNISTD_H */
@@ -1,47 +0,0 @@
#include "actor_notification_table.h"
#include "common_protocol.h"
#include "redis.h"
/// Serialize an actor creation notification as a flatbuffer and queue it with
/// init_table_callback, using redis_publish_actor_creation_notification as the
/// operation to run. No retry info, done callback, or user context is supplied.
void publish_actor_creation_notification(DBHandle *db_handle, const ActorID &actor_id,
                                         const WorkerID &driver_id,
                                         const DBClientID &local_scheduler_id) {
  // Build and finalize the flatbuffer message.
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(CreateActorCreationNotification(fbb, to_flatbuf(fbb, actor_id),
                                             to_flatbuf(fbb, driver_id),
                                             to_flatbuf(fbb, local_scheduler_id)));

  // Copy the finished buffer behind a size header in a single allocation.
  const auto serialized_size = fbb.GetSize();
  auto *payload = static_cast<ActorCreationNotificationData *>(
      malloc(sizeof(ActorCreationNotificationData) + serialized_size));
  payload->size = serialized_size;
  memcpy(payload->flatbuffer_data, fbb.GetBufferPointer(), serialized_size);

  init_table_callback(db_handle, UniqueID::nil(), __func__,
                      new CommonCallbackData(payload), NULL, NULL,
                      redis_publish_actor_creation_notification, NULL);
}
/// Register `subscribe_callback` (with `subscribe_context`) for actor
/// notification events. The callback/context pair is packed into a
/// heap-allocated ActorNotificationTableSubscribeData and queued with
/// init_table_callback using redis_actor_notification_table_subscribe.
void actor_notification_table_subscribe(
    DBHandle *db_handle, actor_notification_table_subscribe_callback subscribe_callback,
    void *subscribe_context, RetryInfo *retry) {
  auto *subscription = static_cast<ActorNotificationTableSubscribeData *>(
      malloc(sizeof(ActorNotificationTableSubscribeData)));
  subscription->subscribe_context = subscribe_context;
  subscription->subscribe_callback = subscribe_callback;

  init_table_callback(db_handle, UniqueID::nil(), __func__,
                      new CommonCallbackData(subscription), retry, NULL,
                      redis_actor_notification_table_subscribe, NULL);
}
/// Mark `actor_id` as removed by delegating directly to
/// redis_actor_table_mark_removed.
void actor_table_mark_removed(DBHandle *db_handle,
                              ActorID actor_id) {
  redis_actor_table_mark_removed(db_handle, actor_id);
}
@@ -1,74 +0,0 @@
#ifndef ACTOR_NOTIFICATION_TABLE_H
#define ACTOR_NOTIFICATION_TABLE_H
#include "task.h"
#include "db.h"
#include "table.h"
/*
* ==== Subscribing to the actor notification table ====
*/
/* Callback for subscribing to the actor notification table. */
typedef void (*actor_notification_table_subscribe_callback)(
const ActorID &actor_id,
const WorkerID &driver_id,
const DBClientID &local_scheduler_id,
void *user_context);
/// Publish an actor creation notification. This is published by a local
/// scheduler once it creates an actor.
///
/// \param db_handle Database handle.
/// \param actor_id The ID of the actor that was created.
/// \param driver_id The ID of the driver that created the actor.
/// \param local_scheduler_id The ID of the local scheduler that created the
/// actor.
/// \return Void.
void publish_actor_creation_notification(DBHandle *db_handle,
const ActorID &actor_id,
const WorkerID &driver_id,
const DBClientID &local_scheduler_id);
/// Data that is needed to publish an actor creation notification.
typedef struct {
/// The size of the flatbuffer object.
int64_t size;
/// The information to be sent.
uint8_t flatbuffer_data[0];
} ActorCreationNotificationData;
/**
* Register a callback to process actor notification events.
*
* @param db_handle Database handle.
* @param subscribe_callback Callback that will be called when an actor
* notification event happens.
* @param subscribe_context Context that will be passed into the
* subscribe_callback.
* @param retry Information about retrying the request to the database.
* @return Void.
*/
void actor_notification_table_subscribe(
DBHandle *db_handle,
actor_notification_table_subscribe_callback subscribe_callback,
void *subscribe_context,
RetryInfo *retry);
/* Data that is needed to register actor notification table subscribe
* callbacks with the state database. */
typedef struct {
actor_notification_table_subscribe_callback subscribe_callback;
void *subscribe_context;
} ActorNotificationTableSubscribeData;
/**
* Marks an actor as removed. This prevents the actor from being resurrected.
*
* @param db_handle The database handle.
* @param actor_id The actor id to mark as removed.
* @return Void.
*/
void actor_table_mark_removed(DBHandle *db_handle, ActorID actor_id);
#endif /* ACTOR_NOTIFICATION_TABLE_H */
-70
View File
@@ -1,70 +0,0 @@
#ifndef DB_H
#define DB_H
#include <vector>
#include "common.h"
#include "event_loop.h"
typedef struct DBHandle DBHandle;
/**
* Connect to the global system store.
*
* @param db_primary_address The hostname to use to connect to the primary
* database.
* @param db_primary_port The port to use to connect to the primary database.
* (Database shard addresses and ports are not parameters of this function.)
* @param client_type The type of this client.
* @param node_ip_address The hostname of the client that is connecting.
* @param args A vector of extra arguments strings. They should alternate
* between the name of the argument and the value of the argument. For
* examples: "port", "1234", "socket_name", "/tmp/s1". This vector should
* have an even length.
* @return This returns a handle to the database, which must be freed with
* db_disconnect after use.
*/
DBHandle *db_connect(const std::string &db_primary_address,
int db_primary_port,
const char *client_type,
const char *node_ip_address,
const std::vector<std::string> &args);
/**
* Attach global system store connection to an event loop. Callbacks from
* queries to the global system store will trigger events in the event loop.
*
* @param db The handle to the database that is connected.
* @param loop The event loop the database gets connected to.
* @param reattach Can only be true in unit tests. If true, the database is
* reattached to the loop.
* @return Void.
*/
void db_attach(DBHandle *db, event_loop *loop, bool reattach);
/**
* Disconnect from the global system store.
*
* @param db The database connection to close and clean up.
* @return Void.
*/
void db_disconnect(DBHandle *db);
/**
* Free the database handle.
*
* @param db The database connection to clean up.
* @return Void.
*/
void DBHandle_free(DBHandle *db);
/**
* Returns the db client ID.
*
* @param db The handle to the database.
* @return The db client ID for this connection to the database.
*/
DBClientID get_db_client_id(DBHandle *db);
#endif
-90
View File
@@ -1,90 +0,0 @@
#include "db_client_table.h"
#include "redis.h"
/// Queue removal of `db_client_id` from the db client table. The request is
/// handed to init_table_callback with redis_db_client_table_remove as the
/// operation; `done_callback` and `user_context` are forwarded unchanged.
void db_client_table_remove(DBHandle *db_handle, DBClientID db_client_id,
                            RetryInfo *retry,
                            db_client_table_done_callback done_callback,
                            void *user_context) {
  table_done_callback on_done = (table_done_callback) done_callback;
  init_table_callback(db_handle, db_client_id, __func__, new CommonCallbackData(NULL),
                      retry, on_done, redis_db_client_table_remove, user_context);
}
/// Register `subscribe_callback` (with `subscribe_context`) for db client
/// table events. The pair is stored in a heap-allocated
/// DBClientTableSubscribeData and queued with init_table_callback using
/// redis_db_client_table_subscribe; `done_callback` and `user_context` are
/// forwarded unchanged.
void db_client_table_subscribe(DBHandle *db_handle,
                               db_client_table_subscribe_callback subscribe_callback,
                               void *subscribe_context, RetryInfo *retry,
                               db_client_table_done_callback done_callback,
                               void *user_context) {
  auto *subscription = static_cast<DBClientTableSubscribeData *>(
      malloc(sizeof(DBClientTableSubscribeData)));
  subscription->subscribe_context = subscribe_context;
  subscription->subscribe_callback = subscribe_callback;

  table_done_callback on_done = (table_done_callback) done_callback;
  init_table_callback(db_handle, UniqueID::nil(), __func__,
                      new CommonCallbackData(subscription), retry, on_done,
                      redis_db_client_table_subscribe, user_context);
}
/// Look up each ID in `manager_ids` in the db client cache and return the
/// manager addresses of the clients that are still alive, in input order.
///
/// Every looked-up client must have a non-empty manager address (enforced with
/// RAY_CHECK). A warning is logged if the whole lookup takes longer than
/// RayConfig's max_time_for_loop().
///
/// \param db_handle Database handle whose client cache is consulted.
/// \param manager_ids The db client IDs to look up.
/// \return The manager addresses of the alive clients.
const std::vector<std::string> db_client_table_get_ip_addresses(
    DBHandle *db_handle,
    const std::vector<DBClientID> &manager_ids) {
  /* We time this function because in the past this loop has taken multiple
   * seconds under stressful situations on hundreds of machines causing the
   * plasma manager to die (because it went too long without sending
   * heartbeats). */
  const int64_t start_time = current_time_ms();

  /* Collect the manager address of every requested client that is alive. */
  std::vector<std::string> manager_vector;
  manager_vector.reserve(manager_ids.size());
  for (auto const &manager_id : manager_ids) {
    /* Bind by const reference: avoids a copy if the lookup returns a
     * reference, and extends the temporary's lifetime if it returns by
     * value. */
    const DBClient &client = redis_cache_get_db_client(db_handle, manager_id);
    RAY_CHECK(!client.manager_address.empty());
    if (client.is_alive) {
      manager_vector.push_back(client.manager_address);
    }
  }

  const int64_t elapsed_ms = current_time_ms() - start_time;
  if (elapsed_ms > RayConfig::instance().max_time_for_loop()) {
    /* The message now names the function that is actually called above. */
    RAY_LOG(WARNING) << "calling redis_cache_get_db_client in a loop with "
                     << manager_ids.size() << " manager IDs took "
                     << elapsed_ms << " milliseconds.";
  }
  return manager_vector;
}
/// Subscription callback that stores the received client entry in the db
/// client cache. `user_context` is the DBHandle that owns the cache.
void db_client_table_update_cache_callback(DBClient *db_client, void *user_context) {
  auto *handle = static_cast<DBHandle *>(user_context);
  redis_cache_set_db_client(handle, *db_client);
}
/// Subscribe to the db client table with
/// db_client_table_update_cache_callback, passing the handle itself as the
/// subscription context. No retry info or done callback is supplied.
void db_client_table_cache_init(DBHandle *db_handle) {
  void *subscribe_context = db_handle;
  db_client_table_subscribe(db_handle, db_client_table_update_cache_callback,
                            subscribe_context, NULL, NULL, NULL);
}
/// Return the cached entry for `client_id` via redis_cache_get_db_client.
/// `client_id` must not be nil (enforced with RAY_CHECK).
DBClient db_client_table_cache_get(DBHandle *db_handle,
                                   DBClientID client_id) {
  const bool id_is_nil = client_id.is_nil();
  RAY_CHECK(!id_is_nil);
  return redis_cache_get_db_client(db_handle, client_id);
}
/// Queue the heartbeat operation (redis_plasma_manager_send_heartbeat) with
/// init_table_callback. The retry settings use zero retries, no failure
/// callback, and RayConfig's heartbeat_timeout_milliseconds() as the timeout.
void plasma_manager_send_heartbeat(DBHandle *db_handle) {
  RetryInfo heartbeat_retry;
  heartbeat_retry.fail_callback = NULL;
  heartbeat_retry.timeout = RayConfig::instance().heartbeat_timeout_milliseconds();
  heartbeat_retry.num_retries = 0;

  init_table_callback(db_handle, UniqueID::nil(), __func__, new CommonCallbackData(NULL),
                      &heartbeat_retry, NULL, redis_plasma_manager_send_heartbeat, NULL);
}
-120
View File
@@ -1,120 +0,0 @@
#ifndef DB_CLIENT_TABLE_H
#define DB_CLIENT_TABLE_H
#include <vector>
#include "db.h"
#include "table.h"
typedef void (*db_client_table_done_callback)(DBClientID db_client_id,
void *user_context);
/**
* Remove a client from the db clients table.
*
* @param db_handle Database handle.
* @param db_client_id The database client ID to remove.
* @param retry Information about retrying the request to the database.
* @param done_callback Function to be called when database returns result.
* @param user_context Data that will be passed to done_callback and
* fail_callback.
* @return Void.
*
*/
void db_client_table_remove(DBHandle *db_handle,
DBClientID db_client_id,
RetryInfo *retry,
db_client_table_done_callback done_callback,
void *user_context);
/*
* ==== Subscribing to the db client table ====
*/
/* An entry in the db client table. */
typedef struct {
/** The database client ID. */
DBClientID id;
/** The database client type. */
std::string client_type;
/** An optional auxiliary address for the plasma manager associated with this
* database client. */
std::string manager_address;
/** Whether or not the database client exists. If this is false for an entry,
* then it will never again be true. */
bool is_alive;
} DBClient;
/* Callback for subscribing to the db client table. */
typedef void (*db_client_table_subscribe_callback)(DBClient *db_client,
void *user_context);
/**
* Register a callback for a db client table event.
*
* @param db_handle Database handle.
* @param subscribe_callback Callback that will be called when the db client
* table is updated.
* @param subscribe_context Context that will be passed into the
* subscribe_callback.
* @param retry Information about retrying the request to the database.
* @param done_callback Function to be called when database returns result.
* @param user_context Data that will be passed to done_callback and
* fail_callback.
* @return Void.
*/
void db_client_table_subscribe(
DBHandle *db_handle,
db_client_table_subscribe_callback subscribe_callback,
void *subscribe_context,
RetryInfo *retry,
db_client_table_done_callback done_callback,
void *user_context);
/* Data that is needed to register db client table subscribe callbacks with the
* state database. */
typedef struct {
db_client_table_subscribe_callback subscribe_callback;
void *subscribe_context;
} DBClientTableSubscribeData;
/**
 * Look up the given db clients in the db client cache and return the manager
 * addresses of those that are alive.
 *
 * @param db Database handle.
 * @param manager_ids The db client IDs to look up. Each looked-up client must
 *        have a non-empty manager address.
 * @return The manager addresses of the clients that are alive.
 */
const std::vector<std::string> db_client_table_get_ip_addresses(
DBHandle *db,
const std::vector<DBClientID> &manager_ids);
/**
* Initialize the db client cache. The cache is updated with each notification
* from the db client table.
*
* @param db_handle Database handle.
* @return Void.
*/
void db_client_table_cache_init(DBHandle *db_handle);
/**
* Get a db client from the cache. If the requested client is not there,
* request the latest entry from the db client table.
*
* @param db_handle Database handle.
* @param client_id The ID of the client to look up in the cache.
* @return The database client in the cache.
*/
DBClient db_client_table_cache_get(DBHandle *db_handle, DBClientID client_id);
/*
* ==== Plasma manager heartbeats ====
*/
/**
* Start sending heartbeats to the plasma_managers channel. Each
* heartbeat contains this database client's ID. Heartbeats can be subscribed
* to through the plasma_managers channel. Once called, this "retries" the
* heartbeat operation forever, every heartbeat_timeout_milliseconds
* milliseconds.
*
* @param db_handle Database handle.
* @return Void.
*/
void plasma_manager_send_heartbeat(DBHandle *db_handle);
#endif /* DB_CLIENT_TABLE_H */
-23
View File
@@ -1,23 +0,0 @@
#include "driver_table.h"
#include "redis.h"
/// Register `subscribe_callback` (with `subscribe_context`) for driver table
/// events. The pair is stored in a heap-allocated DriverTableSubscribeData and
/// queued with init_table_callback using redis_driver_table_subscribe.
void driver_table_subscribe(DBHandle *db_handle,
                            driver_table_subscribe_callback subscribe_callback,
                            void *subscribe_context, RetryInfo *retry) {
  auto *subscription =
      static_cast<DriverTableSubscribeData *>(malloc(sizeof(DriverTableSubscribeData)));
  subscription->subscribe_context = subscribe_context;
  subscription->subscribe_callback = subscribe_callback;

  init_table_callback(db_handle, UniqueID::nil(), __func__,
                      new CommonCallbackData(subscription), retry, NULL,
                      redis_driver_table_subscribe, NULL);
}
/// Queue a driver-death operation for `driver_id` with init_table_callback,
/// using redis_driver_table_send_driver_death. No payload, done callback, or
/// user context is supplied.
void driver_table_send_driver_death(DBHandle *db_handle, WorkerID driver_id,
                                    RetryInfo *retry) {
  init_table_callback(db_handle, driver_id, __func__, new CommonCallbackData(NULL),
                      retry, NULL, redis_driver_table_send_driver_death, NULL);
}

Some files were not shown because too many files have changed in this diff Show More